pyo3_stub_gen/generate/
class.rs

1use indexmap::IndexMap;
2
3use crate::generate::variant_methods::get_variant_methods;
4use crate::{
5    generate::{
6        docstring, indent, GetterDisplay, Import, MemberDef, MethodDef, Parameter,
7        ParameterDefault, Parameters, SetterDisplay,
8    },
9    stub_type::ImportRef,
10    type_info::*,
11    TypeInfo,
12};
13use std::collections::HashSet;
14use std::{fmt, vec};
15
/// Definition of a Python class.
///
/// Rendered into stub source by the `fmt::Display` impl (unqualified type
/// names) or by `ClassDef::fmt_for_module` (module-qualified type names).
#[derive(Debug, Clone, PartialEq)]
pub struct ClassDef {
    /// Class name written after the `class` keyword.
    pub name: &'static str,
    /// Module the class belongs to, if known. The `Display` impl falls back to
    /// `name` as the property target module when this is `None`.
    pub module: Option<&'static str>,
    /// Docstring source; trimmed before it is written.
    pub doc: &'static str,
    /// Class-level attributes, written before properties and methods.
    pub attrs: Vec<MemberDef>,
    /// Properties keyed by name; each value is a `(getter, setter)` pair where
    /// either side may be absent.
    pub getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)>,
    /// Methods keyed by name. Several entries under one name are written one
    /// after another, and decorated with `@typing.overload` when at least one
    /// of them has `is_overload` set.
    pub methods: IndexMap<String, Vec<MethodDef>>,
    /// Base classes listed in the `class` header.
    pub bases: Vec<TypeInfo>,
    /// Nested classes, written indented inside this class body
    /// (e.g. the variants of a complex enum).
    pub classes: Vec<ClassDef>,
    /// When `Some`, a `__match_args__` tuple with these names is written.
    pub match_args: Option<Vec<String>>,
    /// When `false`, the class is decorated with `@typing.final`.
    pub subclass: bool,
}
30
31impl Import for ClassDef {
32    fn import(&self) -> HashSet<ImportRef> {
33        let mut import = HashSet::new();
34        if !self.subclass {
35            // for @typing.final
36            import.insert("typing".into());
37        }
38        for base in &self.bases {
39            import.extend(base.import.clone());
40        }
41        for attr in &self.attrs {
42            import.extend(attr.import());
43        }
44        for (getter, setter) in self.getter_setters.values() {
45            if let Some(getter) = getter {
46                import.extend(getter.import());
47            }
48            if let Some(setter) = setter {
49                import.extend(setter.import());
50            }
51        }
52        for method in self.methods.values() {
53            if method.len() > 1 {
54                // for @typing.overload
55                import.insert("typing".into());
56            }
57            for method in method {
58                import.extend(method.import());
59            }
60        }
61        for class in &self.classes {
62            import.extend(class.import());
63        }
64        import
65    }
66}
67
68impl From<&PyComplexEnumInfo> for ClassDef {
69    fn from(info: &PyComplexEnumInfo) -> Self {
70        // Since there are multiple `#[pymethods]` for a single class, we need to merge them.
71        // This is only an initializer. See `StubInfo::gather` for the actual merging.
72
73        let enum_info = Self {
74            name: info.pyclass_name,
75            module: info.module,
76            doc: info.doc,
77            getter_setters: IndexMap::new(),
78            methods: IndexMap::new(),
79            classes: info
80                .variants
81                .iter()
82                .map(|v| ClassDef::from_variant(info, v))
83                .collect(),
84            bases: Vec::new(),
85            match_args: None,
86            attrs: Vec::new(),
87            subclass: true, // Complex enums can be subclassed by their variants
88        };
89
90        enum_info
91    }
92}
93
94impl ClassDef {
95    fn from_variant(enum_info: &PyComplexEnumInfo, info: &VariantInfo) -> Self {
96        let methods = get_variant_methods(enum_info, info);
97
98        Self {
99            name: info.pyclass_name,
100            module: enum_info.module,
101            doc: info.doc,
102            getter_setters: info
103                .fields
104                .iter()
105                .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
106                .collect(),
107            methods,
108            classes: Vec::new(),
109            bases: vec![TypeInfo::unqualified(enum_info.pyclass_name)],
110            match_args: Some(info.fields.iter().map(|f| f.name.to_string()).collect()),
111            attrs: Vec::new(),
112            subclass: false,
113        }
114    }
115}
116
117impl From<&PyClassInfo> for ClassDef {
118    fn from(info: &PyClassInfo) -> Self {
119        // Since there are multiple `#[pymethods]` for a single class, we need to merge them.
120        // This is only an initializer. See `StubInfo::gather` for the actual merging.
121        let mut getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)> = info
122            .getters
123            .iter()
124            .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
125            .collect();
126        for setter in info.setters {
127            getter_setters.entry(setter.name.to_string()).or_default().1 = Some(MemberDef {
128                name: setter.name,
129                r#type: (setter.r#type)(),
130                doc: setter.doc,
131                default: setter.default.map(|f| f()),
132                deprecated: setter.deprecated.clone(),
133            });
134        }
135        let mut new = Self {
136            name: info.pyclass_name,
137            module: info.module,
138            doc: info.doc,
139            attrs: Vec::new(),
140            getter_setters,
141            methods: Default::default(),
142            classes: Vec::new(),
143            bases: info.bases.iter().map(|f| f()).collect(),
144            match_args: None,
145            subclass: info.subclass,
146        };
147        if info.has_eq {
148            new.add_eq_method();
149        }
150        if info.has_ord {
151            new.add_ord_methods();
152        }
153        if info.has_hash {
154            new.add_hash_method();
155        }
156        if info.has_str {
157            new.add_str_method();
158        }
159        new
160    }
161}
162impl ClassDef {
163    fn add_eq_method(&mut self) {
164        let method = MethodDef {
165            name: "__eq__",
166            parameters: Parameters {
167                positional_or_keyword: vec![Parameter {
168                    name: "other",
169                    kind: ParameterKind::PositionalOrKeyword,
170                    type_info: TypeInfo::builtin("object"),
171                    default: ParameterDefault::None,
172                }],
173                ..Parameters::new()
174            },
175            r#return: TypeInfo::builtin("bool"),
176            doc: "",
177            r#type: MethodType::Instance,
178            is_async: false,
179            deprecated: None,
180            type_ignored: None,
181            is_overload: false,
182        };
183        self.methods
184            .entry("__eq__".to_string())
185            .or_default()
186            .push(method);
187    }
188
189    fn add_ord_methods(&mut self) {
190        let ord_methods = ["__lt__", "__le__", "__gt__", "__ge__"];
191
192        for name in &ord_methods {
193            let method = MethodDef {
194                name,
195                parameters: Parameters {
196                    positional_or_keyword: vec![Parameter {
197                        name: "other",
198                        kind: ParameterKind::PositionalOrKeyword,
199                        type_info: TypeInfo::builtin("object"),
200                        default: ParameterDefault::None,
201                    }],
202                    ..Parameters::new()
203                },
204                r#return: TypeInfo::builtin("bool"),
205                doc: "",
206                r#type: MethodType::Instance,
207                is_async: false,
208                deprecated: None,
209                type_ignored: None,
210                is_overload: false,
211            };
212            self.methods
213                .entry(name.to_string())
214                .or_default()
215                .push(method);
216        }
217    }
218
219    fn add_hash_method(&mut self) {
220        let method = MethodDef {
221            name: "__hash__",
222            parameters: Parameters::new(),
223            r#return: TypeInfo::builtin("int"),
224            doc: "",
225            r#type: MethodType::Instance,
226            is_async: false,
227            deprecated: None,
228            type_ignored: None,
229            is_overload: false,
230        };
231        self.methods
232            .entry("__hash__".to_string())
233            .or_default()
234            .push(method);
235    }
236
237    fn add_str_method(&mut self) {
238        let method = MethodDef {
239            name: "__str__",
240            parameters: Parameters::new(),
241            r#return: TypeInfo::builtin("str"),
242            doc: "",
243            r#type: MethodType::Instance,
244            is_async: false,
245            deprecated: None,
246            type_ignored: None,
247            is_overload: false,
248        };
249        self.methods
250            .entry("__str__".to_string())
251            .or_default()
252            .push(method);
253    }
254
255    /// Resolve all ModuleRef::Default to actual module name.
256    /// Called after construction, before formatting.
257    pub fn resolve_default_modules(&mut self, default_module_name: &str) {
258        // Resolve in getter/setter types
259        for (getter, setter) in self.getter_setters.values_mut() {
260            if let Some(getter) = getter {
261                getter.r#type.resolve_default_module(default_module_name);
262            }
263            if let Some(setter) = setter {
264                setter.r#type.resolve_default_module(default_module_name);
265            }
266        }
267
268        // Resolve in method parameter and return types
269        for methods in self.methods.values_mut() {
270            for method in methods {
271                // Resolve all parameter types
272                for param in &mut method.parameters.positional_only {
273                    param.type_info.resolve_default_module(default_module_name);
274                }
275                for param in &mut method.parameters.positional_or_keyword {
276                    param.type_info.resolve_default_module(default_module_name);
277                }
278                for param in &mut method.parameters.keyword_only {
279                    param.type_info.resolve_default_module(default_module_name);
280                }
281                if let Some(varargs) = &mut method.parameters.varargs {
282                    varargs
283                        .type_info
284                        .resolve_default_module(default_module_name);
285                }
286                if let Some(varkw) = &mut method.parameters.varkw {
287                    varkw.type_info.resolve_default_module(default_module_name);
288                }
289                method.r#return.resolve_default_module(default_module_name);
290            }
291        }
292
293        // Resolve in base classes
294        for base in &mut self.bases {
295            base.resolve_default_module(default_module_name);
296        }
297
298        // Resolve in class attributes
299        for attr in &mut self.attrs {
300            attr.r#type.resolve_default_module(default_module_name);
301        }
302
303        // Recursively resolve in nested classes
304        for class in &mut self.classes {
305            class.resolve_default_modules(default_module_name);
306        }
307    }
308}
309
310impl fmt::Display for ClassDef {
311    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
312        let bases = self
313            .bases
314            .iter()
315            .map(|i| i.name.clone())
316            .reduce(|acc, path| format!("{acc}, {path}"))
317            .map(|bases| format!("({bases})"))
318            .unwrap_or_default();
319        if !self.subclass {
320            writeln!(f, "@typing.final")?;
321        }
322        writeln!(f, "class {}{}:", self.name, bases)?;
323        let indent = indent();
324        let doc = self.doc.trim();
325        docstring::write_docstring(f, doc, indent)?;
326
327        if let Some(match_args) = &self.match_args {
328            if match_args.is_empty() {
329                writeln!(f, "{indent}__match_args__ = ()")?;
330            } else {
331                let match_args_txt = match_args
332                    .iter()
333                    .map(|a| format!(r##""{a}""##))
334                    .collect::<Vec<_>>()
335                    .join(", ");
336                writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
337            }
338        }
339        for attr in &self.attrs {
340            attr.fmt(f)?;
341        }
342        for (getter, setter) in self.getter_setters.values() {
343            if let Some(getter) = getter {
344                write!(
345                    f,
346                    "{}",
347                    GetterDisplay {
348                        member: getter,
349                        target_module: self.module.unwrap_or(self.name)
350                    }
351                )?;
352            }
353            if let Some(setter) = setter {
354                write!(
355                    f,
356                    "{}",
357                    SetterDisplay {
358                        member: setter,
359                        target_module: self.module.unwrap_or(self.name)
360                    }
361                )?;
362            }
363        }
364        for (_method_name, methods) in &self.methods {
365            // Check if we should add @overload to all methods
366            let has_overload = methods.iter().any(|m| m.is_overload);
367            let should_add_overload = methods.len() > 1 && has_overload;
368
369            for method in methods {
370                if should_add_overload {
371                    writeln!(f, "{indent}@typing.overload")?;
372                }
373                method.fmt(f)?;
374            }
375        }
376        for class in &self.classes {
377            let emit = format!("{class}");
378            for line in emit.lines() {
379                writeln!(f, "{indent}{line}")?;
380            }
381        }
382        if self.attrs.is_empty() && self.getter_setters.is_empty() && self.methods.is_empty() {
383            writeln!(f, "{indent}...")?;
384        }
385        writeln!(f)?;
386        Ok(())
387    }
388}
389
390impl ClassDef {
391    /// Format class with module-qualified type names
392    ///
393    /// This method uses the target module context to qualify type identifiers
394    /// within compound type expressions based on their source modules.
395    pub fn fmt_for_module(&self, target_module: &str, f: &mut fmt::Formatter) -> fmt::Result {
396        // Qualify base classes
397        let bases = self
398            .bases
399            .iter()
400            .map(|i| i.qualified_for_module(target_module))
401            .reduce(|acc, path| format!("{acc}, {path}"))
402            .map(|bases| format!("({bases})"))
403            .unwrap_or_default();
404
405        if !self.subclass {
406            writeln!(f, "@typing.final")?;
407        }
408        writeln!(f, "class {}{}:", self.name, bases)?;
409
410        let indent = indent();
411        let doc = self.doc.trim();
412        docstring::write_docstring(f, doc, indent)?;
413
414        if let Some(match_args) = &self.match_args {
415            if match_args.is_empty() {
416                writeln!(f, "{indent}__match_args__ = ()")?;
417            } else {
418                let match_args_txt = match_args
419                    .iter()
420                    .map(|a| format!(r##""{a}""##))
421                    .collect::<Vec<_>>()
422                    .join(", ");
423                writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
424            }
425        }
426
427        // Format attributes with qualified types
428        for attr in &self.attrs {
429            attr.fmt_for_module(target_module, f, indent)?;
430        }
431
432        // Format properties with qualified types
433        for (getter, setter) in self.getter_setters.values() {
434            if let Some(getter) = getter {
435                write!(
436                    f,
437                    "{}",
438                    GetterDisplay {
439                        member: getter,
440                        target_module
441                    }
442                )?;
443            }
444            if let Some(setter) = setter {
445                write!(
446                    f,
447                    "{}",
448                    SetterDisplay {
449                        member: setter,
450                        target_module
451                    }
452                )?;
453            }
454        }
455
456        // Format methods with qualified types
457        for (_method_name, methods) in &self.methods {
458            let has_overload = methods.iter().any(|m| m.is_overload);
459            let should_add_overload = methods.len() > 1 && has_overload;
460
461            for method in methods {
462                if should_add_overload {
463                    writeln!(f, "{indent}@typing.overload")?;
464                }
465                method.fmt_for_module(target_module, f, indent)?;
466            }
467        }
468
469        // Format nested classes recursively
470        for class in &self.classes {
471            // Create a temporary formatter to capture nested class output
472            struct FmtAdapter<'a, 'b> {
473                class: &'a ClassDef,
474                target_module: &'b str,
475            }
476            impl<'a, 'b> fmt::Display for FmtAdapter<'a, 'b> {
477                fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
478                    self.class.fmt_for_module(self.target_module, f)
479                }
480            }
481            let emit = format!(
482                "{}",
483                FmtAdapter {
484                    class,
485                    target_module
486                }
487            );
488            for line in emit.lines() {
489                writeln!(f, "{indent}{line}")?;
490            }
491        }
492
493        if self.attrs.is_empty() && self.getter_setters.is_empty() && self.methods.is_empty() {
494            writeln!(f, "{indent}...")?;
495        }
496        writeln!(f)?;
497        Ok(())
498    }
499}