pyo3_stub_gen/generate/
module.rs1use crate::generate::*;
2use crate::stub_type::ImportRef;
3use itertools::Itertools;
4use std::{
5 any::TypeId,
6 collections::{BTreeMap, BTreeSet},
7 fmt,
8};
9
10#[derive(Debug, Clone, PartialEq, Default)]
12pub struct Module {
13 pub doc: String,
14 pub class: BTreeMap<TypeId, ClassDef>,
15 pub enum_: BTreeMap<TypeId, EnumDef>,
16 pub function: BTreeMap<&'static str, Vec<FunctionDef>>,
17 pub variables: BTreeMap<&'static str, VariableDef>,
18 pub name: String,
19 pub default_module_name: String,
20 pub submodules: BTreeSet<String>,
22}
23
24impl Import for Module {
25 fn import(&self) -> HashSet<ImportRef> {
26 let mut imports = HashSet::new();
27 for class in self.class.values() {
28 imports.extend(class.import());
29 }
30 for enum_ in self.enum_.values() {
31 imports.extend(enum_.import());
32 }
33 for function in self.function.values().flatten() {
34 imports.extend(function.import());
35 }
36 imports
37 }
38}
39
40impl fmt::Display for Module {
41 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42 writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
43 writeln!(f, "# ruff: noqa: E501, F401")?;
44 if !self.doc.is_empty() {
45 docstring::write_docstring(f, &self.doc, "")?;
46 }
47 writeln!(f)?;
48 let mut imports = self.import();
49 let any_overloaded = self.function.values().any(|functions| functions.len() > 1);
50 if any_overloaded {
51 imports.insert("typing".into());
52 }
53
54 let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
56 for import_ref in imports.into_iter().sorted() {
57 match import_ref {
58 ImportRef::Module(module_ref) => {
59 let name = module_ref.get().unwrap_or(&self.default_module_name);
60 if name != self.name {
61 writeln!(f, "import {name}")?;
62 }
63 }
64 ImportRef::Type(type_ref) => {
65 let module_name = type_ref.module.get().unwrap_or(&self.default_module_name);
66 if module_name != self.name {
67 type_ref_grouped
68 .entry(module_name.to_string())
69 .or_default()
70 .push(type_ref.name);
71 }
72 }
73 }
74 }
75 for (module_name, type_names) in type_ref_grouped {
76 let mut sorted_type_names = type_names.clone();
77 sorted_type_names.sort();
78 writeln!(
79 f,
80 "from {} import {}",
81 module_name,
82 sorted_type_names.join(", ")
83 )?;
84 }
85 for submod in &self.submodules {
86 writeln!(f, "from . import {submod}")?;
87 }
88 writeln!(f)?;
89
90 for var in self.variables.values() {
91 writeln!(f, "{var}")?;
92 }
93 for class in self.class.values().sorted_by_key(|class| class.name) {
94 write!(f, "{class}")?;
95 }
96 for enum_ in self.enum_.values().sorted_by_key(|class| class.name) {
97 write!(f, "{enum_}")?;
98 }
99 for functions in self.function.values() {
100 let overloaded = functions.len() > 1;
101 for function in functions {
102 if overloaded {
103 writeln!(f, "@typing.overload")?;
104 }
105 write!(f, "{function}")?;
106 }
107 }
108 Ok(())
109 }
110}