// pyo3_stub_gen/generate/stub_info.rs

1use crate::{
2    generate::{docstring::normalize_docstring, *},
3    pyproject::{PyProject, StubGenConfig},
4    type_info::*,
5};
6use anyhow::{Context, Result};
7use std::{
8    collections::{BTreeMap, BTreeSet},
9    fs,
10    path::*,
11};
12
13#[derive(Debug, Clone, PartialEq)]
14pub struct StubInfo {
15    pub modules: BTreeMap<String, Module>,
16    pub python_root: PathBuf,
17    /// Whether this is a mixed Python/Rust layout (has `python-source` in pyproject.toml)
18    pub is_mixed_layout: bool,
19    /// Configuration options for stub generation
20    pub config: StubGenConfig,
21    /// Directory containing pyproject.toml (for relative path calculations)
22    pub pyproject_dir: Option<PathBuf>,
23}
24
25impl StubInfo {
26    /// Initialize [StubInfo] from a `pyproject.toml` file in `CARGO_MANIFEST_DIR`.
27    /// This is automatically set up by the [crate::define_stub_info_gatherer] macro.
28    pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
29        let path = path.as_ref();
30        let pyproject = PyProject::parse_toml(path)?;
31        let mut config = pyproject.stub_gen_config();
32
33        // Resolve doc_gen paths relative to pyproject.toml location
34        if let Some(resolved_doc_gen) = pyproject.doc_gen_config_resolved() {
35            config.doc_gen = Some(resolved_doc_gen);
36        }
37
38        let pyproject_dir = path.parent().map(|p| p.to_path_buf());
39
40        let mut stub_info = StubInfoBuilder::from_pyproject_toml(pyproject, config).build()?;
41        stub_info.pyproject_dir = pyproject_dir;
42        Ok(stub_info)
43    }
44
45    /// Initialize [StubInfo] with a specific module name, project root, and configuration.
46    /// This must be placed in your PyO3 library crate, i.e. the same crate where [inventory::submit]ted,
47    /// not in the `gen_stub` executables due to [inventory]'s mechanism.
48    pub fn from_project_root(
49        default_module_name: String,
50        project_root: PathBuf,
51        is_mixed_layout: bool,
52        config: StubGenConfig,
53    ) -> Result<Self> {
54        StubInfoBuilder::from_project_root(
55            default_module_name,
56            project_root,
57            is_mixed_layout,
58            config,
59        )
60        .build()
61    }
62
63    pub fn generate(&self) -> Result<()> {
64        // Validate: Pure Rust layout can only have a single module
65        if !self.is_mixed_layout && self.modules.len() > 1 {
66            let module_names: Vec<_> = self.modules.keys().collect();
67            anyhow::bail!(
68                "Pure Rust layout does not support multiple modules or submodules. Found {} modules: {}. \
69                 Please use mixed Python/Rust layout (add `python-source` to [tool.maturin] in pyproject.toml) \
70                 if you need multiple modules or submodules.",
71                self.modules.len(),
72                module_names.iter().map(|s| format!("'{}'", s)).collect::<Vec<_>>().join(", ")
73            );
74        }
75
76        // Generate stub files
77        for (name, module) in self.modules.iter() {
78            // Convert dashes to underscores for Python compatibility
79            let normalized_name = name.replace("-", "_");
80            let path = normalized_name.replace(".", "/");
81
82            // Determine destination path based solely on layout type
83            let dest = if self.is_mixed_layout {
84                // Mixed Python/Rust: Always use directory-based structure
85                self.python_root.join(&path).join("__init__.pyi")
86            } else {
87                // Pure Rust: Always use single file at root (use first segment of module name)
88                let package_name = normalized_name
89                    .split('.')
90                    .next()
91                    .filter(|s| !s.is_empty())
92                    .ok_or_else(|| {
93                        anyhow::anyhow!(
94                            "Module name is empty after normalization: original name was `{name}`"
95                        )
96                    })?;
97                self.python_root.join(format!("{package_name}.pyi"))
98            };
99
100            let dir = dest.parent().context("Cannot get parent directory")?;
101            if !dir.exists() {
102                fs::create_dir_all(dir)?;
103            }
104
105            let content = module.format_with_config(self.config.use_type_statement);
106            fs::write(&dest, content)?;
107            log::info!(
108                "Generate stub file of a module `{name}` at {dest}",
109                dest = dest.display()
110            );
111        }
112
113        // Generate documentation if configured
114        if let Some(doc_config) = &self.config.doc_gen {
115            self.generate_docs(doc_config)?;
116        }
117
118        Ok(())
119    }
120
121    fn generate_docs(&self, config: &crate::docgen::DocGenConfig) -> Result<()> {
122        log::info!("Generating API documentation...");
123
124        // 1. Build DocPackage IR
125        let doc_package = crate::docgen::builder::DocPackageBuilder::new(self).build()?;
126
127        // 2. Render to JSON
128        let json_output = crate::docgen::render::render_to_json(&doc_package)?;
129
130        // 3. Write files
131        fs::create_dir_all(&config.output_dir)?;
132        fs::write(config.output_dir.join(&config.json_output), json_output)?;
133
134        // 4. Copy Sphinx extension
135        crate::docgen::render::copy_sphinx_extension(&config.output_dir)?;
136
137        // 5. Generate RST files
138        if config.separate_pages {
139            crate::docgen::render::generate_module_pages(&doc_package, &config.output_dir)?;
140            crate::docgen::render::generate_index_rst(&doc_package, &config.output_dir, config)?;
141            log::info!("Generated separate .rst pages for each module");
142        }
143
144        log::info!("Generated API docs at {:?}", config.output_dir);
145        Ok(())
146    }
147}
148
149struct StubInfoBuilder {
150    modules: BTreeMap<String, Module>,
151    default_module_name: String,
152    python_root: PathBuf,
153    is_mixed_layout: bool,
154    config: StubGenConfig,
155}
156
157impl StubInfoBuilder {
158    fn from_pyproject_toml(pyproject: PyProject, config: StubGenConfig) -> Self {
159        let is_mixed_layout = pyproject.python_source().is_some();
160        let python_root = pyproject
161            .python_source()
162            .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()));
163
164        Self {
165            modules: BTreeMap::new(),
166            default_module_name: pyproject.module_name().to_string(),
167            python_root,
168            is_mixed_layout,
169            config,
170        }
171    }
172
173    fn from_project_root(
174        default_module_name: String,
175        project_root: PathBuf,
176        is_mixed_layout: bool,
177        config: StubGenConfig,
178    ) -> Self {
179        Self {
180            modules: BTreeMap::new(),
181            default_module_name,
182            python_root: project_root,
183            is_mixed_layout,
184            config,
185        }
186    }
187
188    fn get_module(&mut self, name: Option<&str>) -> &mut Module {
189        let name = name.unwrap_or(&self.default_module_name).to_string();
190        let module = self.modules.entry(name.clone()).or_default();
191        module.name = name;
192        module.default_module_name = self.default_module_name.clone();
193        module
194    }
195
196    fn register_submodules(&mut self) {
197        let mut all_parent_child_pairs: Vec<(String, String)> = Vec::new();
198
199        // For each existing module, collect all parent-child relationships
200        for module in self.modules.keys() {
201            let path = module.split('.').collect::<Vec<_>>();
202
203            // Generate all parent paths and their immediate children
204            for i in 1..path.len() {
205                let parent = path[..i].join(".");
206                let child = path[i].to_string();
207                all_parent_child_pairs.push((parent, child));
208            }
209        }
210
211        // Group children by parent
212        let mut parent_to_children: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
213        for (parent, child) in all_parent_child_pairs {
214            parent_to_children.entry(parent).or_default().insert(child);
215        }
216
217        // Create or update all parent modules
218        for (parent, children) in parent_to_children {
219            let module = self.modules.entry(parent.clone()).or_default();
220            module.name = parent;
221            module.default_module_name = self.default_module_name.clone();
222            module.submodules.extend(children);
223        }
224    }
225
226    fn add_class(&mut self, info: &PyClassInfo) {
227        let mut class_def = ClassDef::from(info);
228        class_def.resolve_default_modules(&self.default_module_name);
229        self.get_module(info.module)
230            .class
231            .insert((info.struct_id)(), class_def);
232    }
233
234    fn add_complex_enum(&mut self, info: &PyComplexEnumInfo) {
235        let mut class_def = ClassDef::from(info);
236        class_def.resolve_default_modules(&self.default_module_name);
237        self.get_module(info.module)
238            .class
239            .insert((info.enum_id)(), class_def);
240    }
241
242    fn add_enum(&mut self, info: &PyEnumInfo) {
243        self.get_module(info.module)
244            .enum_
245            .insert((info.enum_id)(), EnumDef::from(info));
246    }
247
248    fn add_function(&mut self, info: &PyFunctionInfo) -> Result<()> {
249        // Clone default_module_name to avoid borrow checker issues
250        let default_module_name = self.default_module_name.clone();
251
252        let target = self
253            .get_module(info.module)
254            .function
255            .entry(info.name)
256            .or_default();
257
258        // Validation: Check for multiple non-overload functions
259        let mut new_func = FunctionDef::from(info);
260        new_func.resolve_default_modules(&default_module_name);
261
262        if !new_func.is_overload {
263            let non_overload_count = target.iter().filter(|f| !f.is_overload).count();
264            if non_overload_count > 0 {
265                anyhow::bail!(
266                    "Multiple functions with name '{}' found without @overload decorator. \
267                     Please add @overload decorator to all variants.",
268                    info.name
269                );
270            }
271        }
272
273        target.push(new_func);
274        Ok(())
275    }
276
277    fn add_variable(&mut self, info: &PyVariableInfo) {
278        self.get_module(Some(info.module))
279            .variables
280            .insert(info.name, VariableDef::from(info));
281    }
282
283    fn add_type_alias(&mut self, info: &TypeAliasInfo) {
284        self.get_module(Some(info.module))
285            .type_aliases
286            .insert(info.name, TypeAliasDef::from(info));
287    }
288
289    fn add_module_doc(&mut self, info: &ModuleDocInfo) {
290        let raw_doc = (info.doc)();
291        self.get_module(Some(info.module)).doc = normalize_docstring(&raw_doc);
292    }
293
294    fn add_module_export(&mut self, info: &ReexportModuleMembers) {
295        let use_wildcard = info.items.is_none();
296        let items = info
297            .items
298            .map(|items| items.iter().map(|s| s.to_string()).collect())
299            .unwrap_or_default();
300
301        self.get_module(Some(info.target_module))
302            .module_re_exports
303            .push(ModuleReExport {
304                source_module: info.source_module.to_string(),
305                items,
306                use_wildcard_import: use_wildcard,
307            });
308    }
309
310    fn add_verbatim_export(&mut self, info: &ExportVerbatim) {
311        self.get_module(Some(info.target_module))
312            .verbatim_all_entries
313            .insert(info.name.to_string());
314    }
315
316    fn add_exclude(&mut self, info: &ExcludeFromAll) {
317        self.get_module(Some(info.target_module))
318            .excluded_all_entries
319            .insert(info.name.to_string());
320    }
321
322    fn resolve_wildcard_re_exports(&mut self) -> Result<()> {
323        // Collect wildcard re-exports and their resolved items for __all__
324        let mut resolutions: Vec<(String, usize, Vec<String>)> = Vec::new();
325
326        for (module_name, module) in &self.modules {
327            for (idx, re_export) in module.module_re_exports.iter().enumerate() {
328                if re_export.use_wildcard_import && re_export.items.is_empty() {
329                    // Wildcard - resolve items for __all__
330                    if let Some(source_mod) = self.modules.get(&re_export.source_module) {
331                        // Internal module - collect all public items that would be in __all__
332                        let mut items = Vec::new();
333                        for class in source_mod.class.values() {
334                            if !class.name.starts_with('_') {
335                                items.push(class.name.to_string());
336                            }
337                        }
338                        for enum_ in source_mod.enum_.values() {
339                            if !enum_.name.starts_with('_') {
340                                items.push(enum_.name.to_string());
341                            }
342                        }
343                        for func_name in source_mod.function.keys() {
344                            if !func_name.starts_with('_') {
345                                items.push(func_name.to_string());
346                            }
347                        }
348                        for var_name in source_mod.variables.keys() {
349                            if !var_name.starts_with('_') {
350                                items.push(var_name.to_string());
351                            }
352                        }
353                        for alias_name in source_mod.type_aliases.keys() {
354                            if !alias_name.starts_with('_') {
355                                items.push(alias_name.to_string());
356                            }
357                        }
358                        // FIX: Add underscore filtering for submodules in wildcard resolution
359                        for submod in &source_mod.submodules {
360                            if !submod.starts_with('_') {
361                                items.push(submod.to_string());
362                            }
363                        }
364                        resolutions.push((module_name.clone(), idx, items));
365                    } else {
366                        // External module - cannot resolve, error
367                        anyhow::bail!(
368                            "Cannot resolve wildcard re-export in module '{}': source module '{}' not found. \
369                             Wildcard re-exports only work with internal modules.",
370                            module_name,
371                            re_export.source_module
372                        );
373                    }
374                }
375            }
376        }
377
378        // Apply resolutions (populate items for wildcard imports)
379        for (module_name, idx, items) in resolutions {
380            if let Some(module) = self.modules.get_mut(&module_name) {
381                module.module_re_exports[idx].items = items;
382            }
383        }
384
385        Ok(())
386    }
387
388    fn add_methods(&mut self, info: &PyMethodsInfo) -> Result<()> {
389        let struct_id = (info.struct_id)();
390        for module in self.modules.values_mut() {
391            if let Some(entry) = module.class.get_mut(&struct_id) {
392                for attr in info.attrs {
393                    entry.attrs.push(MemberDef {
394                        name: attr.name,
395                        r#type: (attr.r#type)(),
396                        doc: attr.doc,
397                        default: attr.default.map(|f| f()),
398                        deprecated: attr.deprecated.clone(),
399                    });
400                }
401                for getter in info.getters {
402                    entry
403                        .getter_setters
404                        .entry(getter.name.to_string())
405                        .or_default()
406                        .0 = Some(MemberDef {
407                        name: getter.name,
408                        r#type: (getter.r#type)(),
409                        doc: getter.doc,
410                        default: getter.default.map(|f| f()),
411                        deprecated: getter.deprecated.clone(),
412                    });
413                }
414                for setter in info.setters {
415                    entry
416                        .getter_setters
417                        .entry(setter.name.to_string())
418                        .or_default()
419                        .1 = Some(MemberDef {
420                        name: setter.name,
421                        r#type: (setter.r#type)(),
422                        doc: setter.doc,
423                        default: setter.default.map(|f| f()),
424                        deprecated: setter.deprecated.clone(),
425                    });
426                }
427                for method in info.methods {
428                    let entries = entry.methods.entry(method.name.to_string()).or_default();
429
430                    // Validation: Check for multiple non-overload methods
431                    let new_method = MethodDef::from(method);
432                    if !new_method.is_overload {
433                        let non_overload_count = entries.iter().filter(|m| !m.is_overload).count();
434                        if non_overload_count > 0 {
435                            anyhow::bail!(
436                                "Multiple methods with name '{}' in class '{}' found without @overload decorator. \
437                                 Please add @overload decorator to all variants.",
438                                method.name, entry.name
439                            );
440                        }
441                    }
442
443                    entries.push(new_method);
444                }
445                return Ok(());
446            } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
447                for attr in info.attrs {
448                    entry.attrs.push(MemberDef {
449                        name: attr.name,
450                        r#type: (attr.r#type)(),
451                        doc: attr.doc,
452                        default: attr.default.map(|f| f()),
453                        deprecated: attr.deprecated.clone(),
454                    });
455                }
456                for getter in info.getters {
457                    entry.getters.push(MemberDef {
458                        name: getter.name,
459                        r#type: (getter.r#type)(),
460                        doc: getter.doc,
461                        default: getter.default.map(|f| f()),
462                        deprecated: getter.deprecated.clone(),
463                    });
464                }
465                for setter in info.setters {
466                    entry.setters.push(MemberDef {
467                        name: setter.name,
468                        r#type: (setter.r#type)(),
469                        doc: setter.doc,
470                        default: setter.default.map(|f| f()),
471                        deprecated: setter.deprecated.clone(),
472                    });
473                }
474                for method in info.methods {
475                    // Validation: Check for multiple non-overload methods
476                    let new_method = MethodDef::from(method);
477                    if !new_method.is_overload {
478                        let non_overload_count = entry
479                            .methods
480                            .iter()
481                            .filter(|m| m.name == method.name && !m.is_overload)
482                            .count();
483                        if non_overload_count > 0 {
484                            anyhow::bail!(
485                                "Multiple methods with name '{}' in enum '{}' found without @overload decorator. \
486                                 Please add @overload decorator to all variants.",
487                                method.name, entry.name
488                            );
489                        }
490                    }
491
492                    entry.methods.push(new_method);
493                }
494                return Ok(());
495            }
496        }
497        unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
498    }
499
500    fn build(mut self) -> Result<StubInfo> {
501        for info in inventory::iter::<PyClassInfo> {
502            self.add_class(info);
503        }
504        for info in inventory::iter::<PyComplexEnumInfo> {
505            self.add_complex_enum(info);
506        }
507        for info in inventory::iter::<PyEnumInfo> {
508            self.add_enum(info);
509        }
510        for info in inventory::iter::<PyFunctionInfo> {
511            self.add_function(info)?;
512        }
513        for info in inventory::iter::<PyVariableInfo> {
514            self.add_variable(info);
515        }
516        for info in inventory::iter::<TypeAliasInfo> {
517            self.add_type_alias(info);
518        }
519        for info in inventory::iter::<ModuleDocInfo> {
520            self.add_module_doc(info);
521        }
522        // Sort PyMethodsInfo by source location for deterministic IndexMap insertion order
523        let mut methods_infos: Vec<&PyMethodsInfo> = inventory::iter::<PyMethodsInfo>().collect();
524        methods_infos.sort_by_key(|info| (info.file, info.line, info.column));
525        for info in methods_infos {
526            self.add_methods(info)?;
527        }
528        // Collect __all__ export directives
529        for info in inventory::iter::<ReexportModuleMembers> {
530            self.add_module_export(info);
531        }
532        for info in inventory::iter::<ExportVerbatim> {
533            self.add_verbatim_export(info);
534        }
535        for info in inventory::iter::<ExcludeFromAll> {
536            self.add_exclude(info);
537        }
538        self.register_submodules();
539
540        // Resolve wildcard re-exports
541        self.resolve_wildcard_re_exports()?;
542
543        Ok(StubInfo {
544            modules: self.modules,
545            python_root: self.python_root,
546            is_mixed_layout: self.is_mixed_layout,
547            config: self.config,
548            pyproject_dir: None, // Will be set by from_pyproject_toml()
549        })
550    }
551}
552
#[cfg(test)]
mod tests {
    use super::*;

    /// An otherwise-empty [Module] with the given name and default module name.
    fn empty_module(name: &str, default_module_name: &str) -> Module {
        Module {
            name: name.to_string(),
            default_module_name: default_module_name.to_string(),
            ..Default::default()
        }
    }

    /// A pure-layout builder rooted at `/tmp` with the default configuration.
    fn pure_builder(default_module_name: &str) -> StubInfoBuilder {
        StubInfoBuilder::from_project_root(
            default_module_name.to_string(),
            PathBuf::from("/tmp"),
            false,
            StubGenConfig::default(),
        )
    }

    #[test]
    fn test_register_submodules_creates_empty_parent_modules() {
        let mut builder = pure_builder("test_module");
        // Only the child is registered; its parent has to be synthesized.
        let child = "test_module.sub_mod";
        builder
            .modules
            .insert(child.to_string(), empty_module(child, "test_module"));

        builder.register_submodules();

        let parent = builder
            .modules
            .get("test_module")
            .expect("empty parent module should have been created");
        assert_eq!(parent.name, "test_module");
        assert!(parent.submodules.contains("sub_mod"));

        // The child module itself must be left in place.
        assert!(builder.modules.contains_key(child));
    }

    #[test]
    fn test_register_submodules_with_multiple_levels() {
        let mut builder = pure_builder("root");
        let leaf = "root.level1.level2.deep_mod";
        builder
            .modules
            .insert(leaf.to_string(), empty_module(leaf, "root"));

        builder.register_submodules();

        // Every intermediate ancestor exists, and so does the leaf.
        for expected in ["root", "root.level1", "root.level1.level2", leaf] {
            assert!(
                builder.modules.contains_key(expected),
                "missing module `{expected}`"
            );
        }

        // Each ancestor lists exactly its direct child.
        for (parent, child) in [
            ("root", "level1"),
            ("root.level1", "level2"),
            ("root.level1.level2", "deep_mod"),
        ] {
            assert!(
                builder.modules[parent].submodules.contains(child),
                "`{parent}` should list `{child}` as a submodule"
            );
        }
    }

    #[test]
    fn test_pure_layout_rejects_multiple_modules() {
        // A pure Rust layout emits a single stub file, so a second module
        // (top-level or submodule) must be rejected before anything is written.
        let modules = BTreeMap::from([
            (
                "mymodule".to_string(),
                empty_module("mymodule", "mymodule"),
            ),
            (
                "mymodule.sub".to_string(),
                empty_module("mymodule.sub", "mymodule"),
            ),
        ]);
        let stub_info = StubInfo {
            modules,
            python_root: PathBuf::from("/tmp"),
            is_mixed_layout: false,
            config: StubGenConfig::default(),
            pyproject_dir: None,
        };

        let err = stub_info
            .generate()
            .expect_err("pure Rust layout with two modules must fail");
        assert!(err
            .to_string()
            .contains("Pure Rust layout does not support multiple modules or submodules"));
    }
}