// pyo3_stub_gen/generate/module.rs
use crate::generate::*;
use itertools::Itertools;
use std::{
    any::TypeId,
    collections::{BTreeMap, BTreeSet},
    fmt,
};

9#[derive(Debug, Clone, PartialEq, Default)]
11pub struct Module {
12 pub class: BTreeMap<TypeId, ClassDef>,
13 pub enum_: BTreeMap<TypeId, EnumDef>,
14 pub function: BTreeMap<&'static str, Vec<FunctionDef>>,
15 pub error: BTreeMap<&'static str, ErrorDef>,
16 pub variables: BTreeMap<&'static str, VariableDef>,
17 pub name: String,
18 pub default_module_name: String,
19 pub submodules: BTreeSet<String>,
21}
22
23impl Import for Module {
24 fn import(&self) -> HashSet<ModuleRef> {
25 let mut imports = HashSet::new();
26 for class in self.class.values() {
27 imports.extend(class.import());
28 }
29 for function in self.function.values().flatten() {
30 imports.extend(function.import());
31 }
32 imports
33 }
34}
35
36impl fmt::Display for Module {
37 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38 writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
39 writeln!(f, "# ruff: noqa: E501, F401")?;
40 writeln!(f)?;
41 let mut imports = self.import();
42 let any_overloaded = self.function.values().any(|functions| functions.len() > 1);
43 if any_overloaded {
44 imports.insert(ModuleRef::Named("typing".to_string()));
45 }
46
47 for import in imports.into_iter().sorted() {
48 let name = import.get().unwrap_or(&self.default_module_name);
49 if name != self.name {
50 writeln!(f, "import {name}")?;
51 }
52 }
53 for submod in &self.submodules {
54 writeln!(f, "from . import {submod}")?;
55 }
56 if !self.enum_.is_empty() {
57 writeln!(f, "from enum import Enum")?;
58 }
59 writeln!(f)?;
60
61 for var in self.variables.values() {
62 writeln!(f, "{var}")?;
63 }
64 for class in self.class.values().sorted_by_key(|class| class.name) {
65 write!(f, "{class}")?;
66 }
67 for enum_ in self.enum_.values().sorted_by_key(|class| class.name) {
68 write!(f, "{enum_}")?;
69 }
70 for functions in self.function.values() {
71 let overloaded = functions.len() > 1;
72 for function in functions {
73 if overloaded {
74 writeln!(f, "@typing.overload")?;
75 }
76 write!(f, "{function}")?;
77 }
78 }
79 for error in self.error.values() {
80 writeln!(f, "{error}")?;
81 }
82 Ok(())
83 }
84}