1use crate::{
2 generate::{docstring::normalize_docstring, *},
3 pyproject::{PyProject, StubGenConfig},
4 type_info::*,
5};
6use anyhow::{Context, Result};
7use std::{
8 collections::{BTreeMap, BTreeSet},
9 fs,
10 path::*,
11};
12
/// Collected stub-generation state for one project: the modules to emit, where to write them,
/// and the configuration that controls `.pyi` / `__init__.py` / documentation output.
#[derive(Debug, Clone, PartialEq)]
pub struct StubInfo {
    /// Collected modules keyed by their dotted Python module name (e.g. `pkg.sub`).
    pub modules: BTreeMap<String, Module>,
    /// Root directory under which stub files and generated `__init__.py` files are written.
    pub python_root: PathBuf,
    /// Whether the project uses the mixed Python/Rust layout. When `false` (pure Rust layout)
    /// only a single module is accepted and its stub is written as `<python_root>/<package>.pyi`.
    pub is_mixed_layout: bool,
    /// Stub-generation settings (read here: `generate_init_py`, `use_type_statement`, `doc_gen`).
    pub config: StubGenConfig,
    /// Directory containing `pyproject.toml` when built via `StubInfo::from_pyproject_toml`;
    /// `None` otherwise.
    pub pyproject_dir: Option<PathBuf>,
    /// Name of the PyO3 module. In the mixed layout only this module and its dotted descendants
    /// receive `.pyi` stubs; other modules receive a generated `__init__.py` instead.
    pub default_module_name: String,
    /// Project name: `[project].name` from `pyproject.toml`, or the first dotted segment of
    /// `default_module_name` when built via `StubInfo::from_project_root`.
    pub project_name: String,
}
29
30impl StubInfo {
31 pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
34 let path = path.as_ref();
35 let pyproject = PyProject::parse_toml(path)?;
36 let mut config = pyproject.stub_gen_config();
37
38 if let Some(resolved_doc_gen) = pyproject.doc_gen_config_resolved() {
40 config.doc_gen = Some(resolved_doc_gen);
41 }
42
43 let pyproject_dir = path.parent().map(|p| p.to_path_buf());
44
45 let mut stub_info = StubInfoBuilder::from_pyproject_toml(pyproject, config).build()?;
46 stub_info.pyproject_dir = pyproject_dir;
47 Ok(stub_info)
48 }
49
50 pub fn from_project_root(
54 default_module_name: String,
55 project_root: PathBuf,
56 is_mixed_layout: bool,
57 config: StubGenConfig,
58 ) -> Result<Self> {
59 StubInfoBuilder::from_project_root(
60 default_module_name,
61 project_root,
62 is_mixed_layout,
63 config,
64 )
65 .build()
66 }
67
68 pub fn generate(&self) -> Result<()> {
69 if !self.is_mixed_layout && self.modules.len() > 1 {
71 let module_names: Vec<_> = self.modules.keys().collect();
72 anyhow::bail!(
73 "Pure Rust layout does not support multiple modules or submodules. Found {} modules: {}. \
74 Please use mixed Python/Rust layout (add `python-source` to [tool.maturin] in pyproject.toml) \
75 if you need multiple modules or submodules.",
76 self.modules.len(),
77 module_names.iter().map(|s| format!("'{}'", s)).collect::<Vec<_>>().join(", ")
78 );
79 }
80
81 for (name, module) in self.modules.iter() {
82 if module.is_empty() {
84 continue;
85 }
86
87 let normalized_name = name.replace("-", "_");
89 let path = normalized_name.replace(".", "/");
90
91 if self.is_pyo3_generated(name) {
92 let dest = if self.is_mixed_layout {
94 self.python_root.join(&path).join("__init__.pyi")
95 } else {
96 let package_name = normalized_name
98 .split('.')
99 .next()
100 .filter(|s| !s.is_empty())
101 .ok_or_else(|| {
102 anyhow::anyhow!(
103 "Module name is empty after normalization: original name was `{name}`"
104 )
105 })?;
106 self.python_root.join(format!("{package_name}.pyi"))
107 };
108
109 self.write_stub_file(&dest, module)?;
110 } else {
111 if !module.is_init_py_compatible() {
113 anyhow::bail!(
115 "Module '{}' has PyO3 items (classes, functions, etc.) but is not under \
116 the PyO3 module path '{}'. Either move these items to a module under '{}', \
117 or check your module path configuration.",
118 name,
119 self.default_module_name,
120 self.default_module_name
121 );
122 }
123
124 if !self.config.generate_init_py.is_enabled_for(name) {
125 anyhow::bail!(
126 "Module '{}' is not a PyO3 module and requires `generate-init-py` to be enabled. \
127 Add `generate-init-py = true` or `generate-init-py = [\"{}\"]` to \
128 [tool.pyo3-stub-gen] in pyproject.toml.",
129 name,
130 name
131 );
132 }
133
134 let dir = self.python_root.join(&path);
136 if !dir.exists() {
137 fs::create_dir_all(&dir)?;
138 }
139
140 let init_py_dest = dir.join("__init__.py");
141 let init_py_content = module.format_init_py();
142 fs::write(&init_py_dest, init_py_content)?;
143 log::info!(
144 "Generate __init__.py for module `{name}` at {dest}",
145 dest = init_py_dest.display()
146 );
147 }
148 }
149
150 if let Some(doc_config) = &self.config.doc_gen {
152 self.generate_docs(doc_config)?;
153 }
154
155 Ok(())
156 }
157
158 fn write_stub_file(&self, dest: &std::path::Path, module: &module::Module) -> Result<()> {
159 let dir = dest.parent().context("Cannot get parent directory")?;
160 if !dir.exists() {
161 fs::create_dir_all(dir)?;
162 }
163
164 let content = module.format_with_config(self.config.use_type_statement);
165 fs::write(dest, content)?;
166 log::info!(
167 "Generate stub file of a module `{}` at {dest}",
168 module.name,
169 dest = dest.display()
170 );
171 Ok(())
172 }
173
174 fn is_pyo3_generated(&self, module: &str) -> bool {
179 if !self.is_mixed_layout {
181 return true;
182 }
183
184 let normalized_module = module.replace("-", "_");
186 let normalized_module_name = self.default_module_name.replace("-", "_");
187
188 normalized_module == normalized_module_name
189 || normalized_module.starts_with(&format!("{}.", normalized_module_name))
190 }
191
192 fn generate_docs(&self, config: &crate::docgen::DocGenConfig) -> Result<()> {
193 log::info!("Generating API documentation...");
194
195 let doc_package = crate::docgen::builder::DocPackageBuilder::new(self).build()?;
197
198 let json_output = crate::docgen::render::render_to_json(&doc_package)?;
200
201 fs::create_dir_all(&config.output_dir)?;
203 fs::write(config.output_dir.join(&config.json_output), json_output)?;
204
205 crate::docgen::render::copy_sphinx_extension(&config.output_dir)?;
207
208 if config.separate_pages {
210 crate::docgen::render::generate_module_pages(&doc_package, &config.output_dir)?;
211 crate::docgen::render::generate_index_rst(&doc_package, &config.output_dir, config)?;
212 log::info!("Generated separate .rst pages for each module");
213 }
214
215 log::info!("Generated API docs at {:?}", config.output_dir);
216 Ok(())
217 }
218}
219
220struct StubInfoBuilder {
221 modules: BTreeMap<String, Module>,
222 default_module_name: String,
223 project_name: String,
224 python_root: PathBuf,
225 is_mixed_layout: bool,
226 config: StubGenConfig,
227}
228
229impl StubInfoBuilder {
230 fn from_pyproject_toml(pyproject: PyProject, config: StubGenConfig) -> Self {
231 let is_mixed_layout = pyproject.python_source().is_some();
232 let python_root = pyproject
233 .python_source()
234 .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()));
235
236 Self {
237 modules: BTreeMap::new(),
238 default_module_name: pyproject.module_name().to_string(),
239 project_name: pyproject.project.name.clone(),
240 python_root,
241 is_mixed_layout,
242 config,
243 }
244 }
245
246 fn from_project_root(
247 default_module_name: String,
248 project_root: PathBuf,
249 is_mixed_layout: bool,
250 config: StubGenConfig,
251 ) -> Self {
252 let project_name = default_module_name
254 .split('.')
255 .next()
256 .unwrap_or(&default_module_name)
257 .to_string();
258
259 Self {
260 modules: BTreeMap::new(),
261 default_module_name,
262 project_name,
263 python_root: project_root,
264 is_mixed_layout,
265 config,
266 }
267 }
268
269 fn get_module(&mut self, name: Option<&str>) -> &mut Module {
270 let name = name.unwrap_or(&self.default_module_name).to_string();
271 let module = self.modules.entry(name.clone()).or_default();
272 module.name = name;
273 module.default_module_name = self.default_module_name.clone();
274 module
275 }
276
277 fn register_submodules(&mut self) {
278 let mut parent_to_children: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
287
288 for module in self.modules.keys() {
290 let path = module.split('.').collect::<Vec<_>>();
291
292 for i in 1..path.len() {
294 let parent = path[..i].join(".");
295
296 if self.is_pyo3_generated(&parent) {
298 let child = path[i].to_string();
299 parent_to_children.entry(parent).or_default().insert(child);
300 }
301 }
302 }
303
304 for (parent, children) in parent_to_children {
306 let module = self.modules.entry(parent.clone()).or_default();
307 module.name = parent;
308 module.default_module_name = self.default_module_name.clone();
309 module.submodules.extend(children);
310 }
311 }
312
313 fn is_pyo3_generated(&self, module: &str) -> bool {
322 if !self.is_mixed_layout {
324 return true;
325 }
326
327 let normalized_module = module.replace("-", "_");
329 let normalized_module_name = self.default_module_name.replace("-", "_");
330
331 normalized_module == normalized_module_name
332 || normalized_module.starts_with(&format!("{}.", normalized_module_name))
333 }
334
335 fn add_class(&mut self, info: &PyClassInfo) {
336 let mut class_def = ClassDef::from(info);
337 class_def.resolve_default_modules(&self.default_module_name);
338 self.get_module(info.module)
339 .class
340 .insert((info.struct_id)(), class_def);
341 }
342
343 fn add_complex_enum(&mut self, info: &PyComplexEnumInfo) {
344 let mut class_def = ClassDef::from(info);
345 class_def.resolve_default_modules(&self.default_module_name);
346 self.get_module(info.module)
347 .class
348 .insert((info.enum_id)(), class_def);
349 }
350
351 fn add_enum(&mut self, info: &PyEnumInfo) {
352 self.get_module(info.module)
353 .enum_
354 .insert((info.enum_id)(), EnumDef::from(info));
355 }
356
357 fn add_function(&mut self, info: &PyFunctionInfo) -> Result<()> {
358 let default_module_name = self.default_module_name.clone();
360
361 let target = self
362 .get_module(info.module)
363 .function
364 .entry(info.name)
365 .or_default();
366
367 let mut new_func = FunctionDef::from(info);
369 new_func.resolve_default_modules(&default_module_name);
370
371 if !new_func.is_overload {
372 let non_overload_count = target.iter().filter(|f| !f.is_overload).count();
373 if non_overload_count > 0 {
374 anyhow::bail!(
375 "Multiple functions with name '{}' found without @overload decorator. \
376 Please add @overload decorator to all variants.",
377 info.name
378 );
379 }
380 }
381
382 target.push(new_func);
383 Ok(())
384 }
385
386 fn add_variable(&mut self, info: &PyVariableInfo) {
387 self.get_module(Some(info.module))
388 .variables
389 .insert(info.name, VariableDef::from(info));
390 }
391
392 fn add_type_alias(&mut self, info: &TypeAliasInfo) {
393 self.get_module(Some(info.module))
394 .type_aliases
395 .insert(info.name, TypeAliasDef::from(info));
396 }
397
398 fn add_module_doc(&mut self, info: &ModuleDocInfo) {
399 let raw_doc = (info.doc)();
400 self.get_module(Some(info.module)).doc = normalize_docstring(&raw_doc);
401 }
402
403 fn add_module_export(&mut self, info: &ReexportModuleMembers) {
404 use crate::type_info::ReexportItems;
405
406 let (items, additional_items) = match info.items {
407 ReexportItems::Wildcard => (Vec::new(), Vec::new()),
408 ReexportItems::Explicit(items) => {
409 (items.iter().map(|s| s.to_string()).collect(), Vec::new())
410 }
411 ReexportItems::WildcardPlus(additional) => (
412 Vec::new(),
413 additional.iter().map(|s| s.to_string()).collect(),
414 ),
415 };
416
417 self.get_module(Some(info.target_module))
418 .module_re_exports
419 .push(ModuleReExport {
420 source_module: info.source_module.to_string(),
421 items,
422 additional_items,
423 });
424 }
425
426 fn add_verbatim_export(&mut self, info: &ExportVerbatim) {
427 self.get_module(Some(info.target_module))
428 .verbatim_all_entries
429 .insert(info.name.to_string());
430 }
431
432 fn add_exclude(&mut self, info: &ExcludeFromAll) {
433 self.get_module(Some(info.target_module))
434 .excluded_all_entries
435 .insert(info.name.to_string());
436 }
437
438 fn resolve_wildcard_re_exports(&mut self) -> Result<()> {
439 let mut resolutions: Vec<(String, usize, Vec<String>, Vec<String>)> = Vec::new();
442
443 for (module_name, module) in &self.modules {
444 for (idx, re_export) in module.module_re_exports.iter().enumerate() {
445 if re_export.items.is_empty() {
446 if let Some(source_mod) = self.modules.get(&re_export.source_module) {
448 let mut items = Vec::new();
450 for class in source_mod.class.values() {
451 if !class.name.starts_with('_') {
452 items.push(class.name.to_string());
453 }
454 }
455 for enum_ in source_mod.enum_.values() {
456 if !enum_.name.starts_with('_') {
457 items.push(enum_.name.to_string());
458 }
459 }
460 for func_name in source_mod.function.keys() {
461 if !func_name.starts_with('_') {
462 items.push(func_name.to_string());
463 }
464 }
465 for var_name in source_mod.variables.keys() {
466 if !var_name.starts_with('_') {
467 items.push(var_name.to_string());
468 }
469 }
470 for alias_name in source_mod.type_aliases.keys() {
471 if !alias_name.starts_with('_') {
472 items.push(alias_name.to_string());
473 }
474 }
475 for submod in &source_mod.submodules {
476 if !submod.starts_with('_') {
477 items.push(submod.to_string());
478 }
479 }
480 let additional = re_export.additional_items.clone();
482 resolutions.push((module_name.clone(), idx, items, additional));
483 } else {
484 anyhow::bail!(
486 "Cannot resolve wildcard re-export in module '{}': source module '{}' not found. \
487 Wildcard re-exports only work with internal modules.",
488 module_name,
489 re_export.source_module
490 );
491 }
492 }
493 }
494 }
495
496 for (module_name, idx, mut items, additional) in resolutions {
498 items.extend(additional);
500 let mut seen = BTreeSet::new();
502 items.retain(|item| seen.insert(item.clone()));
503 if let Some(module) = self.modules.get_mut(&module_name) {
504 module.module_re_exports[idx].items = items;
505 module.module_re_exports[idx].additional_items.clear();
506 }
507 }
508
509 Ok(())
510 }
511
512 fn add_methods(&mut self, info: &PyMethodsInfo) -> Result<()> {
513 let struct_id = (info.struct_id)();
514 for module in self.modules.values_mut() {
515 if let Some(entry) = module.class.get_mut(&struct_id) {
516 for attr in info.attrs {
517 entry.attrs.push(MemberDef {
518 name: attr.name,
519 r#type: (attr.r#type)(),
520 doc: attr.doc,
521 default: attr.default.map(|f| f()),
522 deprecated: attr.deprecated.clone(),
523 });
524 }
525 for getter in info.getters {
526 entry
527 .getter_setters
528 .entry(getter.name.to_string())
529 .or_default()
530 .0 = Some(MemberDef {
531 name: getter.name,
532 r#type: (getter.r#type)(),
533 doc: getter.doc,
534 default: getter.default.map(|f| f()),
535 deprecated: getter.deprecated.clone(),
536 });
537 }
538 for setter in info.setters {
539 entry
540 .getter_setters
541 .entry(setter.name.to_string())
542 .or_default()
543 .1 = Some(MemberDef {
544 name: setter.name,
545 r#type: (setter.r#type)(),
546 doc: setter.doc,
547 default: setter.default.map(|f| f()),
548 deprecated: setter.deprecated.clone(),
549 });
550 }
551 for method in info.methods {
552 let entries = entry.methods.entry(method.name.to_string()).or_default();
553
554 let new_method = MethodDef::from(method);
556 if !new_method.is_overload {
557 let non_overload_count = entries.iter().filter(|m| !m.is_overload).count();
558 if non_overload_count > 0 {
559 anyhow::bail!(
560 "Multiple methods with name '{}' in class '{}' found without @overload decorator. \
561 Please add @overload decorator to all variants.",
562 method.name, entry.name
563 );
564 }
565 }
566
567 entries.push(new_method);
568 }
569 return Ok(());
570 } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
571 for attr in info.attrs {
572 entry.attrs.push(MemberDef {
573 name: attr.name,
574 r#type: (attr.r#type)(),
575 doc: attr.doc,
576 default: attr.default.map(|f| f()),
577 deprecated: attr.deprecated.clone(),
578 });
579 }
580 for getter in info.getters {
581 entry.getters.push(MemberDef {
582 name: getter.name,
583 r#type: (getter.r#type)(),
584 doc: getter.doc,
585 default: getter.default.map(|f| f()),
586 deprecated: getter.deprecated.clone(),
587 });
588 }
589 for setter in info.setters {
590 entry.setters.push(MemberDef {
591 name: setter.name,
592 r#type: (setter.r#type)(),
593 doc: setter.doc,
594 default: setter.default.map(|f| f()),
595 deprecated: setter.deprecated.clone(),
596 });
597 }
598 for method in info.methods {
599 let new_method = MethodDef::from(method);
601 if !new_method.is_overload {
602 let non_overload_count = entry
603 .methods
604 .iter()
605 .filter(|m| m.name == method.name && !m.is_overload)
606 .count();
607 if non_overload_count > 0 {
608 anyhow::bail!(
609 "Multiple methods with name '{}' in enum '{}' found without @overload decorator. \
610 Please add @overload decorator to all variants.",
611 method.name, entry.name
612 );
613 }
614 }
615
616 entry.methods.push(new_method);
617 }
618 return Ok(());
619 }
620 }
621 unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
622 }
623
624 fn build(mut self) -> Result<StubInfo> {
625 for info in inventory::iter::<PyClassInfo> {
626 self.add_class(info);
627 }
628 for info in inventory::iter::<PyComplexEnumInfo> {
629 self.add_complex_enum(info);
630 }
631 for info in inventory::iter::<PyEnumInfo> {
632 self.add_enum(info);
633 }
634 for info in inventory::iter::<PyFunctionInfo> {
635 self.add_function(info)?;
636 }
637 for info in inventory::iter::<PyVariableInfo> {
638 self.add_variable(info);
639 }
640 for info in inventory::iter::<TypeAliasInfo> {
641 self.add_type_alias(info);
642 }
643 for info in inventory::iter::<ModuleDocInfo> {
644 self.add_module_doc(info);
645 }
646 let mut methods_infos: Vec<&PyMethodsInfo> = inventory::iter::<PyMethodsInfo>().collect();
648 methods_infos.sort_by_key(|info| (info.file, info.line, info.column));
649 for info in methods_infos {
650 self.add_methods(info)?;
651 }
652 for info in inventory::iter::<ReexportModuleMembers> {
654 self.add_module_export(info);
655 }
656 for info in inventory::iter::<ExportVerbatim> {
657 self.add_verbatim_export(info);
658 }
659 for info in inventory::iter::<ExcludeFromAll> {
660 self.add_exclude(info);
661 }
662 self.register_submodules();
663
664 self.resolve_wildcard_re_exports()?;
666
667 Ok(StubInfo {
668 modules: self.modules,
669 python_root: self.python_root,
670 is_mixed_layout: self.is_mixed_layout,
671 config: self.config,
672 pyproject_dir: None, default_module_name: self.default_module_name,
674 project_name: self.project_name,
675 })
676 }
677}
678
#[cfg(test)]
mod tests {
    use super::*;

    /// Pure-layout builder rooted at `/tmp` whose default module is `default_module`.
    fn pure_layout_builder(default_module: &str) -> StubInfoBuilder {
        StubInfoBuilder::from_project_root(
            default_module.to_string(),
            PathBuf::from("/tmp"),
            false,
            StubGenConfig::default(),
        )
    }

    /// An otherwise-empty module called `name` that belongs to `default_module`.
    fn empty_module(name: &str, default_module: &str) -> Module {
        Module {
            name: name.to_string(),
            default_module_name: default_module.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_register_submodules_creates_empty_parent_modules() {
        let mut builder = pure_layout_builder("test_module");
        let child = "test_module.sub_mod";
        builder
            .modules
            .insert(child.to_string(), empty_module(child, "test_module"));

        builder.register_submodules();

        let parent = builder
            .modules
            .get("test_module")
            .expect("parent module should have been created");
        assert_eq!(parent.name, "test_module");
        assert!(parent.submodules.contains("sub_mod"));

        // The original child module must still be present.
        assert!(builder.modules.contains_key(child));
    }

    #[test]
    fn test_register_submodules_with_multiple_levels() {
        let mut builder = pure_layout_builder("root");
        let leaf = "root.level1.level2.deep_mod";
        builder
            .modules
            .insert(leaf.to_string(), empty_module(leaf, "root"));

        builder.register_submodules();

        for expected in ["root", "root.level1", "root.level1.level2", leaf] {
            assert!(
                builder.modules.contains_key(expected),
                "missing module `{}`",
                expected
            );
        }

        for (parent, child) in [
            ("root", "level1"),
            ("root.level1", "level2"),
            ("root.level1.level2", "deep_mod"),
        ] {
            assert!(
                builder.modules[parent].submodules.contains(child),
                "`{}` should list `{}` as a submodule",
                parent,
                child
            );
        }
    }

    #[test]
    fn test_pure_layout_rejects_multiple_modules() {
        let mut modules = BTreeMap::new();
        for (name, doc) in [("mymodule", "Test module"), ("mymodule.sub", "Test submodule")] {
            modules.insert(
                name.to_string(),
                Module {
                    doc: doc.to_string(),
                    ..empty_module(name, "mymodule")
                },
            );
        }

        let stub_info = StubInfo {
            modules,
            python_root: PathBuf::from("/tmp"),
            is_mixed_layout: false,
            config: StubGenConfig::default(),
            pyproject_dir: None,
            default_module_name: "mymodule".to_string(),
            project_name: "mymodule".to_string(),
        };

        let err = stub_info
            .generate()
            .expect_err("pure layout with a submodule must be rejected");
        assert!(err
            .to_string()
            .contains("Pure Rust layout does not support multiple modules or submodules"));
    }
}