pyo3_stub_gen/generate/stub_info.rs

1use crate::{
2    generate::*,
3    pyproject::{PyProject, StubGenConfig},
4    type_info::*,
5};
6use anyhow::{Context, Result};
7use std::{
8    collections::{BTreeMap, BTreeSet},
9    fs,
10    path::*,
11};
12
13#[derive(Debug, Clone, PartialEq)]
14pub struct StubInfo {
15    pub modules: BTreeMap<String, Module>,
16    pub python_root: PathBuf,
17    /// Whether this is a mixed Python/Rust layout (has `python-source` in pyproject.toml)
18    pub is_mixed_layout: bool,
19    /// Configuration options for stub generation
20    pub config: StubGenConfig,
21}
22
23impl StubInfo {
24    /// Initialize [StubInfo] from a `pyproject.toml` file in `CARGO_MANIFEST_DIR`.
25    /// This is automatically set up by the [crate::define_stub_info_gatherer] macro.
26    pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
27        let pyproject = PyProject::parse_toml(path)?;
28        let config = pyproject.stub_gen_config();
29        StubInfoBuilder::from_pyproject_toml(pyproject, config).build()
30    }
31
32    /// Initialize [StubInfo] with a specific module name, project root, and configuration.
33    /// This must be placed in your PyO3 library crate, i.e. the same crate where [inventory::submit]ted,
34    /// not in the `gen_stub` executables due to [inventory]'s mechanism.
35    pub fn from_project_root(
36        default_module_name: String,
37        project_root: PathBuf,
38        is_mixed_layout: bool,
39        config: StubGenConfig,
40    ) -> Result<Self> {
41        StubInfoBuilder::from_project_root(
42            default_module_name,
43            project_root,
44            is_mixed_layout,
45            config,
46        )
47        .build()
48    }
49
50    pub fn generate(&self) -> Result<()> {
51        // Validate: Pure Rust layout can only have a single module
52        if !self.is_mixed_layout && self.modules.len() > 1 {
53            let module_names: Vec<_> = self.modules.keys().collect();
54            anyhow::bail!(
55                "Pure Rust layout does not support multiple modules or submodules. Found {} modules: {}. \
56                 Please use mixed Python/Rust layout (add `python-source` to [tool.maturin] in pyproject.toml) \
57                 if you need multiple modules or submodules.",
58                self.modules.len(),
59                module_names.iter().map(|s| format!("'{}'", s)).collect::<Vec<_>>().join(", ")
60            );
61        }
62
63        for (name, module) in self.modules.iter() {
64            // Convert dashes to underscores for Python compatibility
65            let normalized_name = name.replace("-", "_");
66            let path = normalized_name.replace(".", "/");
67
68            // Determine destination path based solely on layout type
69            let dest = if self.is_mixed_layout {
70                // Mixed Python/Rust: Always use directory-based structure
71                self.python_root.join(&path).join("__init__.pyi")
72            } else {
73                // Pure Rust: Always use single file at root (use first segment of module name)
74                let package_name = normalized_name
75                    .split('.')
76                    .next()
77                    .filter(|s| !s.is_empty())
78                    .ok_or_else(|| {
79                        anyhow::anyhow!(
80                            "Module name is empty after normalization: original name was `{name}`"
81                        )
82                    })?;
83                self.python_root.join(format!("{package_name}.pyi"))
84            };
85
86            let dir = dest.parent().context("Cannot get parent directory")?;
87            if !dir.exists() {
88                fs::create_dir_all(dir)?;
89            }
90
91            let content = module.format_with_config(self.config.use_type_statement);
92            fs::write(&dest, content)?;
93            log::info!(
94                "Generate stub file of a module `{name}` at {dest}",
95                dest = dest.display()
96            );
97        }
98        Ok(())
99    }
100}
101
102struct StubInfoBuilder {
103    modules: BTreeMap<String, Module>,
104    default_module_name: String,
105    python_root: PathBuf,
106    is_mixed_layout: bool,
107    config: StubGenConfig,
108}
109
110impl StubInfoBuilder {
111    fn from_pyproject_toml(pyproject: PyProject, config: StubGenConfig) -> Self {
112        let is_mixed_layout = pyproject.python_source().is_some();
113        let python_root = pyproject
114            .python_source()
115            .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()));
116
117        Self {
118            modules: BTreeMap::new(),
119            default_module_name: pyproject.module_name().to_string(),
120            python_root,
121            is_mixed_layout,
122            config,
123        }
124    }
125
126    fn from_project_root(
127        default_module_name: String,
128        project_root: PathBuf,
129        is_mixed_layout: bool,
130        config: StubGenConfig,
131    ) -> Self {
132        Self {
133            modules: BTreeMap::new(),
134            default_module_name,
135            python_root: project_root,
136            is_mixed_layout,
137            config,
138        }
139    }
140
141    fn get_module(&mut self, name: Option<&str>) -> &mut Module {
142        let name = name.unwrap_or(&self.default_module_name).to_string();
143        let module = self.modules.entry(name.clone()).or_default();
144        module.name = name;
145        module.default_module_name = self.default_module_name.clone();
146        module
147    }
148
149    fn register_submodules(&mut self) {
150        let mut all_parent_child_pairs: Vec<(String, String)> = Vec::new();
151
152        // For each existing module, collect all parent-child relationships
153        for module in self.modules.keys() {
154            let path = module.split('.').collect::<Vec<_>>();
155
156            // Generate all parent paths and their immediate children
157            for i in 1..path.len() {
158                let parent = path[..i].join(".");
159                let child = path[i].to_string();
160                all_parent_child_pairs.push((parent, child));
161            }
162        }
163
164        // Group children by parent
165        let mut parent_to_children: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
166        for (parent, child) in all_parent_child_pairs {
167            parent_to_children.entry(parent).or_default().insert(child);
168        }
169
170        // Create or update all parent modules
171        for (parent, children) in parent_to_children {
172            let module = self.modules.entry(parent.clone()).or_default();
173            module.name = parent;
174            module.default_module_name = self.default_module_name.clone();
175            module.submodules.extend(children);
176        }
177    }
178
179    fn add_class(&mut self, info: &PyClassInfo) {
180        let mut class_def = ClassDef::from(info);
181        class_def.resolve_default_modules(&self.default_module_name);
182        self.get_module(info.module)
183            .class
184            .insert((info.struct_id)(), class_def);
185    }
186
187    fn add_complex_enum(&mut self, info: &PyComplexEnumInfo) {
188        let mut class_def = ClassDef::from(info);
189        class_def.resolve_default_modules(&self.default_module_name);
190        self.get_module(info.module)
191            .class
192            .insert((info.enum_id)(), class_def);
193    }
194
195    fn add_enum(&mut self, info: &PyEnumInfo) {
196        self.get_module(info.module)
197            .enum_
198            .insert((info.enum_id)(), EnumDef::from(info));
199    }
200
201    fn add_function(&mut self, info: &PyFunctionInfo) -> Result<()> {
202        // Clone default_module_name to avoid borrow checker issues
203        let default_module_name = self.default_module_name.clone();
204
205        let target = self
206            .get_module(info.module)
207            .function
208            .entry(info.name)
209            .or_default();
210
211        // Validation: Check for multiple non-overload functions
212        let mut new_func = FunctionDef::from(info);
213        new_func.resolve_default_modules(&default_module_name);
214
215        if !new_func.is_overload {
216            let non_overload_count = target.iter().filter(|f| !f.is_overload).count();
217            if non_overload_count > 0 {
218                anyhow::bail!(
219                    "Multiple functions with name '{}' found without @overload decorator. \
220                     Please add @overload decorator to all variants.",
221                    info.name
222                );
223            }
224        }
225
226        target.push(new_func);
227        Ok(())
228    }
229
230    fn add_variable(&mut self, info: &PyVariableInfo) {
231        self.get_module(Some(info.module))
232            .variables
233            .insert(info.name, VariableDef::from(info));
234    }
235
236    fn add_type_alias(&mut self, info: &TypeAliasInfo) {
237        self.get_module(Some(info.module))
238            .type_aliases
239            .insert(info.name, TypeAliasDef::from(info));
240    }
241
242    fn add_module_doc(&mut self, info: &ModuleDocInfo) {
243        self.get_module(Some(info.module)).doc = (info.doc)();
244    }
245
246    fn add_module_export(&mut self, info: &ReexportModuleMembers) {
247        let use_wildcard = info.items.is_none();
248        let items = info
249            .items
250            .map(|items| items.iter().map(|s| s.to_string()).collect())
251            .unwrap_or_default();
252
253        self.get_module(Some(info.target_module))
254            .module_re_exports
255            .push(ModuleReExport {
256                source_module: info.source_module.to_string(),
257                items,
258                use_wildcard_import: use_wildcard,
259            });
260    }
261
262    fn add_verbatim_export(&mut self, info: &ExportVerbatim) {
263        self.get_module(Some(info.target_module))
264            .verbatim_all_entries
265            .insert(info.name.to_string());
266    }
267
268    fn add_exclude(&mut self, info: &ExcludeFromAll) {
269        self.get_module(Some(info.target_module))
270            .excluded_all_entries
271            .insert(info.name.to_string());
272    }
273
274    fn resolve_wildcard_re_exports(&mut self) -> Result<()> {
275        // Collect wildcard re-exports and their resolved items for __all__
276        let mut resolutions: Vec<(String, usize, Vec<String>)> = Vec::new();
277
278        for (module_name, module) in &self.modules {
279            for (idx, re_export) in module.module_re_exports.iter().enumerate() {
280                if re_export.use_wildcard_import && re_export.items.is_empty() {
281                    // Wildcard - resolve items for __all__
282                    if let Some(source_mod) = self.modules.get(&re_export.source_module) {
283                        // Internal module - collect all public items that would be in __all__
284                        let mut items = Vec::new();
285                        for class in source_mod.class.values() {
286                            if !class.name.starts_with('_') {
287                                items.push(class.name.to_string());
288                            }
289                        }
290                        for enum_ in source_mod.enum_.values() {
291                            if !enum_.name.starts_with('_') {
292                                items.push(enum_.name.to_string());
293                            }
294                        }
295                        for func_name in source_mod.function.keys() {
296                            if !func_name.starts_with('_') {
297                                items.push(func_name.to_string());
298                            }
299                        }
300                        for var_name in source_mod.variables.keys() {
301                            if !var_name.starts_with('_') {
302                                items.push(var_name.to_string());
303                            }
304                        }
305                        for alias_name in source_mod.type_aliases.keys() {
306                            if !alias_name.starts_with('_') {
307                                items.push(alias_name.to_string());
308                            }
309                        }
310                        // FIX: Add underscore filtering for submodules in wildcard resolution
311                        for submod in &source_mod.submodules {
312                            if !submod.starts_with('_') {
313                                items.push(submod.to_string());
314                            }
315                        }
316                        resolutions.push((module_name.clone(), idx, items));
317                    } else {
318                        // External module - cannot resolve, error
319                        anyhow::bail!(
320                            "Cannot resolve wildcard re-export in module '{}': source module '{}' not found. \
321                             Wildcard re-exports only work with internal modules.",
322                            module_name,
323                            re_export.source_module
324                        );
325                    }
326                }
327            }
328        }
329
330        // Apply resolutions (populate items for wildcard imports)
331        for (module_name, idx, items) in resolutions {
332            if let Some(module) = self.modules.get_mut(&module_name) {
333                module.module_re_exports[idx].items = items;
334            }
335        }
336
337        Ok(())
338    }
339
340    fn add_methods(&mut self, info: &PyMethodsInfo) -> Result<()> {
341        let struct_id = (info.struct_id)();
342        for module in self.modules.values_mut() {
343            if let Some(entry) = module.class.get_mut(&struct_id) {
344                for attr in info.attrs {
345                    entry.attrs.push(MemberDef {
346                        name: attr.name,
347                        r#type: (attr.r#type)(),
348                        doc: attr.doc,
349                        default: attr.default.map(|f| f()),
350                        deprecated: attr.deprecated.clone(),
351                    });
352                }
353                for getter in info.getters {
354                    entry
355                        .getter_setters
356                        .entry(getter.name.to_string())
357                        .or_default()
358                        .0 = Some(MemberDef {
359                        name: getter.name,
360                        r#type: (getter.r#type)(),
361                        doc: getter.doc,
362                        default: getter.default.map(|f| f()),
363                        deprecated: getter.deprecated.clone(),
364                    });
365                }
366                for setter in info.setters {
367                    entry
368                        .getter_setters
369                        .entry(setter.name.to_string())
370                        .or_default()
371                        .1 = Some(MemberDef {
372                        name: setter.name,
373                        r#type: (setter.r#type)(),
374                        doc: setter.doc,
375                        default: setter.default.map(|f| f()),
376                        deprecated: setter.deprecated.clone(),
377                    });
378                }
379                for method in info.methods {
380                    let entries = entry.methods.entry(method.name.to_string()).or_default();
381
382                    // Validation: Check for multiple non-overload methods
383                    let new_method = MethodDef::from(method);
384                    if !new_method.is_overload {
385                        let non_overload_count = entries.iter().filter(|m| !m.is_overload).count();
386                        if non_overload_count > 0 {
387                            anyhow::bail!(
388                                "Multiple methods with name '{}' in class '{}' found without @overload decorator. \
389                                 Please add @overload decorator to all variants.",
390                                method.name, entry.name
391                            );
392                        }
393                    }
394
395                    entries.push(new_method);
396                }
397                return Ok(());
398            } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
399                for attr in info.attrs {
400                    entry.attrs.push(MemberDef {
401                        name: attr.name,
402                        r#type: (attr.r#type)(),
403                        doc: attr.doc,
404                        default: attr.default.map(|f| f()),
405                        deprecated: attr.deprecated.clone(),
406                    });
407                }
408                for getter in info.getters {
409                    entry.getters.push(MemberDef {
410                        name: getter.name,
411                        r#type: (getter.r#type)(),
412                        doc: getter.doc,
413                        default: getter.default.map(|f| f()),
414                        deprecated: getter.deprecated.clone(),
415                    });
416                }
417                for setter in info.setters {
418                    entry.setters.push(MemberDef {
419                        name: setter.name,
420                        r#type: (setter.r#type)(),
421                        doc: setter.doc,
422                        default: setter.default.map(|f| f()),
423                        deprecated: setter.deprecated.clone(),
424                    });
425                }
426                for method in info.methods {
427                    // Validation: Check for multiple non-overload methods
428                    let new_method = MethodDef::from(method);
429                    if !new_method.is_overload {
430                        let non_overload_count = entry
431                            .methods
432                            .iter()
433                            .filter(|m| m.name == method.name && !m.is_overload)
434                            .count();
435                        if non_overload_count > 0 {
436                            anyhow::bail!(
437                                "Multiple methods with name '{}' in enum '{}' found without @overload decorator. \
438                                 Please add @overload decorator to all variants.",
439                                method.name, entry.name
440                            );
441                        }
442                    }
443
444                    entry.methods.push(new_method);
445                }
446                return Ok(());
447            }
448        }
449        unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
450    }
451
452    fn build(mut self) -> Result<StubInfo> {
453        for info in inventory::iter::<PyClassInfo> {
454            self.add_class(info);
455        }
456        for info in inventory::iter::<PyComplexEnumInfo> {
457            self.add_complex_enum(info);
458        }
459        for info in inventory::iter::<PyEnumInfo> {
460            self.add_enum(info);
461        }
462        for info in inventory::iter::<PyFunctionInfo> {
463            self.add_function(info)?;
464        }
465        for info in inventory::iter::<PyVariableInfo> {
466            self.add_variable(info);
467        }
468        for info in inventory::iter::<TypeAliasInfo> {
469            self.add_type_alias(info);
470        }
471        for info in inventory::iter::<ModuleDocInfo> {
472            self.add_module_doc(info);
473        }
474        // Sort PyMethodsInfo by source location for deterministic IndexMap insertion order
475        let mut methods_infos: Vec<&PyMethodsInfo> = inventory::iter::<PyMethodsInfo>().collect();
476        methods_infos.sort_by_key(|info| (info.file, info.line, info.column));
477        for info in methods_infos {
478            self.add_methods(info)?;
479        }
480        // Collect __all__ export directives
481        for info in inventory::iter::<ReexportModuleMembers> {
482            self.add_module_export(info);
483        }
484        for info in inventory::iter::<ExportVerbatim> {
485            self.add_verbatim_export(info);
486        }
487        for info in inventory::iter::<ExcludeFromAll> {
488            self.add_exclude(info);
489        }
490        self.register_submodules();
491
492        // Resolve wildcard re-exports
493        self.resolve_wildcard_re_exports()?;
494
495        Ok(StubInfo {
496            modules: self.modules,
497            python_root: self.python_root,
498            is_mixed_layout: self.is_mixed_layout,
499            config: self.config,
500        })
501    }
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder for a pure Rust layout rooted at `/tmp` with the default config.
    fn pure_layout_builder(default_module_name: &str) -> StubInfoBuilder {
        StubInfoBuilder::from_project_root(
            default_module_name.to_string(),
            PathBuf::from("/tmp"),
            false,
            StubGenConfig::default(),
        )
    }

    /// An otherwise empty module called `name`.
    fn empty_module(name: &str, default_module_name: &str) -> Module {
        Module {
            name: name.to_string(),
            default_module_name: default_module_name.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_register_submodules_creates_empty_parent_modules() {
        let child = "test_module.sub_mod";
        let mut builder = pure_layout_builder("test_module");
        // Simulate a module with only submodules
        builder
            .modules
            .insert(child.to_string(), empty_module(child, "test_module"));

        builder.register_submodules();

        // The empty parent module must have been created and must list its child
        let parent = builder
            .modules
            .get("test_module")
            .expect("parent module should be created");
        assert_eq!(parent.name, "test_module");
        assert!(parent.submodules.contains("sub_mod"));

        // The submodule itself is untouched
        assert!(builder.modules.contains_key(child));
    }

    #[test]
    fn test_register_submodules_with_multiple_levels() {
        let deepest = "root.level1.level2.deep_mod";
        let mut builder = pure_layout_builder("root");
        builder
            .modules
            .insert(deepest.to_string(), empty_module(deepest, "root"));

        builder.register_submodules();

        // Every intermediate parent module exists
        for name in ["root", "root.level1", "root.level1.level2", deepest] {
            assert!(builder.modules.contains_key(name), "missing module `{name}`");
        }

        // Each parent lists its immediate child
        let links = [
            ("root", "level1"),
            ("root.level1", "level2"),
            ("root.level1.level2", "deep_mod"),
        ];
        for (parent, child) in links {
            assert!(
                builder.modules[parent].submodules.contains(child),
                "`{parent}` should list `{child}` as a submodule"
            );
        }
    }

    #[test]
    fn test_pure_layout_rejects_multiple_modules() {
        // Pure Rust layout should reject multiple modules (whether submodules or top-level)
        let modules: BTreeMap<String, Module> = ["mymodule", "mymodule.sub"]
            .iter()
            .map(|&name| (name.to_string(), empty_module(name, "mymodule")))
            .collect();
        let stub_info = StubInfo {
            modules,
            python_root: PathBuf::from("/tmp"),
            is_mixed_layout: false,
            config: StubGenConfig::default(),
        };

        let err = stub_info
            .generate()
            .expect_err("pure Rust layout must reject multiple modules");
        assert!(err
            .to_string()
            .contains("Pure Rust layout does not support multiple modules or submodules"));
    }
}