pyo3_stub_gen/generate/class.rs

1use crate::{generate::*, type_info::*, TypeInfo};
2use std::fmt;
3
4/// Definition of a Python class.
5#[derive(Debug, Clone, PartialEq)]
6pub struct ClassDef {
7    pub name: &'static str,
8    pub doc: &'static str,
9    pub attrs: Vec<MemberDef>,
10    pub getters: Vec<MemberDef>,
11    pub setters: Vec<MemberDef>,
12    pub methods: Vec<MethodDef>,
13    pub bases: Vec<TypeInfo>,
14}
15
16impl Import for ClassDef {
17    fn import(&self) -> HashSet<ModuleRef> {
18        let mut import = HashSet::new();
19        for base in &self.bases {
20            import.extend(base.import.clone());
21        }
22        for attr in &self.attrs {
23            import.extend(attr.import());
24        }
25        for getter in &self.getters {
26            import.extend(getter.import());
27        }
28        for setter in &self.setters {
29            import.extend(setter.import());
30        }
31        for method in &self.methods {
32            import.extend(method.import());
33        }
34        import
35    }
36}
37
38impl From<&PyClassInfo> for ClassDef {
39    fn from(info: &PyClassInfo) -> Self {
40        // Since there are multiple `#[pymethods]` for a single class, we need to merge them.
41        // This is only an initializer. See `StubInfo::gather` for the actual merging.
42        Self {
43            name: info.pyclass_name,
44            doc: info.doc,
45            attrs: Vec::new(),
46            setters: info.setters.iter().map(MemberDef::from).collect(),
47            getters: info.getters.iter().map(MemberDef::from).collect(),
48            methods: Vec::new(),
49            bases: info.bases.iter().map(|f| f()).collect(),
50        }
51    }
52}
53
54impl fmt::Display for ClassDef {
55    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
56        let bases = self
57            .bases
58            .iter()
59            .map(|i| i.name.clone())
60            .reduce(|acc, path| format!("{acc}, {path}"))
61            .map(|bases| format!("({bases})"))
62            .unwrap_or_default();
63        writeln!(f, "class {}{}:", self.name, bases)?;
64        let indent = indent();
65        let doc = self.doc.trim();
66        docstring::write_docstring(f, doc, indent)?;
67        for attr in &self.attrs {
68            attr.fmt(f)?;
69        }
70        for getter in &self.getters {
71            GetterDisplay(getter).fmt(f)?;
72        }
73        for setter in &self.setters {
74            SetterDisplay(setter).fmt(f)?;
75        }
76        for method in &self.methods {
77            method.fmt(f)?;
78        }
79        if self.attrs.is_empty()
80            && self.getters.is_empty()
81            && self.setters.is_empty()
82            && self.methods.is_empty()
83        {
84            writeln!(f, "{indent}...")?;
85        }
86        writeln!(f)?;
87        Ok(())
88    }
89}