// pyo3_stub_gen/generate/stub_info.rs

1use crate::{generate::*, pyproject::PyProject, type_info::*};
2use anyhow::{Context, Result};
3use std::{
4    collections::{BTreeMap, BTreeSet},
5    fs,
6    io::Write,
7    path::*,
8};
9
10#[derive(Debug, Clone, PartialEq)]
11pub struct StubInfo {
12    pub modules: BTreeMap<String, Module>,
13    pub python_root: PathBuf,
14}
15
16impl StubInfo {
17    /// Initialize [StubInfo] from a `pyproject.toml` file in `CARGO_MANIFEST_DIR`.
18    /// This is automatically set up by the [crate::define_stub_info_gatherer] macro.
19    pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
20        let pyproject = PyProject::parse_toml(path)?;
21        Ok(StubInfoBuilder::from_pyproject_toml(pyproject).build())
22    }
23
24    /// Initialize [StubInfo] with a specific module name and project root.
25    /// This must be placed in your PyO3 library crate, i.e. the same crate where [inventory::submit]ted,
26    /// not in the `gen_stub` executables due to [inventory]'s mechanism.
27    pub fn from_project_root(default_module_name: String, project_root: PathBuf) -> Result<Self> {
28        Ok(StubInfoBuilder::from_project_root(default_module_name, project_root).build())
29    }
30
31    pub fn generate(&self) -> Result<()> {
32        for (name, module) in self.modules.iter() {
33            // Convert dashes to underscores for Python compatibility
34            let normalized_name = name.replace("-", "_");
35            let path = normalized_name.replace(".", "/");
36            let dest = if module.submodules.is_empty() {
37                self.python_root.join(format!("{path}.pyi"))
38            } else {
39                self.python_root.join(path).join("__init__.pyi")
40            };
41
42            let dir = dest.parent().context("Cannot get parent directory")?;
43            if !dir.exists() {
44                fs::create_dir_all(dir)?;
45            }
46
47            let mut f = fs::File::create(&dest)?;
48            write!(f, "{module}")?;
49            log::info!(
50                "Generate stub file of a module `{name}` at {dest}",
51                dest = dest.display()
52            );
53        }
54        Ok(())
55    }
56}
57
58struct StubInfoBuilder {
59    modules: BTreeMap<String, Module>,
60    default_module_name: String,
61    python_root: PathBuf,
62}
63
64impl StubInfoBuilder {
65    fn from_pyproject_toml(pyproject: PyProject) -> Self {
66        StubInfoBuilder::from_project_root(
67            pyproject.module_name().to_string(),
68            pyproject
69                .python_source()
70                .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())),
71        )
72    }
73
74    fn from_project_root(default_module_name: String, project_root: PathBuf) -> Self {
75        Self {
76            modules: BTreeMap::new(),
77            default_module_name,
78            python_root: project_root,
79        }
80    }
81
82    fn get_module(&mut self, name: Option<&str>) -> &mut Module {
83        let name = name.unwrap_or(&self.default_module_name).to_string();
84        let module = self.modules.entry(name.clone()).or_default();
85        module.name = name;
86        module.default_module_name = self.default_module_name.clone();
87        module
88    }
89
90    fn register_submodules(&mut self) {
91        let mut map: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
92        for module in self.modules.keys() {
93            let path = module.split('.').collect::<Vec<_>>();
94            let n = path.len();
95            if n <= 1 {
96                continue;
97            }
98            map.entry(path[..n - 1].join("."))
99                .or_default()
100                .insert(path[n - 1].to_string());
101        }
102        for (parent, children) in map {
103            if let Some(module) = self.modules.get_mut(&parent) {
104                module.submodules.extend(children);
105            }
106        }
107    }
108
109    fn add_class(&mut self, info: &PyClassInfo) {
110        self.get_module(info.module)
111            .class
112            .insert((info.struct_id)(), ClassDef::from(info));
113    }
114
115    fn add_complex_enum(&mut self, info: &PyComplexEnumInfo) {
116        self.get_module(info.module)
117            .class
118            .insert((info.enum_id)(), ClassDef::from(info));
119    }
120
121    fn add_enum(&mut self, info: &PyEnumInfo) {
122        self.get_module(info.module)
123            .enum_
124            .insert((info.enum_id)(), EnumDef::from(info));
125    }
126
127    fn add_function(&mut self, info: &PyFunctionInfo) {
128        let target = self
129            .get_module(info.module)
130            .function
131            .entry(info.name)
132            .or_default();
133        target.push(FunctionDef::from(info));
134    }
135
136    fn add_error(&mut self, info: &PyErrorInfo) {
137        self.get_module(Some(info.module))
138            .error
139            .insert(info.name, ErrorDef::from(info));
140    }
141
142    fn add_variable(&mut self, info: &PyVariableInfo) {
143        self.get_module(Some(info.module))
144            .variables
145            .insert(info.name, VariableDef::from(info));
146    }
147
148    fn add_methods(&mut self, info: &PyMethodsInfo) {
149        let struct_id = (info.struct_id)();
150        for module in self.modules.values_mut() {
151            if let Some(entry) = module.class.get_mut(&struct_id) {
152                for attr in info.attrs {
153                    entry.attrs.push(MemberDef {
154                        name: attr.name,
155                        r#type: (attr.r#type)(),
156                        doc: attr.doc,
157                        default: attr.default.map(|s| s.as_str()),
158                    });
159                }
160                for getter in info.getters {
161                    entry.getters.push(MemberDef {
162                        name: getter.name,
163                        r#type: (getter.r#type)(),
164                        doc: getter.doc,
165                        default: getter.default.map(|s| s.as_str()),
166                    });
167                }
168                for setter in info.setters {
169                    entry.setters.push(MemberDef {
170                        name: setter.name,
171                        r#type: (setter.r#type)(),
172                        doc: setter.doc,
173                        default: setter.default.map(|s| s.as_str()),
174                    });
175                }
176                for method in info.methods {
177                    let entries = entry.methods.entry(method.name.to_string()).or_default();
178                    entries.push(MethodDef::from(method));
179                }
180                return;
181            } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
182                for attr in info.attrs {
183                    entry.attrs.push(MemberDef {
184                        name: attr.name,
185                        r#type: (attr.r#type)(),
186                        doc: attr.doc,
187                        default: attr.default.map(|s| s.as_str()),
188                    });
189                }
190                for getter in info.getters {
191                    entry.getters.push(MemberDef {
192                        name: getter.name,
193                        r#type: (getter.r#type)(),
194                        doc: getter.doc,
195                        default: getter.default.map(|s| s.as_str()),
196                    });
197                }
198                for setter in info.setters {
199                    entry.setters.push(MemberDef {
200                        name: setter.name,
201                        r#type: (setter.r#type)(),
202                        doc: setter.doc,
203                        default: setter.default.map(|s| s.as_str()),
204                    });
205                }
206                for method in info.methods {
207                    entry.methods.push(MethodDef::from(method))
208                }
209                return;
210            }
211        }
212        unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
213    }
214
215    fn build(mut self) -> StubInfo {
216        for info in inventory::iter::<PyClassInfo> {
217            self.add_class(info);
218        }
219        for info in inventory::iter::<PyComplexEnumInfo> {
220            self.add_complex_enum(info);
221        }
222        for info in inventory::iter::<PyEnumInfo> {
223            self.add_enum(info);
224        }
225        for info in inventory::iter::<PyFunctionInfo> {
226            self.add_function(info);
227        }
228        for info in inventory::iter::<PyErrorInfo> {
229            self.add_error(info);
230        }
231        for info in inventory::iter::<PyVariableInfo> {
232            self.add_variable(info);
233        }
234        for info in inventory::iter::<PyMethodsInfo> {
235            self.add_methods(info);
236        }
237        self.register_submodules();
238        StubInfo {
239            modules: self.modules,
240            python_root: self.python_root,
241        }
242    }
243}