1use crate::{
2 generate::{docstring::normalize_docstring, *},
3 pyproject::{PyProject, StubGenConfig},
4 type_info::*,
5};
6use anyhow::{Context, Result};
7use std::{
8 collections::{BTreeMap, BTreeSet},
9 fs,
10 path::*,
11};
12
13#[derive(Debug, Clone, PartialEq)]
14pub struct StubInfo {
15 pub modules: BTreeMap<String, Module>,
16 pub python_root: PathBuf,
17 pub is_mixed_layout: bool,
19 pub config: StubGenConfig,
21 pub pyproject_dir: Option<PathBuf>,
23 pub default_module_name: String,
25 pub project_name: String,
28}
29
30impl StubInfo {
31 pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
34 let path = path.as_ref();
35 let pyproject = PyProject::parse_toml(path)?;
36 let mut config = pyproject.stub_gen_config();
37
38 if let Some(resolved_doc_gen) = pyproject.doc_gen_config_resolved() {
40 config.doc_gen = Some(resolved_doc_gen);
41 }
42
43 let pyproject_dir = path.parent().map(|p| p.to_path_buf());
44
45 let mut stub_info = StubInfoBuilder::from_pyproject_toml(pyproject, config).build()?;
46 stub_info.pyproject_dir = pyproject_dir;
47 Ok(stub_info)
48 }
49
50 pub fn from_project_root(
54 default_module_name: String,
55 project_root: PathBuf,
56 is_mixed_layout: bool,
57 config: StubGenConfig,
58 ) -> Result<Self> {
59 StubInfoBuilder::from_project_root(
60 default_module_name,
61 project_root,
62 is_mixed_layout,
63 config,
64 )
65 .build()
66 }
67
68 pub fn generate(&self) -> Result<()> {
69 if !self.is_mixed_layout && self.modules.len() > 1 {
71 let module_names: Vec<_> = self.modules.keys().collect();
72 anyhow::bail!(
73 "Pure Rust layout does not support multiple modules or submodules. Found {} modules: {}. \
74 Please use mixed Python/Rust layout (add `python-source` to [tool.maturin] in pyproject.toml) \
75 if you need multiple modules or submodules.",
76 self.modules.len(),
77 module_names.iter().map(|s| format!("'{}'", s)).collect::<Vec<_>>().join(", ")
78 );
79 }
80
81 for (name, module) in self.modules.iter() {
82 if module.is_empty() {
84 continue;
85 }
86
87 let normalized_name = name.replace("-", "_");
89 let path = normalized_name.replace(".", "/");
90
91 if self.is_pyo3_generated(name) {
92 let dest = if self.is_mixed_layout {
94 self.python_root.join(&path).join("__init__.pyi")
95 } else {
96 let package_name = normalized_name
98 .split('.')
99 .next()
100 .filter(|s| !s.is_empty())
101 .ok_or_else(|| {
102 anyhow::anyhow!(
103 "Module name is empty after normalization: original name was `{name}`"
104 )
105 })?;
106 self.python_root.join(format!("{package_name}.pyi"))
107 };
108
109 self.write_stub_file(&dest, module)?;
110 } else {
111 if !module.is_init_py_compatible() {
113 anyhow::bail!(
115 "Module '{}' has PyO3 items (classes, functions, etc.) but is not under \
116 the PyO3 module path '{}'. Either move these items to a module under '{}', \
117 or check your module path configuration.",
118 name,
119 self.default_module_name,
120 self.default_module_name
121 );
122 }
123
124 if !self.config.generate_init_py.is_enabled_for(name) {
125 anyhow::bail!(
126 "Module '{}' is not a PyO3 module and requires `generate-init-py` to be enabled. \
127 Add `generate-init-py = true` or `generate-init-py = [\"{}\"]` to \
128 [tool.pyo3-stub-gen] in pyproject.toml.",
129 name,
130 name
131 );
132 }
133
134 let dir = self.python_root.join(&path);
136 if !dir.exists() {
137 fs::create_dir_all(&dir)?;
138 }
139
140 let init_py_dest = dir.join("__init__.py");
141 let init_py_content = module.format_init_py();
142 fs::write(&init_py_dest, init_py_content)?;
143 log::info!(
144 "Generate __init__.py for module `{name}` at {dest}",
145 dest = init_py_dest.display()
146 );
147 }
148 }
149
150 if let Some(doc_config) = &self.config.doc_gen {
152 self.generate_docs(doc_config)?;
153 }
154
155 Ok(())
156 }
157
158 fn write_stub_file(&self, dest: &std::path::Path, module: &module::Module) -> Result<()> {
159 let dir = dest.parent().context("Cannot get parent directory")?;
160 if !dir.exists() {
161 fs::create_dir_all(dir)?;
162 }
163
164 let content = module.format_with_config(self.config.use_type_statement);
165 fs::write(dest, content)?;
166 log::info!(
167 "Generate stub file of a module `{}` at {dest}",
168 module.name,
169 dest = dest.display()
170 );
171 Ok(())
172 }
173
174 fn is_pyo3_generated(&self, module: &str) -> bool {
179 if !self.is_mixed_layout {
181 return true;
182 }
183
184 let normalized_module = module.replace("-", "_");
186 let normalized_module_name = self.default_module_name.replace("-", "_");
187
188 normalized_module == normalized_module_name
189 || normalized_module.starts_with(&format!("{}.", normalized_module_name))
190 }
191
192 fn generate_docs(&self, config: &crate::docgen::DocGenConfig) -> Result<()> {
193 config.validate()?;
194
195 log::info!("Generating API documentation...");
196
197 let doc_package = crate::docgen::builder::DocPackageBuilder::new(self).build()?;
199
200 let json_output = crate::docgen::render::render_to_json(&doc_package)?;
202
203 fs::create_dir_all(&config.output_dir)?;
205 fs::write(config.output_dir.join(&config.json_output), json_output)?;
206
207 crate::docgen::render::copy_sphinx_extension(&config.output_dir)?;
209
210 if config.separate_pages {
212 crate::docgen::render::generate_module_pages(&doc_package, &config.output_dir, config)?;
213 if config.generate_index {
214 crate::docgen::render::generate_index_rst(
215 &doc_package,
216 &config.output_dir,
217 config,
218 )?;
219 }
220 if config.separate_items {
221 crate::docgen::render::generate_item_pages(&doc_package, &config.output_dir)?;
222 log::info!("Generated separate .rst pages for each item");
223 }
224 log::info!("Generated separate .rst pages for each module");
225 }
226
227 log::info!("Generated API docs at {:?}", config.output_dir);
228 Ok(())
229 }
230}
231
232struct StubInfoBuilder {
233 modules: BTreeMap<String, Module>,
234 default_module_name: String,
235 project_name: String,
236 python_root: PathBuf,
237 is_mixed_layout: bool,
238 config: StubGenConfig,
239}
240
241impl StubInfoBuilder {
242 fn from_pyproject_toml(pyproject: PyProject, config: StubGenConfig) -> Self {
243 let is_mixed_layout = pyproject.python_source().is_some();
244 let python_root = pyproject
245 .python_source()
246 .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()));
247
248 Self {
249 modules: BTreeMap::new(),
250 default_module_name: pyproject.module_name().to_string(),
251 project_name: pyproject.project.name.clone(),
252 python_root,
253 is_mixed_layout,
254 config,
255 }
256 }
257
258 fn from_project_root(
259 default_module_name: String,
260 project_root: PathBuf,
261 is_mixed_layout: bool,
262 config: StubGenConfig,
263 ) -> Self {
264 let project_name = default_module_name
266 .split('.')
267 .next()
268 .unwrap_or(&default_module_name)
269 .to_string();
270
271 Self {
272 modules: BTreeMap::new(),
273 default_module_name,
274 project_name,
275 python_root: project_root,
276 is_mixed_layout,
277 config,
278 }
279 }
280
281 fn get_module(&mut self, name: Option<&str>) -> &mut Module {
282 let name = name.unwrap_or(&self.default_module_name).to_string();
283 let module = self.modules.entry(name.clone()).or_default();
284 module.name = name;
285 module.default_module_name = self.default_module_name.clone();
286 module
287 }
288
289 fn register_submodules(&mut self) {
290 let mut parent_to_children: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
299
300 for module in self.modules.keys() {
302 let path = module.split('.').collect::<Vec<_>>();
303
304 for i in 1..path.len() {
306 let parent = path[..i].join(".");
307
308 if self.is_pyo3_generated(&parent) {
310 let child = path[i].to_string();
311 parent_to_children.entry(parent).or_default().insert(child);
312 }
313 }
314 }
315
316 for (parent, children) in parent_to_children {
318 let module = self.modules.entry(parent.clone()).or_default();
319 module.name = parent;
320 module.default_module_name = self.default_module_name.clone();
321 module.submodules.extend(children);
322 }
323 }
324
325 fn is_pyo3_generated(&self, module: &str) -> bool {
334 if !self.is_mixed_layout {
336 return true;
337 }
338
339 let normalized_module = module.replace("-", "_");
341 let normalized_module_name = self.default_module_name.replace("-", "_");
342
343 normalized_module == normalized_module_name
344 || normalized_module.starts_with(&format!("{}.", normalized_module_name))
345 }
346
347 fn add_class(&mut self, info: &PyClassInfo) {
348 let mut class_def = ClassDef::from(info);
349 class_def.resolve_default_modules(&self.default_module_name);
350 self.get_module(info.module)
351 .class
352 .insert((info.struct_id)(), class_def);
353 }
354
355 fn add_complex_enum(&mut self, info: &PyComplexEnumInfo) {
356 let mut class_def = ClassDef::from(info);
357 class_def.resolve_default_modules(&self.default_module_name);
358 self.get_module(info.module)
359 .class
360 .insert((info.enum_id)(), class_def);
361 }
362
363 fn add_enum(&mut self, info: &PyEnumInfo) {
364 self.get_module(info.module)
365 .enum_
366 .insert((info.enum_id)(), EnumDef::from(info));
367 }
368
369 fn add_function(&mut self, info: &PyFunctionInfo) -> Result<()> {
370 let default_module_name = self.default_module_name.clone();
372
373 let target = self
374 .get_module(info.module)
375 .function
376 .entry(info.name)
377 .or_default();
378
379 let mut new_func = FunctionDef::from(info);
381 new_func.resolve_default_modules(&default_module_name);
382
383 if !new_func.is_overload {
384 let non_overload_count = target.iter().filter(|f| !f.is_overload).count();
385 if non_overload_count > 0 {
386 anyhow::bail!(
387 "Multiple functions with name '{}' found without @overload decorator. \
388 Please add @overload decorator to all variants.",
389 info.name
390 );
391 }
392 }
393
394 target.push(new_func);
395 Ok(())
396 }
397
398 fn add_variable(&mut self, info: &PyVariableInfo) {
399 self.get_module(Some(info.module))
400 .variables
401 .insert(info.name, VariableDef::from(info));
402 }
403
404 fn add_type_alias(&mut self, info: &TypeAliasInfo) {
405 self.get_module(Some(info.module))
406 .type_aliases
407 .insert(info.name, TypeAliasDef::from(info));
408 }
409
410 fn add_module_doc(&mut self, info: &ModuleDocInfo) {
411 let raw_doc = (info.doc)();
412 self.get_module(Some(info.module)).doc = normalize_docstring(&raw_doc);
413 }
414
415 fn add_module_export(&mut self, info: &ReexportModuleMembers) {
416 use crate::type_info::ReexportItems;
417
418 let (items, additional_items) = match info.items {
419 ReexportItems::Wildcard => (Vec::new(), Vec::new()),
420 ReexportItems::Explicit(items) => {
421 (items.iter().map(|s| s.to_string()).collect(), Vec::new())
422 }
423 ReexportItems::WildcardPlus(additional) => (
424 Vec::new(),
425 additional.iter().map(|s| s.to_string()).collect(),
426 ),
427 };
428
429 self.get_module(Some(info.target_module))
430 .module_re_exports
431 .push(ModuleReExport {
432 source_module: info.source_module.to_string(),
433 items,
434 additional_items,
435 });
436 }
437
438 fn add_verbatim_export(&mut self, info: &ExportVerbatim) {
439 self.get_module(Some(info.target_module))
440 .verbatim_all_entries
441 .insert(info.name.to_string());
442 }
443
444 fn add_exclude(&mut self, info: &ExcludeFromAll) {
445 self.get_module(Some(info.target_module))
446 .excluded_all_entries
447 .insert(info.name.to_string());
448 }
449
450 fn resolve_wildcard_re_exports(&mut self) -> Result<()> {
451 let mut resolutions: Vec<(String, usize, Vec<String>, Vec<String>)> = Vec::new();
454
455 for (module_name, module) in &self.modules {
456 for (idx, re_export) in module.module_re_exports.iter().enumerate() {
457 if re_export.items.is_empty() {
458 if let Some(source_mod) = self.modules.get(&re_export.source_module) {
460 let mut items = Vec::new();
462 for class in source_mod.class.values() {
463 if !class.name.starts_with('_') {
464 items.push(class.name.to_string());
465 }
466 }
467 for enum_ in source_mod.enum_.values() {
468 if !enum_.name.starts_with('_') {
469 items.push(enum_.name.to_string());
470 }
471 }
472 for func_name in source_mod.function.keys() {
473 if !func_name.starts_with('_') {
474 items.push(func_name.to_string());
475 }
476 }
477 for var_name in source_mod.variables.keys() {
478 if !var_name.starts_with('_') {
479 items.push(var_name.to_string());
480 }
481 }
482 for alias_name in source_mod.type_aliases.keys() {
483 if !alias_name.starts_with('_') {
484 items.push(alias_name.to_string());
485 }
486 }
487 for submod in &source_mod.submodules {
488 if !submod.starts_with('_') {
489 items.push(submod.to_string());
490 }
491 }
492 let additional = re_export.additional_items.clone();
494 resolutions.push((module_name.clone(), idx, items, additional));
495 } else {
496 anyhow::bail!(
498 "Cannot resolve wildcard re-export in module '{}': source module '{}' not found. \
499 Wildcard re-exports only work with internal modules.",
500 module_name,
501 re_export.source_module
502 );
503 }
504 }
505 }
506 }
507
508 for (module_name, idx, mut items, additional) in resolutions {
510 items.extend(additional);
512 let mut seen = BTreeSet::new();
514 items.retain(|item| seen.insert(item.clone()));
515 if let Some(module) = self.modules.get_mut(&module_name) {
516 module.module_re_exports[idx].items = items;
517 module.module_re_exports[idx].additional_items.clear();
518 }
519 }
520
521 Ok(())
522 }
523
524 fn add_methods(&mut self, info: &PyMethodsInfo) -> Result<()> {
525 let struct_id = (info.struct_id)();
526 for module in self.modules.values_mut() {
527 if let Some(entry) = module.class.get_mut(&struct_id) {
528 for attr in info.attrs {
529 entry.attrs.push(MemberDef {
530 name: attr.name,
531 r#type: (attr.r#type)(),
532 doc: attr.doc,
533 default: attr.default.map(|f| f()),
534 deprecated: attr.deprecated.clone(),
535 });
536 }
537 for getter in info.getters {
538 entry
539 .getter_setters
540 .entry(getter.name.to_string())
541 .or_default()
542 .0 = Some(MemberDef {
543 name: getter.name,
544 r#type: (getter.r#type)(),
545 doc: getter.doc,
546 default: getter.default.map(|f| f()),
547 deprecated: getter.deprecated.clone(),
548 });
549 }
550 for setter in info.setters {
551 entry
552 .getter_setters
553 .entry(setter.name.to_string())
554 .or_default()
555 .1 = Some(MemberDef {
556 name: setter.name,
557 r#type: (setter.r#type)(),
558 doc: setter.doc,
559 default: setter.default.map(|f| f()),
560 deprecated: setter.deprecated.clone(),
561 });
562 }
563 for method in info.methods {
564 let entries = entry.methods.entry(method.name.to_string()).or_default();
565
566 let new_method = MethodDef::from(method);
568 if !new_method.is_overload {
569 let non_overload_count = entries.iter().filter(|m| !m.is_overload).count();
570 if non_overload_count > 0 {
571 anyhow::bail!(
572 "Multiple methods with name '{}' in class '{}' found without @overload decorator. \
573 Please add @overload decorator to all variants.",
574 method.name, entry.name
575 );
576 }
577 }
578
579 entries.push(new_method);
580 }
581 return Ok(());
582 } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
583 for attr in info.attrs {
584 entry.attrs.push(MemberDef {
585 name: attr.name,
586 r#type: (attr.r#type)(),
587 doc: attr.doc,
588 default: attr.default.map(|f| f()),
589 deprecated: attr.deprecated.clone(),
590 });
591 }
592 for getter in info.getters {
593 entry.getters.push(MemberDef {
594 name: getter.name,
595 r#type: (getter.r#type)(),
596 doc: getter.doc,
597 default: getter.default.map(|f| f()),
598 deprecated: getter.deprecated.clone(),
599 });
600 }
601 for setter in info.setters {
602 entry.setters.push(MemberDef {
603 name: setter.name,
604 r#type: (setter.r#type)(),
605 doc: setter.doc,
606 default: setter.default.map(|f| f()),
607 deprecated: setter.deprecated.clone(),
608 });
609 }
610 for method in info.methods {
611 let new_method = MethodDef::from(method);
613 if !new_method.is_overload {
614 let non_overload_count = entry
615 .methods
616 .iter()
617 .filter(|m| m.name == method.name && !m.is_overload)
618 .count();
619 if non_overload_count > 0 {
620 anyhow::bail!(
621 "Multiple methods with name '{}' in enum '{}' found without @overload decorator. \
622 Please add @overload decorator to all variants.",
623 method.name, entry.name
624 );
625 }
626 }
627
628 entry.methods.push(new_method);
629 }
630 return Ok(());
631 }
632 }
633 unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
634 }
635
636 fn build(mut self) -> Result<StubInfo> {
637 for info in inventory::iter::<PyClassInfo> {
638 self.add_class(info);
639 }
640 for info in inventory::iter::<PyComplexEnumInfo> {
641 self.add_complex_enum(info);
642 }
643 for info in inventory::iter::<PyEnumInfo> {
644 self.add_enum(info);
645 }
646 for info in inventory::iter::<PyFunctionInfo> {
647 self.add_function(info)?;
648 }
649 for info in inventory::iter::<PyVariableInfo> {
650 self.add_variable(info);
651 }
652 for info in inventory::iter::<TypeAliasInfo> {
653 self.add_type_alias(info);
654 }
655 for info in inventory::iter::<ModuleDocInfo> {
656 self.add_module_doc(info);
657 }
658 let mut methods_infos: Vec<&PyMethodsInfo> = inventory::iter::<PyMethodsInfo>().collect();
660 methods_infos.sort_by_key(|info| (info.file, info.line, info.column));
661 for info in methods_infos {
662 self.add_methods(info)?;
663 }
664 for info in inventory::iter::<ReexportModuleMembers> {
666 self.add_module_export(info);
667 }
668 for info in inventory::iter::<ExportVerbatim> {
669 self.add_verbatim_export(info);
670 }
671 for info in inventory::iter::<ExcludeFromAll> {
672 self.add_exclude(info);
673 }
674 self.register_submodules();
675
676 self.resolve_wildcard_re_exports()?;
678
679 Ok(StubInfo {
680 modules: self.modules,
681 python_root: self.python_root,
682 is_mixed_layout: self.is_mixed_layout,
683 config: self.config,
684 pyproject_dir: None, default_module_name: self.default_module_name,
686 project_name: self.project_name,
687 })
688 }
689}
690
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder for a pure-Rust-layout project rooted at `/tmp` with default config.
    fn pure_layout_builder(default_module_name: &str) -> StubInfoBuilder {
        StubInfoBuilder::from_project_root(
            default_module_name.to_string(),
            "/tmp".into(),
            false,
            StubGenConfig::default(),
        )
    }

    /// A module called `name` that belongs to `default_module_name`; all other
    /// fields are defaulted.
    fn named_module(name: &str, default_module_name: &str) -> Module {
        Module {
            name: name.to_string(),
            default_module_name: default_module_name.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_register_submodules_creates_empty_parent_modules() {
        let child = "test_module.sub_mod";
        let mut builder = pure_layout_builder("test_module");
        builder
            .modules
            .insert(child.to_string(), named_module(child, "test_module"));

        builder.register_submodules();

        // The missing parent is synthesised and lists the child segment.
        let parent = builder
            .modules
            .get("test_module")
            .expect("parent module should have been created");
        assert_eq!(parent.name, "test_module");
        assert!(parent.submodules.contains("sub_mod"));

        // The child module itself is kept.
        assert!(builder.modules.contains_key(child));
    }

    #[test]
    fn test_register_submodules_with_multiple_levels() {
        let deepest = "root.level1.level2.deep_mod";
        let mut builder = pure_layout_builder("root");
        builder
            .modules
            .insert(deepest.to_string(), named_module(deepest, "root"));

        builder.register_submodules();

        // Every intermediate level exists and records its direct child.
        for (parent, child) in [
            ("root", "level1"),
            ("root.level1", "level2"),
            ("root.level1.level2", "deep_mod"),
        ] {
            let module = builder
                .modules
                .get(parent)
                .unwrap_or_else(|| panic!("missing parent module {parent}"));
            assert!(
                module.submodules.contains(child),
                "{parent} should list {child} as a submodule"
            );
        }
        assert!(builder.modules.contains_key(deepest));
    }

    #[test]
    fn test_pure_layout_rejects_multiple_modules() {
        let mut modules = BTreeMap::new();
        for (name, doc) in [
            ("mymodule", "Test module"),
            ("mymodule.sub", "Test submodule"),
        ] {
            modules.insert(
                name.to_string(),
                Module {
                    doc: doc.to_string(),
                    ..named_module(name, "mymodule")
                },
            );
        }

        let stub_info = StubInfo {
            modules,
            python_root: PathBuf::from("/tmp"),
            is_mixed_layout: false,
            config: StubGenConfig::default(),
            pyproject_dir: None,
            default_module_name: "mymodule".to_string(),
            project_name: "mymodule".to_string(),
        };

        let err = stub_info
            .generate()
            .expect_err("a pure Rust layout with two modules must be rejected");
        assert!(err
            .to_string()
            .contains("Pure Rust layout does not support multiple modules or submodules"));
    }
}