1use indexmap::IndexMap;
2
3use crate::generate::variant_methods::get_variant_methods;
4use crate::{generate::*, type_info::*, TypeInfo};
5use std::{fmt, vec};
6
/// Description of a Python class, rendered as a `class` statement by its
/// `Display` implementation.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassDef {
    /// Class name written right after the `class` keyword.
    pub name: &'static str,
    /// Documentation text; it is trimmed and emitted as the class docstring.
    pub doc: &'static str,
    /// Members rendered through their own formatting implementation.
    pub attrs: Vec<MemberDef>,
    /// Members rendered through `GetterDisplay`.
    pub getters: Vec<MemberDef>,
    /// Members rendered through `SetterDisplay`.
    pub setters: Vec<MemberDef>,
    /// Methods grouped under a string key (the method name for the entries
    /// created in this file). A group holding more than one definition is
    /// rendered with `@typing.overload` on each of them.
    pub methods: IndexMap<String, Vec<MethodDef>>,
    /// Base classes, listed in parentheses after the class name.
    pub bases: Vec<TypeInfo>,
    /// Nested class definitions, rendered indented inside this class.
    pub classes: Vec<ClassDef>,
    /// When `Some`, the names emitted as the class's `__match_args__` tuple.
    pub match_args: Option<Vec<String>>,
}
20
21impl Import for ClassDef {
22 fn import(&self) -> HashSet<ModuleRef> {
23 let mut import = HashSet::new();
24 for base in &self.bases {
25 import.extend(base.import.clone());
26 }
27 for attr in &self.attrs {
28 import.extend(attr.import());
29 }
30 for getter in &self.getters {
31 import.extend(getter.import());
32 }
33 for setter in &self.setters {
34 import.extend(setter.import());
35 }
36 for method in self.methods.values() {
37 if method.len() > 1 {
38 import.insert("typing".into());
40 }
41 for method in method {
42 import.extend(method.import());
43 }
44 }
45 for class in &self.classes {
46 import.extend(class.import());
47 }
48 import
49 }
50}
51
52impl From<&PyComplexEnumInfo> for ClassDef {
53 fn from(info: &PyComplexEnumInfo) -> Self {
54 let enum_info = Self {
58 name: info.pyclass_name,
59 doc: info.doc,
60 getters: Vec::new(),
61 setters: Vec::new(),
62 methods: IndexMap::new(),
63 classes: info
64 .variants
65 .iter()
66 .map(|v| ClassDef::from_variant(info, v))
67 .collect(),
68 bases: Vec::new(),
69 match_args: None,
70 attrs: Vec::new(),
71 };
72
73 enum_info
74 }
75}
76
77impl ClassDef {
78 fn from_variant(enum_info: &PyComplexEnumInfo, info: &VariantInfo) -> Self {
79 let methods = get_variant_methods(enum_info, info);
80
81 Self {
82 name: info.pyclass_name,
83 doc: info.doc,
84 getters: info.fields.iter().map(MemberDef::from).collect(),
85 setters: Vec::new(),
86 methods,
87 classes: Vec::new(),
88 bases: vec![TypeInfo::unqualified(enum_info.pyclass_name)],
89 match_args: Some(info.fields.iter().map(|f| f.name.to_string()).collect()),
90 attrs: Vec::new(),
91 }
92 }
93}
94
95impl From<&PyClassInfo> for ClassDef {
96 fn from(info: &PyClassInfo) -> Self {
97 let mut new = Self {
100 name: info.pyclass_name,
101 doc: info.doc,
102 attrs: Vec::new(),
103 setters: info.setters.iter().map(MemberDef::from).collect(),
104 getters: info.getters.iter().map(MemberDef::from).collect(),
105 methods: Default::default(),
106 classes: Vec::new(),
107 bases: info.bases.iter().map(|f| f()).collect(),
108 match_args: None,
109 };
110 if info.has_eq {
111 new.add_eq_method();
112 }
113 if info.has_ord {
114 new.add_ord_methods();
115 }
116 if info.has_hash {
117 new.add_hash_method();
118 }
119 if info.has_str {
120 new.add_str_method();
121 }
122 new
123 }
124}
125impl ClassDef {
126 fn add_eq_method(&mut self) {
127 let method = MethodDef {
128 name: "__eq__",
129 args: vec![Arg {
130 name: "other",
131 r#type: TypeInfo::builtin("object"),
132 signature: None,
133 }],
134 r#return: TypeInfo::builtin("bool"),
135 doc: "",
136 r#type: MethodType::Instance,
137 is_async: false,
138 deprecated: None,
139 type_ignored: None,
140 };
141 self.methods
142 .entry("__eq__".to_string())
143 .or_default()
144 .push(method);
145 }
146
147 fn add_ord_methods(&mut self) {
148 let ord_methods = ["__lt__", "__le__", "__gt__", "__ge__"];
149
150 for name in &ord_methods {
151 let method = MethodDef {
152 name,
153 args: vec![Arg {
154 name: "other",
155 r#type: TypeInfo::builtin("object"),
156 signature: None,
157 }],
158 r#return: TypeInfo::builtin("bool"),
159 doc: "",
160 r#type: MethodType::Instance,
161 is_async: false,
162 deprecated: None,
163 type_ignored: None,
164 };
165 self.methods
166 .entry(name.to_string())
167 .or_default()
168 .push(method);
169 }
170 }
171
172 fn add_hash_method(&mut self) {
173 let method = MethodDef {
174 name: "__hash__",
175 args: vec![],
176 r#return: TypeInfo::builtin("int"),
177 doc: "",
178 r#type: MethodType::Instance,
179 is_async: false,
180 deprecated: None,
181 type_ignored: None,
182 };
183 self.methods
184 .entry("__hash__".to_string())
185 .or_default()
186 .push(method);
187 }
188
189 fn add_str_method(&mut self) {
190 let method = MethodDef {
191 name: "__str__",
192 args: vec![],
193 r#return: TypeInfo::builtin("str"),
194 doc: "",
195 r#type: MethodType::Instance,
196 is_async: false,
197 deprecated: None,
198 type_ignored: None,
199 };
200 self.methods
201 .entry("__str__".to_string())
202 .or_default()
203 .push(method);
204 }
205}
206
207impl fmt::Display for ClassDef {
208 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
209 let bases = self
210 .bases
211 .iter()
212 .map(|i| i.name.clone())
213 .reduce(|acc, path| format!("{acc}, {path}"))
214 .map(|bases| format!("({bases})"))
215 .unwrap_or_default();
216 writeln!(f, "class {}{}:", self.name, bases)?;
217 let indent = indent();
218 let doc = self.doc.trim();
219 docstring::write_docstring(f, doc, indent)?;
220
221 if let Some(match_args) = &self.match_args {
222 let match_args_txt = if match_args.is_empty() {
223 "()".to_string()
224 } else {
225 match_args
226 .iter()
227 .map(|a| format!(r##""{a}""##))
228 .collect::<Vec<_>>()
229 .join(", ")
230 };
231
232 writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
233 }
234 for attr in &self.attrs {
235 attr.fmt(f)?;
236 }
237 for getter in &self.getters {
238 GetterDisplay(getter).fmt(f)?;
239 }
240 for setter in &self.setters {
241 SetterDisplay(setter).fmt(f)?;
242 }
243 for methods in self.methods.values() {
244 let overloaded = methods.len() > 1;
245 for method in methods {
246 if overloaded {
247 writeln!(f, "{indent}@typing.overload")?;
248 }
249 method.fmt(f)?;
250 }
251 }
252 for class in &self.classes {
253 let emit = format!("{class}");
254 for line in emit.lines() {
255 writeln!(f, "{indent}{line}")?;
256 }
257 }
258 if self.attrs.is_empty()
259 && self.getters.is_empty()
260 && self.setters.is_empty()
261 && self.methods.is_empty()
262 {
263 writeln!(f, "{indent}...")?;
264 }
265 writeln!(f)?;
266 Ok(())
267 }
268}