1use crate::{
2 generate::*,
3 pyproject::{PyProject, StubGenConfig},
4 type_info::*,
5};
6use anyhow::{Context, Result};
7use std::{
8 collections::{BTreeMap, BTreeSet},
9 fs,
10 path::*,
11};
12
13#[derive(Debug, Clone, PartialEq)]
14pub struct StubInfo {
15 pub modules: BTreeMap<String, Module>,
16 pub python_root: PathBuf,
17 pub is_mixed_layout: bool,
19 pub config: StubGenConfig,
21}
22
23impl StubInfo {
24 pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
27 let pyproject = PyProject::parse_toml(path)?;
28 let config = pyproject.stub_gen_config();
29 StubInfoBuilder::from_pyproject_toml(pyproject, config).build()
30 }
31
32 pub fn from_project_root(
36 default_module_name: String,
37 project_root: PathBuf,
38 is_mixed_layout: bool,
39 config: StubGenConfig,
40 ) -> Result<Self> {
41 StubInfoBuilder::from_project_root(
42 default_module_name,
43 project_root,
44 is_mixed_layout,
45 config,
46 )
47 .build()
48 }
49
50 pub fn generate(&self) -> Result<()> {
51 if !self.is_mixed_layout && self.modules.len() > 1 {
53 let module_names: Vec<_> = self.modules.keys().collect();
54 anyhow::bail!(
55 "Pure Rust layout does not support multiple modules or submodules. Found {} modules: {}. \
56 Please use mixed Python/Rust layout (add `python-source` to [tool.maturin] in pyproject.toml) \
57 if you need multiple modules or submodules.",
58 self.modules.len(),
59 module_names.iter().map(|s| format!("'{}'", s)).collect::<Vec<_>>().join(", ")
60 );
61 }
62
63 for (name, module) in self.modules.iter() {
64 let normalized_name = name.replace("-", "_");
66 let path = normalized_name.replace(".", "/");
67
68 let dest = if self.is_mixed_layout {
70 self.python_root.join(&path).join("__init__.pyi")
72 } else {
73 let package_name = normalized_name
75 .split('.')
76 .next()
77 .filter(|s| !s.is_empty())
78 .ok_or_else(|| {
79 anyhow::anyhow!(
80 "Module name is empty after normalization: original name was `{name}`"
81 )
82 })?;
83 self.python_root.join(format!("{package_name}.pyi"))
84 };
85
86 let dir = dest.parent().context("Cannot get parent directory")?;
87 if !dir.exists() {
88 fs::create_dir_all(dir)?;
89 }
90
91 let content = module.format_with_config(self.config.use_type_statement);
92 fs::write(&dest, content)?;
93 log::info!(
94 "Generate stub file of a module `{name}` at {dest}",
95 dest = dest.display()
96 );
97 }
98 Ok(())
99 }
100}
101
102struct StubInfoBuilder {
103 modules: BTreeMap<String, Module>,
104 default_module_name: String,
105 python_root: PathBuf,
106 is_mixed_layout: bool,
107 config: StubGenConfig,
108}
109
110impl StubInfoBuilder {
111 fn from_pyproject_toml(pyproject: PyProject, config: StubGenConfig) -> Self {
112 let is_mixed_layout = pyproject.python_source().is_some();
113 let python_root = pyproject
114 .python_source()
115 .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()));
116
117 Self {
118 modules: BTreeMap::new(),
119 default_module_name: pyproject.module_name().to_string(),
120 python_root,
121 is_mixed_layout,
122 config,
123 }
124 }
125
126 fn from_project_root(
127 default_module_name: String,
128 project_root: PathBuf,
129 is_mixed_layout: bool,
130 config: StubGenConfig,
131 ) -> Self {
132 Self {
133 modules: BTreeMap::new(),
134 default_module_name,
135 python_root: project_root,
136 is_mixed_layout,
137 config,
138 }
139 }
140
141 fn get_module(&mut self, name: Option<&str>) -> &mut Module {
142 let name = name.unwrap_or(&self.default_module_name).to_string();
143 let module = self.modules.entry(name.clone()).or_default();
144 module.name = name;
145 module.default_module_name = self.default_module_name.clone();
146 module
147 }
148
149 fn register_submodules(&mut self) {
150 let mut all_parent_child_pairs: Vec<(String, String)> = Vec::new();
151
152 for module in self.modules.keys() {
154 let path = module.split('.').collect::<Vec<_>>();
155
156 for i in 1..path.len() {
158 let parent = path[..i].join(".");
159 let child = path[i].to_string();
160 all_parent_child_pairs.push((parent, child));
161 }
162 }
163
164 let mut parent_to_children: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
166 for (parent, child) in all_parent_child_pairs {
167 parent_to_children.entry(parent).or_default().insert(child);
168 }
169
170 for (parent, children) in parent_to_children {
172 let module = self.modules.entry(parent.clone()).or_default();
173 module.name = parent;
174 module.default_module_name = self.default_module_name.clone();
175 module.submodules.extend(children);
176 }
177 }
178
179 fn add_class(&mut self, info: &PyClassInfo) {
180 let mut class_def = ClassDef::from(info);
181 class_def.resolve_default_modules(&self.default_module_name);
182 self.get_module(info.module)
183 .class
184 .insert((info.struct_id)(), class_def);
185 }
186
187 fn add_complex_enum(&mut self, info: &PyComplexEnumInfo) {
188 let mut class_def = ClassDef::from(info);
189 class_def.resolve_default_modules(&self.default_module_name);
190 self.get_module(info.module)
191 .class
192 .insert((info.enum_id)(), class_def);
193 }
194
195 fn add_enum(&mut self, info: &PyEnumInfo) {
196 self.get_module(info.module)
197 .enum_
198 .insert((info.enum_id)(), EnumDef::from(info));
199 }
200
201 fn add_function(&mut self, info: &PyFunctionInfo) -> Result<()> {
202 let default_module_name = self.default_module_name.clone();
204
205 let target = self
206 .get_module(info.module)
207 .function
208 .entry(info.name)
209 .or_default();
210
211 let mut new_func = FunctionDef::from(info);
213 new_func.resolve_default_modules(&default_module_name);
214
215 if !new_func.is_overload {
216 let non_overload_count = target.iter().filter(|f| !f.is_overload).count();
217 if non_overload_count > 0 {
218 anyhow::bail!(
219 "Multiple functions with name '{}' found without @overload decorator. \
220 Please add @overload decorator to all variants.",
221 info.name
222 );
223 }
224 }
225
226 target.push(new_func);
227 Ok(())
228 }
229
230 fn add_variable(&mut self, info: &PyVariableInfo) {
231 self.get_module(Some(info.module))
232 .variables
233 .insert(info.name, VariableDef::from(info));
234 }
235
236 fn add_type_alias(&mut self, info: &TypeAliasInfo) {
237 self.get_module(Some(info.module))
238 .type_aliases
239 .insert(info.name, TypeAliasDef::from(info));
240 }
241
242 fn add_module_doc(&mut self, info: &ModuleDocInfo) {
243 self.get_module(Some(info.module)).doc = (info.doc)();
244 }
245
246 fn add_module_export(&mut self, info: &ReexportModuleMembers) {
247 let use_wildcard = info.items.is_none();
248 let items = info
249 .items
250 .map(|items| items.iter().map(|s| s.to_string()).collect())
251 .unwrap_or_default();
252
253 self.get_module(Some(info.target_module))
254 .module_re_exports
255 .push(ModuleReExport {
256 source_module: info.source_module.to_string(),
257 items,
258 use_wildcard_import: use_wildcard,
259 });
260 }
261
262 fn add_verbatim_export(&mut self, info: &ExportVerbatim) {
263 self.get_module(Some(info.target_module))
264 .verbatim_all_entries
265 .insert(info.name.to_string());
266 }
267
268 fn add_exclude(&mut self, info: &ExcludeFromAll) {
269 self.get_module(Some(info.target_module))
270 .excluded_all_entries
271 .insert(info.name.to_string());
272 }
273
274 fn resolve_wildcard_re_exports(&mut self) -> Result<()> {
275 let mut resolutions: Vec<(String, usize, Vec<String>)> = Vec::new();
277
278 for (module_name, module) in &self.modules {
279 for (idx, re_export) in module.module_re_exports.iter().enumerate() {
280 if re_export.use_wildcard_import && re_export.items.is_empty() {
281 if let Some(source_mod) = self.modules.get(&re_export.source_module) {
283 let mut items = Vec::new();
285 for class in source_mod.class.values() {
286 if !class.name.starts_with('_') {
287 items.push(class.name.to_string());
288 }
289 }
290 for enum_ in source_mod.enum_.values() {
291 if !enum_.name.starts_with('_') {
292 items.push(enum_.name.to_string());
293 }
294 }
295 for func_name in source_mod.function.keys() {
296 if !func_name.starts_with('_') {
297 items.push(func_name.to_string());
298 }
299 }
300 for var_name in source_mod.variables.keys() {
301 if !var_name.starts_with('_') {
302 items.push(var_name.to_string());
303 }
304 }
305 for alias_name in source_mod.type_aliases.keys() {
306 if !alias_name.starts_with('_') {
307 items.push(alias_name.to_string());
308 }
309 }
310 for submod in &source_mod.submodules {
312 if !submod.starts_with('_') {
313 items.push(submod.to_string());
314 }
315 }
316 resolutions.push((module_name.clone(), idx, items));
317 } else {
318 anyhow::bail!(
320 "Cannot resolve wildcard re-export in module '{}': source module '{}' not found. \
321 Wildcard re-exports only work with internal modules.",
322 module_name,
323 re_export.source_module
324 );
325 }
326 }
327 }
328 }
329
330 for (module_name, idx, items) in resolutions {
332 if let Some(module) = self.modules.get_mut(&module_name) {
333 module.module_re_exports[idx].items = items;
334 }
335 }
336
337 Ok(())
338 }
339
340 fn add_methods(&mut self, info: &PyMethodsInfo) -> Result<()> {
341 let struct_id = (info.struct_id)();
342 for module in self.modules.values_mut() {
343 if let Some(entry) = module.class.get_mut(&struct_id) {
344 for attr in info.attrs {
345 entry.attrs.push(MemberDef {
346 name: attr.name,
347 r#type: (attr.r#type)(),
348 doc: attr.doc,
349 default: attr.default.map(|f| f()),
350 deprecated: attr.deprecated.clone(),
351 });
352 }
353 for getter in info.getters {
354 entry
355 .getter_setters
356 .entry(getter.name.to_string())
357 .or_default()
358 .0 = Some(MemberDef {
359 name: getter.name,
360 r#type: (getter.r#type)(),
361 doc: getter.doc,
362 default: getter.default.map(|f| f()),
363 deprecated: getter.deprecated.clone(),
364 });
365 }
366 for setter in info.setters {
367 entry
368 .getter_setters
369 .entry(setter.name.to_string())
370 .or_default()
371 .1 = Some(MemberDef {
372 name: setter.name,
373 r#type: (setter.r#type)(),
374 doc: setter.doc,
375 default: setter.default.map(|f| f()),
376 deprecated: setter.deprecated.clone(),
377 });
378 }
379 for method in info.methods {
380 let entries = entry.methods.entry(method.name.to_string()).or_default();
381
382 let new_method = MethodDef::from(method);
384 if !new_method.is_overload {
385 let non_overload_count = entries.iter().filter(|m| !m.is_overload).count();
386 if non_overload_count > 0 {
387 anyhow::bail!(
388 "Multiple methods with name '{}' in class '{}' found without @overload decorator. \
389 Please add @overload decorator to all variants.",
390 method.name, entry.name
391 );
392 }
393 }
394
395 entries.push(new_method);
396 }
397 return Ok(());
398 } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
399 for attr in info.attrs {
400 entry.attrs.push(MemberDef {
401 name: attr.name,
402 r#type: (attr.r#type)(),
403 doc: attr.doc,
404 default: attr.default.map(|f| f()),
405 deprecated: attr.deprecated.clone(),
406 });
407 }
408 for getter in info.getters {
409 entry.getters.push(MemberDef {
410 name: getter.name,
411 r#type: (getter.r#type)(),
412 doc: getter.doc,
413 default: getter.default.map(|f| f()),
414 deprecated: getter.deprecated.clone(),
415 });
416 }
417 for setter in info.setters {
418 entry.setters.push(MemberDef {
419 name: setter.name,
420 r#type: (setter.r#type)(),
421 doc: setter.doc,
422 default: setter.default.map(|f| f()),
423 deprecated: setter.deprecated.clone(),
424 });
425 }
426 for method in info.methods {
427 let new_method = MethodDef::from(method);
429 if !new_method.is_overload {
430 let non_overload_count = entry
431 .methods
432 .iter()
433 .filter(|m| m.name == method.name && !m.is_overload)
434 .count();
435 if non_overload_count > 0 {
436 anyhow::bail!(
437 "Multiple methods with name '{}' in enum '{}' found without @overload decorator. \
438 Please add @overload decorator to all variants.",
439 method.name, entry.name
440 );
441 }
442 }
443
444 entry.methods.push(new_method);
445 }
446 return Ok(());
447 }
448 }
449 unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
450 }
451
452 fn build(mut self) -> Result<StubInfo> {
453 for info in inventory::iter::<PyClassInfo> {
454 self.add_class(info);
455 }
456 for info in inventory::iter::<PyComplexEnumInfo> {
457 self.add_complex_enum(info);
458 }
459 for info in inventory::iter::<PyEnumInfo> {
460 self.add_enum(info);
461 }
462 for info in inventory::iter::<PyFunctionInfo> {
463 self.add_function(info)?;
464 }
465 for info in inventory::iter::<PyVariableInfo> {
466 self.add_variable(info);
467 }
468 for info in inventory::iter::<TypeAliasInfo> {
469 self.add_type_alias(info);
470 }
471 for info in inventory::iter::<ModuleDocInfo> {
472 self.add_module_doc(info);
473 }
474 let mut methods_infos: Vec<&PyMethodsInfo> = inventory::iter::<PyMethodsInfo>().collect();
476 methods_infos.sort_by_key(|info| (info.file, info.line, info.column));
477 for info in methods_infos {
478 self.add_methods(info)?;
479 }
480 for info in inventory::iter::<ReexportModuleMembers> {
482 self.add_module_export(info);
483 }
484 for info in inventory::iter::<ExportVerbatim> {
485 self.add_verbatim_export(info);
486 }
487 for info in inventory::iter::<ExcludeFromAll> {
488 self.add_exclude(info);
489 }
490 self.register_submodules();
491
492 self.resolve_wildcard_re_exports()?;
494
495 Ok(StubInfo {
496 modules: self.modules,
497 python_root: self.python_root,
498 is_mixed_layout: self.is_mixed_layout,
499 config: self.config,
500 })
501 }
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an otherwise empty module called `name` whose default module is `default`.
    fn empty_module(name: &str, default: &str) -> Module {
        Module {
            name: name.to_string(),
            default_module_name: default.to_string(),
            ..Default::default()
        }
    }

    /// Creates a pure-layout builder rooted at `/tmp` with the default configuration.
    fn builder_for(default: &str) -> StubInfoBuilder {
        StubInfoBuilder::from_project_root(
            default.to_string(),
            PathBuf::from("/tmp"),
            false,
            StubGenConfig::default(),
        )
    }

    #[test]
    fn test_register_submodules_creates_empty_parent_modules() {
        let mut b = builder_for("test_module");
        let child = "test_module.sub_mod";
        b.modules
            .insert(child.to_string(), empty_module(child, "test_module"));

        b.register_submodules();

        let parent = b
            .modules
            .get("test_module")
            .expect("parent module should have been created");
        assert_eq!(parent.name, "test_module");
        assert!(parent.submodules.contains("sub_mod"));
        assert!(b.modules.contains_key(child));
    }

    #[test]
    fn test_register_submodules_with_multiple_levels() {
        let mut b = builder_for("root");
        let leaf = "root.level1.level2.deep_mod";
        b.modules.insert(leaf.to_string(), empty_module(leaf, "root"));

        b.register_submodules();

        for name in ["root", "root.level1", "root.level1.level2", leaf] {
            assert!(b.modules.contains_key(name), "missing module `{name}`");
        }
        for (parent, child) in [
            ("root", "level1"),
            ("root.level1", "level2"),
            ("root.level1.level2", "deep_mod"),
        ] {
            assert!(
                b.modules[parent].submodules.contains(child),
                "`{parent}` should list `{child}` as a submodule"
            );
        }
    }

    #[test]
    fn test_pure_layout_rejects_multiple_modules() {
        let modules: BTreeMap<String, Module> = ["mymodule", "mymodule.sub"]
            .iter()
            .map(|&name| (name.to_string(), empty_module(name, "mymodule")))
            .collect();
        let stub_info = StubInfo {
            modules,
            python_root: PathBuf::from("/tmp"),
            is_mixed_layout: false,
            config: StubGenConfig::default(),
        };

        let err = stub_info
            .generate()
            .expect_err("pure layout with a submodule must be rejected");
        assert!(err
            .to_string()
            .contains("Pure Rust layout does not support multiple modules or submodules"));
    }
}