1use indexmap::IndexSet;
4use rustpython_parser::{ast, Parse};
5use syn::{Error, LitStr, Result, Type};
6
7use super::pyfunction::PythonFunctionStub;
8use super::{
9 dedent, extract_deprecated_from_decorators, extract_docstring, extract_return_type,
10 type_annotation_to_type_override,
11};
12use crate::gen_stub::{
13 arg::ArgInfo, method::MethodInfo, method::MethodType, pymethods::PyMethodsInfo,
14 util::TypeOrOverride,
15};
16
17pub struct PythonMethodStub {
19 pub func_stub: PythonFunctionStub,
20 pub method_type: MethodType,
21}
22
23impl TryFrom<PythonMethodStub> for MethodInfo {
24 type Error = syn::Error;
25
26 fn try_from(stub: PythonMethodStub) -> Result<Self> {
27 let func_name = stub.func_stub.func_def.name.to_string();
28
29 let doc = extract_docstring(&stub.func_stub.func_def);
31
32 let args = extract_args_for_method(
34 &stub.func_stub.func_def.args,
35 &stub.func_stub.imports,
36 stub.method_type,
37 )?;
38
39 let return_type =
41 extract_return_type(&stub.func_stub.func_def.returns, &stub.func_stub.imports)?;
42
43 let deprecated =
45 extract_deprecated_from_decorators(&stub.func_stub.func_def.decorator_list);
46
47 Ok(MethodInfo {
49 name: func_name,
50 args,
51 sig: None,
52 r#return: return_type,
53 doc,
54 r#type: stub.method_type,
55 is_async: stub.func_stub.is_async,
56 deprecated,
57 type_ignored: None,
58 })
59 }
60}
61
62pub struct PythonClassStub {
64 pub class_def: ast::StmtClassDef,
65 pub imports: Vec<String>,
66}
67
68impl PythonClassStub {
69 pub fn new(input: &LitStr) -> Result<Self> {
71 let stub_content = input.value();
72
73 let dedented_content = dedent(&stub_content);
75
76 let parsed = ast::Suite::parse(&dedented_content, "<stub>")
78 .map_err(|e| Error::new(input.span(), format!("Failed to parse Python stub: {}", e)))?;
79
80 let mut imports = Vec::new();
82 let mut class_def: Option<ast::StmtClassDef> = None;
83
84 for stmt in parsed {
85 match stmt {
86 ast::Stmt::Import(import_stmt) => {
87 for alias in &import_stmt.names {
88 imports.push(alias.name.to_string());
89 }
90 }
91 ast::Stmt::ImportFrom(import_from_stmt) => {
92 if let Some(module) = &import_from_stmt.module {
93 imports.push(module.to_string());
94 }
95 }
96 ast::Stmt::ClassDef(cls_def) => {
97 if class_def.is_some() {
98 return Err(Error::new(
99 input.span(),
100 "Multiple class definitions found. Only one class is allowed per gen_methods_from_python! call",
101 ));
102 }
103 class_def = Some(cls_def);
104 }
105 _ => {
106 }
108 }
109 }
110
111 let class_def = class_def
113 .ok_or_else(|| Error::new(input.span(), "No class definition found in Python stub"))?;
114
115 Ok(Self { class_def, imports })
116 }
117}
118
119impl TryFrom<PythonClassStub> for PyMethodsInfo {
120 type Error = syn::Error;
121
122 fn try_from(stub: PythonClassStub) -> Result<Self> {
123 let class_name = stub.class_def.name.to_string();
124 let mut methods = Vec::new();
125
126 for stmt in &stub.class_def.body {
128 match stmt {
129 ast::Stmt::FunctionDef(func_def) => {
130 let method_type = determine_method_type(func_def, &func_def.args);
132
133 let func_stub = PythonFunctionStub {
135 func_def: func_def.clone(),
136 imports: stub.imports.clone(),
137 is_async: false,
138 };
139
140 let method_stub = PythonMethodStub {
142 func_stub,
143 method_type,
144 };
145 let method = MethodInfo::try_from(method_stub)?;
146 methods.push(method);
147 }
148 ast::Stmt::AsyncFunctionDef(func_def) => {
149 let sync_func = ast::StmtFunctionDef {
151 range: func_def.range,
152 name: func_def.name.clone(),
153 type_params: func_def.type_params.clone(),
154 args: func_def.args.clone(),
155 body: func_def.body.clone(),
156 decorator_list: func_def.decorator_list.clone(),
157 returns: func_def.returns.clone(),
158 type_comment: func_def.type_comment.clone(),
159 };
160
161 let method_type = determine_method_type(&sync_func, &sync_func.args);
163
164 let func_stub = PythonFunctionStub {
166 func_def: sync_func,
167 imports: stub.imports.clone(),
168 is_async: true,
169 };
170
171 let method_stub = PythonMethodStub {
173 func_stub,
174 method_type,
175 };
176 let method = MethodInfo::try_from(method_stub)?;
177 methods.push(method);
178 }
179 _ => {
180 }
182 }
183 }
184
185 if methods.is_empty() {
186 return Err(Error::new(
187 proc_macro2::Span::call_site(),
188 "No method definitions found in class body",
189 ));
190 }
191
192 let struct_id: Type = syn::parse_str(&class_name).map_err(|e| {
194 Error::new(
195 proc_macro2::Span::call_site(),
196 format!("Failed to parse class name '{}': {}", class_name, e),
197 )
198 })?;
199
200 Ok(PyMethodsInfo {
201 struct_id,
202 attrs: Vec::new(),
203 getters: Vec::new(),
204 setters: Vec::new(),
205 methods,
206 })
207 }
208}
209
210pub fn parse_python_methods_stub(input: &LitStr) -> Result<PyMethodsInfo> {
212 let stub = PythonClassStub::new(input)?;
213 PyMethodsInfo::try_from(stub).map_err(|e| Error::new(input.span(), format!("{}", e)))
214}
215
216fn determine_method_type(func_def: &ast::StmtFunctionDef, args: &ast::Arguments) -> MethodType {
218 for decorator in &func_def.decorator_list {
220 if let ast::Expr::Name(name) = decorator {
221 match name.id.as_str() {
222 "staticmethod" => return MethodType::Static,
223 "classmethod" => return MethodType::Class,
224 _ => {}
225 }
226 }
227 }
228
229 if func_def.name.as_str() == "__new__" {
231 return MethodType::New;
232 }
233
234 if let Some(first_arg) = args.args.first() {
236 let arg_name = first_arg.def.arg.as_str();
237 if arg_name == "self" {
238 return MethodType::Instance;
239 } else if arg_name == "cls" {
240 return MethodType::Class;
241 }
242 }
243
244 MethodType::Instance
246}
247
248fn extract_args_for_method(
250 args: &ast::Arguments,
251 imports: &[String],
252 method_type: MethodType,
253) -> Result<Vec<ArgInfo>> {
254 let mut arg_infos = Vec::new();
255
256 let dummy_type: Type = syn::parse_str("()").unwrap();
258
259 for (idx, arg) in args.args.iter().enumerate() {
261 let arg_name = arg.def.arg.to_string();
262
263 if idx == 0
265 && ((method_type == MethodType::Instance && arg_name == "self")
266 || (method_type == MethodType::Class && arg_name == "cls")
267 || (method_type == MethodType::New && arg_name == "cls"))
268 {
269 continue;
270 }
271
272 let type_override = if let Some(annotation) = &arg.def.annotation {
273 type_annotation_to_type_override(annotation, imports, dummy_type.clone())?
274 } else {
275 TypeOrOverride::OverrideType {
277 r#type: dummy_type.clone(),
278 type_repr: "typing.Any".to_string(),
279 imports: IndexSet::from(["typing".to_string()]),
280 }
281 };
282
283 arg_infos.push(ArgInfo {
284 name: arg_name,
285 r#type: type_override,
286 });
287 }
288
289 Ok(arg_infos)
290}
291
#[cfg(test)]
mod test {
    use super::*;
    use proc_macro2::TokenStream as TokenStream2;
    use quote::{quote, ToTokens};

    // A single instance method: `self` is not listed in `args`, and the
    // docstring becomes `doc`.
    #[test]
    fn test_single_method_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class Incrementer:
                def increment(self, x: int) -> int:
                    """Increment by one"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "increment",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "x",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Increment by one",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Instance,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    // Every `def` in the class body becomes a method, in declaration order.
    #[test]
    fn test_multiple_methods_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class Incrementer:
                def increment_1(self, x: int) -> int:
                    """First method"""

                def increment_2(self, x: float) -> float:
                    """Second method"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 2);

        assert_eq!(py_methods_info.methods[0].name, "increment_1");
        assert_eq!(py_methods_info.methods[1].name, "increment_2");
        Ok(())
    }

    // `@staticmethod` yields `MethodType::Static` and keeps every parameter.
    #[test]
    fn test_static_method_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class MyClass:
                @staticmethod
                def create(name: str) -> str:
                    """Create something"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "create",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "name",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "str".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Create something",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Static,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    // `@classmethod` yields `MethodType::Class`, and the leading `cls` is not listed in `args`.
    #[test]
    fn test_class_method_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class MyClass:
                @classmethod
                def from_string(cls, s: str) -> int:
                    """Create from string"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "from_string",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "s",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Create from string",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Class,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    // `__new__` yields `MethodType::New` with `cls` not listed, leaving `args` empty.
    #[test]
    fn test_new_method_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class MyClass:
                def __new__(cls) -> object:
                    """Constructor"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "__new__",
            args: &[],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "object".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Constructor",
            r#type: ::pyo3_stub_gen::type_info::MethodType::New,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    // Top-level imports in the stub are attached to every annotation of the
    // method: both the argument and the return type carry both modules.
    #[test]
    fn test_method_with_imports_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import typing
            from collections.abc import Callable

            class MyClass:
                def process(self, func: Callable[[str], int]) -> typing.Optional[int]:
                    """Process a callback"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "process",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "func",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "Callable[[str], int]".to_string(),
                        import: ::std::collections::HashSet::from([
                            "typing".into(),
                            "collections.abc".into(),
                        ]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "typing.Optional[int]".to_string(),
                import: ::std::collections::HashSet::from([
                    "typing".into(),
                    "collections.abc".into(),
                ]),
            },
            doc: "Process a callback",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Instance,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    // `async def` is converted like a regular method but with `is_async: true`.
    #[test]
    fn test_async_method_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class MyClass:
                async def fetch_data(self, url: str) -> str:
                    """Fetch data asynchronously"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "fetch_data",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "url",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "str".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Fetch data asynchronously",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Instance,
            is_async: true,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    /// Pretty-prints a token stream that represents an expression.
    ///
    /// The expression is wrapped in `const _: () = ...;` so `prettyplease` can
    /// format it as a file item, and the wrapper is then stripped off again.
    /// Panics if the tokens do not parse or the wrapper is not found.
    fn format_as_value(tt: TokenStream2) -> String {
        let ttt = quote! { const _: () = #tt; };
        let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
        formatted
            .trim()
            .strip_prefix("const _: () = ")
            .unwrap()
            .strip_suffix(';')
            .unwrap()
            .to_string()
    }
}