pyo3_stub_gen/generate/
class.rs

1use indexmap::IndexMap;
2
3use crate::generate::docstring::normalize_docstring;
4use crate::generate::variant_methods::get_variant_methods;
5use crate::{
6    generate::{
7        docstring, indent, GetterDisplay, Import, MemberDef, MethodDef, Parameter,
8        ParameterDefault, Parameters, SetterDisplay,
9    },
10    stub_type::ImportRef,
11    type_info::*,
12    TypeInfo,
13};
14use std::collections::HashSet;
15use std::{fmt, vec};
16
17/// Definition of a Python class.
18#[derive(Debug, Clone, PartialEq)]
19pub struct ClassDef {
20    pub name: &'static str,
21    pub module: Option<&'static str>,
22    pub doc: &'static str,
23    pub attrs: Vec<MemberDef>,
24    pub getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)>,
25    pub methods: IndexMap<String, Vec<MethodDef>>,
26    pub bases: Vec<TypeInfo>,
27    pub classes: Vec<ClassDef>,
28    pub match_args: Option<Vec<String>>,
29    pub subclass: bool,
30}
31
32impl Import for ClassDef {
33    fn import(&self) -> HashSet<ImportRef> {
34        let mut import = HashSet::new();
35        if !self.subclass {
36            // for @typing.final
37            import.insert("typing".into());
38        }
39        for base in &self.bases {
40            import.extend(base.import.clone());
41        }
42        for attr in &self.attrs {
43            import.extend(attr.import());
44        }
45        for (getter, setter) in self.getter_setters.values() {
46            if let Some(getter) = getter {
47                import.extend(getter.import());
48            }
49            if let Some(setter) = setter {
50                import.extend(setter.import());
51            }
52        }
53        for method in self.methods.values() {
54            if method.len() > 1 {
55                // for @typing.overload
56                import.insert("typing".into());
57            }
58            for method in method {
59                import.extend(method.import());
60            }
61        }
62        for class in &self.classes {
63            import.extend(class.import());
64        }
65        import
66    }
67}
68
69impl From<&PyComplexEnumInfo> for ClassDef {
70    fn from(info: &PyComplexEnumInfo) -> Self {
71        // Since there are multiple `#[pymethods]` for a single class, we need to merge them.
72        // This is only an initializer. See `StubInfo::gather` for the actual merging.
73
74        let doc = if info.doc.is_empty() {
75            ""
76        } else {
77            Box::leak(normalize_docstring(info.doc).into_boxed_str())
78        };
79
80        let enum_info = Self {
81            name: info.pyclass_name,
82            module: info.module,
83            doc,
84            getter_setters: IndexMap::new(),
85            methods: IndexMap::new(),
86            classes: info
87                .variants
88                .iter()
89                .map(|v| ClassDef::from_variant(info, v))
90                .collect(),
91            bases: Vec::new(),
92            match_args: None,
93            attrs: Vec::new(),
94            subclass: true, // Complex enums can be subclassed by their variants
95        };
96
97        enum_info
98    }
99}
100
101impl ClassDef {
102    fn from_variant(enum_info: &PyComplexEnumInfo, info: &VariantInfo) -> Self {
103        let methods = get_variant_methods(enum_info, info);
104
105        let doc = if info.doc.is_empty() {
106            ""
107        } else {
108            Box::leak(normalize_docstring(info.doc).into_boxed_str())
109        };
110
111        Self {
112            name: info.pyclass_name,
113            module: enum_info.module,
114            doc,
115            getter_setters: info
116                .fields
117                .iter()
118                .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
119                .collect(),
120            methods,
121            classes: Vec::new(),
122            bases: vec![TypeInfo::unqualified(enum_info.pyclass_name)],
123            match_args: Some(info.fields.iter().map(|f| f.name.to_string()).collect()),
124            attrs: Vec::new(),
125            subclass: false,
126        }
127    }
128}
129
130impl From<&PyClassInfo> for ClassDef {
131    fn from(info: &PyClassInfo) -> Self {
132        // Since there are multiple `#[pymethods]` for a single class, we need to merge them.
133        // This is only an initializer. See `StubInfo::gather` for the actual merging.
134        let doc = if info.doc.is_empty() {
135            ""
136        } else {
137            Box::leak(normalize_docstring(info.doc).into_boxed_str())
138        };
139
140        let mut getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)> = info
141            .getters
142            .iter()
143            .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
144            .collect();
145        for setter in info.setters {
146            let setter_doc = if setter.doc.is_empty() {
147                ""
148            } else {
149                Box::leak(normalize_docstring(setter.doc).into_boxed_str())
150            };
151
152            getter_setters.entry(setter.name.to_string()).or_default().1 = Some(MemberDef {
153                name: setter.name,
154                r#type: (setter.r#type)(),
155                doc: setter_doc,
156                default: setter.default.map(|f| f()),
157                deprecated: setter.deprecated.clone(),
158            });
159        }
160        let mut new = Self {
161            name: info.pyclass_name,
162            module: info.module,
163            doc,
164            attrs: Vec::new(),
165            getter_setters,
166            methods: Default::default(),
167            classes: Vec::new(),
168            bases: info.bases.iter().map(|f| f()).collect(),
169            match_args: None,
170            subclass: info.subclass,
171        };
172        if info.has_eq {
173            new.add_eq_method();
174        }
175        if info.has_ord {
176            new.add_ord_methods();
177        }
178        if info.has_hash {
179            new.add_hash_method();
180        }
181        if info.has_str {
182            new.add_str_method();
183        }
184        new
185    }
186}
187impl ClassDef {
188    fn add_eq_method(&mut self) {
189        let method = MethodDef {
190            name: "__eq__",
191            parameters: Parameters {
192                positional_or_keyword: vec![Parameter {
193                    name: "other",
194                    kind: ParameterKind::PositionalOrKeyword,
195                    type_info: TypeInfo::builtin("object"),
196                    default: ParameterDefault::None,
197                }],
198                ..Parameters::new()
199            },
200            r#return: TypeInfo::builtin("bool"),
201            doc: "",
202            r#type: MethodType::Instance,
203            is_async: false,
204            deprecated: None,
205            type_ignored: None,
206            is_overload: false,
207        };
208        self.methods
209            .entry("__eq__".to_string())
210            .or_default()
211            .push(method);
212    }
213
214    fn add_ord_methods(&mut self) {
215        let ord_methods = ["__lt__", "__le__", "__gt__", "__ge__"];
216
217        for name in &ord_methods {
218            let method = MethodDef {
219                name,
220                parameters: Parameters {
221                    positional_or_keyword: vec![Parameter {
222                        name: "other",
223                        kind: ParameterKind::PositionalOrKeyword,
224                        type_info: TypeInfo::builtin("object"),
225                        default: ParameterDefault::None,
226                    }],
227                    ..Parameters::new()
228                },
229                r#return: TypeInfo::builtin("bool"),
230                doc: "",
231                r#type: MethodType::Instance,
232                is_async: false,
233                deprecated: None,
234                type_ignored: None,
235                is_overload: false,
236            };
237            self.methods
238                .entry(name.to_string())
239                .or_default()
240                .push(method);
241        }
242    }
243
244    fn add_hash_method(&mut self) {
245        let method = MethodDef {
246            name: "__hash__",
247            parameters: Parameters::new(),
248            r#return: TypeInfo::builtin("int"),
249            doc: "",
250            r#type: MethodType::Instance,
251            is_async: false,
252            deprecated: None,
253            type_ignored: None,
254            is_overload: false,
255        };
256        self.methods
257            .entry("__hash__".to_string())
258            .or_default()
259            .push(method);
260    }
261
262    fn add_str_method(&mut self) {
263        let method = MethodDef {
264            name: "__str__",
265            parameters: Parameters::new(),
266            r#return: TypeInfo::builtin("str"),
267            doc: "",
268            r#type: MethodType::Instance,
269            is_async: false,
270            deprecated: None,
271            type_ignored: None,
272            is_overload: false,
273        };
274        self.methods
275            .entry("__str__".to_string())
276            .or_default()
277            .push(method);
278    }
279
280    /// Resolve all ModuleRef::Default to actual module name.
281    /// Called after construction, before formatting.
282    pub fn resolve_default_modules(&mut self, default_module_name: &str) {
283        // Resolve in getter/setter types
284        for (getter, setter) in self.getter_setters.values_mut() {
285            if let Some(getter) = getter {
286                getter.r#type.resolve_default_module(default_module_name);
287            }
288            if let Some(setter) = setter {
289                setter.r#type.resolve_default_module(default_module_name);
290            }
291        }
292
293        // Resolve in method parameter and return types
294        for methods in self.methods.values_mut() {
295            for method in methods {
296                // Resolve all parameter types
297                for param in &mut method.parameters.positional_only {
298                    param.type_info.resolve_default_module(default_module_name);
299                }
300                for param in &mut method.parameters.positional_or_keyword {
301                    param.type_info.resolve_default_module(default_module_name);
302                }
303                for param in &mut method.parameters.keyword_only {
304                    param.type_info.resolve_default_module(default_module_name);
305                }
306                if let Some(varargs) = &mut method.parameters.varargs {
307                    varargs
308                        .type_info
309                        .resolve_default_module(default_module_name);
310                }
311                if let Some(varkw) = &mut method.parameters.varkw {
312                    varkw.type_info.resolve_default_module(default_module_name);
313                }
314                method.r#return.resolve_default_module(default_module_name);
315            }
316        }
317
318        // Resolve in base classes
319        for base in &mut self.bases {
320            base.resolve_default_module(default_module_name);
321        }
322
323        // Resolve in class attributes
324        for attr in &mut self.attrs {
325            attr.r#type.resolve_default_module(default_module_name);
326        }
327
328        // Recursively resolve in nested classes
329        for class in &mut self.classes {
330            class.resolve_default_modules(default_module_name);
331        }
332    }
333}
334
335impl fmt::Display for ClassDef {
336    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
337        let bases = self
338            .bases
339            .iter()
340            .map(|i| i.name.clone())
341            .reduce(|acc, path| format!("{acc}, {path}"))
342            .map(|bases| format!("({bases})"))
343            .unwrap_or_default();
344        if !self.subclass {
345            writeln!(f, "@typing.final")?;
346        }
347        writeln!(f, "class {}{}:", self.name, bases)?;
348        let indent = indent();
349        let doc = self.doc.trim();
350        docstring::write_docstring(f, doc, indent)?;
351
352        if let Some(match_args) = &self.match_args {
353            if match_args.is_empty() {
354                writeln!(f, "{indent}__match_args__ = ()")?;
355            } else {
356                let match_args_txt = match_args
357                    .iter()
358                    .map(|a| format!(r##""{a}""##))
359                    .collect::<Vec<_>>()
360                    .join(", ");
361                writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
362            }
363        }
364        for attr in &self.attrs {
365            attr.fmt(f)?;
366        }
367        for (getter, setter) in self.getter_setters.values() {
368            if let Some(getter) = getter {
369                write!(
370                    f,
371                    "{}",
372                    GetterDisplay {
373                        member: getter,
374                        target_module: self.module.unwrap_or(self.name)
375                    }
376                )?;
377            }
378            if let Some(setter) = setter {
379                write!(
380                    f,
381                    "{}",
382                    SetterDisplay {
383                        member: setter,
384                        target_module: self.module.unwrap_or(self.name)
385                    }
386                )?;
387            }
388        }
389        for (_method_name, methods) in &self.methods {
390            // Check if we should add @overload to all methods
391            let has_overload = methods.iter().any(|m| m.is_overload);
392            let should_add_overload = methods.len() > 1 && has_overload;
393
394            for method in methods {
395                if should_add_overload {
396                    writeln!(f, "{indent}@typing.overload")?;
397                }
398                method.fmt(f)?;
399            }
400        }
401        for class in &self.classes {
402            let emit = format!("{class}");
403            for line in emit.lines() {
404                writeln!(f, "{indent}{line}")?;
405            }
406        }
407        if self.attrs.is_empty() && self.getter_setters.is_empty() && self.methods.is_empty() {
408            writeln!(f, "{indent}...")?;
409        }
410        writeln!(f)?;
411        Ok(())
412    }
413}
414
415impl ClassDef {
416    /// Format class with module-qualified type names
417    ///
418    /// This method uses the target module context to qualify type identifiers
419    /// within compound type expressions based on their source modules.
420    pub fn fmt_for_module(&self, target_module: &str, f: &mut fmt::Formatter) -> fmt::Result {
421        // Qualify base classes
422        let bases = self
423            .bases
424            .iter()
425            .map(|i| i.qualified_for_module(target_module))
426            .reduce(|acc, path| format!("{acc}, {path}"))
427            .map(|bases| format!("({bases})"))
428            .unwrap_or_default();
429
430        if !self.subclass {
431            writeln!(f, "@typing.final")?;
432        }
433        writeln!(f, "class {}{}:", self.name, bases)?;
434
435        let indent = indent();
436        let doc = self.doc.trim();
437        docstring::write_docstring(f, doc, indent)?;
438
439        if let Some(match_args) = &self.match_args {
440            if match_args.is_empty() {
441                writeln!(f, "{indent}__match_args__ = ()")?;
442            } else {
443                let match_args_txt = match_args
444                    .iter()
445                    .map(|a| format!(r##""{a}""##))
446                    .collect::<Vec<_>>()
447                    .join(", ");
448                writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
449            }
450        }
451
452        // Format attributes with qualified types
453        for attr in &self.attrs {
454            attr.fmt_for_module(target_module, f, indent)?;
455        }
456
457        // Format properties with qualified types
458        for (getter, setter) in self.getter_setters.values() {
459            if let Some(getter) = getter {
460                write!(
461                    f,
462                    "{}",
463                    GetterDisplay {
464                        member: getter,
465                        target_module
466                    }
467                )?;
468            }
469            if let Some(setter) = setter {
470                write!(
471                    f,
472                    "{}",
473                    SetterDisplay {
474                        member: setter,
475                        target_module
476                    }
477                )?;
478            }
479        }
480
481        // Format methods with qualified types
482        for (_method_name, methods) in &self.methods {
483            let has_overload = methods.iter().any(|m| m.is_overload);
484            let should_add_overload = methods.len() > 1 && has_overload;
485
486            for method in methods {
487                if should_add_overload {
488                    writeln!(f, "{indent}@typing.overload")?;
489                }
490                method.fmt_for_module(target_module, f, indent)?;
491            }
492        }
493
494        // Format nested classes recursively
495        for class in &self.classes {
496            // Create a temporary formatter to capture nested class output
497            struct FmtAdapter<'a, 'b> {
498                class: &'a ClassDef,
499                target_module: &'b str,
500            }
501            impl<'a, 'b> fmt::Display for FmtAdapter<'a, 'b> {
502                fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
503                    self.class.fmt_for_module(self.target_module, f)
504                }
505            }
506            let emit = format!(
507                "{}",
508                FmtAdapter {
509                    class,
510                    target_module
511                }
512            );
513            for line in emit.lines() {
514                writeln!(f, "{indent}{line}")?;
515            }
516        }
517
518        if self.attrs.is_empty() && self.getter_setters.is_empty() && self.methods.is_empty() {
519            writeln!(f, "{indent}...")?;
520        }
521        writeln!(f)?;
522        Ok(())
523    }
524}