1use indexmap::IndexMap;
2
3use crate::generate::variant_methods::get_variant_methods;
4use crate::{
5 generate::{
6 docstring, indent, GetterDisplay, Import, MemberDef, MethodDef, Parameter,
7 ParameterDefault, Parameters, SetterDisplay,
8 },
9 stub_type::ImportRef,
10 type_info::*,
11 TypeInfo,
12};
13use std::collections::HashSet;
14use std::{fmt, vec};
15
16#[derive(Debug, Clone, PartialEq)]
18pub struct ClassDef {
19 pub name: &'static str,
20 pub module: Option<&'static str>,
21 pub doc: &'static str,
22 pub attrs: Vec<MemberDef>,
23 pub getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)>,
24 pub methods: IndexMap<String, Vec<MethodDef>>,
25 pub bases: Vec<TypeInfo>,
26 pub classes: Vec<ClassDef>,
27 pub match_args: Option<Vec<String>>,
28 pub subclass: bool,
29}
30
31impl Import for ClassDef {
32 fn import(&self) -> HashSet<ImportRef> {
33 let mut import = HashSet::new();
34 if !self.subclass {
35 import.insert("typing".into());
37 }
38 for base in &self.bases {
39 import.extend(base.import.clone());
40 }
41 for attr in &self.attrs {
42 import.extend(attr.import());
43 }
44 for (getter, setter) in self.getter_setters.values() {
45 if let Some(getter) = getter {
46 import.extend(getter.import());
47 }
48 if let Some(setter) = setter {
49 import.extend(setter.import());
50 }
51 }
52 for method in self.methods.values() {
53 if method.len() > 1 {
54 import.insert("typing".into());
56 }
57 for method in method {
58 import.extend(method.import());
59 }
60 }
61 for class in &self.classes {
62 import.extend(class.import());
63 }
64 import
65 }
66}
67
68impl From<&PyComplexEnumInfo> for ClassDef {
69 fn from(info: &PyComplexEnumInfo) -> Self {
70 let enum_info = Self {
74 name: info.pyclass_name,
75 module: info.module,
76 doc: info.doc,
77 getter_setters: IndexMap::new(),
78 methods: IndexMap::new(),
79 classes: info
80 .variants
81 .iter()
82 .map(|v| ClassDef::from_variant(info, v))
83 .collect(),
84 bases: Vec::new(),
85 match_args: None,
86 attrs: Vec::new(),
87 subclass: true, };
89
90 enum_info
91 }
92}
93
94impl ClassDef {
95 fn from_variant(enum_info: &PyComplexEnumInfo, info: &VariantInfo) -> Self {
96 let methods = get_variant_methods(enum_info, info);
97
98 Self {
99 name: info.pyclass_name,
100 module: enum_info.module,
101 doc: info.doc,
102 getter_setters: info
103 .fields
104 .iter()
105 .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
106 .collect(),
107 methods,
108 classes: Vec::new(),
109 bases: vec![TypeInfo::unqualified(enum_info.pyclass_name)],
110 match_args: Some(info.fields.iter().map(|f| f.name.to_string()).collect()),
111 attrs: Vec::new(),
112 subclass: false,
113 }
114 }
115}
116
117impl From<&PyClassInfo> for ClassDef {
118 fn from(info: &PyClassInfo) -> Self {
119 let mut getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)> = info
122 .getters
123 .iter()
124 .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
125 .collect();
126 for setter in info.setters {
127 getter_setters.entry(setter.name.to_string()).or_default().1 = Some(MemberDef {
128 name: setter.name,
129 r#type: (setter.r#type)(),
130 doc: setter.doc,
131 default: setter.default.map(|f| f()),
132 deprecated: setter.deprecated.clone(),
133 });
134 }
135 let mut new = Self {
136 name: info.pyclass_name,
137 module: info.module,
138 doc: info.doc,
139 attrs: Vec::new(),
140 getter_setters,
141 methods: Default::default(),
142 classes: Vec::new(),
143 bases: info.bases.iter().map(|f| f()).collect(),
144 match_args: None,
145 subclass: info.subclass,
146 };
147 if info.has_eq {
148 new.add_eq_method();
149 }
150 if info.has_ord {
151 new.add_ord_methods();
152 }
153 if info.has_hash {
154 new.add_hash_method();
155 }
156 if info.has_str {
157 new.add_str_method();
158 }
159 new
160 }
161}
162impl ClassDef {
163 fn add_eq_method(&mut self) {
164 let method = MethodDef {
165 name: "__eq__",
166 parameters: Parameters {
167 positional_or_keyword: vec![Parameter {
168 name: "other",
169 kind: ParameterKind::PositionalOrKeyword,
170 type_info: TypeInfo::builtin("object"),
171 default: ParameterDefault::None,
172 }],
173 ..Parameters::new()
174 },
175 r#return: TypeInfo::builtin("bool"),
176 doc: "",
177 r#type: MethodType::Instance,
178 is_async: false,
179 deprecated: None,
180 type_ignored: None,
181 is_overload: false,
182 };
183 self.methods
184 .entry("__eq__".to_string())
185 .or_default()
186 .push(method);
187 }
188
189 fn add_ord_methods(&mut self) {
190 let ord_methods = ["__lt__", "__le__", "__gt__", "__ge__"];
191
192 for name in &ord_methods {
193 let method = MethodDef {
194 name,
195 parameters: Parameters {
196 positional_or_keyword: vec![Parameter {
197 name: "other",
198 kind: ParameterKind::PositionalOrKeyword,
199 type_info: TypeInfo::builtin("object"),
200 default: ParameterDefault::None,
201 }],
202 ..Parameters::new()
203 },
204 r#return: TypeInfo::builtin("bool"),
205 doc: "",
206 r#type: MethodType::Instance,
207 is_async: false,
208 deprecated: None,
209 type_ignored: None,
210 is_overload: false,
211 };
212 self.methods
213 .entry(name.to_string())
214 .or_default()
215 .push(method);
216 }
217 }
218
219 fn add_hash_method(&mut self) {
220 let method = MethodDef {
221 name: "__hash__",
222 parameters: Parameters::new(),
223 r#return: TypeInfo::builtin("int"),
224 doc: "",
225 r#type: MethodType::Instance,
226 is_async: false,
227 deprecated: None,
228 type_ignored: None,
229 is_overload: false,
230 };
231 self.methods
232 .entry("__hash__".to_string())
233 .or_default()
234 .push(method);
235 }
236
237 fn add_str_method(&mut self) {
238 let method = MethodDef {
239 name: "__str__",
240 parameters: Parameters::new(),
241 r#return: TypeInfo::builtin("str"),
242 doc: "",
243 r#type: MethodType::Instance,
244 is_async: false,
245 deprecated: None,
246 type_ignored: None,
247 is_overload: false,
248 };
249 self.methods
250 .entry("__str__".to_string())
251 .or_default()
252 .push(method);
253 }
254
255 pub fn resolve_default_modules(&mut self, default_module_name: &str) {
258 for (getter, setter) in self.getter_setters.values_mut() {
260 if let Some(getter) = getter {
261 getter.r#type.resolve_default_module(default_module_name);
262 }
263 if let Some(setter) = setter {
264 setter.r#type.resolve_default_module(default_module_name);
265 }
266 }
267
268 for methods in self.methods.values_mut() {
270 for method in methods {
271 for param in &mut method.parameters.positional_only {
273 param.type_info.resolve_default_module(default_module_name);
274 }
275 for param in &mut method.parameters.positional_or_keyword {
276 param.type_info.resolve_default_module(default_module_name);
277 }
278 for param in &mut method.parameters.keyword_only {
279 param.type_info.resolve_default_module(default_module_name);
280 }
281 if let Some(varargs) = &mut method.parameters.varargs {
282 varargs
283 .type_info
284 .resolve_default_module(default_module_name);
285 }
286 if let Some(varkw) = &mut method.parameters.varkw {
287 varkw.type_info.resolve_default_module(default_module_name);
288 }
289 method.r#return.resolve_default_module(default_module_name);
290 }
291 }
292
293 for base in &mut self.bases {
295 base.resolve_default_module(default_module_name);
296 }
297
298 for attr in &mut self.attrs {
300 attr.r#type.resolve_default_module(default_module_name);
301 }
302
303 for class in &mut self.classes {
305 class.resolve_default_modules(default_module_name);
306 }
307 }
308}
309
310impl fmt::Display for ClassDef {
311 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
312 let bases = self
313 .bases
314 .iter()
315 .map(|i| i.name.clone())
316 .reduce(|acc, path| format!("{acc}, {path}"))
317 .map(|bases| format!("({bases})"))
318 .unwrap_or_default();
319 if !self.subclass {
320 writeln!(f, "@typing.final")?;
321 }
322 writeln!(f, "class {}{}:", self.name, bases)?;
323 let indent = indent();
324 let doc = self.doc.trim();
325 docstring::write_docstring(f, doc, indent)?;
326
327 if let Some(match_args) = &self.match_args {
328 if match_args.is_empty() {
329 writeln!(f, "{indent}__match_args__ = ()")?;
330 } else {
331 let match_args_txt = match_args
332 .iter()
333 .map(|a| format!(r##""{a}""##))
334 .collect::<Vec<_>>()
335 .join(", ");
336 writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
337 }
338 }
339 for attr in &self.attrs {
340 attr.fmt(f)?;
341 }
342 for (getter, setter) in self.getter_setters.values() {
343 if let Some(getter) = getter {
344 write!(
345 f,
346 "{}",
347 GetterDisplay {
348 member: getter,
349 target_module: self.module.unwrap_or(self.name)
350 }
351 )?;
352 }
353 if let Some(setter) = setter {
354 write!(
355 f,
356 "{}",
357 SetterDisplay {
358 member: setter,
359 target_module: self.module.unwrap_or(self.name)
360 }
361 )?;
362 }
363 }
364 for (_method_name, methods) in &self.methods {
365 let has_overload = methods.iter().any(|m| m.is_overload);
367 let should_add_overload = methods.len() > 1 && has_overload;
368
369 for method in methods {
370 if should_add_overload {
371 writeln!(f, "{indent}@typing.overload")?;
372 }
373 method.fmt(f)?;
374 }
375 }
376 for class in &self.classes {
377 let emit = format!("{class}");
378 for line in emit.lines() {
379 writeln!(f, "{indent}{line}")?;
380 }
381 }
382 if self.attrs.is_empty() && self.getter_setters.is_empty() && self.methods.is_empty() {
383 writeln!(f, "{indent}...")?;
384 }
385 writeln!(f)?;
386 Ok(())
387 }
388}
389
390impl ClassDef {
391 pub fn fmt_for_module(&self, target_module: &str, f: &mut fmt::Formatter) -> fmt::Result {
396 let bases = self
398 .bases
399 .iter()
400 .map(|i| i.qualified_for_module(target_module))
401 .reduce(|acc, path| format!("{acc}, {path}"))
402 .map(|bases| format!("({bases})"))
403 .unwrap_or_default();
404
405 if !self.subclass {
406 writeln!(f, "@typing.final")?;
407 }
408 writeln!(f, "class {}{}:", self.name, bases)?;
409
410 let indent = indent();
411 let doc = self.doc.trim();
412 docstring::write_docstring(f, doc, indent)?;
413
414 if let Some(match_args) = &self.match_args {
415 if match_args.is_empty() {
416 writeln!(f, "{indent}__match_args__ = ()")?;
417 } else {
418 let match_args_txt = match_args
419 .iter()
420 .map(|a| format!(r##""{a}""##))
421 .collect::<Vec<_>>()
422 .join(", ");
423 writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
424 }
425 }
426
427 for attr in &self.attrs {
429 attr.fmt_for_module(target_module, f, indent)?;
430 }
431
432 for (getter, setter) in self.getter_setters.values() {
434 if let Some(getter) = getter {
435 write!(
436 f,
437 "{}",
438 GetterDisplay {
439 member: getter,
440 target_module
441 }
442 )?;
443 }
444 if let Some(setter) = setter {
445 write!(
446 f,
447 "{}",
448 SetterDisplay {
449 member: setter,
450 target_module
451 }
452 )?;
453 }
454 }
455
456 for (_method_name, methods) in &self.methods {
458 let has_overload = methods.iter().any(|m| m.is_overload);
459 let should_add_overload = methods.len() > 1 && has_overload;
460
461 for method in methods {
462 if should_add_overload {
463 writeln!(f, "{indent}@typing.overload")?;
464 }
465 method.fmt_for_module(target_module, f, indent)?;
466 }
467 }
468
469 for class in &self.classes {
471 struct FmtAdapter<'a, 'b> {
473 class: &'a ClassDef,
474 target_module: &'b str,
475 }
476 impl<'a, 'b> fmt::Display for FmtAdapter<'a, 'b> {
477 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
478 self.class.fmt_for_module(self.target_module, f)
479 }
480 }
481 let emit = format!(
482 "{}",
483 FmtAdapter {
484 class,
485 target_module
486 }
487 );
488 for line in emit.lines() {
489 writeln!(f, "{indent}{line}")?;
490 }
491 }
492
493 if self.attrs.is_empty() && self.getter_setters.is_empty() && self.methods.is_empty() {
494 writeln!(f, "{indent}...")?;
495 }
496 writeln!(f)?;
497 Ok(())
498 }
499}