pyo3_stub_gen/generate/class.rs

1use indexmap::IndexMap;
2
3use crate::generate::variant_methods::get_variant_methods;
4use crate::{generate::*, type_info::*, TypeInfo};
5use std::{fmt, vec};
6
/// Definition of a Python class.
///
/// Rendered as a Python stub `class` definition by its `Display` impl, and
/// queried for the modules that stub needs through its `Import` impl.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassDef {
    /// Python class name, written right after `class` in the stub header.
    pub name: &'static str,
    /// Docstring source text; it is trimmed before being written into the body.
    pub doc: &'static str,
    /// Attribute members, rendered with `MemberDef`'s own `Display` before getters.
    pub attrs: Vec<MemberDef>,
    /// Members rendered through `GetterDisplay`.
    pub getters: Vec<MemberDef>,
    /// Members rendered through `SetterDisplay`.
    pub setters: Vec<MemberDef>,
    /// Methods keyed by Python name and emitted in map (insertion) order. A key
    /// holding more than one `MethodDef` is rendered with `@typing.overload` on
    /// every entry.
    pub methods: IndexMap<String, Vec<MethodDef>>,
    /// Base classes, rendered as the parenthesized list in the class header.
    pub bases: Vec<TypeInfo>,
    /// Nested class definitions, rendered indented inside this class body
    /// (complex-enum variants are stored here).
    pub classes: Vec<ClassDef>,
    /// Names exported as `__match_args__`; `None` omits the attribute entirely.
    pub match_args: Option<Vec<String>>,
}
20
21impl Import for ClassDef {
22    fn import(&self) -> HashSet<ModuleRef> {
23        let mut import = HashSet::new();
24        for base in &self.bases {
25            import.extend(base.import.clone());
26        }
27        for attr in &self.attrs {
28            import.extend(attr.import());
29        }
30        for getter in &self.getters {
31            import.extend(getter.import());
32        }
33        for setter in &self.setters {
34            import.extend(setter.import());
35        }
36        for method in self.methods.values() {
37            if method.len() > 1 {
38                // for @typing.overload
39                import.insert("typing".into());
40            }
41            for method in method {
42                import.extend(method.import());
43            }
44        }
45        for class in &self.classes {
46            import.extend(class.import());
47        }
48        import
49    }
50}
51
52impl From<&PyComplexEnumInfo> for ClassDef {
53    fn from(info: &PyComplexEnumInfo) -> Self {
54        // Since there are multiple `#[pymethods]` for a single class, we need to merge them.
55        // This is only an initializer. See `StubInfo::gather` for the actual merging.
56
57        let enum_info = Self {
58            name: info.pyclass_name,
59            doc: info.doc,
60            getters: Vec::new(),
61            setters: Vec::new(),
62            methods: IndexMap::new(),
63            classes: info
64                .variants
65                .iter()
66                .map(|v| ClassDef::from_variant(info, v))
67                .collect(),
68            bases: Vec::new(),
69            match_args: None,
70            attrs: Vec::new(),
71        };
72
73        enum_info
74    }
75}
76
77impl ClassDef {
78    fn from_variant(enum_info: &PyComplexEnumInfo, info: &VariantInfo) -> Self {
79        let methods = get_variant_methods(enum_info, info);
80
81        Self {
82            name: info.pyclass_name,
83            doc: info.doc,
84            getters: info.fields.iter().map(MemberDef::from).collect(),
85            setters: Vec::new(),
86            methods,
87            classes: Vec::new(),
88            bases: vec![TypeInfo::unqualified(enum_info.pyclass_name)],
89            match_args: Some(info.fields.iter().map(|f| f.name.to_string()).collect()),
90            attrs: Vec::new(),
91        }
92    }
93}
94
95impl From<&PyClassInfo> for ClassDef {
96    fn from(info: &PyClassInfo) -> Self {
97        // Since there are multiple `#[pymethods]` for a single class, we need to merge them.
98        // This is only an initializer. See `StubInfo::gather` for the actual merging.
99        let mut new = Self {
100            name: info.pyclass_name,
101            doc: info.doc,
102            attrs: Vec::new(),
103            setters: info.setters.iter().map(MemberDef::from).collect(),
104            getters: info.getters.iter().map(MemberDef::from).collect(),
105            methods: Default::default(),
106            classes: Vec::new(),
107            bases: info.bases.iter().map(|f| f()).collect(),
108            match_args: None,
109        };
110        if info.has_eq {
111            new.add_eq_method();
112        }
113        if info.has_ord {
114            new.add_ord_methods();
115        }
116        if info.has_hash {
117            new.add_hash_method();
118        }
119        if info.has_str {
120            new.add_str_method();
121        }
122        new
123    }
124}
125impl ClassDef {
126    fn add_eq_method(&mut self) {
127        let method = MethodDef {
128            name: "__eq__",
129            args: vec![Arg {
130                name: "other",
131                r#type: TypeInfo::builtin("object"),
132                signature: None,
133            }],
134            r#return: TypeInfo::builtin("bool"),
135            doc: "",
136            r#type: MethodType::Instance,
137            is_async: false,
138            deprecated: None,
139            type_ignored: None,
140        };
141        self.methods
142            .entry("__eq__".to_string())
143            .or_default()
144            .push(method);
145    }
146
147    fn add_ord_methods(&mut self) {
148        let ord_methods = ["__lt__", "__le__", "__gt__", "__ge__"];
149
150        for name in &ord_methods {
151            let method = MethodDef {
152                name,
153                args: vec![Arg {
154                    name: "other",
155                    r#type: TypeInfo::builtin("object"),
156                    signature: None,
157                }],
158                r#return: TypeInfo::builtin("bool"),
159                doc: "",
160                r#type: MethodType::Instance,
161                is_async: false,
162                deprecated: None,
163                type_ignored: None,
164            };
165            self.methods
166                .entry(name.to_string())
167                .or_default()
168                .push(method);
169        }
170    }
171
172    fn add_hash_method(&mut self) {
173        let method = MethodDef {
174            name: "__hash__",
175            args: vec![],
176            r#return: TypeInfo::builtin("int"),
177            doc: "",
178            r#type: MethodType::Instance,
179            is_async: false,
180            deprecated: None,
181            type_ignored: None,
182        };
183        self.methods
184            .entry("__hash__".to_string())
185            .or_default()
186            .push(method);
187    }
188
189    fn add_str_method(&mut self) {
190        let method = MethodDef {
191            name: "__str__",
192            args: vec![],
193            r#return: TypeInfo::builtin("str"),
194            doc: "",
195            r#type: MethodType::Instance,
196            is_async: false,
197            deprecated: None,
198            type_ignored: None,
199        };
200        self.methods
201            .entry("__str__".to_string())
202            .or_default()
203            .push(method);
204    }
205}
206
207impl fmt::Display for ClassDef {
208    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
209        let bases = self
210            .bases
211            .iter()
212            .map(|i| i.name.clone())
213            .reduce(|acc, path| format!("{acc}, {path}"))
214            .map(|bases| format!("({bases})"))
215            .unwrap_or_default();
216        writeln!(f, "class {}{}:", self.name, bases)?;
217        let indent = indent();
218        let doc = self.doc.trim();
219        docstring::write_docstring(f, doc, indent)?;
220
221        if let Some(match_args) = &self.match_args {
222            let match_args_txt = if match_args.is_empty() {
223                "()".to_string()
224            } else {
225                match_args
226                    .iter()
227                    .map(|a| format!(r##""{a}""##))
228                    .collect::<Vec<_>>()
229                    .join(", ")
230            };
231
232            writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
233        }
234        for attr in &self.attrs {
235            attr.fmt(f)?;
236        }
237        for getter in &self.getters {
238            GetterDisplay(getter).fmt(f)?;
239        }
240        for setter in &self.setters {
241            SetterDisplay(setter).fmt(f)?;
242        }
243        for methods in self.methods.values() {
244            let overloaded = methods.len() > 1;
245            for method in methods {
246                if overloaded {
247                    writeln!(f, "{indent}@typing.overload")?;
248                }
249                method.fmt(f)?;
250            }
251        }
252        for class in &self.classes {
253            let emit = format!("{class}");
254            for line in emit.lines() {
255                writeln!(f, "{indent}{line}")?;
256            }
257        }
258        if self.attrs.is_empty()
259            && self.getters.is_empty()
260            && self.setters.is_empty()
261            && self.methods.is_empty()
262        {
263            writeln!(f, "{indent}...")?;
264        }
265        writeln!(f)?;
266        Ok(())
267    }
268}