pyo3_stub_gen/generate/
module.rs

1use crate::generate::*;
2use crate::stub_type::ImportRef;
3use itertools::Itertools;
4use std::{
5    any::TypeId,
6    collections::{BTreeMap, BTreeSet},
7    fmt,
8};
9
/// Type info for a Python (sub-)module. This corresponds to a single `*.pyi` file.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Module {
    /// Module-level docstring; written at the top of the stub only when non-empty.
    pub doc: String,
    /// Class definitions keyed by `TypeId`; emitted sorted by class name.
    pub class: BTreeMap<TypeId, ClassDef>,
    /// Enum definitions keyed by `TypeId`; emitted sorted by enum name. When this map
    /// is non-empty the stub also gets a `from enum import Enum` line.
    pub enum_: BTreeMap<TypeId, EnumDef>,
    /// Functions grouped by a static key (presumably the Python function name — verify
    /// against the code that fills this map). A key with more than one definition is
    /// emitted as a series of `@typing.overload` stubs.
    pub function: BTreeMap<&'static str, Vec<FunctionDef>>,
    /// Module-level variables, emitted before classes, enums and functions.
    pub variables: BTreeMap<&'static str, VariableDef>,
    /// Name of this module; imports resolving to this name are skipped so the stub
    /// never imports itself.
    pub name: String,
    /// Module name used for an import reference that carries no explicit module.
    pub default_module_name: String,
    /// Direct submodules of this module, each emitted as `from . import <submodule>`.
    pub submodules: BTreeSet<String>,
}
23
24impl Import for Module {
25    fn import(&self) -> HashSet<ImportRef> {
26        let mut imports = HashSet::new();
27        for class in self.class.values() {
28            imports.extend(class.import());
29        }
30        for function in self.function.values().flatten() {
31            imports.extend(function.import());
32        }
33        imports
34    }
35}
36
37impl fmt::Display for Module {
38    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39        writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
40        writeln!(f, "# ruff: noqa: E501, F401")?;
41        if !self.doc.is_empty() {
42            docstring::write_docstring(f, &self.doc, "")?;
43        }
44        writeln!(f)?;
45        let mut imports = self.import();
46        let any_overloaded = self.function.values().any(|functions| functions.len() > 1);
47        if any_overloaded {
48            imports.insert("typing".into());
49        }
50
51        // To gather `from submod import A, B, C` style imports
52        let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
53        for import_ref in imports.into_iter().sorted() {
54            match import_ref {
55                ImportRef::Module(module_ref) => {
56                    let name = module_ref.get().unwrap_or(&self.default_module_name);
57                    if name != self.name {
58                        writeln!(f, "import {name}")?;
59                    }
60                }
61                ImportRef::Type(type_ref) => {
62                    let module_name = type_ref.module.get().unwrap_or(&self.default_module_name);
63                    if module_name != self.name {
64                        type_ref_grouped
65                            .entry(module_name.to_string())
66                            .or_default()
67                            .push(type_ref.name);
68                    }
69                }
70            }
71        }
72        for (module_name, type_names) in type_ref_grouped {
73            let mut sorted_type_names = type_names.clone();
74            sorted_type_names.sort();
75            writeln!(
76                f,
77                "from {} import {}",
78                module_name,
79                sorted_type_names.join(", ")
80            )?;
81        }
82        for submod in &self.submodules {
83            writeln!(f, "from . import {submod}")?;
84        }
85        if !self.enum_.is_empty() {
86            writeln!(f, "from enum import Enum")?;
87        }
88        writeln!(f)?;
89
90        for var in self.variables.values() {
91            writeln!(f, "{var}")?;
92        }
93        for class in self.class.values().sorted_by_key(|class| class.name) {
94            write!(f, "{class}")?;
95        }
96        for enum_ in self.enum_.values().sorted_by_key(|class| class.name) {
97            write!(f, "{enum_}")?;
98        }
99        for functions in self.function.values() {
100            let overloaded = functions.len() > 1;
101            for function in functions {
102                if overloaded {
103                    writeln!(f, "@typing.overload")?;
104                }
105                write!(f, "{function}")?;
106            }
107        }
108        Ok(())
109    }
110}