pyo3_stub_gen/generate/
stub_info.rs

1use crate::{generate::*, pyproject::PyProject, type_info::*};
2use anyhow::{Context, Result};
3use std::{
4    collections::{BTreeMap, BTreeSet},
5    fs,
6    io::Write,
7    path::*,
8};
9
10#[derive(Debug, Clone, PartialEq)]
11pub struct StubInfo {
12    pub modules: BTreeMap<String, Module>,
13    pub python_root: PathBuf,
14    /// Whether this is a mixed Python/Rust layout (has `python-source` in pyproject.toml)
15    pub is_mixed_layout: bool,
16}
17
18impl StubInfo {
19    /// Initialize [StubInfo] from a `pyproject.toml` file in `CARGO_MANIFEST_DIR`.
20    /// This is automatically set up by the [crate::define_stub_info_gatherer] macro.
21    pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
22        let pyproject = PyProject::parse_toml(path)?;
23        StubInfoBuilder::from_pyproject_toml(pyproject).build()
24    }
25
26    /// Initialize [StubInfo] with a specific module name and project root.
27    /// This must be placed in your PyO3 library crate, i.e. the same crate where [inventory::submit]ted,
28    /// not in the `gen_stub` executables due to [inventory]'s mechanism.
29    pub fn from_project_root(
30        default_module_name: String,
31        project_root: PathBuf,
32        is_mixed_layout: bool,
33    ) -> Result<Self> {
34        StubInfoBuilder::from_project_root(default_module_name, project_root, is_mixed_layout)
35            .build()
36    }
37
38    pub fn generate(&self) -> Result<()> {
39        // Validate: Pure Rust layout can only have a single module
40        if !self.is_mixed_layout && self.modules.len() > 1 {
41            let module_names: Vec<_> = self.modules.keys().collect();
42            anyhow::bail!(
43                "Pure Rust layout does not support multiple modules or submodules. Found {} modules: {}. \
44                 Please use mixed Python/Rust layout (add `python-source` to [tool.maturin] in pyproject.toml) \
45                 if you need multiple modules or submodules.",
46                self.modules.len(),
47                module_names.iter().map(|s| format!("'{}'", s)).collect::<Vec<_>>().join(", ")
48            );
49        }
50
51        for (name, module) in self.modules.iter() {
52            // Convert dashes to underscores for Python compatibility
53            let normalized_name = name.replace("-", "_");
54            let path = normalized_name.replace(".", "/");
55
56            // Determine destination path based solely on layout type
57            let dest = if self.is_mixed_layout {
58                // Mixed Python/Rust: Always use directory-based structure
59                self.python_root.join(&path).join("__init__.pyi")
60            } else {
61                // Pure Rust: Always use single file at root (use first segment of module name)
62                let package_name = normalized_name
63                    .split('.')
64                    .next()
65                    .filter(|s| !s.is_empty())
66                    .ok_or_else(|| {
67                        anyhow::anyhow!(
68                            "Module name is empty after normalization: original name was `{name}`"
69                        )
70                    })?;
71                self.python_root.join(format!("{package_name}.pyi"))
72            };
73
74            let dir = dest.parent().context("Cannot get parent directory")?;
75            if !dir.exists() {
76                fs::create_dir_all(dir)?;
77            }
78
79            let mut f = fs::File::create(&dest)?;
80            write!(f, "{module}")?;
81            log::info!(
82                "Generate stub file of a module `{name}` at {dest}",
83                dest = dest.display()
84            );
85        }
86        Ok(())
87    }
88}
89
90struct StubInfoBuilder {
91    modules: BTreeMap<String, Module>,
92    default_module_name: String,
93    python_root: PathBuf,
94    is_mixed_layout: bool,
95}
96
97impl StubInfoBuilder {
98    fn from_pyproject_toml(pyproject: PyProject) -> Self {
99        let is_mixed_layout = pyproject.python_source().is_some();
100        let python_root = pyproject
101            .python_source()
102            .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()));
103
104        Self {
105            modules: BTreeMap::new(),
106            default_module_name: pyproject.module_name().to_string(),
107            python_root,
108            is_mixed_layout,
109        }
110    }
111
112    fn from_project_root(
113        default_module_name: String,
114        project_root: PathBuf,
115        is_mixed_layout: bool,
116    ) -> Self {
117        Self {
118            modules: BTreeMap::new(),
119            default_module_name,
120            python_root: project_root,
121            is_mixed_layout,
122        }
123    }
124
125    fn get_module(&mut self, name: Option<&str>) -> &mut Module {
126        let name = name.unwrap_or(&self.default_module_name).to_string();
127        let module = self.modules.entry(name.clone()).or_default();
128        module.name = name;
129        module.default_module_name = self.default_module_name.clone();
130        module
131    }
132
133    fn register_submodules(&mut self) {
134        let mut all_parent_child_pairs: Vec<(String, String)> = Vec::new();
135
136        // For each existing module, collect all parent-child relationships
137        for module in self.modules.keys() {
138            let path = module.split('.').collect::<Vec<_>>();
139
140            // Generate all parent paths and their immediate children
141            for i in 1..path.len() {
142                let parent = path[..i].join(".");
143                let child = path[i].to_string();
144                all_parent_child_pairs.push((parent, child));
145            }
146        }
147
148        // Group children by parent
149        let mut parent_to_children: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
150        for (parent, child) in all_parent_child_pairs {
151            parent_to_children.entry(parent).or_default().insert(child);
152        }
153
154        // Create or update all parent modules
155        for (parent, children) in parent_to_children {
156            let module = self.modules.entry(parent.clone()).or_default();
157            module.name = parent;
158            module.default_module_name = self.default_module_name.clone();
159            module.submodules.extend(children);
160        }
161    }
162
163    fn add_class(&mut self, info: &PyClassInfo) {
164        self.get_module(info.module)
165            .class
166            .insert((info.struct_id)(), ClassDef::from(info));
167    }
168
169    fn add_complex_enum(&mut self, info: &PyComplexEnumInfo) {
170        self.get_module(info.module)
171            .class
172            .insert((info.enum_id)(), ClassDef::from(info));
173    }
174
175    fn add_enum(&mut self, info: &PyEnumInfo) {
176        self.get_module(info.module)
177            .enum_
178            .insert((info.enum_id)(), EnumDef::from(info));
179    }
180
181    fn add_function(&mut self, info: &PyFunctionInfo) -> Result<()> {
182        let target = self
183            .get_module(info.module)
184            .function
185            .entry(info.name)
186            .or_default();
187
188        // Validation: Check for multiple non-overload functions
189        let new_func = FunctionDef::from(info);
190        if !new_func.is_overload {
191            let non_overload_count = target.iter().filter(|f| !f.is_overload).count();
192            if non_overload_count > 0 {
193                anyhow::bail!(
194                    "Multiple functions with name '{}' found without @overload decorator. \
195                     Please add @overload decorator to all variants.",
196                    info.name
197                );
198            }
199        }
200
201        target.push(new_func);
202        Ok(())
203    }
204
205    fn add_variable(&mut self, info: &PyVariableInfo) {
206        self.get_module(Some(info.module))
207            .variables
208            .insert(info.name, VariableDef::from(info));
209    }
210
211    fn add_module_doc(&mut self, info: &ModuleDocInfo) {
212        self.get_module(Some(info.module)).doc = (info.doc)();
213    }
214
215    fn add_methods(&mut self, info: &PyMethodsInfo) -> Result<()> {
216        let struct_id = (info.struct_id)();
217        for module in self.modules.values_mut() {
218            if let Some(entry) = module.class.get_mut(&struct_id) {
219                for attr in info.attrs {
220                    entry.attrs.push(MemberDef {
221                        name: attr.name,
222                        r#type: (attr.r#type)(),
223                        doc: attr.doc,
224                        default: attr.default.map(|f| f()),
225                        deprecated: attr.deprecated.clone(),
226                    });
227                }
228                for getter in info.getters {
229                    entry
230                        .getter_setters
231                        .entry(getter.name.to_string())
232                        .or_default()
233                        .0 = Some(MemberDef {
234                        name: getter.name,
235                        r#type: (getter.r#type)(),
236                        doc: getter.doc,
237                        default: getter.default.map(|f| f()),
238                        deprecated: getter.deprecated.clone(),
239                    });
240                }
241                for setter in info.setters {
242                    entry
243                        .getter_setters
244                        .entry(setter.name.to_string())
245                        .or_default()
246                        .1 = Some(MemberDef {
247                        name: setter.name,
248                        r#type: (setter.r#type)(),
249                        doc: setter.doc,
250                        default: setter.default.map(|f| f()),
251                        deprecated: setter.deprecated.clone(),
252                    });
253                }
254                for method in info.methods {
255                    let entries = entry.methods.entry(method.name.to_string()).or_default();
256
257                    // Validation: Check for multiple non-overload methods
258                    let new_method = MethodDef::from(method);
259                    if !new_method.is_overload {
260                        let non_overload_count = entries.iter().filter(|m| !m.is_overload).count();
261                        if non_overload_count > 0 {
262                            anyhow::bail!(
263                                "Multiple methods with name '{}' in class '{}' found without @overload decorator. \
264                                 Please add @overload decorator to all variants.",
265                                method.name, entry.name
266                            );
267                        }
268                    }
269
270                    entries.push(new_method);
271                }
272                return Ok(());
273            } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
274                for attr in info.attrs {
275                    entry.attrs.push(MemberDef {
276                        name: attr.name,
277                        r#type: (attr.r#type)(),
278                        doc: attr.doc,
279                        default: attr.default.map(|f| f()),
280                        deprecated: attr.deprecated.clone(),
281                    });
282                }
283                for getter in info.getters {
284                    entry.getters.push(MemberDef {
285                        name: getter.name,
286                        r#type: (getter.r#type)(),
287                        doc: getter.doc,
288                        default: getter.default.map(|f| f()),
289                        deprecated: getter.deprecated.clone(),
290                    });
291                }
292                for setter in info.setters {
293                    entry.setters.push(MemberDef {
294                        name: setter.name,
295                        r#type: (setter.r#type)(),
296                        doc: setter.doc,
297                        default: setter.default.map(|f| f()),
298                        deprecated: setter.deprecated.clone(),
299                    });
300                }
301                for method in info.methods {
302                    // Validation: Check for multiple non-overload methods
303                    let new_method = MethodDef::from(method);
304                    if !new_method.is_overload {
305                        let non_overload_count = entry
306                            .methods
307                            .iter()
308                            .filter(|m| m.name == method.name && !m.is_overload)
309                            .count();
310                        if non_overload_count > 0 {
311                            anyhow::bail!(
312                                "Multiple methods with name '{}' in enum '{}' found without @overload decorator. \
313                                 Please add @overload decorator to all variants.",
314                                method.name, entry.name
315                            );
316                        }
317                    }
318
319                    entry.methods.push(new_method);
320                }
321                return Ok(());
322            }
323        }
324        unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
325    }
326
327    fn build(mut self) -> Result<StubInfo> {
328        for info in inventory::iter::<PyClassInfo> {
329            self.add_class(info);
330        }
331        for info in inventory::iter::<PyComplexEnumInfo> {
332            self.add_complex_enum(info);
333        }
334        for info in inventory::iter::<PyEnumInfo> {
335            self.add_enum(info);
336        }
337        for info in inventory::iter::<PyFunctionInfo> {
338            self.add_function(info)?;
339        }
340        for info in inventory::iter::<PyVariableInfo> {
341            self.add_variable(info);
342        }
343        for info in inventory::iter::<ModuleDocInfo> {
344            self.add_module_doc(info);
345        }
346        // Sort PyMethodsInfo by source location for deterministic IndexMap insertion order
347        let mut methods_infos: Vec<&PyMethodsInfo> = inventory::iter::<PyMethodsInfo>().collect();
348        methods_infos.sort_by_key(|info| (info.file, info.line, info.column));
349        for info in methods_infos {
350            self.add_methods(info)?;
351        }
352        self.register_submodules();
353        Ok(StubInfo {
354            modules: self.modules,
355            python_root: self.python_root,
356            is_mixed_layout: self.is_mixed_layout,
357        })
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Module` with only its `name` and `default_module_name` filled in.
    fn empty_module(name: &str, default_module_name: &str) -> Module {
        Module {
            name: name.to_string(),
            default_module_name: default_module_name.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_register_submodules_creates_empty_parent_modules() {
        let mut builder =
            StubInfoBuilder::from_project_root("test_module".to_string(), "/tmp".into(), false);

        // Register only the submodule; its parent has no items of its own.
        let submodule = "test_module.sub_mod";
        builder
            .modules
            .insert(submodule.to_string(), empty_module(submodule, "test_module"));

        builder.register_submodules();

        // The parent must now exist and list the submodule.
        assert!(builder.modules.contains_key("test_module"));
        let parent = &builder.modules["test_module"];
        assert_eq!(parent.name, "test_module");
        assert!(parent.submodules.contains("sub_mod"));

        // The submodule itself is left in place.
        assert!(builder.modules.contains_key(submodule));
    }

    #[test]
    fn test_register_submodules_with_multiple_levels() {
        let mut builder =
            StubInfoBuilder::from_project_root("root".to_string(), "/tmp".into(), false);

        // A single deeply nested module with no registered ancestors.
        let leaf = "root.level1.level2.deep_mod";
        builder.modules.insert(leaf.to_string(), empty_module(leaf, "root"));

        builder.register_submodules();

        // Every ancestor is synthesized and the leaf survives.
        for name in ["root", "root.level1", "root.level1.level2", leaf] {
            assert!(builder.modules.contains_key(name));
        }

        // Each ancestor lists its direct child segment.
        let expected_links = [
            ("root", "level1"),
            ("root.level1", "level2"),
            ("root.level1.level2", "deep_mod"),
        ];
        for (parent, child) in expected_links {
            assert!(builder.modules[parent].submodules.contains(child));
        }
    }

    #[test]
    fn test_pure_layout_rejects_multiple_modules() {
        // Pure Rust layout should reject multiple modules (whether submodules or top-level)
        let modules: BTreeMap<String, Module> = ["mymodule", "mymodule.sub"]
            .iter()
            .map(|&name| (name.to_string(), empty_module(name, "mymodule")))
            .collect();
        let stub_info = StubInfo {
            modules,
            python_root: PathBuf::from("/tmp"),
            is_mixed_layout: false,
        };

        let message = stub_info.generate().err().map(|err| err.to_string());
        assert!(matches!(
            message.as_deref(),
            Some(text) if text.contains("Pure Rust layout does not support multiple modules or submodules")
        ));
    }
}