1use indexmap::IndexMap;
2
3use crate::generate::variant_methods::get_variant_methods;
4use crate::{generate::*, type_info::*, TypeInfo};
5use std::{fmt, vec};
6
/// Definition of a Python class as it appears in a generated stub file.
///
/// Rendered to stub source by this type's `Display` impl; the `Import` impl
/// collects the imports that rendering refers to.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassDef {
    /// Class name written after the `class` keyword.
    pub name: &'static str,
    /// Docstring; trimmed before it is written into the class body.
    pub doc: &'static str,
    /// Plain class attributes, each rendered with its own `Display` impl.
    pub attrs: Vec<MemberDef>,
    /// Properties keyed by name as `(getter, setter)`; either side may be absent.
    /// Rendered in the map's iteration (insertion) order.
    pub getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)>,
    /// Methods keyed by name. A name with more than one definition is rendered
    /// with `@typing.overload` on each definition.
    pub methods: IndexMap<String, Vec<MethodDef>>,
    /// Base classes, rendered as `class Name(Base1, Base2):`.
    pub bases: Vec<TypeInfo>,
    /// Nested class definitions, rendered indented inside this class body.
    pub classes: Vec<ClassDef>,
    /// When `Some`, a `__match_args__` tuple listing these names is emitted.
    pub match_args: Option<Vec<String>>,
    /// When `false`, the class is decorated with `@typing.final`.
    pub subclass: bool,
}
20
21impl Import for ClassDef {
22 fn import(&self) -> HashSet<ImportRef> {
23 let mut import = HashSet::new();
24 if !self.subclass {
25 import.insert("typing".into());
27 }
28 for base in &self.bases {
29 import.extend(base.import.clone());
30 }
31 for attr in &self.attrs {
32 import.extend(attr.import());
33 }
34 for (getter, setter) in self.getter_setters.values() {
35 if let Some(getter) = getter {
36 import.extend(getter.import());
37 }
38 if let Some(setter) = setter {
39 import.extend(setter.import());
40 }
41 }
42 for method in self.methods.values() {
43 if method.len() > 1 {
44 import.insert("typing".into());
46 }
47 for method in method {
48 import.extend(method.import());
49 }
50 }
51 for class in &self.classes {
52 import.extend(class.import());
53 }
54 import
55 }
56}
57
58impl From<&PyComplexEnumInfo> for ClassDef {
59 fn from(info: &PyComplexEnumInfo) -> Self {
60 let enum_info = Self {
64 name: info.pyclass_name,
65 doc: info.doc,
66 getter_setters: IndexMap::new(),
67 methods: IndexMap::new(),
68 classes: info
69 .variants
70 .iter()
71 .map(|v| ClassDef::from_variant(info, v))
72 .collect(),
73 bases: Vec::new(),
74 match_args: None,
75 attrs: Vec::new(),
76 subclass: true, };
78
79 enum_info
80 }
81}
82
83impl ClassDef {
84 fn from_variant(enum_info: &PyComplexEnumInfo, info: &VariantInfo) -> Self {
85 let methods = get_variant_methods(enum_info, info);
86
87 Self {
88 name: info.pyclass_name,
89 doc: info.doc,
90 getter_setters: info
91 .fields
92 .iter()
93 .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
94 .collect(),
95 methods,
96 classes: Vec::new(),
97 bases: vec![TypeInfo::unqualified(enum_info.pyclass_name)],
98 match_args: Some(info.fields.iter().map(|f| f.name.to_string()).collect()),
99 attrs: Vec::new(),
100 subclass: false,
101 }
102 }
103}
104
105impl From<&PyClassInfo> for ClassDef {
106 fn from(info: &PyClassInfo) -> Self {
107 let mut getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)> = info
110 .getters
111 .iter()
112 .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
113 .collect();
114 for setter in info.setters {
115 getter_setters.entry(setter.name.to_string()).or_default().1 = Some(MemberDef {
116 name: setter.name,
117 r#type: (setter.r#type)(),
118 doc: setter.doc,
119 default: setter.default.map(|f| f()),
120 deprecated: setter.deprecated.clone(),
121 });
122 }
123 let mut new = Self {
124 name: info.pyclass_name,
125 doc: info.doc,
126 attrs: Vec::new(),
127 getter_setters,
128 methods: Default::default(),
129 classes: Vec::new(),
130 bases: info.bases.iter().map(|f| f()).collect(),
131 match_args: None,
132 subclass: info.subclass,
133 };
134 if info.has_eq {
135 new.add_eq_method();
136 }
137 if info.has_ord {
138 new.add_ord_methods();
139 }
140 if info.has_hash {
141 new.add_hash_method();
142 }
143 if info.has_str {
144 new.add_str_method();
145 }
146 new
147 }
148}
149impl ClassDef {
150 fn add_eq_method(&mut self) {
151 let method = MethodDef {
152 name: "__eq__",
153 args: vec![Arg {
154 name: "other",
155 r#type: TypeInfo::builtin("object"),
156 signature: None,
157 }],
158 r#return: TypeInfo::builtin("bool"),
159 doc: "",
160 r#type: MethodType::Instance,
161 is_async: false,
162 deprecated: None,
163 type_ignored: None,
164 };
165 self.methods
166 .entry("__eq__".to_string())
167 .or_default()
168 .push(method);
169 }
170
171 fn add_ord_methods(&mut self) {
172 let ord_methods = ["__lt__", "__le__", "__gt__", "__ge__"];
173
174 for name in &ord_methods {
175 let method = MethodDef {
176 name,
177 args: vec![Arg {
178 name: "other",
179 r#type: TypeInfo::builtin("object"),
180 signature: None,
181 }],
182 r#return: TypeInfo::builtin("bool"),
183 doc: "",
184 r#type: MethodType::Instance,
185 is_async: false,
186 deprecated: None,
187 type_ignored: None,
188 };
189 self.methods
190 .entry(name.to_string())
191 .or_default()
192 .push(method);
193 }
194 }
195
196 fn add_hash_method(&mut self) {
197 let method = MethodDef {
198 name: "__hash__",
199 args: vec![],
200 r#return: TypeInfo::builtin("int"),
201 doc: "",
202 r#type: MethodType::Instance,
203 is_async: false,
204 deprecated: None,
205 type_ignored: None,
206 };
207 self.methods
208 .entry("__hash__".to_string())
209 .or_default()
210 .push(method);
211 }
212
213 fn add_str_method(&mut self) {
214 let method = MethodDef {
215 name: "__str__",
216 args: vec![],
217 r#return: TypeInfo::builtin("str"),
218 doc: "",
219 r#type: MethodType::Instance,
220 is_async: false,
221 deprecated: None,
222 type_ignored: None,
223 };
224 self.methods
225 .entry("__str__".to_string())
226 .or_default()
227 .push(method);
228 }
229}
230
231impl fmt::Display for ClassDef {
232 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
233 let bases = self
234 .bases
235 .iter()
236 .map(|i| i.name.clone())
237 .reduce(|acc, path| format!("{acc}, {path}"))
238 .map(|bases| format!("({bases})"))
239 .unwrap_or_default();
240 if !self.subclass {
241 writeln!(f, "@typing.final")?;
242 }
243 writeln!(f, "class {}{}:", self.name, bases)?;
244 let indent = indent();
245 let doc = self.doc.trim();
246 docstring::write_docstring(f, doc, indent)?;
247
248 if let Some(match_args) = &self.match_args {
249 if match_args.is_empty() {
250 writeln!(f, "{indent}__match_args__ = ()")?;
251 } else {
252 let match_args_txt = match_args
253 .iter()
254 .map(|a| format!(r##""{a}""##))
255 .collect::<Vec<_>>()
256 .join(", ");
257 writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
258 }
259 }
260 for attr in &self.attrs {
261 attr.fmt(f)?;
262 }
263 for (getter, setter) in self.getter_setters.values() {
264 if let Some(getter) = getter {
265 GetterDisplay(getter).fmt(f)?;
266 }
267 if let Some(setter) = setter {
268 SetterDisplay(setter).fmt(f)?;
269 }
270 }
271 for methods in self.methods.values() {
272 let overloaded = methods.len() > 1;
273 for method in methods {
274 if overloaded {
275 writeln!(f, "{indent}@typing.overload")?;
276 }
277 method.fmt(f)?;
278 }
279 }
280 for class in &self.classes {
281 let emit = format!("{class}");
282 for line in emit.lines() {
283 writeln!(f, "{indent}{line}")?;
284 }
285 }
286 if self.attrs.is_empty() && self.getter_setters.is_empty() && self.methods.is_empty() {
287 writeln!(f, "{indent}...")?;
288 }
289 writeln!(f)?;
290 Ok(())
291 }
292}