1use indexmap::IndexMap;
2
3use crate::generate::docstring::normalize_docstring;
4use crate::generate::variant_methods::get_variant_methods;
5use crate::{
6 generate::{
7 docstring, indent, GetterDisplay, Import, MemberDef, MethodDef, Parameter,
8 ParameterDefault, Parameters, SetterDisplay,
9 },
10 stub_type::ImportRef,
11 type_info::*,
12 TypeInfo,
13};
14use std::collections::HashSet;
15use std::{fmt, vec};
16
/// Description of a single Python class that is rendered into a stub.
///
/// The `Display` implementation and `fmt_for_module` turn a value of this type
/// into a `class ...:` block, including nested classes.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassDef {
    /// Class name written after the `class` keyword.
    pub name: &'static str,
    /// Module associated with the class, if any. `Display` uses it as the
    /// target module for getters/setters and falls back to `name` when absent.
    pub module: Option<&'static str>,
    /// Docstring text; it is trimmed before being written.
    pub doc: &'static str,
    /// Plain class attributes, rendered before the getters/setters.
    pub attrs: Vec<MemberDef>,
    /// Properties keyed by name; each entry holds an optional getter and an
    /// optional setter definition, in that order.
    pub getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)>,
    /// Methods keyed by name. A name may map to several definitions; when it
    /// does and at least one is flagged `is_overload`, every definition in the
    /// group is prefixed with `@typing.overload`.
    pub methods: IndexMap<String, Vec<MethodDef>>,
    /// Base classes, rendered as a parenthesised list after the class name.
    pub bases: Vec<TypeInfo>,
    /// Nested classes, rendered indented inside this class body.
    pub classes: Vec<ClassDef>,
    /// Names emitted as the `__match_args__` tuple; `None` omits the line.
    pub match_args: Option<Vec<String>>,
    /// When `false`, the class is decorated with `@typing.final`.
    pub subclass: bool,
}
31
32impl Import for ClassDef {
33 fn import(&self) -> HashSet<ImportRef> {
34 let mut import = HashSet::new();
35 if !self.subclass {
36 import.insert("typing".into());
38 }
39 for base in &self.bases {
40 import.extend(base.import.clone());
41 }
42 for attr in &self.attrs {
43 import.extend(attr.import());
44 }
45 for (getter, setter) in self.getter_setters.values() {
46 if let Some(getter) = getter {
47 import.extend(getter.import());
48 }
49 if let Some(setter) = setter {
50 import.extend(setter.import());
51 }
52 }
53 for method in self.methods.values() {
54 if method.len() > 1 {
55 import.insert("typing".into());
57 }
58 for method in method {
59 import.extend(method.import());
60 }
61 }
62 for class in &self.classes {
63 import.extend(class.import());
64 }
65 import
66 }
67}
68
69impl From<&PyComplexEnumInfo> for ClassDef {
70 fn from(info: &PyComplexEnumInfo) -> Self {
71 let doc = if info.doc.is_empty() {
75 ""
76 } else {
77 Box::leak(normalize_docstring(info.doc).into_boxed_str())
78 };
79
80 let enum_info = Self {
81 name: info.pyclass_name,
82 module: info.module,
83 doc,
84 getter_setters: IndexMap::new(),
85 methods: IndexMap::new(),
86 classes: info
87 .variants
88 .iter()
89 .map(|v| ClassDef::from_variant(info, v))
90 .collect(),
91 bases: Vec::new(),
92 match_args: None,
93 attrs: Vec::new(),
94 subclass: true, };
96
97 enum_info
98 }
99}
100
101impl ClassDef {
102 fn from_variant(enum_info: &PyComplexEnumInfo, info: &VariantInfo) -> Self {
103 let methods = get_variant_methods(enum_info, info);
104
105 let doc = if info.doc.is_empty() {
106 ""
107 } else {
108 Box::leak(normalize_docstring(info.doc).into_boxed_str())
109 };
110
111 Self {
112 name: info.pyclass_name,
113 module: enum_info.module,
114 doc,
115 getter_setters: info
116 .fields
117 .iter()
118 .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
119 .collect(),
120 methods,
121 classes: Vec::new(),
122 bases: vec![TypeInfo::unqualified(enum_info.pyclass_name)],
123 match_args: Some(info.fields.iter().map(|f| f.name.to_string()).collect()),
124 attrs: Vec::new(),
125 subclass: false,
126 }
127 }
128}
129
130impl From<&PyClassInfo> for ClassDef {
131 fn from(info: &PyClassInfo) -> Self {
132 let doc = if info.doc.is_empty() {
135 ""
136 } else {
137 Box::leak(normalize_docstring(info.doc).into_boxed_str())
138 };
139
140 let mut getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)> = info
141 .getters
142 .iter()
143 .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
144 .collect();
145 for setter in info.setters {
146 let setter_doc = if setter.doc.is_empty() {
147 ""
148 } else {
149 Box::leak(normalize_docstring(setter.doc).into_boxed_str())
150 };
151
152 getter_setters.entry(setter.name.to_string()).or_default().1 = Some(MemberDef {
153 name: setter.name,
154 r#type: (setter.r#type)(),
155 doc: setter_doc,
156 default: setter.default.map(|f| f()),
157 deprecated: setter.deprecated.clone(),
158 });
159 }
160 let mut new = Self {
161 name: info.pyclass_name,
162 module: info.module,
163 doc,
164 attrs: Vec::new(),
165 getter_setters,
166 methods: Default::default(),
167 classes: Vec::new(),
168 bases: info.bases.iter().map(|f| f()).collect(),
169 match_args: None,
170 subclass: info.subclass,
171 };
172 if info.has_eq {
173 new.add_eq_method();
174 }
175 if info.has_ord {
176 new.add_ord_methods();
177 }
178 if info.has_hash {
179 new.add_hash_method();
180 }
181 if info.has_str {
182 new.add_str_method();
183 }
184 new
185 }
186}
187impl ClassDef {
188 fn add_eq_method(&mut self) {
189 let method = MethodDef {
190 name: "__eq__",
191 parameters: Parameters {
192 positional_or_keyword: vec![Parameter {
193 name: "other",
194 kind: ParameterKind::PositionalOrKeyword,
195 type_info: TypeInfo::builtin("object"),
196 default: ParameterDefault::None,
197 }],
198 ..Parameters::new()
199 },
200 r#return: TypeInfo::builtin("bool"),
201 doc: "",
202 r#type: MethodType::Instance,
203 is_async: false,
204 deprecated: None,
205 type_ignored: None,
206 is_overload: false,
207 };
208 self.methods
209 .entry("__eq__".to_string())
210 .or_default()
211 .push(method);
212 }
213
214 fn add_ord_methods(&mut self) {
215 let ord_methods = ["__lt__", "__le__", "__gt__", "__ge__"];
216
217 for name in &ord_methods {
218 let method = MethodDef {
219 name,
220 parameters: Parameters {
221 positional_or_keyword: vec![Parameter {
222 name: "other",
223 kind: ParameterKind::PositionalOrKeyword,
224 type_info: TypeInfo::builtin("object"),
225 default: ParameterDefault::None,
226 }],
227 ..Parameters::new()
228 },
229 r#return: TypeInfo::builtin("bool"),
230 doc: "",
231 r#type: MethodType::Instance,
232 is_async: false,
233 deprecated: None,
234 type_ignored: None,
235 is_overload: false,
236 };
237 self.methods
238 .entry(name.to_string())
239 .or_default()
240 .push(method);
241 }
242 }
243
244 fn add_hash_method(&mut self) {
245 let method = MethodDef {
246 name: "__hash__",
247 parameters: Parameters::new(),
248 r#return: TypeInfo::builtin("int"),
249 doc: "",
250 r#type: MethodType::Instance,
251 is_async: false,
252 deprecated: None,
253 type_ignored: None,
254 is_overload: false,
255 };
256 self.methods
257 .entry("__hash__".to_string())
258 .or_default()
259 .push(method);
260 }
261
262 fn add_str_method(&mut self) {
263 let method = MethodDef {
264 name: "__str__",
265 parameters: Parameters::new(),
266 r#return: TypeInfo::builtin("str"),
267 doc: "",
268 r#type: MethodType::Instance,
269 is_async: false,
270 deprecated: None,
271 type_ignored: None,
272 is_overload: false,
273 };
274 self.methods
275 .entry("__str__".to_string())
276 .or_default()
277 .push(method);
278 }
279
280 pub fn resolve_default_modules(&mut self, default_module_name: &str) {
283 for (getter, setter) in self.getter_setters.values_mut() {
285 if let Some(getter) = getter {
286 getter.r#type.resolve_default_module(default_module_name);
287 }
288 if let Some(setter) = setter {
289 setter.r#type.resolve_default_module(default_module_name);
290 }
291 }
292
293 for methods in self.methods.values_mut() {
295 for method in methods {
296 for param in &mut method.parameters.positional_only {
298 param.type_info.resolve_default_module(default_module_name);
299 }
300 for param in &mut method.parameters.positional_or_keyword {
301 param.type_info.resolve_default_module(default_module_name);
302 }
303 for param in &mut method.parameters.keyword_only {
304 param.type_info.resolve_default_module(default_module_name);
305 }
306 if let Some(varargs) = &mut method.parameters.varargs {
307 varargs
308 .type_info
309 .resolve_default_module(default_module_name);
310 }
311 if let Some(varkw) = &mut method.parameters.varkw {
312 varkw.type_info.resolve_default_module(default_module_name);
313 }
314 method.r#return.resolve_default_module(default_module_name);
315 }
316 }
317
318 for base in &mut self.bases {
320 base.resolve_default_module(default_module_name);
321 }
322
323 for attr in &mut self.attrs {
325 attr.r#type.resolve_default_module(default_module_name);
326 }
327
328 for class in &mut self.classes {
330 class.resolve_default_modules(default_module_name);
331 }
332 }
333}
334
335impl fmt::Display for ClassDef {
336 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
337 let bases = self
338 .bases
339 .iter()
340 .map(|i| i.name.clone())
341 .reduce(|acc, path| format!("{acc}, {path}"))
342 .map(|bases| format!("({bases})"))
343 .unwrap_or_default();
344 if !self.subclass {
345 writeln!(f, "@typing.final")?;
346 }
347 writeln!(f, "class {}{}:", self.name, bases)?;
348 let indent = indent();
349 let doc = self.doc.trim();
350 docstring::write_docstring(f, doc, indent)?;
351
352 if let Some(match_args) = &self.match_args {
353 if match_args.is_empty() {
354 writeln!(f, "{indent}__match_args__ = ()")?;
355 } else {
356 let match_args_txt = match_args
357 .iter()
358 .map(|a| format!(r##""{a}""##))
359 .collect::<Vec<_>>()
360 .join(", ");
361 writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
362 }
363 }
364 for attr in &self.attrs {
365 attr.fmt(f)?;
366 }
367 for (getter, setter) in self.getter_setters.values() {
368 if let Some(getter) = getter {
369 write!(
370 f,
371 "{}",
372 GetterDisplay {
373 member: getter,
374 target_module: self.module.unwrap_or(self.name)
375 }
376 )?;
377 }
378 if let Some(setter) = setter {
379 write!(
380 f,
381 "{}",
382 SetterDisplay {
383 member: setter,
384 target_module: self.module.unwrap_or(self.name)
385 }
386 )?;
387 }
388 }
389 for (_method_name, methods) in &self.methods {
390 let has_overload = methods.iter().any(|m| m.is_overload);
392 let should_add_overload = methods.len() > 1 && has_overload;
393
394 for method in methods {
395 if should_add_overload {
396 writeln!(f, "{indent}@typing.overload")?;
397 }
398 method.fmt(f)?;
399 }
400 }
401 for class in &self.classes {
402 let emit = format!("{class}");
403 for line in emit.lines() {
404 writeln!(f, "{indent}{line}")?;
405 }
406 }
407 if self.attrs.is_empty() && self.getter_setters.is_empty() && self.methods.is_empty() {
408 writeln!(f, "{indent}...")?;
409 }
410 writeln!(f)?;
411 Ok(())
412 }
413}
414
415impl ClassDef {
416 pub fn fmt_for_module(&self, target_module: &str, f: &mut fmt::Formatter) -> fmt::Result {
421 let bases = self
423 .bases
424 .iter()
425 .map(|i| i.qualified_for_module(target_module))
426 .reduce(|acc, path| format!("{acc}, {path}"))
427 .map(|bases| format!("({bases})"))
428 .unwrap_or_default();
429
430 if !self.subclass {
431 writeln!(f, "@typing.final")?;
432 }
433 writeln!(f, "class {}{}:", self.name, bases)?;
434
435 let indent = indent();
436 let doc = self.doc.trim();
437 docstring::write_docstring(f, doc, indent)?;
438
439 if let Some(match_args) = &self.match_args {
440 if match_args.is_empty() {
441 writeln!(f, "{indent}__match_args__ = ()")?;
442 } else {
443 let match_args_txt = match_args
444 .iter()
445 .map(|a| format!(r##""{a}""##))
446 .collect::<Vec<_>>()
447 .join(", ");
448 writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
449 }
450 }
451
452 for attr in &self.attrs {
454 attr.fmt_for_module(target_module, f, indent)?;
455 }
456
457 for (getter, setter) in self.getter_setters.values() {
459 if let Some(getter) = getter {
460 write!(
461 f,
462 "{}",
463 GetterDisplay {
464 member: getter,
465 target_module
466 }
467 )?;
468 }
469 if let Some(setter) = setter {
470 write!(
471 f,
472 "{}",
473 SetterDisplay {
474 member: setter,
475 target_module
476 }
477 )?;
478 }
479 }
480
481 for (_method_name, methods) in &self.methods {
483 let has_overload = methods.iter().any(|m| m.is_overload);
484 let should_add_overload = methods.len() > 1 && has_overload;
485
486 for method in methods {
487 if should_add_overload {
488 writeln!(f, "{indent}@typing.overload")?;
489 }
490 method.fmt_for_module(target_module, f, indent)?;
491 }
492 }
493
494 for class in &self.classes {
496 struct FmtAdapter<'a, 'b> {
498 class: &'a ClassDef,
499 target_module: &'b str,
500 }
501 impl<'a, 'b> fmt::Display for FmtAdapter<'a, 'b> {
502 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
503 self.class.fmt_for_module(self.target_module, f)
504 }
505 }
506 let emit = format!(
507 "{}",
508 FmtAdapter {
509 class,
510 target_module
511 }
512 );
513 for line in emit.lines() {
514 writeln!(f, "{indent}{line}")?;
515 }
516 }
517
518 if self.attrs.is_empty() && self.getter_setters.is_empty() && self.methods.is_empty() {
519 writeln!(f, "{indent}...")?;
520 }
521 writeln!(f)?;
522 Ok(())
523 }
524}