pyo3_stub_gen/generate/class.rs

1use indexmap::IndexMap;
2
3use crate::generate::variant_methods::get_variant_methods;
4use crate::{generate::*, type_info::*, TypeInfo};
5use std::{fmt, vec};
6
7/// Definition of a Python class.
8#[derive(Debug, Clone, PartialEq)]
9pub struct ClassDef {
10    pub name: &'static str,
11    pub doc: &'static str,
12    pub attrs: Vec<MemberDef>,
13    pub getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)>,
14    pub methods: IndexMap<String, Vec<MethodDef>>,
15    pub bases: Vec<TypeInfo>,
16    pub classes: Vec<ClassDef>,
17    pub match_args: Option<Vec<String>>,
18    pub subclass: bool,
19}
20
21impl Import for ClassDef {
22    fn import(&self) -> HashSet<ImportRef> {
23        let mut import = HashSet::new();
24        if !self.subclass {
25            // for @typing.final
26            import.insert("typing".into());
27        }
28        for base in &self.bases {
29            import.extend(base.import.clone());
30        }
31        for attr in &self.attrs {
32            import.extend(attr.import());
33        }
34        for (getter, setter) in self.getter_setters.values() {
35            if let Some(getter) = getter {
36                import.extend(getter.import());
37            }
38            if let Some(setter) = setter {
39                import.extend(setter.import());
40            }
41        }
42        for method in self.methods.values() {
43            if method.len() > 1 {
44                // for @typing.overload
45                import.insert("typing".into());
46            }
47            for method in method {
48                import.extend(method.import());
49            }
50        }
51        for class in &self.classes {
52            import.extend(class.import());
53        }
54        import
55    }
56}
57
58impl From<&PyComplexEnumInfo> for ClassDef {
59    fn from(info: &PyComplexEnumInfo) -> Self {
60        // Since there are multiple `#[pymethods]` for a single class, we need to merge them.
61        // This is only an initializer. See `StubInfo::gather` for the actual merging.
62
63        let enum_info = Self {
64            name: info.pyclass_name,
65            doc: info.doc,
66            getter_setters: IndexMap::new(),
67            methods: IndexMap::new(),
68            classes: info
69                .variants
70                .iter()
71                .map(|v| ClassDef::from_variant(info, v))
72                .collect(),
73            bases: Vec::new(),
74            match_args: None,
75            attrs: Vec::new(),
76            subclass: true, // Complex enums can be subclassed by their variants
77        };
78
79        enum_info
80    }
81}
82
83impl ClassDef {
84    fn from_variant(enum_info: &PyComplexEnumInfo, info: &VariantInfo) -> Self {
85        let methods = get_variant_methods(enum_info, info);
86
87        Self {
88            name: info.pyclass_name,
89            doc: info.doc,
90            getter_setters: info
91                .fields
92                .iter()
93                .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
94                .collect(),
95            methods,
96            classes: Vec::new(),
97            bases: vec![TypeInfo::unqualified(enum_info.pyclass_name)],
98            match_args: Some(info.fields.iter().map(|f| f.name.to_string()).collect()),
99            attrs: Vec::new(),
100            subclass: false,
101        }
102    }
103}
104
105impl From<&PyClassInfo> for ClassDef {
106    fn from(info: &PyClassInfo) -> Self {
107        // Since there are multiple `#[pymethods]` for a single class, we need to merge them.
108        // This is only an initializer. See `StubInfo::gather` for the actual merging.
109        let mut getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)> = info
110            .getters
111            .iter()
112            .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
113            .collect();
114        for setter in info.setters {
115            getter_setters.entry(setter.name.to_string()).or_default().1 = Some(MemberDef {
116                name: setter.name,
117                r#type: (setter.r#type)(),
118                doc: setter.doc,
119                default: setter.default.map(|f| f()),
120                deprecated: setter.deprecated.clone(),
121            });
122        }
123        let mut new = Self {
124            name: info.pyclass_name,
125            doc: info.doc,
126            attrs: Vec::new(),
127            getter_setters,
128            methods: Default::default(),
129            classes: Vec::new(),
130            bases: info.bases.iter().map(|f| f()).collect(),
131            match_args: None,
132            subclass: info.subclass,
133        };
134        if info.has_eq {
135            new.add_eq_method();
136        }
137        if info.has_ord {
138            new.add_ord_methods();
139        }
140        if info.has_hash {
141            new.add_hash_method();
142        }
143        if info.has_str {
144            new.add_str_method();
145        }
146        new
147    }
148}
149impl ClassDef {
150    fn add_eq_method(&mut self) {
151        let method = MethodDef {
152            name: "__eq__",
153            args: vec![Arg {
154                name: "other",
155                r#type: TypeInfo::builtin("object"),
156                signature: None,
157            }],
158            r#return: TypeInfo::builtin("bool"),
159            doc: "",
160            r#type: MethodType::Instance,
161            is_async: false,
162            deprecated: None,
163            type_ignored: None,
164        };
165        self.methods
166            .entry("__eq__".to_string())
167            .or_default()
168            .push(method);
169    }
170
171    fn add_ord_methods(&mut self) {
172        let ord_methods = ["__lt__", "__le__", "__gt__", "__ge__"];
173
174        for name in &ord_methods {
175            let method = MethodDef {
176                name,
177                args: vec![Arg {
178                    name: "other",
179                    r#type: TypeInfo::builtin("object"),
180                    signature: None,
181                }],
182                r#return: TypeInfo::builtin("bool"),
183                doc: "",
184                r#type: MethodType::Instance,
185                is_async: false,
186                deprecated: None,
187                type_ignored: None,
188            };
189            self.methods
190                .entry(name.to_string())
191                .or_default()
192                .push(method);
193        }
194    }
195
196    fn add_hash_method(&mut self) {
197        let method = MethodDef {
198            name: "__hash__",
199            args: vec![],
200            r#return: TypeInfo::builtin("int"),
201            doc: "",
202            r#type: MethodType::Instance,
203            is_async: false,
204            deprecated: None,
205            type_ignored: None,
206        };
207        self.methods
208            .entry("__hash__".to_string())
209            .or_default()
210            .push(method);
211    }
212
213    fn add_str_method(&mut self) {
214        let method = MethodDef {
215            name: "__str__",
216            args: vec![],
217            r#return: TypeInfo::builtin("str"),
218            doc: "",
219            r#type: MethodType::Instance,
220            is_async: false,
221            deprecated: None,
222            type_ignored: None,
223        };
224        self.methods
225            .entry("__str__".to_string())
226            .or_default()
227            .push(method);
228    }
229}
230
231impl fmt::Display for ClassDef {
232    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
233        let bases = self
234            .bases
235            .iter()
236            .map(|i| i.name.clone())
237            .reduce(|acc, path| format!("{acc}, {path}"))
238            .map(|bases| format!("({bases})"))
239            .unwrap_or_default();
240        if !self.subclass {
241            writeln!(f, "@typing.final")?;
242        }
243        writeln!(f, "class {}{}:", self.name, bases)?;
244        let indent = indent();
245        let doc = self.doc.trim();
246        docstring::write_docstring(f, doc, indent)?;
247
248        if let Some(match_args) = &self.match_args {
249            if match_args.is_empty() {
250                writeln!(f, "{indent}__match_args__ = ()")?;
251            } else {
252                let match_args_txt = match_args
253                    .iter()
254                    .map(|a| format!(r##""{a}""##))
255                    .collect::<Vec<_>>()
256                    .join(", ");
257                writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
258            }
259        }
260        for attr in &self.attrs {
261            attr.fmt(f)?;
262        }
263        for (getter, setter) in self.getter_setters.values() {
264            if let Some(getter) = getter {
265                GetterDisplay(getter).fmt(f)?;
266            }
267            if let Some(setter) = setter {
268                SetterDisplay(setter).fmt(f)?;
269            }
270        }
271        for methods in self.methods.values() {
272            let overloaded = methods.len() > 1;
273            for method in methods {
274                if overloaded {
275                    writeln!(f, "{indent}@typing.overload")?;
276                }
277                method.fmt(f)?;
278            }
279        }
280        for class in &self.classes {
281            let emit = format!("{class}");
282            for line in emit.lines() {
283                writeln!(f, "{indent}{line}")?;
284            }
285        }
286        if self.attrs.is_empty() && self.getter_setters.is_empty() && self.methods.is_empty() {
287            writeln!(f, "{indent}...")?;
288        }
289        writeln!(f)?;
290        Ok(())
291    }
292}