pyo3_stub_gen/generate/module.rs

1use crate::generate::*;
2use crate::stub_type::ImportRef;
3use itertools::Itertools;
4use std::{
5    any::TypeId,
6    collections::{BTreeMap, BTreeSet},
7    fmt,
8};
9
10/// Type info for a Python (sub-)module. This corresponds to a single `*.pyi` file.
11#[derive(Debug, Clone, PartialEq, Default)]
12pub struct Module {
13    pub doc: String,
14    pub class: BTreeMap<TypeId, ClassDef>,
15    pub enum_: BTreeMap<TypeId, EnumDef>,
16    pub function: BTreeMap<&'static str, Vec<FunctionDef>>,
17    pub variables: BTreeMap<&'static str, VariableDef>,
18    pub name: String,
19    pub default_module_name: String,
20    /// Direct submodules of this module.
21    pub submodules: BTreeSet<String>,
22}
23
24impl Import for Module {
25    fn import(&self) -> HashSet<ImportRef> {
26        let mut imports = HashSet::new();
27        for class in self.class.values() {
28            imports.extend(class.import());
29        }
30        for enum_ in self.enum_.values() {
31            imports.extend(enum_.import());
32        }
33        for function in self.function.values().flatten() {
34            imports.extend(function.import());
35        }
36        imports
37    }
38}
39
40impl fmt::Display for Module {
41    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42        writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
43        writeln!(f, "# ruff: noqa: E501, F401")?;
44        if !self.doc.is_empty() {
45            docstring::write_docstring(f, &self.doc, "")?;
46        }
47        writeln!(f)?;
48        let mut imports = self.import();
49        let any_overloaded = self.function.values().any(|functions| functions.len() > 1);
50        if any_overloaded {
51            imports.insert("typing".into());
52        }
53
54        // To gather `from submod import A, B, C` style imports
55        let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
56        for import_ref in imports.into_iter().sorted() {
57            match import_ref {
58                ImportRef::Module(module_ref) => {
59                    let name = module_ref.get().unwrap_or(&self.default_module_name);
60                    if name != self.name {
61                        writeln!(f, "import {name}")?;
62                    }
63                }
64                ImportRef::Type(type_ref) => {
65                    let module_name = type_ref.module.get().unwrap_or(&self.default_module_name);
66                    if module_name != self.name {
67                        type_ref_grouped
68                            .entry(module_name.to_string())
69                            .or_default()
70                            .push(type_ref.name);
71                    }
72                }
73            }
74        }
75        for (module_name, type_names) in type_ref_grouped {
76            let mut sorted_type_names = type_names.clone();
77            sorted_type_names.sort();
78            writeln!(
79                f,
80                "from {} import {}",
81                module_name,
82                sorted_type_names.join(", ")
83            )?;
84        }
85        for submod in &self.submodules {
86            writeln!(f, "from . import {submod}")?;
87        }
88        writeln!(f)?;
89
90        for var in self.variables.values() {
91            writeln!(f, "{var}")?;
92        }
93        for class in self.class.values().sorted_by_key(|class| class.name) {
94            write!(f, "{class}")?;
95        }
96        for enum_ in self.enum_.values().sorted_by_key(|class| class.name) {
97            write!(f, "{enum_}")?;
98        }
99        for functions in self.function.values() {
100            let overloaded = functions.len() > 1;
101            for function in functions {
102                if overloaded {
103                    writeln!(f, "@typing.overload")?;
104                }
105                write!(f, "{function}")?;
106            }
107        }
108        Ok(())
109    }
110}