pyo3_stub_gen/generate/
module.rs

1use crate::generate::*;
2use itertools::Itertools;
3use std::{
4    any::TypeId,
5    collections::{BTreeMap, BTreeSet},
6    fmt,
7};
8
9/// Type info for a Python (sub-)module. This corresponds to a single `*.pyi` file.
10#[derive(Debug, Clone, PartialEq, Default)]
11pub struct Module {
12    pub class: BTreeMap<TypeId, ClassDef>,
13    pub enum_: BTreeMap<TypeId, EnumDef>,
14    pub function: BTreeMap<&'static str, Vec<FunctionDef>>,
15    pub error: BTreeMap<&'static str, ErrorDef>,
16    pub variables: BTreeMap<&'static str, VariableDef>,
17    pub name: String,
18    pub default_module_name: String,
19    /// Direct submodules of this module.
20    pub submodules: BTreeSet<String>,
21}
22
23impl Import for Module {
24    fn import(&self) -> HashSet<ModuleRef> {
25        let mut imports = HashSet::new();
26        for class in self.class.values() {
27            imports.extend(class.import());
28        }
29        for function in self.function.values().flatten() {
30            imports.extend(function.import());
31        }
32        imports
33    }
34}
35
36impl fmt::Display for Module {
37    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38        writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
39        writeln!(f, "# ruff: noqa: E501, F401")?;
40        writeln!(f)?;
41        let mut imports = self.import();
42        let any_overloaded = self.function.values().any(|functions| functions.len() > 1);
43        if any_overloaded {
44            imports.insert(ModuleRef::Named("typing".to_string()));
45        }
46
47        for import in imports.into_iter().sorted() {
48            let name = import.get().unwrap_or(&self.default_module_name);
49            if name != self.name {
50                writeln!(f, "import {name}")?;
51            }
52        }
53        for submod in &self.submodules {
54            writeln!(f, "from . import {submod}")?;
55        }
56        if !self.enum_.is_empty() {
57            writeln!(f, "from enum import Enum")?;
58        }
59        writeln!(f)?;
60
61        for var in self.variables.values() {
62            writeln!(f, "{var}")?;
63        }
64        for class in self.class.values().sorted_by_key(|class| class.name) {
65            write!(f, "{class}")?;
66        }
67        for enum_ in self.enum_.values().sorted_by_key(|class| class.name) {
68            write!(f, "{enum_}")?;
69        }
70        for functions in self.function.values() {
71            let overloaded = functions.len() > 1;
72            for function in functions {
73                if overloaded {
74                    writeln!(f, "@typing.overload")?;
75                }
76                write!(f, "{function}")?;
77            }
78        }
79        for error in self.error.values() {
80            writeln!(f, "{error}")?;
81        }
82        Ok(())
83    }
84}