pyo3_stub_gen/generate/
module.rs1use crate::generate::*;
2use crate::stub_type::ImportRef;
3use itertools::Itertools;
4use std::{
5 any::TypeId,
6 collections::{BTreeMap, BTreeSet},
7 fmt,
8};
9
10#[derive(Debug, Clone, PartialEq, Default)]
12pub struct Module {
13 pub doc: String,
14 pub class: BTreeMap<TypeId, ClassDef>,
15 pub enum_: BTreeMap<TypeId, EnumDef>,
16 pub function: BTreeMap<&'static str, Vec<FunctionDef>>,
17 pub variables: BTreeMap<&'static str, VariableDef>,
18 pub name: String,
19 pub default_module_name: String,
20 pub submodules: BTreeSet<String>,
22}
23
24impl Import for Module {
25 fn import(&self) -> HashSet<ImportRef> {
26 let mut imports = HashSet::new();
27 for class in self.class.values() {
28 imports.extend(class.import());
29 }
30 for enum_ in self.enum_.values() {
31 imports.extend(enum_.import());
32 }
33 for function in self.function.values().flatten() {
34 imports.extend(function.import());
35 }
36 imports
37 }
38}
39
40impl fmt::Display for Module {
41 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42 writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
43 writeln!(f, "# ruff: noqa: E501, F401")?;
44 if !self.doc.is_empty() {
45 docstring::write_docstring(f, &self.doc, "")?;
46 }
47 writeln!(f)?;
48 let mut imports = self.import();
49 let any_overloaded = self.function.values().any(|functions| {
51 let has_overload = functions.iter().any(|f| f.is_overload);
52 functions.len() > 1 && has_overload
53 });
54 if any_overloaded {
55 imports.insert("typing".into());
56 }
57
58 let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
60 for import_ref in imports.into_iter().sorted() {
61 match import_ref {
62 ImportRef::Module(module_ref) => {
63 let name = module_ref.get().unwrap_or(&self.default_module_name);
64 if name != self.name {
65 writeln!(f, "import {name}")?;
66 }
67 }
68 ImportRef::Type(type_ref) => {
69 let module_name = type_ref.module.get().unwrap_or(&self.default_module_name);
70 if module_name != self.name {
71 type_ref_grouped
72 .entry(module_name.to_string())
73 .or_default()
74 .push(type_ref.name);
75 }
76 }
77 }
78 }
79 for (module_name, type_names) in type_ref_grouped {
80 let mut sorted_type_names = type_names.clone();
81 sorted_type_names.sort();
82 writeln!(
83 f,
84 "from {} import {}",
85 module_name,
86 sorted_type_names.join(", ")
87 )?;
88 }
89 for submod in &self.submodules {
90 writeln!(f, "from . import {submod}")?;
91 }
92 writeln!(f)?;
93
94 for var in self.variables.values() {
95 writeln!(f, "{var}")?;
96 }
97 for class in self.class.values().sorted_by_key(|class| class.name) {
98 write!(f, "{class}")?;
99 }
100 for enum_ in self.enum_.values().sorted_by_key(|class| class.name) {
101 write!(f, "{enum_}")?;
102 }
103 for functions in self.function.values() {
104 let has_overload = functions.iter().any(|func| func.is_overload);
106 let should_add_overload = functions.len() > 1 && has_overload;
107
108 let mut sorted_functions = functions.clone();
110 sorted_functions.sort_by_key(|func| (func.file, func.line, func.column, func.index));
111 for function in sorted_functions {
112 if should_add_overload {
113 writeln!(f, "@typing.overload")?;
114 }
115 write!(f, "{function}")?;
116 }
117 }
118 Ok(())
119 }
120}