pyo3_stub_gen/generate/
stub_info.rs

1use crate::{generate::*, pyproject::PyProject, type_info::*};
2use anyhow::{Context, Result};
3use std::{
4    collections::{BTreeMap, BTreeSet},
5    fs,
6    io::Write,
7    path::*,
8};
9
10#[derive(Debug, Clone, PartialEq)]
11pub struct StubInfo {
12    pub modules: BTreeMap<String, Module>,
13    pub python_root: PathBuf,
14}
15
16impl StubInfo {
17    /// Initialize [StubInfo] from a `pyproject.toml` file in `CARGO_MANIFEST_DIR`.
18    /// This is automatically set up by the [crate::define_stub_info_gatherer] macro.
19    pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
20        let pyproject = PyProject::parse_toml(path)?;
21        StubInfoBuilder::from_pyproject_toml(pyproject).build()
22    }
23
24    /// Initialize [StubInfo] with a specific module name and project root.
25    /// This must be placed in your PyO3 library crate, i.e. the same crate where [inventory::submit]ted,
26    /// not in the `gen_stub` executables due to [inventory]'s mechanism.
27    pub fn from_project_root(default_module_name: String, project_root: PathBuf) -> Result<Self> {
28        StubInfoBuilder::from_project_root(default_module_name, project_root).build()
29    }
30
31    pub fn generate(&self) -> Result<()> {
32        for (name, module) in self.modules.iter() {
33            // Convert dashes to underscores for Python compatibility
34            let normalized_name = name.replace("-", "_");
35            let path = normalized_name.replace(".", "/");
36            let dest = if module.submodules.is_empty() {
37                self.python_root.join(format!("{path}.pyi"))
38            } else {
39                self.python_root.join(path).join("__init__.pyi")
40            };
41
42            let dir = dest.parent().context("Cannot get parent directory")?;
43            if !dir.exists() {
44                fs::create_dir_all(dir)?;
45            }
46
47            let mut f = fs::File::create(&dest)?;
48            write!(f, "{module}")?;
49            log::info!(
50                "Generate stub file of a module `{name}` at {dest}",
51                dest = dest.display()
52            );
53        }
54        Ok(())
55    }
56}
57
58struct StubInfoBuilder {
59    modules: BTreeMap<String, Module>,
60    default_module_name: String,
61    python_root: PathBuf,
62}
63
64impl StubInfoBuilder {
65    fn from_pyproject_toml(pyproject: PyProject) -> Self {
66        StubInfoBuilder::from_project_root(
67            pyproject.module_name().to_string(),
68            pyproject
69                .python_source()
70                .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())),
71        )
72    }
73
74    fn from_project_root(default_module_name: String, project_root: PathBuf) -> Self {
75        Self {
76            modules: BTreeMap::new(),
77            default_module_name,
78            python_root: project_root,
79        }
80    }
81
82    fn get_module(&mut self, name: Option<&str>) -> &mut Module {
83        let name = name.unwrap_or(&self.default_module_name).to_string();
84        let module = self.modules.entry(name.clone()).or_default();
85        module.name = name;
86        module.default_module_name = self.default_module_name.clone();
87        module
88    }
89
90    fn register_submodules(&mut self) {
91        let mut map: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
92        for module in self.modules.keys() {
93            let path = module.split('.').collect::<Vec<_>>();
94            let n = path.len();
95            if n <= 1 {
96                continue;
97            }
98            map.entry(path[..n - 1].join("."))
99                .or_default()
100                .insert(path[n - 1].to_string());
101        }
102        for (parent, children) in map {
103            if let Some(module) = self.modules.get_mut(&parent) {
104                module.submodules.extend(children);
105            }
106        }
107    }
108
109    fn add_class(&mut self, info: &PyClassInfo) {
110        self.get_module(info.module)
111            .class
112            .insert((info.struct_id)(), ClassDef::from(info));
113    }
114
115    fn add_complex_enum(&mut self, info: &PyComplexEnumInfo) {
116        self.get_module(info.module)
117            .class
118            .insert((info.enum_id)(), ClassDef::from(info));
119    }
120
121    fn add_enum(&mut self, info: &PyEnumInfo) {
122        self.get_module(info.module)
123            .enum_
124            .insert((info.enum_id)(), EnumDef::from(info));
125    }
126
127    fn add_function(&mut self, info: &PyFunctionInfo) -> Result<()> {
128        let target = self
129            .get_module(info.module)
130            .function
131            .entry(info.name)
132            .or_default();
133
134        // Validation: Check for multiple non-overload functions
135        let new_func = FunctionDef::from(info);
136        if !new_func.is_overload {
137            let non_overload_count = target.iter().filter(|f| !f.is_overload).count();
138            if non_overload_count > 0 {
139                anyhow::bail!(
140                    "Multiple functions with name '{}' found without @overload decorator. \
141                     Please add @overload decorator to all variants.",
142                    info.name
143                );
144            }
145        }
146
147        target.push(new_func);
148        Ok(())
149    }
150
151    fn add_variable(&mut self, info: &PyVariableInfo) {
152        self.get_module(Some(info.module))
153            .variables
154            .insert(info.name, VariableDef::from(info));
155    }
156
157    fn add_module_doc(&mut self, info: &ModuleDocInfo) {
158        self.get_module(Some(info.module)).doc = (info.doc)();
159    }
160
161    fn add_methods(&mut self, info: &PyMethodsInfo) -> Result<()> {
162        let struct_id = (info.struct_id)();
163        for module in self.modules.values_mut() {
164            if let Some(entry) = module.class.get_mut(&struct_id) {
165                for attr in info.attrs {
166                    entry.attrs.push(MemberDef {
167                        name: attr.name,
168                        r#type: (attr.r#type)(),
169                        doc: attr.doc,
170                        default: attr.default.map(|f| f()),
171                        deprecated: attr.deprecated.clone(),
172                    });
173                }
174                for getter in info.getters {
175                    entry
176                        .getter_setters
177                        .entry(getter.name.to_string())
178                        .or_default()
179                        .0 = Some(MemberDef {
180                        name: getter.name,
181                        r#type: (getter.r#type)(),
182                        doc: getter.doc,
183                        default: getter.default.map(|f| f()),
184                        deprecated: getter.deprecated.clone(),
185                    });
186                }
187                for setter in info.setters {
188                    entry
189                        .getter_setters
190                        .entry(setter.name.to_string())
191                        .or_default()
192                        .1 = Some(MemberDef {
193                        name: setter.name,
194                        r#type: (setter.r#type)(),
195                        doc: setter.doc,
196                        default: setter.default.map(|f| f()),
197                        deprecated: setter.deprecated.clone(),
198                    });
199                }
200                for method in info.methods {
201                    let entries = entry.methods.entry(method.name.to_string()).or_default();
202
203                    // Validation: Check for multiple non-overload methods
204                    let new_method = MethodDef::from(method);
205                    if !new_method.is_overload {
206                        let non_overload_count = entries.iter().filter(|m| !m.is_overload).count();
207                        if non_overload_count > 0 {
208                            anyhow::bail!(
209                                "Multiple methods with name '{}' in class '{}' found without @overload decorator. \
210                                 Please add @overload decorator to all variants.",
211                                method.name, entry.name
212                            );
213                        }
214                    }
215
216                    entries.push(new_method);
217                }
218                return Ok(());
219            } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
220                for attr in info.attrs {
221                    entry.attrs.push(MemberDef {
222                        name: attr.name,
223                        r#type: (attr.r#type)(),
224                        doc: attr.doc,
225                        default: attr.default.map(|f| f()),
226                        deprecated: attr.deprecated.clone(),
227                    });
228                }
229                for getter in info.getters {
230                    entry.getters.push(MemberDef {
231                        name: getter.name,
232                        r#type: (getter.r#type)(),
233                        doc: getter.doc,
234                        default: getter.default.map(|f| f()),
235                        deprecated: getter.deprecated.clone(),
236                    });
237                }
238                for setter in info.setters {
239                    entry.setters.push(MemberDef {
240                        name: setter.name,
241                        r#type: (setter.r#type)(),
242                        doc: setter.doc,
243                        default: setter.default.map(|f| f()),
244                        deprecated: setter.deprecated.clone(),
245                    });
246                }
247                for method in info.methods {
248                    // Validation: Check for multiple non-overload methods
249                    let new_method = MethodDef::from(method);
250                    if !new_method.is_overload {
251                        let non_overload_count = entry
252                            .methods
253                            .iter()
254                            .filter(|m| m.name == method.name && !m.is_overload)
255                            .count();
256                        if non_overload_count > 0 {
257                            anyhow::bail!(
258                                "Multiple methods with name '{}' in enum '{}' found without @overload decorator. \
259                                 Please add @overload decorator to all variants.",
260                                method.name, entry.name
261                            );
262                        }
263                    }
264
265                    entry.methods.push(new_method);
266                }
267                return Ok(());
268            }
269        }
270        unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
271    }
272
273    fn build(mut self) -> Result<StubInfo> {
274        for info in inventory::iter::<PyClassInfo> {
275            self.add_class(info);
276        }
277        for info in inventory::iter::<PyComplexEnumInfo> {
278            self.add_complex_enum(info);
279        }
280        for info in inventory::iter::<PyEnumInfo> {
281            self.add_enum(info);
282        }
283        for info in inventory::iter::<PyFunctionInfo> {
284            self.add_function(info)?;
285        }
286        for info in inventory::iter::<PyVariableInfo> {
287            self.add_variable(info);
288        }
289        for info in inventory::iter::<ModuleDocInfo> {
290            self.add_module_doc(info);
291        }
292        // Sort PyMethodsInfo by source location for deterministic IndexMap insertion order
293        let mut methods_infos: Vec<&PyMethodsInfo> = inventory::iter::<PyMethodsInfo>().collect();
294        methods_infos.sort_by_key(|info| (info.file, info.line, info.column));
295        for info in methods_infos {
296            self.add_methods(info)?;
297        }
298        self.register_submodules();
299        Ok(StubInfo {
300            modules: self.modules,
301            python_root: self.python_root,
302        })
303    }
304}