pyo3_stub_gen/generate/
module.rs1use crate::generate::*;
2use crate::stub_type::ImportRef;
3use itertools::Itertools;
4use std::{
5 any::TypeId,
6 collections::{BTreeMap, BTreeSet},
7 fmt,
8};
9
10#[derive(Debug, Clone, PartialEq, Default)]
12pub struct Module {
13 pub doc: String,
14 pub class: BTreeMap<TypeId, ClassDef>,
15 pub enum_: BTreeMap<TypeId, EnumDef>,
16 pub function: BTreeMap<&'static str, Vec<FunctionDef>>,
17 pub variables: BTreeMap<&'static str, VariableDef>,
18 pub name: String,
19 pub default_module_name: String,
20 pub submodules: BTreeSet<String>,
22}
23
24impl Import for Module {
25 fn import(&self) -> HashSet<ImportRef> {
26 let mut imports = HashSet::new();
27 for class in self.class.values() {
28 imports.extend(class.import());
29 }
30 for function in self.function.values().flatten() {
31 imports.extend(function.import());
32 }
33 imports
34 }
35}
36
37impl fmt::Display for Module {
38 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39 writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
40 writeln!(f, "# ruff: noqa: E501, F401")?;
41 if !self.doc.is_empty() {
42 docstring::write_docstring(f, &self.doc, "")?;
43 }
44 writeln!(f)?;
45 let mut imports = self.import();
46 let any_overloaded = self.function.values().any(|functions| functions.len() > 1);
47 if any_overloaded {
48 imports.insert("typing".into());
49 }
50
51 let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
53 for import_ref in imports.into_iter().sorted() {
54 match import_ref {
55 ImportRef::Module(module_ref) => {
56 let name = module_ref.get().unwrap_or(&self.default_module_name);
57 if name != self.name {
58 writeln!(f, "import {name}")?;
59 }
60 }
61 ImportRef::Type(type_ref) => {
62 let module_name = type_ref.module.get().unwrap_or(&self.default_module_name);
63 if module_name != self.name {
64 type_ref_grouped
65 .entry(module_name.to_string())
66 .or_default()
67 .push(type_ref.name);
68 }
69 }
70 }
71 }
72 for (module_name, type_names) in type_ref_grouped {
73 let mut sorted_type_names = type_names.clone();
74 sorted_type_names.sort();
75 writeln!(
76 f,
77 "from {} import {}",
78 module_name,
79 sorted_type_names.join(", ")
80 )?;
81 }
82 for submod in &self.submodules {
83 writeln!(f, "from . import {submod}")?;
84 }
85 if !self.enum_.is_empty() {
86 writeln!(f, "from enum import Enum")?;
87 }
88 writeln!(f)?;
89
90 for var in self.variables.values() {
91 writeln!(f, "{var}")?;
92 }
93 for class in self.class.values().sorted_by_key(|class| class.name) {
94 write!(f, "{class}")?;
95 }
96 for enum_ in self.enum_.values().sorted_by_key(|class| class.name) {
97 write!(f, "{enum_}")?;
98 }
99 for functions in self.function.values() {
100 let overloaded = functions.len() > 1;
101 for function in functions {
102 if overloaded {
103 writeln!(f, "@typing.overload")?;
104 }
105 write!(f, "{function}")?;
106 }
107 }
108 Ok(())
109 }
110}