pyo3_stub_gen/generate/
class.rs

1use indexmap::IndexMap;
2
3use crate::generate::variant_methods::get_variant_methods;
4use crate::{
5    generate::{
6        docstring, indent, GetterDisplay, Import, MemberDef, MethodDef, Parameter,
7        ParameterDefault, Parameters, SetterDisplay,
8    },
9    stub_type::ImportRef,
10    type_info::*,
11    TypeInfo,
12};
13use std::collections::HashSet;
14use std::{fmt, vec};
15
16/// Definition of a Python class.
17#[derive(Debug, Clone, PartialEq)]
18pub struct ClassDef {
19    pub name: &'static str,
20    pub doc: &'static str,
21    pub attrs: Vec<MemberDef>,
22    pub getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)>,
23    pub methods: IndexMap<String, Vec<MethodDef>>,
24    pub bases: Vec<TypeInfo>,
25    pub classes: Vec<ClassDef>,
26    pub match_args: Option<Vec<String>>,
27    pub subclass: bool,
28}
29
30impl Import for ClassDef {
31    fn import(&self) -> HashSet<ImportRef> {
32        let mut import = HashSet::new();
33        if !self.subclass {
34            // for @typing.final
35            import.insert("typing".into());
36        }
37        for base in &self.bases {
38            import.extend(base.import.clone());
39        }
40        for attr in &self.attrs {
41            import.extend(attr.import());
42        }
43        for (getter, setter) in self.getter_setters.values() {
44            if let Some(getter) = getter {
45                import.extend(getter.import());
46            }
47            if let Some(setter) = setter {
48                import.extend(setter.import());
49            }
50        }
51        for method in self.methods.values() {
52            if method.len() > 1 {
53                // for @typing.overload
54                import.insert("typing".into());
55            }
56            for method in method {
57                import.extend(method.import());
58            }
59        }
60        for class in &self.classes {
61            import.extend(class.import());
62        }
63        import
64    }
65}
66
67impl From<&PyComplexEnumInfo> for ClassDef {
68    fn from(info: &PyComplexEnumInfo) -> Self {
69        // Since there are multiple `#[pymethods]` for a single class, we need to merge them.
70        // This is only an initializer. See `StubInfo::gather` for the actual merging.
71
72        let enum_info = Self {
73            name: info.pyclass_name,
74            doc: info.doc,
75            getter_setters: IndexMap::new(),
76            methods: IndexMap::new(),
77            classes: info
78                .variants
79                .iter()
80                .map(|v| ClassDef::from_variant(info, v))
81                .collect(),
82            bases: Vec::new(),
83            match_args: None,
84            attrs: Vec::new(),
85            subclass: true, // Complex enums can be subclassed by their variants
86        };
87
88        enum_info
89    }
90}
91
92impl ClassDef {
93    fn from_variant(enum_info: &PyComplexEnumInfo, info: &VariantInfo) -> Self {
94        let methods = get_variant_methods(enum_info, info);
95
96        Self {
97            name: info.pyclass_name,
98            doc: info.doc,
99            getter_setters: info
100                .fields
101                .iter()
102                .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
103                .collect(),
104            methods,
105            classes: Vec::new(),
106            bases: vec![TypeInfo::unqualified(enum_info.pyclass_name)],
107            match_args: Some(info.fields.iter().map(|f| f.name.to_string()).collect()),
108            attrs: Vec::new(),
109            subclass: false,
110        }
111    }
112}
113
114impl From<&PyClassInfo> for ClassDef {
115    fn from(info: &PyClassInfo) -> Self {
116        // Since there are multiple `#[pymethods]` for a single class, we need to merge them.
117        // This is only an initializer. See `StubInfo::gather` for the actual merging.
118        let mut getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)> = info
119            .getters
120            .iter()
121            .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
122            .collect();
123        for setter in info.setters {
124            getter_setters.entry(setter.name.to_string()).or_default().1 = Some(MemberDef {
125                name: setter.name,
126                r#type: (setter.r#type)(),
127                doc: setter.doc,
128                default: setter.default.map(|f| f()),
129                deprecated: setter.deprecated.clone(),
130            });
131        }
132        let mut new = Self {
133            name: info.pyclass_name,
134            doc: info.doc,
135            attrs: Vec::new(),
136            getter_setters,
137            methods: Default::default(),
138            classes: Vec::new(),
139            bases: info.bases.iter().map(|f| f()).collect(),
140            match_args: None,
141            subclass: info.subclass,
142        };
143        if info.has_eq {
144            new.add_eq_method();
145        }
146        if info.has_ord {
147            new.add_ord_methods();
148        }
149        if info.has_hash {
150            new.add_hash_method();
151        }
152        if info.has_str {
153            new.add_str_method();
154        }
155        new
156    }
157}
158impl ClassDef {
159    fn add_eq_method(&mut self) {
160        let method = MethodDef {
161            name: "__eq__",
162            parameters: Parameters {
163                positional_or_keyword: vec![Parameter {
164                    name: "other",
165                    kind: ParameterKind::PositionalOrKeyword,
166                    type_info: TypeInfo::builtin("object"),
167                    default: ParameterDefault::None,
168                }],
169                ..Parameters::new()
170            },
171            r#return: TypeInfo::builtin("bool"),
172            doc: "",
173            r#type: MethodType::Instance,
174            is_async: false,
175            deprecated: None,
176            type_ignored: None,
177            is_overload: false,
178        };
179        self.methods
180            .entry("__eq__".to_string())
181            .or_default()
182            .push(method);
183    }
184
185    fn add_ord_methods(&mut self) {
186        let ord_methods = ["__lt__", "__le__", "__gt__", "__ge__"];
187
188        for name in &ord_methods {
189            let method = MethodDef {
190                name,
191                parameters: Parameters {
192                    positional_or_keyword: vec![Parameter {
193                        name: "other",
194                        kind: ParameterKind::PositionalOrKeyword,
195                        type_info: TypeInfo::builtin("object"),
196                        default: ParameterDefault::None,
197                    }],
198                    ..Parameters::new()
199                },
200                r#return: TypeInfo::builtin("bool"),
201                doc: "",
202                r#type: MethodType::Instance,
203                is_async: false,
204                deprecated: None,
205                type_ignored: None,
206                is_overload: false,
207            };
208            self.methods
209                .entry(name.to_string())
210                .or_default()
211                .push(method);
212        }
213    }
214
215    fn add_hash_method(&mut self) {
216        let method = MethodDef {
217            name: "__hash__",
218            parameters: Parameters::new(),
219            r#return: TypeInfo::builtin("int"),
220            doc: "",
221            r#type: MethodType::Instance,
222            is_async: false,
223            deprecated: None,
224            type_ignored: None,
225            is_overload: false,
226        };
227        self.methods
228            .entry("__hash__".to_string())
229            .or_default()
230            .push(method);
231    }
232
233    fn add_str_method(&mut self) {
234        let method = MethodDef {
235            name: "__str__",
236            parameters: Parameters::new(),
237            r#return: TypeInfo::builtin("str"),
238            doc: "",
239            r#type: MethodType::Instance,
240            is_async: false,
241            deprecated: None,
242            type_ignored: None,
243            is_overload: false,
244        };
245        self.methods
246            .entry("__str__".to_string())
247            .or_default()
248            .push(method);
249    }
250}
251
252impl fmt::Display for ClassDef {
253    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
254        let bases = self
255            .bases
256            .iter()
257            .map(|i| i.name.clone())
258            .reduce(|acc, path| format!("{acc}, {path}"))
259            .map(|bases| format!("({bases})"))
260            .unwrap_or_default();
261        if !self.subclass {
262            writeln!(f, "@typing.final")?;
263        }
264        writeln!(f, "class {}{}:", self.name, bases)?;
265        let indent = indent();
266        let doc = self.doc.trim();
267        docstring::write_docstring(f, doc, indent)?;
268
269        if let Some(match_args) = &self.match_args {
270            if match_args.is_empty() {
271                writeln!(f, "{indent}__match_args__ = ()")?;
272            } else {
273                let match_args_txt = match_args
274                    .iter()
275                    .map(|a| format!(r##""{a}""##))
276                    .collect::<Vec<_>>()
277                    .join(", ");
278                writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
279            }
280        }
281        for attr in &self.attrs {
282            attr.fmt(f)?;
283        }
284        for (getter, setter) in self.getter_setters.values() {
285            if let Some(getter) = getter {
286                GetterDisplay(getter).fmt(f)?;
287            }
288            if let Some(setter) = setter {
289                SetterDisplay(setter).fmt(f)?;
290            }
291        }
292        for (_method_name, methods) in &self.methods {
293            // Check if we should add @overload to all methods
294            let has_overload = methods.iter().any(|m| m.is_overload);
295            let should_add_overload = methods.len() > 1 && has_overload;
296
297            for method in methods {
298                if should_add_overload {
299                    writeln!(f, "{indent}@typing.overload")?;
300                }
301                method.fmt(f)?;
302            }
303        }
304        for class in &self.classes {
305            let emit = format!("{class}");
306            for line in emit.lines() {
307                writeln!(f, "{indent}{line}")?;
308            }
309        }
310        if self.attrs.is_empty() && self.getter_setters.is_empty() && self.methods.is_empty() {
311            writeln!(f, "{indent}...")?;
312        }
313        writeln!(f)?;
314        Ok(())
315    }
316}