1use crate::generate::*;
2use crate::stub_type::ImportRef;
3use itertools::Itertools;
4use std::{
5 any::TypeId,
6 collections::{BTreeMap, BTreeSet},
7 fmt,
8};
9
/// Names re-exported from another module, rendered in the stub as a
/// `from <source_module> import ...` line.
#[derive(Debug, Clone, PartialEq)]
pub struct ModuleReExport {
    /// Module path written after `from` in the generated import line.
    pub source_module: String,
    /// Names imported from `source_module` when the wildcard form is not used.
    /// They are also added to the module's `__all__` list.
    pub items: Vec<String>,
    /// When `true`, the line is written as `from <source_module> import *`
    /// instead of listing `items`.
    pub use_wildcard_import: bool,
}
17
/// Everything needed to render one Python module as a stub file.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Module {
    /// Module docstring; written at the top of the stub when non-empty.
    pub doc: String,
    /// Class definitions, keyed by `TypeId` (presumably of the originating Rust
    /// type — confirm at the insertion site). Emitted sorted by class name.
    pub class: BTreeMap<TypeId, ClassDef>,
    /// Enum definitions, keyed by `TypeId`. Emitted sorted by enum name.
    pub enum_: BTreeMap<TypeId, EnumDef>,
    /// Function definitions grouped by function name. A group with more than
    /// one entry, at least one of which is flagged `is_overload`, is emitted
    /// with `@typing.overload` on every entry.
    pub function: BTreeMap<&'static str, Vec<FunctionDef>>,
    /// Module-level variables, keyed by name.
    pub variables: BTreeMap<&'static str, VariableDef>,
    /// Type aliases, keyed by name.
    pub type_aliases: BTreeMap<&'static str, TypeAliasDef>,
    /// Name of this module. Imports that resolve to this name are skipped, and
    /// the name is passed to every definition's formatter.
    pub name: String,
    /// Module name substituted for import references that carry no explicit
    /// module. Its first dotted component is used to decide whether an
    /// imported module is rendered as `from <parent> import <child>`.
    pub default_module_name: String,
    /// Direct submodules; each is emitted as `from . import <name>` and listed
    /// in `__all__` unless its name starts with an underscore.
    pub submodules: BTreeSet<String>,
    /// Re-exports from other modules; emitted sorted by source module.
    pub module_re_exports: Vec<ModuleReExport>,
    /// Extra names added to `__all__` as-is, without the underscore filter.
    pub verbatim_all_entries: BTreeSet<String>,
    /// Names removed from `__all__` after all other entries were collected.
    pub excluded_all_entries: BTreeSet<String>,
}
38
39impl Module {
40 pub fn format_with_config(&self, use_type_statement: bool) -> String {
42 use std::fmt::Write;
43 let mut output = String::new();
44
45 struct ModuleFormatter<'a> {
47 module: &'a Module,
48 use_type_statement: bool,
49 }
50
51 impl<'a> fmt::Display for ModuleFormatter<'a> {
52 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
53 writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
55 writeln!(f, "# ruff: noqa: E501, F401, F403, F405")?;
56 if !self.module.doc.is_empty() {
57 docstring::write_docstring(f, &self.module.doc, "")?;
58 }
59 writeln!(f)?;
60
61 let mut imports = self.module.import();
62
63 if !self.use_type_statement && !self.module.type_aliases.is_empty() {
65 imports.insert(ImportRef::Type(crate::stub_type::TypeRef {
66 module: crate::stub_type::ModuleRef::Named("typing".to_string()),
67 name: "TypeAlias".to_string(),
68 }));
69 }
70
71 let any_overloaded = self.module.function.values().any(|functions| {
73 let has_overload = functions.iter().any(|func| func.is_overload);
74 functions.len() > 1 && has_overload
75 });
76 if any_overloaded {
77 imports.insert("typing".into());
78 }
79
80 let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
82 for import_ref in imports.into_iter().sorted() {
83 match import_ref {
84 ImportRef::Module(module_ref) => {
85 let name = module_ref.get().unwrap_or(&self.module.default_module_name);
86 if name != self.module.name && !name.is_empty() {
87 let is_internal_module = if let Some(root) =
88 self.module.default_module_name.split('.').next()
89 {
90 name.starts_with(root)
91 } else {
92 false
93 };
94
95 if is_internal_module && name.contains('.') {
96 let last_dot_pos = name.rfind('.').unwrap();
97 let parent_module = &name[..last_dot_pos];
98 let child_module = &name[last_dot_pos + 1..];
99
100 if !self.module.submodules.contains(child_module) {
101 writeln!(
102 f,
103 "from {} import {}",
104 parent_module, child_module
105 )?;
106 }
107 } else {
108 writeln!(f, "import {name}")?;
109 }
110 }
111 }
112 ImportRef::Type(type_ref) => {
113 let module_name = type_ref
114 .module
115 .get()
116 .unwrap_or(&self.module.default_module_name);
117 if module_name != self.module.name {
118 type_ref_grouped
119 .entry(module_name.to_string())
120 .or_default()
121 .push(type_ref.name);
122 }
123 }
124 }
125 }
126 for (module_name, type_names) in type_ref_grouped {
127 let mut sorted_type_names = type_names.clone();
128 sorted_type_names.sort();
129 writeln!(
130 f,
131 "from {} import {}",
132 module_name,
133 sorted_type_names.join(", ")
134 )?;
135 }
136
137 let mut sorted_re_exports = self.module.module_re_exports.clone();
139 sorted_re_exports.sort_by(|a, b| a.source_module.cmp(&b.source_module));
140 for re_export in &sorted_re_exports {
141 if re_export.use_wildcard_import {
142 writeln!(f, "from {} import *", re_export.source_module)?;
143 } else {
144 let mut sorted_items = re_export.items.clone();
145 sorted_items.sort();
146 writeln!(
147 f,
148 "from {} import {}",
149 re_export.source_module,
150 sorted_items.join(", ")
151 )?;
152 }
153 }
154 for submod in &self.module.submodules {
155 writeln!(f, "from . import {submod}")?;
156 }
157
158 self.module.write_all_list(f)?;
160
161 writeln!(f)?;
162
163 for alias in self.module.type_aliases.values() {
165 alias.fmt_with_config(&self.module.name, f, self.use_type_statement)?;
166 writeln!(f)?;
167 }
168
169 for var in self.module.variables.values() {
171 var.fmt_for_module(&self.module.name, f)?;
172 writeln!(f)?;
173 }
174
175 for class in self.module.class.values().sorted_by_key(|class| class.name) {
177 class.fmt_for_module(&self.module.name, f)?;
178 }
179
180 for enum_ in self.module.enum_.values().sorted_by_key(|enum_| enum_.name) {
182 enum_.fmt_for_module(&self.module.name, f)?;
183 }
184
185 for functions in self.module.function.values() {
187 let has_overload = functions.iter().any(|func| func.is_overload);
188 let should_add_overload = functions.len() > 1 && has_overload;
189
190 let mut sorted_functions = functions.clone();
191 sorted_functions
192 .sort_by_key(|func| (func.file, func.line, func.column, func.index));
193 for function in sorted_functions {
194 if should_add_overload {
195 writeln!(f, "@typing.overload")?;
196 }
197 function.fmt_for_module(&self.module.name, f)?;
198 }
199 }
200
201 Ok(())
202 }
203 }
204
205 write!(
206 &mut output,
207 "{}",
208 ModuleFormatter {
209 module: self,
210 use_type_statement
211 }
212 )
213 .unwrap();
214 output
215 }
216
217 fn write_all_list(&self, f: &mut fmt::Formatter) -> fmt::Result {
218 let mut all_items: BTreeSet<String> = BTreeSet::new();
219
220 for class in self.class.values() {
222 if !class.name.starts_with('_') {
223 all_items.insert(class.name.to_string());
224 }
225 }
226 for enum_ in self.enum_.values() {
227 if !enum_.name.starts_with('_') {
228 all_items.insert(enum_.name.to_string());
229 }
230 }
231 for func_name in self.function.keys() {
232 if !func_name.starts_with('_') {
233 all_items.insert(func_name.to_string());
234 }
235 }
236 for var_name in self.variables.keys() {
237 if !var_name.starts_with('_') {
238 all_items.insert(var_name.to_string());
239 }
240 }
241 for alias_name in self.type_aliases.keys() {
242 if !alias_name.starts_with('_') {
243 all_items.insert(alias_name.to_string());
244 }
245 }
246 for submod in &self.submodules {
248 if !submod.starts_with('_') {
249 all_items.insert(submod.to_string());
250 }
251 }
252
253 for re_export in &self.module_re_exports {
255 all_items.extend(re_export.items.iter().cloned());
256 }
257
258 all_items.extend(self.verbatim_all_entries.iter().cloned());
260
261 for excluded in &self.excluded_all_entries {
263 all_items.remove(excluded);
264 }
265
266 if all_items.is_empty() {
268 writeln!(f, "__all__ = []")?;
269 } else {
270 writeln!(f, "__all__ = [")?;
271 for item in all_items {
272 writeln!(f, " \"{}\",", item)?;
273 }
274 writeln!(f, "]")?;
275 }
276
277 Ok(())
278 }
279}
280
281impl Import for Module {
282 fn import(&self) -> HashSet<ImportRef> {
283 let mut imports = HashSet::new();
284 for class in self.class.values() {
285 imports.extend(class.import());
286 }
287 for enum_ in self.enum_.values() {
288 imports.extend(enum_.import());
289 }
290 for function in self.function.values().flatten() {
291 imports.extend(function.import());
292 }
293 for variable in self.variables.values() {
294 imports.extend(variable.import());
295 }
296 for type_alias in self.type_aliases.values() {
297 imports.extend(type_alias.import());
298 }
299 imports
300 }
301}
302
303impl fmt::Display for Module {
304 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
305 writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
307 writeln!(f, "# ruff: noqa: E501, F401, F403, F405")?;
308 if !self.doc.is_empty() {
309 docstring::write_docstring(f, &self.doc, "")?;
310 }
311 writeln!(f)?;
312
313 let mut imports = self.import();
314 let any_overloaded = self.function.values().any(|functions| {
316 let has_overload = functions.iter().any(|f| f.is_overload);
317 functions.len() > 1 && has_overload
318 });
319 if any_overloaded {
320 imports.insert("typing".into());
321 }
322
323 let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
325 for import_ref in imports.into_iter().sorted() {
326 match import_ref {
327 ImportRef::Module(module_ref) => {
328 let name = module_ref.get().unwrap_or(&self.default_module_name);
329 if name != self.name && !name.is_empty() {
330 let is_internal_module =
333 if let Some(root) = self.default_module_name.split('.').next() {
334 name.starts_with(root)
335 } else {
336 false
337 };
338
339 if is_internal_module && name.contains('.') {
343 let last_dot_pos = name.rfind('.').unwrap();
344 let parent_module = &name[..last_dot_pos];
345 let child_module = &name[last_dot_pos + 1..];
346
347 if !self.submodules.contains(child_module) {
349 writeln!(f, "from {} import {}", parent_module, child_module)?;
350 }
351 } else {
352 writeln!(f, "import {name}")?;
354 }
355 }
356 }
357 ImportRef::Type(type_ref) => {
358 let module_name = type_ref.module.get().unwrap_or(&self.default_module_name);
359 if module_name != self.name {
360 type_ref_grouped
361 .entry(module_name.to_string())
362 .or_default()
363 .push(type_ref.name);
364 }
365 }
366 }
367 }
368 for (module_name, type_names) in type_ref_grouped {
369 let mut sorted_type_names = type_names.clone();
370 sorted_type_names.sort();
371 writeln!(
372 f,
373 "from {} import {}",
374 module_name,
375 sorted_type_names.join(", ")
376 )?;
377 }
378 let mut sorted_re_exports = self.module_re_exports.clone();
380 sorted_re_exports.sort_by(|a, b| a.source_module.cmp(&b.source_module));
381 for re_export in &sorted_re_exports {
382 if re_export.use_wildcard_import {
383 writeln!(f, "from {} import *", re_export.source_module)?;
385 } else {
386 let mut sorted_items = re_export.items.clone();
388 sorted_items.sort();
389 writeln!(
390 f,
391 "from {} import {}",
392 re_export.source_module,
393 sorted_items.join(", ")
394 )?;
395 }
396 }
397 for submod in &self.submodules {
398 writeln!(f, "from . import {submod}")?;
399 }
400
401 self.write_all_list(f)?;
403
404 writeln!(f)?;
405
406 for alias in self.type_aliases.values() {
407 alias.fmt_for_module(&self.name, f)?;
408 writeln!(f)?;
409 }
410 for var in self.variables.values() {
411 var.fmt_for_module(&self.name, f)?;
412 writeln!(f)?;
413 }
414 for class in self.class.values().sorted_by_key(|class| class.name) {
415 class.fmt_for_module(&self.name, f)?;
416 }
417 for enum_ in self.enum_.values().sorted_by_key(|class| class.name) {
418 enum_.fmt_for_module(&self.name, f)?;
419 }
420 for functions in self.function.values() {
421 let has_overload = functions.iter().any(|func| func.is_overload);
423 let should_add_overload = functions.len() > 1 && has_overload;
424
425 let mut sorted_functions = functions.clone();
427 sorted_functions.sort_by_key(|func| (func.file, func.line, func.column, func.index));
428 for function in sorted_functions {
429 if should_add_overload {
430 writeln!(f, "@typing.overload")?;
431 }
432 function.fmt_for_module(&self.name, f)?;
433 }
434 }
435 Ok(())
436 }
437}