pyo3_stub_gen/generate/module.rs

1use crate::generate::*;
2use itertools::Itertools;
3use std::{
4    any::TypeId,
5    collections::{BTreeMap, BTreeSet},
6    fmt,
7};
8
/// Type info for a Python (sub-)module. This corresponds to a single `*.pyi` file.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Module {
    /// Class definitions, keyed by `TypeId`. Rendered sorted by class name; the imports
    /// each class reports contribute to the stub's `import` lines.
    pub class: BTreeMap<TypeId, ClassDef>,
    /// Enum definitions, keyed by `TypeId`. Rendered sorted by name; when this map is
    /// non-empty the stub also emits `from enum import Enum`.
    pub enum_: BTreeMap<TypeId, EnumDef>,
    /// Function definitions, rendered in key order. The imports each function reports
    /// contribute to the stub's `import` lines.
    /// NOTE(review): presumably keyed by the Python function name — confirm at insertion sites.
    pub function: BTreeMap<&'static str, FunctionDef>,
    /// Error definitions, rendered in key order after all functions.
    pub error: BTreeMap<&'static str, ErrorDef>,
    /// Module-level variables, rendered one per line (in key order) before any class.
    pub variables: BTreeMap<&'static str, VariableDef>,
    /// Name of this module. Import references that resolve to this name are not emitted,
    /// so the stub never imports itself.
    pub name: String,
    /// Module name used for import references that do not name a module explicitly.
    pub default_module_name: String,
    /// Direct submodules of this module, each emitted as `from . import <name>`.
    pub submodules: BTreeSet<String>,
}
22
23impl Import for Module {
24    fn import(&self) -> HashSet<ModuleRef> {
25        let mut imports = HashSet::new();
26        for class in self.class.values() {
27            imports.extend(class.import());
28        }
29        for function in self.function.values() {
30            imports.extend(function.import());
31        }
32        imports
33    }
34}
35
36impl fmt::Display for Module {
37    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38        writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
39        writeln!(f, "# ruff: noqa: E501, F401")?;
40        writeln!(f)?;
41        for import in self.import().into_iter().sorted() {
42            let name = import.get().unwrap_or(&self.default_module_name);
43            if name != self.name {
44                writeln!(f, "import {}", name)?;
45            }
46        }
47        for submod in &self.submodules {
48            writeln!(f, "from . import {}", submod)?;
49        }
50        if !self.enum_.is_empty() {
51            writeln!(f, "from enum import Enum")?;
52        }
53        writeln!(f)?;
54
55        for var in self.variables.values() {
56            writeln!(f, "{}", var)?;
57        }
58        for class in self.class.values().sorted_by_key(|class| class.name) {
59            write!(f, "{}", class)?;
60        }
61        for enum_ in self.enum_.values().sorted_by_key(|class| class.name) {
62            write!(f, "{}", enum_)?;
63        }
64        for function in self.function.values() {
65            write!(f, "{}", function)?;
66        }
67        for error in self.error.values() {
68            writeln!(f, "{}", error)?;
69        }
70        Ok(())
71    }
72}