// pyo3_stub_gen/generate/class.rs

1use indexmap::IndexMap;
2
3use crate::generate::variant_methods::get_variant_methods;
4use crate::{generate::*, type_info::*, TypeInfo};
5use std::{fmt, vec};
6
/// Definition of a Python class.
///
/// Rendered as a `.pyi` class stub by its `Display` implementation.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassDef {
    /// Python class name, written right after the `class` keyword.
    pub name: &'static str,
    /// Class docstring; it is trimmed before being written into the class body.
    pub doc: &'static str,
    /// Class attributes, each written into the body through its `fmt` method.
    pub attrs: Vec<MemberDef>,
    /// Members written into the body through `GetterDisplay`.
    pub getters: Vec<MemberDef>,
    /// Members written into the body through `SetterDisplay`.
    pub setters: Vec<MemberDef>,
    /// Methods grouped under a string key (presumably the Python method name; confirm
    /// against callers). A group with more than one entry is rendered as a series of
    /// `@typing.overload` definitions. `IndexMap` keeps the groups in insertion order.
    pub methods: IndexMap<String, Vec<MethodDef>>,
    /// Base classes, listed by their `name` in parentheses after the class name.
    pub bases: Vec<TypeInfo>,
    /// Nested classes, rendered inside this class body one indentation level deeper.
    /// Complex enums use this for their variants.
    pub classes: Vec<ClassDef>,
    /// When `Some`, the names written out as `__match_args__`. Complex-enum variants set
    /// this to their field names.
    pub match_args: Option<Vec<String>>,
}
20
21impl Import for ClassDef {
22    fn import(&self) -> HashSet<ModuleRef> {
23        let mut import = HashSet::new();
24        for base in &self.bases {
25            import.extend(base.import.clone());
26        }
27        for attr in &self.attrs {
28            import.extend(attr.import());
29        }
30        for getter in &self.getters {
31            import.extend(getter.import());
32        }
33        for setter in &self.setters {
34            import.extend(setter.import());
35        }
36        for method in self.methods.values() {
37            if method.len() > 1 {
38                // for @typing.overload
39                import.insert("typing".into());
40            }
41            for method in method {
42                import.extend(method.import());
43            }
44        }
45        for class in &self.classes {
46            import.extend(class.import());
47        }
48        import
49    }
50}
51
52impl From<&PyComplexEnumInfo> for ClassDef {
53    fn from(info: &PyComplexEnumInfo) -> Self {
54        // Since there are multiple `#[pymethods]` for a single class, we need to merge them.
55        // This is only an initializer. See `StubInfo::gather` for the actual merging.
56
57        let enum_info = Self {
58            name: info.pyclass_name,
59            doc: info.doc,
60            getters: Vec::new(),
61            setters: Vec::new(),
62            methods: IndexMap::new(),
63            classes: info
64                .variants
65                .iter()
66                .map(|v| ClassDef::from_variant(info, v))
67                .collect(),
68            bases: Vec::new(),
69            match_args: None,
70            attrs: Vec::new(),
71        };
72
73        enum_info
74    }
75}
76
77impl ClassDef {
78    fn from_variant(enum_info: &PyComplexEnumInfo, info: &VariantInfo) -> Self {
79        let methods = get_variant_methods(enum_info, info);
80
81        Self {
82            name: info.pyclass_name,
83            doc: info.doc,
84            getters: info.fields.iter().map(MemberDef::from).collect(),
85            setters: Vec::new(),
86            methods,
87            classes: Vec::new(),
88            bases: vec![TypeInfo::unqualified(enum_info.pyclass_name)],
89            match_args: Some(info.fields.iter().map(|f| f.name.to_string()).collect()),
90            attrs: Vec::new(),
91        }
92    }
93}
94
95impl From<&PyClassInfo> for ClassDef {
96    fn from(info: &PyClassInfo) -> Self {
97        // Since there are multiple `#[pymethods]` for a single class, we need to merge them.
98        // This is only an initializer. See `StubInfo::gather` for the actual merging.
99        Self {
100            name: info.pyclass_name,
101            doc: info.doc,
102            attrs: Vec::new(),
103            setters: info.setters.iter().map(MemberDef::from).collect(),
104            getters: info.getters.iter().map(MemberDef::from).collect(),
105            methods: Default::default(),
106            classes: Vec::new(),
107            bases: info.bases.iter().map(|f| f()).collect(),
108            match_args: None,
109        }
110    }
111}
112
113impl fmt::Display for ClassDef {
114    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
115        let bases = self
116            .bases
117            .iter()
118            .map(|i| i.name.clone())
119            .reduce(|acc, path| format!("{acc}, {path}"))
120            .map(|bases| format!("({bases})"))
121            .unwrap_or_default();
122        writeln!(f, "class {}{}:", self.name, bases)?;
123        let indent = indent();
124        let doc = self.doc.trim();
125        docstring::write_docstring(f, doc, indent)?;
126
127        if let Some(match_args) = &self.match_args {
128            let match_args_txt = if match_args.is_empty() {
129                "()".to_string()
130            } else {
131                match_args
132                    .iter()
133                    .map(|a| format!(r##""{a}""##))
134                    .collect::<Vec<_>>()
135                    .join(", ")
136            };
137
138            writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
139        }
140        for attr in &self.attrs {
141            attr.fmt(f)?;
142        }
143        for getter in &self.getters {
144            GetterDisplay(getter).fmt(f)?;
145        }
146        for setter in &self.setters {
147            SetterDisplay(setter).fmt(f)?;
148        }
149        for methods in self.methods.values() {
150            let overloaded = methods.len() > 1;
151            for method in methods {
152                if overloaded {
153                    writeln!(f, "{indent}@typing.overload")?;
154                }
155                method.fmt(f)?;
156            }
157        }
158        for class in &self.classes {
159            let emit = format!("{class}");
160            for line in emit.lines() {
161                writeln!(f, "{indent}{line}")?;
162            }
163        }
164        if self.attrs.is_empty()
165            && self.getters.is_empty()
166            && self.setters.is_empty()
167            && self.methods.is_empty()
168        {
169            writeln!(f, "{indent}...")?;
170        }
171        writeln!(f)?;
172        Ok(())
173    }
174}