// pyo3_stub_gen/generate/module.rs

1use crate::generate::*;
2use crate::stub_type::ImportRef;
3use itertools::Itertools;
4use std::{
5    any::TypeId,
6    collections::{BTreeMap, BTreeSet},
7    fmt,
8};
9
10/// Type info for a Python (sub-)module. This corresponds to a single `*.pyi` file.
11#[derive(Debug, Clone, PartialEq, Default)]
12pub struct Module {
13    pub doc: String,
14    pub class: BTreeMap<TypeId, ClassDef>,
15    pub enum_: BTreeMap<TypeId, EnumDef>,
16    pub function: BTreeMap<&'static str, Vec<FunctionDef>>,
17    pub variables: BTreeMap<&'static str, VariableDef>,
18    pub name: String,
19    pub default_module_name: String,
20    /// Direct submodules of this module.
21    pub submodules: BTreeSet<String>,
22}
23
24impl Import for Module {
25    fn import(&self) -> HashSet<ImportRef> {
26        let mut imports = HashSet::new();
27        for class in self.class.values() {
28            imports.extend(class.import());
29        }
30        for enum_ in self.enum_.values() {
31            imports.extend(enum_.import());
32        }
33        for function in self.function.values().flatten() {
34            imports.extend(function.import());
35        }
36        imports
37    }
38}
39
40impl fmt::Display for Module {
41    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42        writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
43        writeln!(f, "# ruff: noqa: E501, F401")?;
44        if !self.doc.is_empty() {
45            docstring::write_docstring(f, &self.doc, "")?;
46        }
47        writeln!(f)?;
48        let mut imports = self.import();
49        // Check if any function group needs @overload decorator
50        let any_overloaded = self.function.values().any(|functions| {
51            let has_overload = functions.iter().any(|f| f.is_overload);
52            functions.len() > 1 && has_overload
53        });
54        if any_overloaded {
55            imports.insert("typing".into());
56        }
57
58        // To gather `from submod import A, B, C` style imports
59        let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
60        for import_ref in imports.into_iter().sorted() {
61            match import_ref {
62                ImportRef::Module(module_ref) => {
63                    let name = module_ref.get().unwrap_or(&self.default_module_name);
64                    if name != self.name {
65                        writeln!(f, "import {name}")?;
66                    }
67                }
68                ImportRef::Type(type_ref) => {
69                    let module_name = type_ref.module.get().unwrap_or(&self.default_module_name);
70                    if module_name != self.name {
71                        type_ref_grouped
72                            .entry(module_name.to_string())
73                            .or_default()
74                            .push(type_ref.name);
75                    }
76                }
77            }
78        }
79        for (module_name, type_names) in type_ref_grouped {
80            let mut sorted_type_names = type_names.clone();
81            sorted_type_names.sort();
82            writeln!(
83                f,
84                "from {} import {}",
85                module_name,
86                sorted_type_names.join(", ")
87            )?;
88        }
89        for submod in &self.submodules {
90            writeln!(f, "from . import {submod}")?;
91        }
92        writeln!(f)?;
93
94        for var in self.variables.values() {
95            writeln!(f, "{var}")?;
96        }
97        for class in self.class.values().sorted_by_key(|class| class.name) {
98            write!(f, "{class}")?;
99        }
100        for enum_ in self.enum_.values().sorted_by_key(|class| class.name) {
101            write!(f, "{enum_}")?;
102        }
103        for functions in self.function.values() {
104            // Check if we should add @overload to all functions
105            let has_overload = functions.iter().any(|func| func.is_overload);
106            let should_add_overload = functions.len() > 1 && has_overload;
107
108            // Sort by source location and index for deterministic ordering
109            let mut sorted_functions = functions.clone();
110            sorted_functions.sort_by_key(|func| (func.file, func.line, func.column, func.index));
111            for function in sorted_functions {
112                if should_add_overload {
113                    writeln!(f, "@typing.overload")?;
114                }
115                write!(f, "{function}")?;
116            }
117        }
118        Ok(())
119    }
120}