pyo3_stub_gen_derive/gen_stub/parse_python/
pyfunction.rs

1//! Parse Python function stub syntax and generate PyFunctionInfo
2
3use rustpython_parser::{ast, Parse};
4use syn::{Error, LitStr, Result};
5
6use super::{
7    dedent, extract_args, extract_deprecated_from_decorators, extract_docstring,
8    extract_return_type,
9};
10use crate::gen_stub::pyfunction::PyFunctionInfo;
11
12/// Intermediate representation for Python function stub
13pub struct PythonFunctionStub {
14    pub func_def: ast::StmtFunctionDef,
15    pub imports: Vec<String>,
16    pub is_async: bool,
17}
18
19impl TryFrom<PythonFunctionStub> for PyFunctionInfo {
20    type Error = syn::Error;
21
22    fn try_from(stub: PythonFunctionStub) -> Result<Self> {
23        let func_name = stub.func_def.name.to_string();
24
25        // Extract docstring
26        let doc = extract_docstring(&stub.func_def);
27
28        // Extract arguments
29        let args = extract_args(&stub.func_def.args, &stub.imports)?;
30
31        // Extract return type
32        let return_type = extract_return_type(&stub.func_def.returns, &stub.imports)?;
33
34        // Try to extract deprecated decorator
35        let deprecated = extract_deprecated_from_decorators(&stub.func_def.decorator_list);
36
37        // Note: type_ignored (# type: ignore comments) cannot be extracted from Python AST
38        // as rustpython-parser doesn't preserve comments
39
40        // Construct PyFunctionInfo
41        Ok(PyFunctionInfo {
42            name: func_name,
43            args,
44            r#return: return_type,
45            sig: None,
46            doc,
47            module: None,
48            is_async: stub.is_async,
49            deprecated,
50            type_ignored: None,
51        })
52    }
53}
54
55/// Parse Python stub string and return PyFunctionInfo
56pub fn parse_python_function_stub(input: LitStr) -> Result<PyFunctionInfo> {
57    let stub_content = input.value();
58
59    // Remove common indentation to allow indented Python code in raw strings
60    let dedented_content = dedent(&stub_content);
61
62    // Parse Python code using rustpython-parser
63    let parsed = ast::Suite::parse(&dedented_content, "<stub>")
64        .map_err(|e| Error::new(input.span(), format!("Failed to parse Python stub: {}", e)))?;
65
66    // Extract imports and function definitions
67    let mut imports = Vec::new();
68    let mut function: Option<(ast::StmtFunctionDef, bool)> = None;
69
70    for stmt in parsed {
71        match stmt {
72            ast::Stmt::Import(import_stmt) => {
73                for alias in &import_stmt.names {
74                    imports.push(alias.name.to_string());
75                }
76            }
77            ast::Stmt::ImportFrom(import_from_stmt) => {
78                if let Some(module) = &import_from_stmt.module {
79                    imports.push(module.to_string());
80                }
81            }
82            ast::Stmt::FunctionDef(func_def) => {
83                if function.is_some() {
84                    return Err(Error::new(
85                        input.span(),
86                        "Multiple function definitions found. Only one function is allowed per gen_function_from_python! call",
87                    ));
88                }
89                function = Some((func_def, false));
90            }
91            ast::Stmt::AsyncFunctionDef(func_def) => {
92                if function.is_some() {
93                    return Err(Error::new(
94                        input.span(),
95                        "Multiple function definitions found. Only one function is allowed per gen_function_from_python! call",
96                    ));
97                }
98                // Convert AsyncFunctionDef to FunctionDef for uniform processing
99                let sync_func = ast::StmtFunctionDef {
100                    range: func_def.range,
101                    name: func_def.name,
102                    type_params: func_def.type_params,
103                    args: func_def.args,
104                    body: func_def.body,
105                    decorator_list: func_def.decorator_list,
106                    returns: func_def.returns,
107                    type_comment: func_def.type_comment,
108                };
109                function = Some((sync_func, true));
110            }
111            _ => {
112                // Ignore other statements
113            }
114        }
115    }
116
117    // Check that exactly one function is defined
118    let (func_def, is_async) = function
119        .ok_or_else(|| Error::new(input.span(), "No function definition found in Python stub"))?;
120
121    // Generate PyFunctionInfo using TryFrom
122    let stub = PythonFunctionStub {
123        func_def,
124        imports,
125        is_async,
126    };
127    PyFunctionInfo::try_from(stub)
128}
129
#[cfg(test)]
mod test {
    use super::*;
    use proc_macro2::TokenStream as TokenStream2;
    use quote::{quote, ToTokens};

    #[test]
    fn test_basic_function() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            def foo(x: int) -> int:
                """A simple function"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "foo",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "x",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "A simple function",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    #[test]
    fn test_function_with_imports() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import typing
            from collections.abc import Callable

            def process(func: Callable[[str], int]) -> typing.Optional[int]:
                """Process a callback function"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "process",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "func",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "Callable[[str], int]".to_string(),
                        import: ::std::collections::HashSet::from([
                            "typing".into(),
                            "collections.abc".into(),
                        ]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "typing.Optional[int]".to_string(),
                import: ::std::collections::HashSet::from([
                    "typing".into(),
                    "collections.abc".into(),
                ]),
            },
            doc: "Process a callback function",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    #[test]
    fn test_complex_types() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import collections.abc
            import typing

            def fn_override_type(cb: collections.abc.Callable[[str], typing.Any]) -> collections.abc.Callable[[str], typing.Any]:
                """Example function with complex types"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "fn_override_type",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "cb",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "collections.abc.Callable[[str], typing.Any]".to_string(),
                        import: ::std::collections::HashSet::from([
                            "collections.abc".into(),
                            "typing".into(),
                        ]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "collections.abc.Callable[[str], typing.Any]".to_string(),
                import: ::std::collections::HashSet::from([
                    "collections.abc".into(),
                    "typing".into(),
                ]),
            },
            doc: "Example function with complex types",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    #[test]
    fn test_multiple_args() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import typing

            def add(a: int, b: int, c: typing.Optional[int]) -> int: ...
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "add",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "a",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    signature: None,
                },
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "b",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    signature: None,
                },
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "c",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "typing.Optional[int]".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from(["typing".into()]),
            },
            doc: "",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    #[test]
    fn test_no_return_type() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            def print_hello(name: str):
                """Print a greeting"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "print_hello",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "name",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: ::pyo3_stub_gen::type_info::no_return_type_output,
            doc: "Print a greeting",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    #[test]
    fn test_async_function() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            async def fetch_data(url: str) -> str:
                """Fetch data from URL"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "fetch_data",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "url",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "str".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Fetch data from URL",
            module: None,
            is_async: true,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    #[test]
    fn test_deprecated_decorator() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            @deprecated
            def old_function(x: int) -> int:
                """This function is deprecated"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "old_function",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "x",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "This function is deprecated",
            module: None,
            is_async: false,
            deprecated: Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
                since: None,
                note: None,
            }),
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    #[test]
    fn test_deprecated_with_message() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            @deprecated("Use new_function instead")
            def old_function(x: int) -> int:
                """This function is deprecated"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "old_function",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "x",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "This function is deprecated",
            module: None,
            is_async: false,
            deprecated: Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
                since: None,
                note: Some("Use new_function instead"),
            }),
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    // Error paths of `parse_python_function_stub`. `PyFunctionInfo` is not
    // required to implement `Debug`, so `.err().expect(..)` is used instead of
    // `unwrap_err()`.

    #[test]
    fn test_no_function_definition_is_error() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import typing
            "#
        })?;
        let err = parse_python_function_stub(stub_str)
            .err()
            .expect("a stub without a function definition must be rejected");
        assert_eq!(
            err.to_string(),
            "No function definition found in Python stub"
        );
        Ok(())
    }

    #[test]
    fn test_multiple_function_definitions_is_error() -> Result<()> {
        // A sync and an async definition must count towards the same limit.
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            def first(x: int) -> int: ...

            async def second(x: int) -> int: ...
            "#
        })?;
        let err = parse_python_function_stub(stub_str)
            .err()
            .expect("a stub with two function definitions must be rejected");
        assert_eq!(
            err.to_string(),
            "Multiple function definitions found. Only one function is allowed per gen_function_from_python! call"
        );
        Ok(())
    }

    #[test]
    fn test_invalid_python_is_error() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            def broken(x: int -> int:
            "#
        })?;
        let err = parse_python_function_stub(stub_str)
            .err()
            .expect("syntactically invalid Python must be rejected");
        assert!(err
            .to_string()
            .starts_with("Failed to parse Python stub: "));
        Ok(())
    }

    /// Pretty-print a token stream that represents an expression so it can be
    /// compared against an inline snapshot.
    ///
    /// The expression is wrapped in `const _: () = ...;` so `prettyplease` can
    /// format it as an item, and the wrapper is then stripped again.
    fn format_as_value(tt: TokenStream2) -> String {
        let ttt = quote! { const _: () = #tt; };
        let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
        formatted
            .trim()
            .strip_prefix("const _: () = ")
            .unwrap()
            .strip_suffix(';')
            .unwrap()
            .to_string()
    }
}