1use indexmap::IndexMap;
2
3use crate::generate::variant_methods::get_variant_methods;
4use crate::{
5 generate::{
6 docstring, indent, GetterDisplay, Import, MemberDef, MethodDef, Parameter,
7 ParameterDefault, Parameters, SetterDisplay,
8 },
9 stub_type::ImportRef,
10 type_info::*,
11 TypeInfo,
12};
13use std::collections::HashSet;
14use std::{fmt, vec};
15
/// A Python class as it will be rendered into a `.pyi` stub.
///
/// Besides plain `#[pyclass]` types, this is also used for complex enums, where
/// the enum becomes the outer class and each variant a nested class.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassDef {
    /// Python-visible class name, written in the `class` header.
    pub name: &'static str,
    /// Class docstring; trimmed before it is written.
    pub doc: &'static str,
    /// Plain class attributes, written before the properties.
    pub attrs: Vec<MemberDef>,
    /// Properties keyed by name; the tuple is `(getter, setter)`, either of which
    /// may be absent.
    pub getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)>,
    /// Methods keyed by Python name. Every definition sharing a name is kept in
    /// the `Vec`; `@typing.overload` is emitted when there is more than one and
    /// at least one is flagged as an overload.
    pub methods: IndexMap<String, Vec<MethodDef>>,
    /// Base classes listed in the `class Name(...)` header.
    pub bases: Vec<TypeInfo>,
    /// Nested classes, rendered indented inside this class's body.
    pub classes: Vec<ClassDef>,
    /// When `Some`, a `__match_args__` tuple with these names is emitted.
    pub match_args: Option<Vec<String>>,
    /// When `false`, the class is decorated with `@typing.final`.
    pub subclass: bool,
}
29
30impl Import for ClassDef {
31 fn import(&self) -> HashSet<ImportRef> {
32 let mut import = HashSet::new();
33 if !self.subclass {
34 import.insert("typing".into());
36 }
37 for base in &self.bases {
38 import.extend(base.import.clone());
39 }
40 for attr in &self.attrs {
41 import.extend(attr.import());
42 }
43 for (getter, setter) in self.getter_setters.values() {
44 if let Some(getter) = getter {
45 import.extend(getter.import());
46 }
47 if let Some(setter) = setter {
48 import.extend(setter.import());
49 }
50 }
51 for method in self.methods.values() {
52 if method.len() > 1 {
53 import.insert("typing".into());
55 }
56 for method in method {
57 import.extend(method.import());
58 }
59 }
60 for class in &self.classes {
61 import.extend(class.import());
62 }
63 import
64 }
65}
66
67impl From<&PyComplexEnumInfo> for ClassDef {
68 fn from(info: &PyComplexEnumInfo) -> Self {
69 let enum_info = Self {
73 name: info.pyclass_name,
74 doc: info.doc,
75 getter_setters: IndexMap::new(),
76 methods: IndexMap::new(),
77 classes: info
78 .variants
79 .iter()
80 .map(|v| ClassDef::from_variant(info, v))
81 .collect(),
82 bases: Vec::new(),
83 match_args: None,
84 attrs: Vec::new(),
85 subclass: true, };
87
88 enum_info
89 }
90}
91
92impl ClassDef {
93 fn from_variant(enum_info: &PyComplexEnumInfo, info: &VariantInfo) -> Self {
94 let methods = get_variant_methods(enum_info, info);
95
96 Self {
97 name: info.pyclass_name,
98 doc: info.doc,
99 getter_setters: info
100 .fields
101 .iter()
102 .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
103 .collect(),
104 methods,
105 classes: Vec::new(),
106 bases: vec![TypeInfo::unqualified(enum_info.pyclass_name)],
107 match_args: Some(info.fields.iter().map(|f| f.name.to_string()).collect()),
108 attrs: Vec::new(),
109 subclass: false,
110 }
111 }
112}
113
114impl From<&PyClassInfo> for ClassDef {
115 fn from(info: &PyClassInfo) -> Self {
116 let mut getter_setters: IndexMap<String, (Option<MemberDef>, Option<MemberDef>)> = info
119 .getters
120 .iter()
121 .map(|info| (info.name.to_string(), (Some(MemberDef::from(info)), None)))
122 .collect();
123 for setter in info.setters {
124 getter_setters.entry(setter.name.to_string()).or_default().1 = Some(MemberDef {
125 name: setter.name,
126 r#type: (setter.r#type)(),
127 doc: setter.doc,
128 default: setter.default.map(|f| f()),
129 deprecated: setter.deprecated.clone(),
130 });
131 }
132 let mut new = Self {
133 name: info.pyclass_name,
134 doc: info.doc,
135 attrs: Vec::new(),
136 getter_setters,
137 methods: Default::default(),
138 classes: Vec::new(),
139 bases: info.bases.iter().map(|f| f()).collect(),
140 match_args: None,
141 subclass: info.subclass,
142 };
143 if info.has_eq {
144 new.add_eq_method();
145 }
146 if info.has_ord {
147 new.add_ord_methods();
148 }
149 if info.has_hash {
150 new.add_hash_method();
151 }
152 if info.has_str {
153 new.add_str_method();
154 }
155 new
156 }
157}
158impl ClassDef {
159 fn add_eq_method(&mut self) {
160 let method = MethodDef {
161 name: "__eq__",
162 parameters: Parameters {
163 positional_or_keyword: vec![Parameter {
164 name: "other",
165 kind: ParameterKind::PositionalOrKeyword,
166 type_info: TypeInfo::builtin("object"),
167 default: ParameterDefault::None,
168 }],
169 ..Parameters::new()
170 },
171 r#return: TypeInfo::builtin("bool"),
172 doc: "",
173 r#type: MethodType::Instance,
174 is_async: false,
175 deprecated: None,
176 type_ignored: None,
177 is_overload: false,
178 };
179 self.methods
180 .entry("__eq__".to_string())
181 .or_default()
182 .push(method);
183 }
184
185 fn add_ord_methods(&mut self) {
186 let ord_methods = ["__lt__", "__le__", "__gt__", "__ge__"];
187
188 for name in &ord_methods {
189 let method = MethodDef {
190 name,
191 parameters: Parameters {
192 positional_or_keyword: vec![Parameter {
193 name: "other",
194 kind: ParameterKind::PositionalOrKeyword,
195 type_info: TypeInfo::builtin("object"),
196 default: ParameterDefault::None,
197 }],
198 ..Parameters::new()
199 },
200 r#return: TypeInfo::builtin("bool"),
201 doc: "",
202 r#type: MethodType::Instance,
203 is_async: false,
204 deprecated: None,
205 type_ignored: None,
206 is_overload: false,
207 };
208 self.methods
209 .entry(name.to_string())
210 .or_default()
211 .push(method);
212 }
213 }
214
215 fn add_hash_method(&mut self) {
216 let method = MethodDef {
217 name: "__hash__",
218 parameters: Parameters::new(),
219 r#return: TypeInfo::builtin("int"),
220 doc: "",
221 r#type: MethodType::Instance,
222 is_async: false,
223 deprecated: None,
224 type_ignored: None,
225 is_overload: false,
226 };
227 self.methods
228 .entry("__hash__".to_string())
229 .or_default()
230 .push(method);
231 }
232
233 fn add_str_method(&mut self) {
234 let method = MethodDef {
235 name: "__str__",
236 parameters: Parameters::new(),
237 r#return: TypeInfo::builtin("str"),
238 doc: "",
239 r#type: MethodType::Instance,
240 is_async: false,
241 deprecated: None,
242 type_ignored: None,
243 is_overload: false,
244 };
245 self.methods
246 .entry("__str__".to_string())
247 .or_default()
248 .push(method);
249 }
250}
251
252impl fmt::Display for ClassDef {
253 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
254 let bases = self
255 .bases
256 .iter()
257 .map(|i| i.name.clone())
258 .reduce(|acc, path| format!("{acc}, {path}"))
259 .map(|bases| format!("({bases})"))
260 .unwrap_or_default();
261 if !self.subclass {
262 writeln!(f, "@typing.final")?;
263 }
264 writeln!(f, "class {}{}:", self.name, bases)?;
265 let indent = indent();
266 let doc = self.doc.trim();
267 docstring::write_docstring(f, doc, indent)?;
268
269 if let Some(match_args) = &self.match_args {
270 if match_args.is_empty() {
271 writeln!(f, "{indent}__match_args__ = ()")?;
272 } else {
273 let match_args_txt = match_args
274 .iter()
275 .map(|a| format!(r##""{a}""##))
276 .collect::<Vec<_>>()
277 .join(", ");
278 writeln!(f, "{indent}__match_args__ = ({match_args_txt},)")?;
279 }
280 }
281 for attr in &self.attrs {
282 attr.fmt(f)?;
283 }
284 for (getter, setter) in self.getter_setters.values() {
285 if let Some(getter) = getter {
286 GetterDisplay(getter).fmt(f)?;
287 }
288 if let Some(setter) = setter {
289 SetterDisplay(setter).fmt(f)?;
290 }
291 }
292 for (_method_name, methods) in &self.methods {
293 let has_overload = methods.iter().any(|m| m.is_overload);
295 let should_add_overload = methods.len() > 1 && has_overload;
296
297 for method in methods {
298 if should_add_overload {
299 writeln!(f, "{indent}@typing.overload")?;
300 }
301 method.fmt(f)?;
302 }
303 }
304 for class in &self.classes {
305 let emit = format!("{class}");
306 for line in emit.lines() {
307 writeln!(f, "{indent}{line}")?;
308 }
309 }
310 if self.attrs.is_empty() && self.getter_setters.is_empty() && self.methods.is_empty() {
311 writeln!(f, "{indent}...")?;
312 }
313 writeln!(f)?;
314 Ok(())
315 }
316}