pyo3_stub_gen_derive/gen_stub/parse_python/
pymethods.rs

1//! Parse Python class method stub syntax and generate MethodInfo
2
3use indexmap::IndexSet;
4use rustpython_parser::{ast, Parse};
5use syn::{Error, LitStr, Result, Type};
6
7use super::pyfunction::PythonFunctionStub;
8use super::{
9    dedent, extract_deprecated_from_decorators, extract_docstring, extract_return_type,
10    type_annotation_to_type_override,
11};
12use crate::gen_stub::{
13    arg::ArgInfo, method::MethodInfo, method::MethodType, pymethods::PyMethodsInfo,
14    util::TypeOrOverride,
15};
16
/// Intermediate representation of a single method parsed from a Python class stub.
///
/// Pairs the parsed function stub with the kind of method it was classified
/// as. It is converted into a `MethodInfo` through the `TryFrom` impl below.
pub struct PythonMethodStub {
    /// Parsed function definition together with the stub's imports and async flag.
    pub func_stub: PythonFunctionStub,
    /// How the method was classified (instance, class, static or `__new__`).
    pub method_type: MethodType,
}
22
23impl TryFrom<PythonMethodStub> for MethodInfo {
24    type Error = syn::Error;
25
26    fn try_from(stub: PythonMethodStub) -> Result<Self> {
27        let func_name = stub.func_stub.func_def.name.to_string();
28
29        // Extract docstring
30        let doc = extract_docstring(&stub.func_stub.func_def);
31
32        // Extract arguments based on method type
33        let args = extract_args_for_method(
34            &stub.func_stub.func_def.args,
35            &stub.func_stub.imports,
36            stub.method_type,
37        )?;
38
39        // Extract return type
40        let return_type =
41            extract_return_type(&stub.func_stub.func_def.returns, &stub.func_stub.imports)?;
42
43        // Try to extract deprecated decorator
44        let deprecated =
45            extract_deprecated_from_decorators(&stub.func_stub.func_def.decorator_list);
46
47        // Construct MethodInfo
48        Ok(MethodInfo {
49            name: func_name,
50            args,
51            sig: None,
52            r#return: return_type,
53            doc,
54            r#type: stub.method_type,
55            is_async: stub.func_stub.is_async,
56            deprecated,
57            type_ignored: None,
58        })
59    }
60}
61
/// Intermediate representation of a Python class stub, used to collect its methods.
pub struct PythonClassStub {
    /// The class definition parsed from the stub source.
    pub class_def: ast::StmtClassDef,
    /// Module names gathered from the stub's `import` / `from ... import` statements.
    pub imports: Vec<String>,
}
67
68impl PythonClassStub {
69    /// Parse Python class definition from a literal string
70    pub fn new(input: &LitStr) -> Result<Self> {
71        let stub_content = input.value();
72
73        // Remove common indentation to allow indented Python code in raw strings
74        let dedented_content = dedent(&stub_content);
75
76        // Parse Python code using rustpython-parser
77        let parsed = ast::Suite::parse(&dedented_content, "<stub>")
78            .map_err(|e| Error::new(input.span(), format!("Failed to parse Python stub: {}", e)))?;
79
80        // Extract imports and class definition
81        let mut imports = Vec::new();
82        let mut class_def: Option<ast::StmtClassDef> = None;
83
84        for stmt in parsed {
85            match stmt {
86                ast::Stmt::Import(import_stmt) => {
87                    for alias in &import_stmt.names {
88                        imports.push(alias.name.to_string());
89                    }
90                }
91                ast::Stmt::ImportFrom(import_from_stmt) => {
92                    if let Some(module) = &import_from_stmt.module {
93                        imports.push(module.to_string());
94                    }
95                }
96                ast::Stmt::ClassDef(cls_def) => {
97                    if class_def.is_some() {
98                        return Err(Error::new(
99                            input.span(),
100                            "Multiple class definitions found. Only one class is allowed per gen_methods_from_python! call",
101                        ));
102                    }
103                    class_def = Some(cls_def);
104                }
105                _ => {
106                    // Ignore other statements
107                }
108            }
109        }
110
111        // Check that exactly one class is defined
112        let class_def = class_def
113            .ok_or_else(|| Error::new(input.span(), "No class definition found in Python stub"))?;
114
115        Ok(Self { class_def, imports })
116    }
117}
118
119impl TryFrom<PythonClassStub> for PyMethodsInfo {
120    type Error = syn::Error;
121
122    fn try_from(stub: PythonClassStub) -> Result<Self> {
123        let class_name = stub.class_def.name.to_string();
124        let mut methods = Vec::new();
125
126        // Extract methods from class body
127        for stmt in &stub.class_def.body {
128            match stmt {
129                ast::Stmt::FunctionDef(func_def) => {
130                    // Determine method type
131                    let method_type = determine_method_type(func_def, &func_def.args);
132
133                    // Create PythonFunctionStub
134                    let func_stub = PythonFunctionStub {
135                        func_def: func_def.clone(),
136                        imports: stub.imports.clone(),
137                        is_async: false,
138                    };
139
140                    // Create PythonMethodStub and convert to MethodInfo
141                    let method_stub = PythonMethodStub {
142                        func_stub,
143                        method_type,
144                    };
145                    let method = MethodInfo::try_from(method_stub)?;
146                    methods.push(method);
147                }
148                ast::Stmt::AsyncFunctionDef(func_def) => {
149                    // Convert AsyncFunctionDef to FunctionDef for uniform processing
150                    let sync_func = ast::StmtFunctionDef {
151                        range: func_def.range,
152                        name: func_def.name.clone(),
153                        type_params: func_def.type_params.clone(),
154                        args: func_def.args.clone(),
155                        body: func_def.body.clone(),
156                        decorator_list: func_def.decorator_list.clone(),
157                        returns: func_def.returns.clone(),
158                        type_comment: func_def.type_comment.clone(),
159                    };
160
161                    // Determine method type
162                    let method_type = determine_method_type(&sync_func, &sync_func.args);
163
164                    // Create PythonFunctionStub
165                    let func_stub = PythonFunctionStub {
166                        func_def: sync_func,
167                        imports: stub.imports.clone(),
168                        is_async: true,
169                    };
170
171                    // Create PythonMethodStub and convert to MethodInfo
172                    let method_stub = PythonMethodStub {
173                        func_stub,
174                        method_type,
175                    };
176                    let method = MethodInfo::try_from(method_stub)?;
177                    methods.push(method);
178                }
179                _ => {
180                    // Ignore other statements (e.g., docstrings, pass)
181                }
182            }
183        }
184
185        if methods.is_empty() {
186            return Err(Error::new(
187                proc_macro2::Span::call_site(),
188                "No method definitions found in class body",
189            ));
190        }
191
192        // Parse class name as Type
193        let struct_id: Type = syn::parse_str(&class_name).map_err(|e| {
194            Error::new(
195                proc_macro2::Span::call_site(),
196                format!("Failed to parse class name '{}': {}", class_name, e),
197            )
198        })?;
199
200        Ok(PyMethodsInfo {
201            struct_id,
202            attrs: Vec::new(),
203            getters: Vec::new(),
204            setters: Vec::new(),
205            methods,
206        })
207    }
208}
209
210/// Parse Python class definition and return PyMethodsInfo
211pub fn parse_python_methods_stub(input: &LitStr) -> Result<PyMethodsInfo> {
212    let stub = PythonClassStub::new(input)?;
213    PyMethodsInfo::try_from(stub).map_err(|e| Error::new(input.span(), format!("{}", e)))
214}
215
216/// Determine method type from decorators and arguments
217fn determine_method_type(func_def: &ast::StmtFunctionDef, args: &ast::Arguments) -> MethodType {
218    // Check for @staticmethod decorator
219    for decorator in &func_def.decorator_list {
220        if let ast::Expr::Name(name) = decorator {
221            match name.id.as_str() {
222                "staticmethod" => return MethodType::Static,
223                "classmethod" => return MethodType::Class,
224                _ => {}
225            }
226        }
227    }
228
229    // Check if it's __new__ (constructor)
230    if func_def.name.as_str() == "__new__" {
231        return MethodType::New;
232    }
233
234    // Check first argument to determine if it's instance/class method
235    if let Some(first_arg) = args.args.first() {
236        let arg_name = first_arg.def.arg.as_str();
237        if arg_name == "self" {
238            return MethodType::Instance;
239        } else if arg_name == "cls" {
240            return MethodType::Class;
241        }
242    }
243
244    // Default to instance method
245    MethodType::Instance
246}
247
248/// Extract arguments for method (handling self/cls)
249fn extract_args_for_method(
250    args: &ast::Arguments,
251    imports: &[String],
252    method_type: MethodType,
253) -> Result<Vec<ArgInfo>> {
254    let mut arg_infos = Vec::new();
255
256    // Dummy type for TypeOrOverride (not used in ToTokens for OverrideType)
257    let dummy_type: Type = syn::parse_str("()").unwrap();
258
259    // Process positional arguments
260    for (idx, arg) in args.args.iter().enumerate() {
261        let arg_name = arg.def.arg.to_string();
262
263        // Skip 'self' or 'cls' for instance/class/new methods (first argument only)
264        if idx == 0
265            && ((method_type == MethodType::Instance && arg_name == "self")
266                || (method_type == MethodType::Class && arg_name == "cls")
267                || (method_type == MethodType::New && arg_name == "cls"))
268        {
269            continue;
270        }
271
272        let type_override = if let Some(annotation) = &arg.def.annotation {
273            type_annotation_to_type_override(annotation, imports, dummy_type.clone())?
274        } else {
275            // No type annotation - use Any
276            TypeOrOverride::OverrideType {
277                r#type: dummy_type.clone(),
278                type_repr: "typing.Any".to_string(),
279                imports: IndexSet::from(["typing".to_string()]),
280            }
281        };
282
283        arg_infos.push(ArgInfo {
284            name: arg_name,
285            r#type: type_override,
286        });
287    }
288
289    Ok(arg_infos)
290}
291
// Snapshot tests: each one parses a small Python class stub and checks the
// resulting method metadata, mostly through the tokens generated for it.
#[cfg(test)]
mod test {
    use super::*;
    use proc_macro2::TokenStream as TokenStream2;
    use quote::{quote, ToTokens};

    /// A class with one instance method: `self` is dropped and `x` is kept.
    #[test]
    fn test_single_method_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class Incrementer:
                def increment(self, x: int) -> int:
                    """Increment by one"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "increment",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "x",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Increment by one",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Instance,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    /// Two methods in one class are both collected, in source order.
    #[test]
    fn test_multiple_methods_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class Incrementer:
                def increment_1(self, x: int) -> int:
                    """First method"""

                def increment_2(self, x: float) -> float:
                    """Second method"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 2);

        assert_eq!(py_methods_info.methods[0].name, "increment_1");
        assert_eq!(py_methods_info.methods[1].name, "increment_2");
        Ok(())
    }

    /// `@staticmethod` yields `MethodType::Static` and keeps every argument.
    #[test]
    fn test_static_method_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class MyClass:
                @staticmethod
                def create(name: str) -> str:
                    """Create something"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "create",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "name",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "str".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Create something",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Static,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    /// `@classmethod` yields `MethodType::Class` and drops the leading `cls`.
    #[test]
    fn test_class_method_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class MyClass:
                @classmethod
                def from_string(cls, s: str) -> int:
                    """Create from string"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "from_string",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "s",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Create from string",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Class,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    /// `__new__` yields `MethodType::New`; its `cls` argument is dropped.
    #[test]
    fn test_new_method_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class MyClass:
                def __new__(cls) -> object:
                    """Constructor"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "__new__",
            args: &[],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "object".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Constructor",
            r#type: ::pyo3_stub_gen::type_info::MethodType::New,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    /// Imports declared before the class are attached to the argument and return types.
    #[test]
    fn test_method_with_imports_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import typing
            from collections.abc import Callable

            class MyClass:
                def process(self, func: Callable[[str], int]) -> typing.Optional[int]:
                    """Process a callback"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "process",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "func",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "Callable[[str], int]".to_string(),
                        import: ::std::collections::HashSet::from([
                            "typing".into(),
                            "collections.abc".into(),
                        ]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "typing.Optional[int]".to_string(),
                import: ::std::collections::HashSet::from([
                    "typing".into(),
                    "collections.abc".into(),
                ]),
            },
            doc: "Process a callback",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Instance,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    /// An `async def` method is reported with `is_async: true`.
    #[test]
    fn test_async_method_in_class() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            class MyClass:
                async def fetch_data(self, url: str) -> str:
                    """Fetch data asynchronously"""
            "#
        })?;
        let py_methods_info = parse_python_methods_stub(&stub_str)?;
        assert_eq!(py_methods_info.methods.len(), 1);

        let out = py_methods_info.methods[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::MethodInfo {
            name: "fetch_data",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "url",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "str".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Fetch data asynchronously",
            r#type: ::pyo3_stub_gen::type_info::MethodType::Instance,
            is_async: true,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    /// Pretty-print a token stream that represents a value expression.
    ///
    /// The tokens are wrapped as `const _: () = <tokens>;` so they parse as a
    /// file, formatted with `prettyplease`, and then the wrapper prefix and
    /// trailing `;` are stripped again. Panics if parsing or stripping fails.
    fn format_as_value(tt: TokenStream2) -> String {
        let ttt = quote! { const _: () = #tt; };
        let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
        formatted
            .trim()
            .strip_prefix("const _: () = ")
            .unwrap()
            .strip_suffix(';')
            .unwrap()
            .to_string()
    }
}