pyo3_stub_gen/generate/module.rs

1use crate::generate::*;
2use itertools::Itertools;
3use std::{
4    any::TypeId,
5    collections::{BTreeMap, BTreeSet},
6    fmt,
7};
8
9/// Type info for a Python (sub-)module. This corresponds to a single `*.pyi` file.
10#[derive(Debug, Clone, PartialEq, Default)]
11pub struct Module {
12    pub class: BTreeMap<TypeId, ClassDef>,
13    pub enum_: BTreeMap<TypeId, EnumDef>,
14    pub function: BTreeMap<&'static str, Vec<FunctionDef>>,
15    pub variables: BTreeMap<&'static str, VariableDef>,
16    pub name: String,
17    pub default_module_name: String,
18    /// Direct submodules of this module.
19    pub submodules: BTreeSet<String>,
20}
21
22impl Import for Module {
23    fn import(&self) -> HashSet<ModuleRef> {
24        let mut imports = HashSet::new();
25        for class in self.class.values() {
26            imports.extend(class.import());
27        }
28        for function in self.function.values().flatten() {
29            imports.extend(function.import());
30        }
31        imports
32    }
33}
34
35impl fmt::Display for Module {
36    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37        writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
38        writeln!(f, "# ruff: noqa: E501, F401")?;
39        writeln!(f)?;
40        let mut imports = self.import();
41        let any_overloaded = self.function.values().any(|functions| functions.len() > 1);
42        if any_overloaded {
43            imports.insert(ModuleRef::Named("typing".to_string()));
44        }
45
46        for import in imports.into_iter().sorted() {
47            let name = import.get().unwrap_or(&self.default_module_name);
48            if name != self.name {
49                writeln!(f, "import {name}")?;
50            }
51        }
52        for submod in &self.submodules {
53            writeln!(f, "from . import {submod}")?;
54        }
55        if !self.enum_.is_empty() {
56            writeln!(f, "from enum import Enum")?;
57        }
58        writeln!(f)?;
59
60        for var in self.variables.values() {
61            writeln!(f, "{var}")?;
62        }
63        for class in self.class.values().sorted_by_key(|class| class.name) {
64            write!(f, "{class}")?;
65        }
66        for enum_ in self.enum_.values().sorted_by_key(|class| class.name) {
67            write!(f, "{enum_}")?;
68        }
69        for functions in self.function.values() {
70            let overloaded = functions.len() > 1;
71            for function in functions {
72                if overloaded {
73                    writeln!(f, "@typing.overload")?;
74                }
75                write!(f, "{function}")?;
76            }
77        }
78        Ok(())
79    }
80}