1use crate::{
2 generate::{docstring::normalize_docstring, *},
3 pyproject::{PyProject, StubGenConfig},
4 type_info::*,
5};
6use anyhow::{Context, Result};
7use std::{
8 collections::{BTreeMap, BTreeSet},
9 fs,
10 path::*,
11};
12
/// Everything needed to emit Python stub files (and optionally API docs) for a project.
#[derive(Debug, Clone, PartialEq)]
pub struct StubInfo {
    /// Modules to emit, keyed by their dotted module name.
    pub modules: BTreeMap<String, Module>,
    /// Directory under which the `.pyi` files are written by `generate`.
    pub python_root: PathBuf,
    /// Selects the output layout in `generate`: when `true`, each module becomes
    /// `<module/path>/__init__.pyi`; when `false`, a single `<package>.pyi` is written
    /// and more than one module is rejected.
    pub is_mixed_layout: bool,
    /// Stub generation settings (`use_type_statement` and `doc_gen` are read by `generate`).
    pub config: StubGenConfig,
    /// Parent directory of the `pyproject.toml` this value was loaded from.
    /// Only set by `from_pyproject_toml`; the builder itself leaves it as `None`.
    pub pyproject_dir: Option<PathBuf>,
}
24
25impl StubInfo {
26 pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
29 let path = path.as_ref();
30 let pyproject = PyProject::parse_toml(path)?;
31 let mut config = pyproject.stub_gen_config();
32
33 if let Some(resolved_doc_gen) = pyproject.doc_gen_config_resolved() {
35 config.doc_gen = Some(resolved_doc_gen);
36 }
37
38 let pyproject_dir = path.parent().map(|p| p.to_path_buf());
39
40 let mut stub_info = StubInfoBuilder::from_pyproject_toml(pyproject, config).build()?;
41 stub_info.pyproject_dir = pyproject_dir;
42 Ok(stub_info)
43 }
44
45 pub fn from_project_root(
49 default_module_name: String,
50 project_root: PathBuf,
51 is_mixed_layout: bool,
52 config: StubGenConfig,
53 ) -> Result<Self> {
54 StubInfoBuilder::from_project_root(
55 default_module_name,
56 project_root,
57 is_mixed_layout,
58 config,
59 )
60 .build()
61 }
62
63 pub fn generate(&self) -> Result<()> {
64 if !self.is_mixed_layout && self.modules.len() > 1 {
66 let module_names: Vec<_> = self.modules.keys().collect();
67 anyhow::bail!(
68 "Pure Rust layout does not support multiple modules or submodules. Found {} modules: {}. \
69 Please use mixed Python/Rust layout (add `python-source` to [tool.maturin] in pyproject.toml) \
70 if you need multiple modules or submodules.",
71 self.modules.len(),
72 module_names.iter().map(|s| format!("'{}'", s)).collect::<Vec<_>>().join(", ")
73 );
74 }
75
76 for (name, module) in self.modules.iter() {
78 let normalized_name = name.replace("-", "_");
80 let path = normalized_name.replace(".", "/");
81
82 let dest = if self.is_mixed_layout {
84 self.python_root.join(&path).join("__init__.pyi")
86 } else {
87 let package_name = normalized_name
89 .split('.')
90 .next()
91 .filter(|s| !s.is_empty())
92 .ok_or_else(|| {
93 anyhow::anyhow!(
94 "Module name is empty after normalization: original name was `{name}`"
95 )
96 })?;
97 self.python_root.join(format!("{package_name}.pyi"))
98 };
99
100 let dir = dest.parent().context("Cannot get parent directory")?;
101 if !dir.exists() {
102 fs::create_dir_all(dir)?;
103 }
104
105 let content = module.format_with_config(self.config.use_type_statement);
106 fs::write(&dest, content)?;
107 log::info!(
108 "Generate stub file of a module `{name}` at {dest}",
109 dest = dest.display()
110 );
111 }
112
113 if let Some(doc_config) = &self.config.doc_gen {
115 self.generate_docs(doc_config)?;
116 }
117
118 Ok(())
119 }
120
121 fn generate_docs(&self, config: &crate::docgen::DocGenConfig) -> Result<()> {
122 log::info!("Generating API documentation...");
123
124 let doc_package = crate::docgen::builder::DocPackageBuilder::new(self).build()?;
126
127 let json_output = crate::docgen::render::render_to_json(&doc_package)?;
129
130 fs::create_dir_all(&config.output_dir)?;
132 fs::write(config.output_dir.join(&config.json_output), json_output)?;
133
134 crate::docgen::render::copy_sphinx_extension(&config.output_dir)?;
136
137 if config.separate_pages {
139 crate::docgen::render::generate_module_pages(&doc_package, &config.output_dir)?;
140 crate::docgen::render::generate_index_rst(&doc_package, &config.output_dir, config)?;
141 log::info!("Generated separate .rst pages for each module");
142 }
143
144 log::info!("Generated API docs at {:?}", config.output_dir);
145 Ok(())
146 }
147}
148
/// Accumulates module contents before they are frozen into a `StubInfo` by `build`.
struct StubInfoBuilder {
    /// Modules collected so far, keyed by dotted module name.
    modules: BTreeMap<String, Module>,
    /// Module used when an item does not name its module explicitly (see `get_module`).
    default_module_name: String,
    /// Copied into `StubInfo::python_root` by `build`.
    python_root: PathBuf,
    /// Copied into `StubInfo::is_mixed_layout` by `build`.
    is_mixed_layout: bool,
    /// Copied into `StubInfo::config` by `build`.
    config: StubGenConfig,
}
156
157impl StubInfoBuilder {
158 fn from_pyproject_toml(pyproject: PyProject, config: StubGenConfig) -> Self {
159 let is_mixed_layout = pyproject.python_source().is_some();
160 let python_root = pyproject
161 .python_source()
162 .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()));
163
164 Self {
165 modules: BTreeMap::new(),
166 default_module_name: pyproject.module_name().to_string(),
167 python_root,
168 is_mixed_layout,
169 config,
170 }
171 }
172
173 fn from_project_root(
174 default_module_name: String,
175 project_root: PathBuf,
176 is_mixed_layout: bool,
177 config: StubGenConfig,
178 ) -> Self {
179 Self {
180 modules: BTreeMap::new(),
181 default_module_name,
182 python_root: project_root,
183 is_mixed_layout,
184 config,
185 }
186 }
187
188 fn get_module(&mut self, name: Option<&str>) -> &mut Module {
189 let name = name.unwrap_or(&self.default_module_name).to_string();
190 let module = self.modules.entry(name.clone()).or_default();
191 module.name = name;
192 module.default_module_name = self.default_module_name.clone();
193 module
194 }
195
196 fn register_submodules(&mut self) {
197 let mut all_parent_child_pairs: Vec<(String, String)> = Vec::new();
198
199 for module in self.modules.keys() {
201 let path = module.split('.').collect::<Vec<_>>();
202
203 for i in 1..path.len() {
205 let parent = path[..i].join(".");
206 let child = path[i].to_string();
207 all_parent_child_pairs.push((parent, child));
208 }
209 }
210
211 let mut parent_to_children: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
213 for (parent, child) in all_parent_child_pairs {
214 parent_to_children.entry(parent).or_default().insert(child);
215 }
216
217 for (parent, children) in parent_to_children {
219 let module = self.modules.entry(parent.clone()).or_default();
220 module.name = parent;
221 module.default_module_name = self.default_module_name.clone();
222 module.submodules.extend(children);
223 }
224 }
225
226 fn add_class(&mut self, info: &PyClassInfo) {
227 let mut class_def = ClassDef::from(info);
228 class_def.resolve_default_modules(&self.default_module_name);
229 self.get_module(info.module)
230 .class
231 .insert((info.struct_id)(), class_def);
232 }
233
234 fn add_complex_enum(&mut self, info: &PyComplexEnumInfo) {
235 let mut class_def = ClassDef::from(info);
236 class_def.resolve_default_modules(&self.default_module_name);
237 self.get_module(info.module)
238 .class
239 .insert((info.enum_id)(), class_def);
240 }
241
242 fn add_enum(&mut self, info: &PyEnumInfo) {
243 self.get_module(info.module)
244 .enum_
245 .insert((info.enum_id)(), EnumDef::from(info));
246 }
247
248 fn add_function(&mut self, info: &PyFunctionInfo) -> Result<()> {
249 let default_module_name = self.default_module_name.clone();
251
252 let target = self
253 .get_module(info.module)
254 .function
255 .entry(info.name)
256 .or_default();
257
258 let mut new_func = FunctionDef::from(info);
260 new_func.resolve_default_modules(&default_module_name);
261
262 if !new_func.is_overload {
263 let non_overload_count = target.iter().filter(|f| !f.is_overload).count();
264 if non_overload_count > 0 {
265 anyhow::bail!(
266 "Multiple functions with name '{}' found without @overload decorator. \
267 Please add @overload decorator to all variants.",
268 info.name
269 );
270 }
271 }
272
273 target.push(new_func);
274 Ok(())
275 }
276
277 fn add_variable(&mut self, info: &PyVariableInfo) {
278 self.get_module(Some(info.module))
279 .variables
280 .insert(info.name, VariableDef::from(info));
281 }
282
283 fn add_type_alias(&mut self, info: &TypeAliasInfo) {
284 self.get_module(Some(info.module))
285 .type_aliases
286 .insert(info.name, TypeAliasDef::from(info));
287 }
288
289 fn add_module_doc(&mut self, info: &ModuleDocInfo) {
290 let raw_doc = (info.doc)();
291 self.get_module(Some(info.module)).doc = normalize_docstring(&raw_doc);
292 }
293
294 fn add_module_export(&mut self, info: &ReexportModuleMembers) {
295 let use_wildcard = info.items.is_none();
296 let items = info
297 .items
298 .map(|items| items.iter().map(|s| s.to_string()).collect())
299 .unwrap_or_default();
300
301 self.get_module(Some(info.target_module))
302 .module_re_exports
303 .push(ModuleReExport {
304 source_module: info.source_module.to_string(),
305 items,
306 use_wildcard_import: use_wildcard,
307 });
308 }
309
310 fn add_verbatim_export(&mut self, info: &ExportVerbatim) {
311 self.get_module(Some(info.target_module))
312 .verbatim_all_entries
313 .insert(info.name.to_string());
314 }
315
316 fn add_exclude(&mut self, info: &ExcludeFromAll) {
317 self.get_module(Some(info.target_module))
318 .excluded_all_entries
319 .insert(info.name.to_string());
320 }
321
322 fn resolve_wildcard_re_exports(&mut self) -> Result<()> {
323 let mut resolutions: Vec<(String, usize, Vec<String>)> = Vec::new();
325
326 for (module_name, module) in &self.modules {
327 for (idx, re_export) in module.module_re_exports.iter().enumerate() {
328 if re_export.use_wildcard_import && re_export.items.is_empty() {
329 if let Some(source_mod) = self.modules.get(&re_export.source_module) {
331 let mut items = Vec::new();
333 for class in source_mod.class.values() {
334 if !class.name.starts_with('_') {
335 items.push(class.name.to_string());
336 }
337 }
338 for enum_ in source_mod.enum_.values() {
339 if !enum_.name.starts_with('_') {
340 items.push(enum_.name.to_string());
341 }
342 }
343 for func_name in source_mod.function.keys() {
344 if !func_name.starts_with('_') {
345 items.push(func_name.to_string());
346 }
347 }
348 for var_name in source_mod.variables.keys() {
349 if !var_name.starts_with('_') {
350 items.push(var_name.to_string());
351 }
352 }
353 for alias_name in source_mod.type_aliases.keys() {
354 if !alias_name.starts_with('_') {
355 items.push(alias_name.to_string());
356 }
357 }
358 for submod in &source_mod.submodules {
360 if !submod.starts_with('_') {
361 items.push(submod.to_string());
362 }
363 }
364 resolutions.push((module_name.clone(), idx, items));
365 } else {
366 anyhow::bail!(
368 "Cannot resolve wildcard re-export in module '{}': source module '{}' not found. \
369 Wildcard re-exports only work with internal modules.",
370 module_name,
371 re_export.source_module
372 );
373 }
374 }
375 }
376 }
377
378 for (module_name, idx, items) in resolutions {
380 if let Some(module) = self.modules.get_mut(&module_name) {
381 module.module_re_exports[idx].items = items;
382 }
383 }
384
385 Ok(())
386 }
387
388 fn add_methods(&mut self, info: &PyMethodsInfo) -> Result<()> {
389 let struct_id = (info.struct_id)();
390 for module in self.modules.values_mut() {
391 if let Some(entry) = module.class.get_mut(&struct_id) {
392 for attr in info.attrs {
393 entry.attrs.push(MemberDef {
394 name: attr.name,
395 r#type: (attr.r#type)(),
396 doc: attr.doc,
397 default: attr.default.map(|f| f()),
398 deprecated: attr.deprecated.clone(),
399 });
400 }
401 for getter in info.getters {
402 entry
403 .getter_setters
404 .entry(getter.name.to_string())
405 .or_default()
406 .0 = Some(MemberDef {
407 name: getter.name,
408 r#type: (getter.r#type)(),
409 doc: getter.doc,
410 default: getter.default.map(|f| f()),
411 deprecated: getter.deprecated.clone(),
412 });
413 }
414 for setter in info.setters {
415 entry
416 .getter_setters
417 .entry(setter.name.to_string())
418 .or_default()
419 .1 = Some(MemberDef {
420 name: setter.name,
421 r#type: (setter.r#type)(),
422 doc: setter.doc,
423 default: setter.default.map(|f| f()),
424 deprecated: setter.deprecated.clone(),
425 });
426 }
427 for method in info.methods {
428 let entries = entry.methods.entry(method.name.to_string()).or_default();
429
430 let new_method = MethodDef::from(method);
432 if !new_method.is_overload {
433 let non_overload_count = entries.iter().filter(|m| !m.is_overload).count();
434 if non_overload_count > 0 {
435 anyhow::bail!(
436 "Multiple methods with name '{}' in class '{}' found without @overload decorator. \
437 Please add @overload decorator to all variants.",
438 method.name, entry.name
439 );
440 }
441 }
442
443 entries.push(new_method);
444 }
445 return Ok(());
446 } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
447 for attr in info.attrs {
448 entry.attrs.push(MemberDef {
449 name: attr.name,
450 r#type: (attr.r#type)(),
451 doc: attr.doc,
452 default: attr.default.map(|f| f()),
453 deprecated: attr.deprecated.clone(),
454 });
455 }
456 for getter in info.getters {
457 entry.getters.push(MemberDef {
458 name: getter.name,
459 r#type: (getter.r#type)(),
460 doc: getter.doc,
461 default: getter.default.map(|f| f()),
462 deprecated: getter.deprecated.clone(),
463 });
464 }
465 for setter in info.setters {
466 entry.setters.push(MemberDef {
467 name: setter.name,
468 r#type: (setter.r#type)(),
469 doc: setter.doc,
470 default: setter.default.map(|f| f()),
471 deprecated: setter.deprecated.clone(),
472 });
473 }
474 for method in info.methods {
475 let new_method = MethodDef::from(method);
477 if !new_method.is_overload {
478 let non_overload_count = entry
479 .methods
480 .iter()
481 .filter(|m| m.name == method.name && !m.is_overload)
482 .count();
483 if non_overload_count > 0 {
484 anyhow::bail!(
485 "Multiple methods with name '{}' in enum '{}' found without @overload decorator. \
486 Please add @overload decorator to all variants.",
487 method.name, entry.name
488 );
489 }
490 }
491
492 entry.methods.push(new_method);
493 }
494 return Ok(());
495 }
496 }
497 unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
498 }
499
    /// Collects every item registered through `inventory` and assembles the final `StubInfo`.
    ///
    /// The registration order below matters: `add_methods` looks up classes and enums
    /// that must already be present, `register_submodules` needs all modules to exist,
    /// and wildcard re-exports are resolved last so they see the complete module contents
    /// (including submodules).
    ///
    /// # Errors
    ///
    /// Propagates failures from `add_function`, `add_methods` and
    /// `resolve_wildcard_re_exports`.
    fn build(mut self) -> Result<StubInfo> {
        for info in inventory::iter::<PyClassInfo> {
            self.add_class(info);
        }
        for info in inventory::iter::<PyComplexEnumInfo> {
            self.add_complex_enum(info);
        }
        for info in inventory::iter::<PyEnumInfo> {
            self.add_enum(info);
        }
        for info in inventory::iter::<PyFunctionInfo> {
            self.add_function(info)?;
        }
        for info in inventory::iter::<PyVariableInfo> {
            self.add_variable(info);
        }
        for info in inventory::iter::<TypeAliasInfo> {
            self.add_type_alias(info);
        }
        for info in inventory::iter::<ModuleDocInfo> {
            self.add_module_doc(info);
        }
        // Method blocks are processed in (file, line, column) order; presumably this
        // makes the member order in the output follow source order — TODO confirm.
        let mut methods_infos: Vec<&PyMethodsInfo> = inventory::iter::<PyMethodsInfo>().collect();
        methods_infos.sort_by_key(|info| (info.file, info.line, info.column));
        for info in methods_infos {
            self.add_methods(info)?;
        }
        for info in inventory::iter::<ReexportModuleMembers> {
            self.add_module_export(info);
        }
        for info in inventory::iter::<ExportVerbatim> {
            self.add_verbatim_export(info);
        }
        for info in inventory::iter::<ExcludeFromAll> {
            self.add_exclude(info);
        }
        // Create missing parent modules and record parent -> child links.
        self.register_submodules();

        // Must run after submodules are registered: wildcard expansion lists them too.
        self.resolve_wildcard_re_exports()?;

        Ok(StubInfo {
            modules: self.modules,
            python_root: self.python_root,
            is_mixed_layout: self.is_mixed_layout,
            config: self.config,
            // The builder does not know where pyproject.toml lives; callers set this afterwards.
            pyproject_dir: None,
        })
    }
551}
552
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an otherwise-default `Module` carrying the given names.
    fn named_module(name: &str, default_module_name: &str) -> Module {
        Module {
            name: name.to_string(),
            default_module_name: default_module_name.to_string(),
            ..Default::default()
        }
    }

    /// Creates a pure-layout builder rooted at `/tmp` that already holds one module.
    fn builder_with_module(root: &str, module: &str) -> StubInfoBuilder {
        let mut builder = StubInfoBuilder::from_project_root(
            root.to_string(),
            "/tmp".into(),
            false,
            StubGenConfig::default(),
        );
        builder
            .modules
            .insert(module.to_string(), named_module(module, root));
        builder
    }

    #[test]
    fn test_register_submodules_creates_empty_parent_modules() {
        let mut builder = builder_with_module("test_module", "test_module.sub_mod");

        builder.register_submodules();

        // The parent must have been created and must list the child.
        let parent_module = builder
            .modules
            .get("test_module")
            .expect("parent module should have been created");
        assert_eq!(parent_module.name, "test_module");
        assert!(parent_module.submodules.contains("sub_mod"));

        // The original submodule is still registered.
        assert!(builder.modules.contains_key("test_module.sub_mod"));
    }

    #[test]
    fn test_register_submodules_with_multiple_levels() {
        let mut builder = builder_with_module("root", "root.level1.level2.deep_mod");

        builder.register_submodules();

        // Every ancestor along the dotted path exists.
        for expected in [
            "root",
            "root.level1",
            "root.level1.level2",
            "root.level1.level2.deep_mod",
        ]
        .iter()
        {
            assert!(builder.modules.contains_key(*expected));
        }

        // Each ancestor lists its direct child.
        for (parent, child) in [
            ("root", "level1"),
            ("root.level1", "level2"),
            ("root.level1.level2", "deep_mod"),
        ]
        .iter()
        {
            assert!(builder.modules[*parent].submodules.contains(*child));
        }
    }

    #[test]
    fn test_pure_layout_rejects_multiple_modules() {
        let modules = ["mymodule", "mymodule.sub"]
            .iter()
            .map(|name| (name.to_string(), named_module(name, "mymodule")))
            .collect();
        let stub_info = StubInfo {
            modules,
            python_root: PathBuf::from("/tmp"),
            is_mixed_layout: false,
            config: StubGenConfig::default(),
            pyproject_dir: None,
        };

        let result = stub_info.generate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Pure Rust layout does not support multiple modules or submodules")
        );
    }
}