1use crate::{generate::*, pyproject::PyProject, type_info::*};
2use anyhow::{Context, Result};
3use std::{
4 collections::{BTreeMap, BTreeSet},
5 fs,
6 io::Write,
7 path::*,
8};
9
10#[derive(Debug, Clone, PartialEq)]
11pub struct StubInfo {
12 pub modules: BTreeMap<String, Module>,
13 pub python_root: PathBuf,
14 pub is_mixed_layout: bool,
16}
17
18impl StubInfo {
19 pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
22 let pyproject = PyProject::parse_toml(path)?;
23 StubInfoBuilder::from_pyproject_toml(pyproject).build()
24 }
25
26 pub fn from_project_root(
30 default_module_name: String,
31 project_root: PathBuf,
32 is_mixed_layout: bool,
33 ) -> Result<Self> {
34 StubInfoBuilder::from_project_root(default_module_name, project_root, is_mixed_layout)
35 .build()
36 }
37
38 pub fn generate(&self) -> Result<()> {
39 if !self.is_mixed_layout && self.modules.len() > 1 {
41 let module_names: Vec<_> = self.modules.keys().collect();
42 anyhow::bail!(
43 "Pure Rust layout does not support multiple modules or submodules. Found {} modules: {}. \
44 Please use mixed Python/Rust layout (add `python-source` to [tool.maturin] in pyproject.toml) \
45 if you need multiple modules or submodules.",
46 self.modules.len(),
47 module_names.iter().map(|s| format!("'{}'", s)).collect::<Vec<_>>().join(", ")
48 );
49 }
50
51 for (name, module) in self.modules.iter() {
52 let normalized_name = name.replace("-", "_");
54 let path = normalized_name.replace(".", "/");
55
56 let dest = if self.is_mixed_layout {
58 self.python_root.join(&path).join("__init__.pyi")
60 } else {
61 let package_name = normalized_name
63 .split('.')
64 .next()
65 .filter(|s| !s.is_empty())
66 .ok_or_else(|| {
67 anyhow::anyhow!(
68 "Module name is empty after normalization: original name was `{name}`"
69 )
70 })?;
71 self.python_root.join(format!("{package_name}.pyi"))
72 };
73
74 let dir = dest.parent().context("Cannot get parent directory")?;
75 if !dir.exists() {
76 fs::create_dir_all(dir)?;
77 }
78
79 let mut f = fs::File::create(&dest)?;
80 write!(f, "{module}")?;
81 log::info!(
82 "Generate stub file of a module `{name}` at {dest}",
83 dest = dest.display()
84 );
85 }
86 Ok(())
87 }
88}
89
90struct StubInfoBuilder {
91 modules: BTreeMap<String, Module>,
92 default_module_name: String,
93 python_root: PathBuf,
94 is_mixed_layout: bool,
95}
96
97impl StubInfoBuilder {
98 fn from_pyproject_toml(pyproject: PyProject) -> Self {
99 let is_mixed_layout = pyproject.python_source().is_some();
100 let python_root = pyproject
101 .python_source()
102 .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()));
103
104 Self {
105 modules: BTreeMap::new(),
106 default_module_name: pyproject.module_name().to_string(),
107 python_root,
108 is_mixed_layout,
109 }
110 }
111
112 fn from_project_root(
113 default_module_name: String,
114 project_root: PathBuf,
115 is_mixed_layout: bool,
116 ) -> Self {
117 Self {
118 modules: BTreeMap::new(),
119 default_module_name,
120 python_root: project_root,
121 is_mixed_layout,
122 }
123 }
124
125 fn get_module(&mut self, name: Option<&str>) -> &mut Module {
126 let name = name.unwrap_or(&self.default_module_name).to_string();
127 let module = self.modules.entry(name.clone()).or_default();
128 module.name = name;
129 module.default_module_name = self.default_module_name.clone();
130 module
131 }
132
133 fn register_submodules(&mut self) {
134 let mut all_parent_child_pairs: Vec<(String, String)> = Vec::new();
135
136 for module in self.modules.keys() {
138 let path = module.split('.').collect::<Vec<_>>();
139
140 for i in 1..path.len() {
142 let parent = path[..i].join(".");
143 let child = path[i].to_string();
144 all_parent_child_pairs.push((parent, child));
145 }
146 }
147
148 let mut parent_to_children: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
150 for (parent, child) in all_parent_child_pairs {
151 parent_to_children.entry(parent).or_default().insert(child);
152 }
153
154 for (parent, children) in parent_to_children {
156 let module = self.modules.entry(parent.clone()).or_default();
157 module.name = parent;
158 module.default_module_name = self.default_module_name.clone();
159 module.submodules.extend(children);
160 }
161 }
162
163 fn add_class(&mut self, info: &PyClassInfo) {
164 self.get_module(info.module)
165 .class
166 .insert((info.struct_id)(), ClassDef::from(info));
167 }
168
169 fn add_complex_enum(&mut self, info: &PyComplexEnumInfo) {
170 self.get_module(info.module)
171 .class
172 .insert((info.enum_id)(), ClassDef::from(info));
173 }
174
175 fn add_enum(&mut self, info: &PyEnumInfo) {
176 self.get_module(info.module)
177 .enum_
178 .insert((info.enum_id)(), EnumDef::from(info));
179 }
180
181 fn add_function(&mut self, info: &PyFunctionInfo) -> Result<()> {
182 let target = self
183 .get_module(info.module)
184 .function
185 .entry(info.name)
186 .or_default();
187
188 let new_func = FunctionDef::from(info);
190 if !new_func.is_overload {
191 let non_overload_count = target.iter().filter(|f| !f.is_overload).count();
192 if non_overload_count > 0 {
193 anyhow::bail!(
194 "Multiple functions with name '{}' found without @overload decorator. \
195 Please add @overload decorator to all variants.",
196 info.name
197 );
198 }
199 }
200
201 target.push(new_func);
202 Ok(())
203 }
204
205 fn add_variable(&mut self, info: &PyVariableInfo) {
206 self.get_module(Some(info.module))
207 .variables
208 .insert(info.name, VariableDef::from(info));
209 }
210
211 fn add_module_doc(&mut self, info: &ModuleDocInfo) {
212 self.get_module(Some(info.module)).doc = (info.doc)();
213 }
214
215 fn add_methods(&mut self, info: &PyMethodsInfo) -> Result<()> {
216 let struct_id = (info.struct_id)();
217 for module in self.modules.values_mut() {
218 if let Some(entry) = module.class.get_mut(&struct_id) {
219 for attr in info.attrs {
220 entry.attrs.push(MemberDef {
221 name: attr.name,
222 r#type: (attr.r#type)(),
223 doc: attr.doc,
224 default: attr.default.map(|f| f()),
225 deprecated: attr.deprecated.clone(),
226 });
227 }
228 for getter in info.getters {
229 entry
230 .getter_setters
231 .entry(getter.name.to_string())
232 .or_default()
233 .0 = Some(MemberDef {
234 name: getter.name,
235 r#type: (getter.r#type)(),
236 doc: getter.doc,
237 default: getter.default.map(|f| f()),
238 deprecated: getter.deprecated.clone(),
239 });
240 }
241 for setter in info.setters {
242 entry
243 .getter_setters
244 .entry(setter.name.to_string())
245 .or_default()
246 .1 = Some(MemberDef {
247 name: setter.name,
248 r#type: (setter.r#type)(),
249 doc: setter.doc,
250 default: setter.default.map(|f| f()),
251 deprecated: setter.deprecated.clone(),
252 });
253 }
254 for method in info.methods {
255 let entries = entry.methods.entry(method.name.to_string()).or_default();
256
257 let new_method = MethodDef::from(method);
259 if !new_method.is_overload {
260 let non_overload_count = entries.iter().filter(|m| !m.is_overload).count();
261 if non_overload_count > 0 {
262 anyhow::bail!(
263 "Multiple methods with name '{}' in class '{}' found without @overload decorator. \
264 Please add @overload decorator to all variants.",
265 method.name, entry.name
266 );
267 }
268 }
269
270 entries.push(new_method);
271 }
272 return Ok(());
273 } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
274 for attr in info.attrs {
275 entry.attrs.push(MemberDef {
276 name: attr.name,
277 r#type: (attr.r#type)(),
278 doc: attr.doc,
279 default: attr.default.map(|f| f()),
280 deprecated: attr.deprecated.clone(),
281 });
282 }
283 for getter in info.getters {
284 entry.getters.push(MemberDef {
285 name: getter.name,
286 r#type: (getter.r#type)(),
287 doc: getter.doc,
288 default: getter.default.map(|f| f()),
289 deprecated: getter.deprecated.clone(),
290 });
291 }
292 for setter in info.setters {
293 entry.setters.push(MemberDef {
294 name: setter.name,
295 r#type: (setter.r#type)(),
296 doc: setter.doc,
297 default: setter.default.map(|f| f()),
298 deprecated: setter.deprecated.clone(),
299 });
300 }
301 for method in info.methods {
302 let new_method = MethodDef::from(method);
304 if !new_method.is_overload {
305 let non_overload_count = entry
306 .methods
307 .iter()
308 .filter(|m| m.name == method.name && !m.is_overload)
309 .count();
310 if non_overload_count > 0 {
311 anyhow::bail!(
312 "Multiple methods with name '{}' in enum '{}' found without @overload decorator. \
313 Please add @overload decorator to all variants.",
314 method.name, entry.name
315 );
316 }
317 }
318
319 entry.methods.push(new_method);
320 }
321 return Ok(());
322 }
323 }
324 unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
325 }
326
327 fn build(mut self) -> Result<StubInfo> {
328 for info in inventory::iter::<PyClassInfo> {
329 self.add_class(info);
330 }
331 for info in inventory::iter::<PyComplexEnumInfo> {
332 self.add_complex_enum(info);
333 }
334 for info in inventory::iter::<PyEnumInfo> {
335 self.add_enum(info);
336 }
337 for info in inventory::iter::<PyFunctionInfo> {
338 self.add_function(info)?;
339 }
340 for info in inventory::iter::<PyVariableInfo> {
341 self.add_variable(info);
342 }
343 for info in inventory::iter::<ModuleDocInfo> {
344 self.add_module_doc(info);
345 }
346 let mut methods_infos: Vec<&PyMethodsInfo> = inventory::iter::<PyMethodsInfo>().collect();
348 methods_infos.sort_by_key(|info| (info.file, info.line, info.column));
349 for info in methods_infos {
350 self.add_methods(info)?;
351 }
352 self.register_submodules();
353 Ok(StubInfo {
354 modules: self.modules,
355 python_root: self.python_root,
356 is_mixed_layout: self.is_mixed_layout,
357 })
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: an otherwise-empty `Module` called `name` whose default module is `default`.
    fn make_module(name: &str, default: &str) -> Module {
        Module {
            name: name.to_string(),
            default_module_name: default.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_register_submodules_creates_empty_parent_modules() {
        let child = "test_module.sub_mod";
        let mut builder =
            StubInfoBuilder::from_project_root("test_module".to_string(), "/tmp".into(), false);
        builder
            .modules
            .insert(child.to_string(), make_module(child, "test_module"));

        builder.register_submodules();

        // The parent did not exist before; it must have been synthesized and must list the child.
        let parent = builder
            .modules
            .get("test_module")
            .expect("parent module should have been registered");
        assert_eq!(parent.name, "test_module");
        assert!(parent.submodules.contains("sub_mod"));

        // The child itself must be left in place.
        assert!(builder.modules.contains_key(child));
    }

    #[test]
    fn test_register_submodules_with_multiple_levels() {
        let deepest = "root.level1.level2.deep_mod";
        let mut builder =
            StubInfoBuilder::from_project_root("root".to_string(), "/tmp".into(), false);
        builder
            .modules
            .insert(deepest.to_string(), make_module(deepest, "root"));

        builder.register_submodules();

        for expected in ["root", "root.level1", "root.level1.level2", deepest] {
            assert!(
                builder.modules.contains_key(expected),
                "module `{expected}` should exist"
            );
        }

        let links = [
            ("root", "level1"),
            ("root.level1", "level2"),
            ("root.level1.level2", "deep_mod"),
        ];
        for (parent, child) in links {
            assert!(
                builder.modules[parent].submodules.contains(child),
                "`{parent}` should list `{child}` as a submodule"
            );
        }
    }

    #[test]
    fn test_pure_layout_rejects_multiple_modules() {
        let modules = BTreeMap::from([
            ("mymodule".to_string(), make_module("mymodule", "mymodule")),
            (
                "mymodule.sub".to_string(),
                make_module("mymodule.sub", "mymodule"),
            ),
        ]);
        let stub_info = StubInfo {
            modules,
            python_root: PathBuf::from("/tmp"),
            is_mixed_layout: false,
        };

        let err = stub_info
            .generate()
            .expect_err("a pure layout with a submodule must be rejected");
        assert!(err
            .to_string()
            .contains("Pure Rust layout does not support multiple modules or submodules"));
    }
}