1use crate::generate::*;
2use crate::stub_type::ImportRef;
3use itertools::Itertools;
4use std::{
5 any::TypeId,
6 collections::{BTreeMap, BTreeSet},
7 fmt,
8};
9
/// A single `from <source_module> import <items...>` re-export attached to a module.
///
/// Re-exports whose `items` list is empty are skipped when the import lines are
/// rendered.
#[derive(Debug, Clone)]
#[derive(PartialEq)]
pub struct ModuleReExport {
    /// Python module path the names are imported from.
    pub source_module: String,
    /// Names imported from `source_module`. Every name is also added to the
    /// owning module's `__all__` (before `excluded_all_entries` is applied).
    pub items: Vec<String>,
    /// NOTE(review): not read by any of the stub / `__init__.py` writers in this
    /// file; its meaning is defined by whoever populates it — confirm.
    pub additional_items: Vec<String>,
}
20
21#[derive(Debug, Clone, PartialEq, Default)]
23pub struct Module {
24 pub doc: String,
25 pub class: BTreeMap<TypeId, ClassDef>,
26 pub enum_: BTreeMap<TypeId, EnumDef>,
27 pub function: BTreeMap<&'static str, Vec<FunctionDef>>,
28 pub variables: BTreeMap<&'static str, VariableDef>,
29 pub type_aliases: BTreeMap<&'static str, TypeAliasDef>,
30 pub name: String,
31 pub default_module_name: String,
32 pub submodules: BTreeSet<String>,
34 pub module_re_exports: Vec<ModuleReExport>,
36 pub verbatim_all_entries: BTreeSet<String>,
38 pub excluded_all_entries: BTreeSet<String>,
40}
41
42impl Module {
43 pub fn is_empty(&self) -> bool {
49 self.doc.is_empty()
50 && self.class.is_empty()
51 && self.enum_.is_empty()
52 && self.function.is_empty()
53 && self.variables.is_empty()
54 && self.type_aliases.is_empty()
55 && self.submodules.is_empty()
56 && self.module_re_exports.is_empty()
57 && self.verbatim_all_entries.is_empty()
58 }
59
60 pub fn is_init_py_compatible(&self) -> bool {
66 self.class.is_empty()
67 && self.enum_.is_empty()
68 && self.function.is_empty()
69 && self.variables.is_empty()
70 && self.type_aliases.is_empty()
71 }
72
73 pub fn declared_item_names(&self) -> Vec<String> {
77 let mut names = Vec::new();
78
79 if !self.doc.is_empty() {
80 names.push("module_doc".to_string());
81 }
82 for class in self.class.values() {
83 names.push(format!("class {}", class.name));
84 }
85 for enum_def in self.enum_.values() {
86 names.push(format!("enum {}", enum_def.name));
87 }
88 for func_name in self.function.keys() {
89 names.push(format!("function {}", func_name));
90 }
91 for var_name in self.variables.keys() {
92 names.push(format!("variable {}", var_name));
93 }
94 for alias_name in self.type_aliases.keys() {
95 names.push(format!("type_alias {}", alias_name));
96 }
97 for re_export in &self.module_re_exports {
98 names.push(format!("re-export from {}", re_export.source_module));
99 }
100
101 names.sort();
102 names
103 }
104
105 pub fn format_with_config(&self, use_type_statement: bool) -> String {
107 use std::fmt::Write;
108 let mut output = String::new();
109
110 struct ModuleFormatter<'a> {
112 module: &'a Module,
113 use_type_statement: bool,
114 }
115
116 impl<'a> fmt::Display for ModuleFormatter<'a> {
117 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
118 writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
120 writeln!(f, "# ruff: noqa: E501, F401, F403, F405")?;
121 if !self.module.doc.is_empty() {
122 docstring::write_docstring(f, &self.module.doc, "")?;
123 }
124 writeln!(f)?;
125
126 let mut imports = self.module.import();
127
128 if !self.use_type_statement && !self.module.type_aliases.is_empty() {
130 imports.insert(ImportRef::Type(crate::stub_type::TypeRef {
131 module: crate::stub_type::ModuleRef::Named("typing".to_string()),
132 name: "TypeAlias".to_string(),
133 }));
134 }
135
136 let any_overloaded = self.module.function.values().any(|functions| {
138 let has_overload = functions.iter().any(|func| func.is_overload);
139 functions.len() > 1 && has_overload
140 });
141 if any_overloaded {
142 imports.insert("typing".into());
143 }
144
145 let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
147 for import_ref in imports.into_iter().sorted() {
148 match import_ref {
149 ImportRef::Module(module_ref) => {
150 let name = module_ref.get().unwrap_or(&self.module.default_module_name);
151 if name != self.module.name && !name.is_empty() {
152 let is_internal_module = if let Some(root) =
153 self.module.default_module_name.split('.').next()
154 {
155 name.starts_with(root)
156 } else {
157 false
158 };
159
160 if is_internal_module && name.contains('.') {
161 let last_dot_pos = name.rfind('.').unwrap();
162 let parent_module = &name[..last_dot_pos];
163 let child_module = &name[last_dot_pos + 1..];
164
165 if !self.module.submodules.contains(child_module) {
166 writeln!(
167 f,
168 "from {} import {}",
169 parent_module, child_module
170 )?;
171 }
172 } else {
173 writeln!(f, "import {name}")?;
174 }
175 }
176 }
177 ImportRef::Type(type_ref) => {
178 let module_name = type_ref
179 .module
180 .get()
181 .unwrap_or(&self.module.default_module_name);
182 if module_name != self.module.name {
183 type_ref_grouped
184 .entry(module_name.to_string())
185 .or_default()
186 .push(type_ref.name);
187 }
188 }
189 }
190 }
191 for (module_name, type_names) in type_ref_grouped {
192 let mut sorted_type_names = type_names.clone();
193 sorted_type_names.sort();
194 writeln!(
195 f,
196 "from {} import {}",
197 module_name,
198 sorted_type_names.join(", ")
199 )?;
200 }
201
202 let mut sorted_re_exports = self.module.module_re_exports.clone();
204 sorted_re_exports.sort_by(|a, b| a.source_module.cmp(&b.source_module));
205 for re_export in &sorted_re_exports {
206 if re_export.items.is_empty() {
207 continue;
208 }
209 let mut sorted_items = re_export.items.clone();
210 sorted_items.sort();
211 writeln!(
212 f,
213 "from {} import {}",
214 re_export.source_module,
215 sorted_items.join(", ")
216 )?;
217 }
218 for submod in &self.module.submodules {
219 writeln!(f, "from . import {submod}")?;
220 }
221
222 self.module.write_all_list(f)?;
224
225 writeln!(f)?;
226
227 for alias in self.module.type_aliases.values() {
229 alias.fmt_with_config(&self.module.name, f, self.use_type_statement)?;
230 writeln!(f)?;
231 }
232
233 for var in self.module.variables.values() {
235 var.fmt_for_module(&self.module.name, f)?;
236 writeln!(f)?;
237 }
238
239 for class in self.module.class.values().sorted_by_key(|class| class.name) {
241 class.fmt_for_module(&self.module.name, f)?;
242 }
243
244 for enum_ in self.module.enum_.values().sorted_by_key(|enum_| enum_.name) {
246 enum_.fmt_for_module(&self.module.name, f)?;
247 }
248
249 for functions in self.module.function.values() {
251 let has_overload = functions.iter().any(|func| func.is_overload);
252 let should_add_overload = functions.len() > 1 && has_overload;
253
254 let mut sorted_functions = functions.clone();
255 sorted_functions
256 .sort_by_key(|func| (func.file, func.line, func.column, func.index));
257 for function in sorted_functions {
258 if should_add_overload {
259 writeln!(f, "@typing.overload")?;
260 }
261 function.fmt_for_module(&self.module.name, f)?;
262 }
263 }
264
265 Ok(())
266 }
267 }
268
269 write!(
270 &mut output,
271 "{}",
272 ModuleFormatter {
273 module: self,
274 use_type_statement
275 }
276 )
277 .unwrap();
278 output
279 }
280
281 fn collect_all_items(&self) -> BTreeSet<String> {
287 let mut all_items: BTreeSet<String> = BTreeSet::new();
288
289 for class in self.class.values() {
291 if !class.name.starts_with('_') {
292 all_items.insert(class.name.to_string());
293 }
294 }
295 for enum_ in self.enum_.values() {
296 if !enum_.name.starts_with('_') {
297 all_items.insert(enum_.name.to_string());
298 }
299 }
300 for func_name in self.function.keys() {
301 if !func_name.starts_with('_') {
302 all_items.insert(func_name.to_string());
303 }
304 }
305 for var_name in self.variables.keys() {
306 if !var_name.starts_with('_') {
307 all_items.insert(var_name.to_string());
308 }
309 }
310 for alias_name in self.type_aliases.keys() {
311 if !alias_name.starts_with('_') {
312 all_items.insert(alias_name.to_string());
313 }
314 }
315 for submod in &self.submodules {
316 if !submod.starts_with('_') {
317 all_items.insert(submod.to_string());
318 }
319 }
320
321 for re_export in &self.module_re_exports {
323 all_items.extend(re_export.items.iter().cloned());
324 }
325
326 all_items.extend(self.verbatim_all_entries.iter().cloned());
328
329 for excluded in &self.excluded_all_entries {
331 all_items.remove(excluded);
332 }
333
334 all_items
335 }
336
337 pub fn format_init_py(&self) -> String {
343 use std::fmt::Write;
344 let mut output = String::new();
345
346 writeln!(
348 output,
349 "# This file is automatically generated by pyo3_stub_gen"
350 )
351 .unwrap();
352 writeln!(output, "# ruff: noqa: F401").unwrap();
353 writeln!(output).unwrap();
354
355 let mut sorted_re_exports = self.module_re_exports.clone();
358 sorted_re_exports.sort_by(|a, b| a.source_module.cmp(&b.source_module));
359 for re_export in &sorted_re_exports {
360 if re_export.items.is_empty() {
361 continue;
362 }
363 let mut sorted_items = re_export.items.clone();
364 sorted_items.sort();
365 writeln!(
366 output,
367 "from {} import {}",
368 re_export.source_module,
369 sorted_items.join(", ")
370 )
371 .unwrap();
372 }
373
374 let mut all_items: BTreeSet<String> = BTreeSet::new();
377 for re_export in &self.module_re_exports {
378 all_items.extend(re_export.items.iter().cloned());
379 }
380 all_items.extend(self.verbatim_all_entries.iter().cloned());
381 for excluded in &self.excluded_all_entries {
382 all_items.remove(excluded);
383 }
384 if all_items.is_empty() {
385 writeln!(output, "__all__ = []").unwrap();
386 } else {
387 writeln!(output, "__all__ = [").unwrap();
388 for item in all_items {
389 writeln!(output, " \"{}\",", item).unwrap();
390 }
391 writeln!(output, "]").unwrap();
392 }
393
394 output
395 }
396
397 fn write_all_list(&self, f: &mut fmt::Formatter) -> fmt::Result {
398 let all_items = self.collect_all_items();
399
400 if all_items.is_empty() {
402 writeln!(f, "__all__ = []")?;
403 } else {
404 writeln!(f, "__all__ = [")?;
405 for item in all_items {
406 writeln!(f, " \"{}\",", item)?;
407 }
408 writeln!(f, "]")?;
409 }
410
411 Ok(())
412 }
413}
414
415impl Import for Module {
416 fn import(&self) -> HashSet<ImportRef> {
417 let mut imports = HashSet::new();
418 for class in self.class.values() {
419 imports.extend(class.import());
420 }
421 for enum_ in self.enum_.values() {
422 imports.extend(enum_.import());
423 }
424 for function in self.function.values().flatten() {
425 imports.extend(function.import());
426 }
427 for variable in self.variables.values() {
428 imports.extend(variable.import());
429 }
430 for type_alias in self.type_aliases.values() {
431 imports.extend(type_alias.import());
432 }
433 imports
434 }
435}
436
437impl fmt::Display for Module {
438 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
439 writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
441 writeln!(f, "# ruff: noqa: E501, F401, F403, F405")?;
442 if !self.doc.is_empty() {
443 docstring::write_docstring(f, &self.doc, "")?;
444 }
445 writeln!(f)?;
446
447 let mut imports = self.import();
448 let any_overloaded = self.function.values().any(|functions| {
450 let has_overload = functions.iter().any(|f| f.is_overload);
451 functions.len() > 1 && has_overload
452 });
453 if any_overloaded {
454 imports.insert("typing".into());
455 }
456
457 let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
459 for import_ref in imports.into_iter().sorted() {
460 match import_ref {
461 ImportRef::Module(module_ref) => {
462 let name = module_ref.get().unwrap_or(&self.default_module_name);
463 if name != self.name && !name.is_empty() {
464 let is_internal_module =
467 if let Some(root) = self.default_module_name.split('.').next() {
468 name.starts_with(root)
469 } else {
470 false
471 };
472
473 if is_internal_module && name.contains('.') {
477 let last_dot_pos = name.rfind('.').unwrap();
478 let parent_module = &name[..last_dot_pos];
479 let child_module = &name[last_dot_pos + 1..];
480
481 if !self.submodules.contains(child_module) {
483 writeln!(f, "from {} import {}", parent_module, child_module)?;
484 }
485 } else {
486 writeln!(f, "import {name}")?;
488 }
489 }
490 }
491 ImportRef::Type(type_ref) => {
492 let module_name = type_ref.module.get().unwrap_or(&self.default_module_name);
493 if module_name != self.name {
494 type_ref_grouped
495 .entry(module_name.to_string())
496 .or_default()
497 .push(type_ref.name);
498 }
499 }
500 }
501 }
502 for (module_name, type_names) in type_ref_grouped {
503 let mut sorted_type_names = type_names.clone();
504 sorted_type_names.sort();
505 writeln!(
506 f,
507 "from {} import {}",
508 module_name,
509 sorted_type_names.join(", ")
510 )?;
511 }
512 let mut sorted_re_exports = self.module_re_exports.clone();
514 sorted_re_exports.sort_by(|a, b| a.source_module.cmp(&b.source_module));
515 for re_export in &sorted_re_exports {
516 if re_export.items.is_empty() {
517 continue;
518 }
519 let mut sorted_items = re_export.items.clone();
520 sorted_items.sort();
521 writeln!(
522 f,
523 "from {} import {}",
524 re_export.source_module,
525 sorted_items.join(", ")
526 )?;
527 }
528 for submod in &self.submodules {
529 writeln!(f, "from . import {submod}")?;
530 }
531
532 self.write_all_list(f)?;
534
535 writeln!(f)?;
536
537 for alias in self.type_aliases.values() {
538 alias.fmt_for_module(&self.name, f)?;
539 writeln!(f)?;
540 }
541 for var in self.variables.values() {
542 var.fmt_for_module(&self.name, f)?;
543 writeln!(f)?;
544 }
545 for class in self.class.values().sorted_by_key(|class| class.name) {
546 class.fmt_for_module(&self.name, f)?;
547 }
548 for enum_ in self.enum_.values().sorted_by_key(|class| class.name) {
549 enum_.fmt_for_module(&self.name, f)?;
550 }
551 for functions in self.function.values() {
552 let has_overload = functions.iter().any(|func| func.is_overload);
554 let should_add_overload = functions.len() > 1 && has_overload;
555
556 let mut sorted_functions = functions.clone();
558 sorted_functions.sort_by_key(|func| (func.file, func.line, func.column, func.index));
559 for function in sorted_functions {
560 if should_add_overload {
561 writeln!(f, "@typing.overload")?;
562 }
563 function.fmt_for_module(&self.name, f)?;
564 }
565 }
566 Ok(())
567 }
568}