1use rustpython_parser::{ast, Parse};
4use syn::{Error, LitStr, Result, Type};
5
6use super::pyfunction::PythonFunctionStub;
7use super::{
8 build_parameters_from_ast, dedent, extract_deprecated_from_decorators, extract_docstring,
9 extract_return_type, has_overload_decorator,
10};
11use crate::gen_stub::{method::MethodInfo, method::MethodType, pymethods::PyMethodsInfo};
12
13pub struct PythonMethodStub {
15 pub func_stub: PythonFunctionStub,
16 pub method_type: MethodType,
17}
18
19impl TryFrom<PythonMethodStub> for MethodInfo {
20 type Error = syn::Error;
21
22 fn try_from(stub: PythonMethodStub) -> Result<Self> {
23 let func_name = stub.func_stub.func_def.name.to_string();
24
25 let doc = extract_docstring(&stub.func_stub.func_def);
27
28 let parameters =
30 build_parameters_from_ast(&stub.func_stub.func_def.args, &stub.func_stub.imports)?;
31
32 let return_type =
38 extract_return_type(&stub.func_stub.func_def.returns, &stub.func_stub.imports)?;
39
40 let deprecated =
42 extract_deprecated_from_decorators(&stub.func_stub.func_def.decorator_list);
43
44 Ok(MethodInfo {
46 name: func_name,
47 parameters,
48 r#return: return_type,
49 doc,
50 r#type: stub.method_type,
51 is_async: stub.func_stub.is_async,
52 deprecated,
53 type_ignored: None,
54 is_overload: stub.func_stub.is_overload,
55 })
56 }
57}
58
59pub struct PythonClassStub {
61 pub class_def: ast::StmtClassDef,
62 pub imports: Vec<String>,
63}
64
65impl PythonClassStub {
66 pub fn new(input: &LitStr) -> Result<Self> {
68 let stub_content = input.value();
69
70 let dedented_content = dedent(&stub_content);
72
73 let parsed = ast::Suite::parse(&dedented_content, "<stub>")
75 .map_err(|e| Error::new(input.span(), format!("Failed to parse Python stub: {}", e)))?;
76
77 let mut imports = Vec::new();
79 let mut class_def: Option<ast::StmtClassDef> = None;
80
81 for stmt in parsed {
82 match stmt {
83 ast::Stmt::Import(import_stmt) => {
84 for alias in &import_stmt.names {
85 imports.push(alias.name.to_string());
86 }
87 }
88 ast::Stmt::ImportFrom(import_from_stmt) => {
89 if let Some(module) = &import_from_stmt.module {
90 imports.push(module.to_string());
91 }
92 }
93 ast::Stmt::ClassDef(cls_def) => {
94 if class_def.is_some() {
95 return Err(Error::new(
96 input.span(),
97 "Multiple class definitions found. Only one class is allowed per gen_methods_from_python! call",
98 ));
99 }
100 class_def = Some(cls_def);
101 }
102 _ => {
103 }
105 }
106 }
107
108 let class_def = class_def
110 .ok_or_else(|| Error::new(input.span(), "No class definition found in Python stub"))?;
111
112 Ok(Self { class_def, imports })
113 }
114}
115
116impl TryFrom<PythonClassStub> for PyMethodsInfo {
117 type Error = syn::Error;
118
119 fn try_from(stub: PythonClassStub) -> Result<Self> {
120 let class_name = stub.class_def.name.to_string();
121 let mut methods = Vec::new();
122
123 for stmt in &stub.class_def.body {
125 match stmt {
126 ast::Stmt::FunctionDef(func_def) => {
127 let method_type = determine_method_type(func_def, &func_def.args);
129
130 let is_overload = has_overload_decorator(&func_def.decorator_list);
132
133 let func_stub = PythonFunctionStub {
135 func_def: func_def.clone(),
136 imports: stub.imports.clone(),
137 is_async: false,
138 is_overload,
139 };
140
141 let method_stub = PythonMethodStub {
143 func_stub,
144 method_type,
145 };
146 let method = MethodInfo::try_from(method_stub)?;
147 methods.push(method);
148 }
149 ast::Stmt::AsyncFunctionDef(func_def) => {
150 let is_overload = has_overload_decorator(&func_def.decorator_list);
152
153 let sync_func = ast::StmtFunctionDef {
155 range: func_def.range,
156 name: func_def.name.clone(),
157 type_params: func_def.type_params.clone(),
158 args: func_def.args.clone(),
159 body: func_def.body.clone(),
160 decorator_list: func_def.decorator_list.clone(),
161 returns: func_def.returns.clone(),
162 type_comment: func_def.type_comment.clone(),
163 };
164
165 let method_type = determine_method_type(&sync_func, &sync_func.args);
167
168 let func_stub = PythonFunctionStub {
170 func_def: sync_func,
171 imports: stub.imports.clone(),
172 is_async: true,
173 is_overload,
174 };
175
176 let method_stub = PythonMethodStub {
178 func_stub,
179 method_type,
180 };
181 let method = MethodInfo::try_from(method_stub)?;
182 methods.push(method);
183 }
184 _ => {
185 }
187 }
188 }
189
190 if methods.is_empty() {
191 return Err(Error::new(
192 proc_macro2::Span::call_site(),
193 "No method definitions found in class body",
194 ));
195 }
196
197 let struct_id: Type = syn::parse_str(&class_name).map_err(|e| {
199 Error::new(
200 proc_macro2::Span::call_site(),
201 format!("Failed to parse class name '{}': {}", class_name, e),
202 )
203 })?;
204
205 Ok(PyMethodsInfo {
206 struct_id,
207 attrs: Vec::new(),
208 getters: Vec::new(),
209 setters: Vec::new(),
210 methods,
211 })
212 }
213}
214
215pub fn parse_python_methods_stub(input: &LitStr) -> Result<PyMethodsInfo> {
217 let stub = PythonClassStub::new(input)?;
218 PyMethodsInfo::try_from(stub).map_err(|e| Error::new(input.span(), format!("{}", e)))
219}
220
221fn determine_method_type(func_def: &ast::StmtFunctionDef, args: &ast::Arguments) -> MethodType {
223 for decorator in &func_def.decorator_list {
225 if let ast::Expr::Name(name) = decorator {
226 match name.id.as_str() {
227 "staticmethod" => return MethodType::Static,
228 "classmethod" => return MethodType::Class,
229 _ => {}
230 }
231 }
232 }
233
234 if func_def.name.as_str() == "__new__" {
236 return MethodType::New;
237 }
238
239 if let Some(first_arg) = args.args.first() {
241 let arg_name = first_arg.def.arg.as_str();
242 if arg_name == "self" {
243 return MethodType::Instance;
244 } else if arg_name == "cls" {
245 return MethodType::Class;
246 }
247 }
248
249 MethodType::Instance
251}
252
#[cfg(test)]
mod test {
    // Snapshot tests for `parse_python_methods_stub`. Each test feeds a Python class
    // stub literal through the parser and compares the generated tokens, pretty-printed
    // by `format_as_value`, against an inline insta snapshot.
    use super::*;
    use proc_macro2::TokenStream as TokenStream2;
    use quote::{quote, ToTokens};

    /// A class with one instance method: `self` is not listed as a parameter, the
    /// docstring becomes `doc`, and the method type is `Instance`.
    #[test]
    fn test_single_method_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class Incrementer:
                def increment(self, x: int) -> int:
                    """Increment by one"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "increment",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "x",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Increment by one",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Instance,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
        }
        "###);
        Ok(())
    }

    /// Every method in the class body is collected, in source order.
    #[test]
    fn test_multiple_methods_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class Incrementer:
                def increment_1(self, x: int) -> int:
                    """First method"""

                def increment_2(self, x: float) -> float:
                    """Second method"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 2);

        assert_eq!(py_methods_info.methods[0].name, "increment_1");
        assert_eq!(py_methods_info.methods[1].name, "increment_2");
        Ok(())
    }

    /// `@staticmethod` yields `MethodType::Static`, and its first parameter is kept.
    #[test]
    fn test_static_method_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class MyClass:
                @staticmethod
                def create(name: str) -> str:
                    """Create something"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "create",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "name",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "str".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Create something",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Static,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
        }
        "###);
        Ok(())
    }

    /// `@classmethod` yields `MethodType::Class`, and `cls` is not listed as a parameter.
    #[test]
    fn test_class_method_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class MyClass:
                @classmethod
                def from_string(cls, s: str) -> int:
                    """Create from string"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "from_string",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "s",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Create from string",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Class,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
        }
        "###);
        Ok(())
    }

    /// A method named `__new__` yields `MethodType::New`, with `cls` dropped from the
    /// parameter list.
    #[test]
    fn test_new_method_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class MyClass:
                def __new__(cls) -> object:
                    """Constructor"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "__new__",
            parameters: &[],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "object".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Constructor",
            r#type: ::pyo3_stub_gen::type_info::MethodType::New,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
        }
        "###);
        Ok(())
    }

    /// Modules imported at the stub's top level (`import typing`,
    /// `from collections.abc import ...`) end up in the `import` set of every
    /// generated `TypeInfo`.
    #[test]
    fn test_method_with_imports_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import typing
            from collections.abc import Callable

            class MyClass:
                def process(self, func: Callable[[str], int]) -> typing.Optional[int]:
                    """Process a callback"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "process",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "func",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "Callable[[str], int]".to_string(),
                        import: ::std::collections::HashSet::from([
                            "typing".into(),
                            "collections.abc".into(),
                        ]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "typing.Optional[int]".to_string(),
                import: ::std::collections::HashSet::from([
                    "typing".into(),
                    "collections.abc".into(),
                ]),
            },
            doc: "Process a callback",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Instance,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
        }
        "###);
        Ok(())
    }

    /// `async def` methods are parsed like regular ones but report `is_async: true`.
    #[test]
    fn test_async_method_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class MyClass:
                async def fetch_data(self, url: str) -> str:
                    """Fetch data asynchronously"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "fetch_data",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "url",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "str".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Fetch data asynchronously",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Instance,
            is_async: true,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
        }
        "###);
        Ok(())
    }

    /// `pyo3_stub_gen.RustType["T"]` annotations are turned into references to the
    /// Rust type's `PyStubType` impl: `type_input` for parameters, `type_output` for
    /// the return value.
    #[test]
    fn test_rust_type_marker_in_method() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class PyProblem:
                def __iadd__(self, other: pyo3_stub_gen.RustType["SomeRustType"]) -> pyo3_stub_gen.RustType["PyProblem"]:
                    """In-place addition using Rust type marker"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "__iadd__",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "other",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: <SomeRustType as ::pyo3_stub_gen::PyStubType>::type_input,
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: <PyProblem as pyo3_stub_gen::PyStubType>::type_output,
            doc: "In-place addition using Rust type marker",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Instance,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
        }
        "###);
        Ok(())
    }

    /// Parameters after a bare `*` are emitted as `KeywordOnly`, and default values
    /// are kept as expression strings. The full `PyMethodsInfo` is snapshotted here,
    /// so the class name also appears as `struct_id`.
    #[test]
    fn test_keyword_only_params_with_defaults() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import builtins
            import typing

            class Placeholder:
                def configure(
                    self,
                    name: builtins.str,
                    *,
                    dtype: builtins.str,
                    ndim: builtins.int,
                    shape: typing.Optional[builtins.str],
                    jagged: builtins.bool = False,
                    latex: typing.Optional[builtins.str] = None,
                    description: typing.Optional[builtins.str] = None,
                ) -> pyo3_stub_gen.RustType["Placeholder"]:
                    """
                    Configure placeholder with keyword-only parameters.

                    This demonstrates keyword-only parameters (after *) which should be
                    preserved in the generated stub file.
                    """
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyMethodsInfo {
            struct_id: std::any::TypeId::of::<Placeholder>,
            attrs: &[],
            getters: &[],
            setters: &[],
            methods: &[
                ::pyo3_stub_gen::type_info::MethodInfo {
                    name: "configure",
                    parameters: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "name",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: || ::pyo3_stub_gen::TypeInfo {
                                name: "builtins.str".to_string(),
                                import: ::std::collections::HashSet::from([
                                    "builtins".into(),
                                    "typing".into(),
                                ]),
                            },
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "dtype",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
                            type_info: || ::pyo3_stub_gen::TypeInfo {
                                name: "builtins.str".to_string(),
                                import: ::std::collections::HashSet::from([
                                    "builtins".into(),
                                    "typing".into(),
                                ]),
                            },
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "ndim",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
                            type_info: || ::pyo3_stub_gen::TypeInfo {
                                name: "builtins.int".to_string(),
                                import: ::std::collections::HashSet::from([
                                    "builtins".into(),
                                    "typing".into(),
                                ]),
                            },
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "shape",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
                            type_info: || ::pyo3_stub_gen::TypeInfo {
                                name: "typing.Optional[builtins.str]".to_string(),
                                import: ::std::collections::HashSet::from([
                                    "builtins".into(),
                                    "typing".into(),
                                ]),
                            },
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "jagged",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
                            type_info: || ::pyo3_stub_gen::TypeInfo {
                                name: "builtins.bool".to_string(),
                                import: ::std::collections::HashSet::from([
                                    "builtins".into(),
                                    "typing".into(),
                                ]),
                            },
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
                                fn _fmt() -> String {
                                    "False".to_string()
                                }
                                _fmt
                            }),
                        },
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "latex",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
                            type_info: || ::pyo3_stub_gen::TypeInfo {
                                name: "typing.Optional[builtins.str]".to_string(),
                                import: ::std::collections::HashSet::from([
                                    "builtins".into(),
                                    "typing".into(),
                                ]),
                            },
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
                                fn _fmt() -> String {
                                    "None".to_string()
                                }
                                _fmt
                            }),
                        },
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "description",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
                            type_info: || ::pyo3_stub_gen::TypeInfo {
                                name: "typing.Optional[builtins.str]".to_string(),
                                import: ::std::collections::HashSet::from([
                                    "builtins".into(),
                                    "typing".into(),
                                ]),
                            },
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
                                fn _fmt() -> String {
                                    "None".to_string()
                                }
                                _fmt
                            }),
                        },
                    ],
                    r#return: <Placeholder as pyo3_stub_gen::PyStubType>::type_output,
                    doc: "\n        Configure placeholder with keyword-only parameters.\n\n        This demonstrates keyword-only parameters (after *) which should be\n        preserved in the generated stub file.\n        ",
                    r#type: ::pyo3_stub_gen::type_info::MethodType::Instance,
                    is_async: false,
                    deprecated: None,
                    type_ignored: None,
                    is_overload: false,
                },
            ],
            file: file!(),
            line: line!(),
            column: column!(),
        }
        "###);
        Ok(())
    }

    /// Pretty-prints a token stream that forms a Rust expression.
    ///
    /// The tokens are wrapped as `const _: () = <tokens>;` so they parse as a file item,
    /// formatted with `prettyplease`, and then the wrapper text is stripped off again.
    /// Panics if the tokens do not parse or the wrapper cannot be stripped.
    fn format_as_value(tt: TokenStream2) -> String {
        let ttt = quote! { const _: () = #tt; };
        let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
        formatted
            .trim()
            .strip_prefix("const _: () = ")
            .unwrap()
            .strip_suffix(';')
            .unwrap()
            .to_string()
    }
}