pyo3_stub_gen_derive/gen_stub/parse_python/
pyfunction.rs

1//! Parse Python function stub syntax and generate PyFunctionInfo
2
3use rustpython_parser::{ast, Parse};
4use syn::{parse::Parse as SynParse, parse::ParseStream, Error, LitStr, Result};
5
6use super::{
7    build_parameters_from_ast, dedent, extract_deprecated_from_decorators, extract_docstring,
8    extract_return_type, has_overload_decorator,
9};
10use crate::gen_stub::pyfunction::PyFunctionInfo;
11
12/// Input for gen_function_from_python! macro
13pub struct GenFunctionFromPythonInput {
14    module: Option<String>,
15    python_stub: LitStr,
16}
17
18impl SynParse for GenFunctionFromPythonInput {
19    fn parse(input: ParseStream) -> Result<Self> {
20        // Check if first token is an identifier (for module parameter)
21        if input.peek(syn::Ident) {
22            let key: syn::Ident = input.parse()?;
23            if key == "module" {
24                let _: syn::token::Eq = input.parse()?;
25                let value: LitStr = input.parse()?;
26                let _: syn::token::Comma = input.parse()?;
27                let python_stub: LitStr = input.parse()?;
28                return Ok(Self {
29                    module: Some(value.value()),
30                    python_stub,
31                });
32            } else {
33                return Err(Error::new(
34                    key.span(),
35                    format!(
36                        "Unknown parameter: {}. Expected 'module' or a string literal",
37                        key
38                    ),
39                ));
40            }
41        }
42
43        // No module parameter, just parse the string literal
44        let python_stub: LitStr = input.parse()?;
45        Ok(Self {
46            module: None,
47            python_stub,
48        })
49    }
50}
51
52/// Intermediate representation for Python function stub
53pub struct PythonFunctionStub {
54    pub func_def: ast::StmtFunctionDef,
55    pub imports: Vec<String>,
56    pub is_async: bool,
57    pub is_overload: bool,
58}
59
60impl TryFrom<PythonFunctionStub> for PyFunctionInfo {
61    type Error = syn::Error;
62
63    fn try_from(stub: PythonFunctionStub) -> Result<Self> {
64        let func_name = stub.func_def.name.to_string();
65
66        // Extract docstring
67        let doc = extract_docstring(&stub.func_def);
68
69        // Build Parameters directly from Python AST with proper kind classification
70        let parameters = build_parameters_from_ast(&stub.func_def.args, &stub.imports)?;
71
72        // Extract return type
73        let return_type = extract_return_type(&stub.func_def.returns, &stub.imports)?;
74
75        // Try to extract deprecated decorator
76        let deprecated = extract_deprecated_from_decorators(&stub.func_def.decorator_list);
77
78        // Note: type_ignored (# type: ignore comments) cannot be extracted from Python AST
79        // as rustpython-parser doesn't preserve comments
80
81        // Construct PyFunctionInfo
82        Ok(PyFunctionInfo {
83            name: func_name,
84            parameters, // Use pre-built Parameters from Python AST
85            r#return: return_type,
86            doc,
87            module: None,
88            is_async: stub.is_async,
89            deprecated,
90            type_ignored: None,
91            is_overload: stub.is_overload,
92            index: 0, // Will be set by caller when generating multiple overloads
93        })
94    }
95}
96
97/// Parse Python stub string and return PyFunctionInfo
98pub fn parse_python_function_stub(input: LitStr) -> Result<PyFunctionInfo> {
99    let stub_content = input.value();
100
101    // Remove common indentation to allow indented Python code in raw strings
102    let dedented_content = dedent(&stub_content);
103
104    // Parse Python code using rustpython-parser
105    let parsed = ast::Suite::parse(&dedented_content, "<stub>")
106        .map_err(|e| Error::new(input.span(), format!("Failed to parse Python stub: {}", e)))?;
107
108    // Extract imports and function definitions
109    let mut imports = Vec::new();
110    let mut function: Option<(ast::StmtFunctionDef, bool)> = None;
111
112    for stmt in parsed {
113        match stmt {
114            ast::Stmt::Import(import_stmt) => {
115                for alias in &import_stmt.names {
116                    imports.push(alias.name.to_string());
117                }
118            }
119            ast::Stmt::ImportFrom(import_from_stmt) => {
120                if let Some(module) = &import_from_stmt.module {
121                    imports.push(module.to_string());
122                }
123            }
124            ast::Stmt::FunctionDef(func_def) => {
125                if function.is_some() {
126                    return Err(Error::new(
127                        input.span(),
128                        "Multiple function definitions found. Only one function is allowed per gen_function_from_python! call",
129                    ));
130                }
131                function = Some((func_def, false));
132            }
133            ast::Stmt::AsyncFunctionDef(func_def) => {
134                if function.is_some() {
135                    return Err(Error::new(
136                        input.span(),
137                        "Multiple function definitions found. Only one function is allowed per gen_function_from_python! call",
138                    ));
139                }
140                // Convert AsyncFunctionDef to FunctionDef for uniform processing
141                let sync_func = ast::StmtFunctionDef {
142                    range: func_def.range,
143                    name: func_def.name,
144                    type_params: func_def.type_params,
145                    args: func_def.args,
146                    body: func_def.body,
147                    decorator_list: func_def.decorator_list,
148                    returns: func_def.returns,
149                    type_comment: func_def.type_comment,
150                };
151                function = Some((sync_func, true));
152            }
153            _ => {
154                // Ignore other statements
155            }
156        }
157    }
158
159    // Check that exactly one function is defined
160    let (func_def, is_async) = function
161        .ok_or_else(|| Error::new(input.span(), "No function definition found in Python stub"))?;
162
163    // Check if function has @overload decorator
164    let is_overload = has_overload_decorator(&func_def.decorator_list);
165
166    // Generate PyFunctionInfo using TryFrom
167    let stub = PythonFunctionStub {
168        func_def,
169        imports,
170        is_async,
171        is_overload,
172    };
173    PyFunctionInfo::try_from(stub)
174}
175
176/// Parse multiple overload function definitions from Python stub string
177/// Used for the `python_overload` parameter
178pub fn parse_python_overload_stubs(
179    input: LitStr,
180    expected_function_name: &str,
181) -> Result<Vec<PyFunctionInfo>> {
182    let stub_content = input.value();
183    let dedented_content = dedent(&stub_content);
184
185    // Parse Python code using rustpython-parser
186    let parsed = ast::Suite::parse(&dedented_content, "<stub>")
187        .map_err(|e| Error::new(input.span(), format!("Failed to parse Python stub: {}", e)))?;
188
189    // Extract imports and function definitions
190    let mut imports = Vec::new();
191    let mut functions: Vec<(ast::StmtFunctionDef, bool)> = Vec::new();
192
193    for stmt in parsed {
194        match stmt {
195            ast::Stmt::Import(import_stmt) => {
196                for alias in &import_stmt.names {
197                    imports.push(alias.name.to_string());
198                }
199            }
200            ast::Stmt::ImportFrom(import_from_stmt) => {
201                if let Some(module) = &import_from_stmt.module {
202                    imports.push(module.to_string());
203                }
204            }
205            ast::Stmt::FunctionDef(func_def) => {
206                functions.push((func_def, false));
207            }
208            ast::Stmt::AsyncFunctionDef(func_def) => {
209                // Convert AsyncFunctionDef to FunctionDef for uniform processing
210                let sync_func = ast::StmtFunctionDef {
211                    range: func_def.range,
212                    name: func_def.name,
213                    type_params: func_def.type_params,
214                    args: func_def.args,
215                    body: func_def.body,
216                    decorator_list: func_def.decorator_list,
217                    returns: func_def.returns,
218                    type_comment: func_def.type_comment,
219                };
220                functions.push((sync_func, true));
221            }
222            _ => {
223                // Ignore other statements
224            }
225        }
226    }
227
228    // Check that at least one function is defined
229    if functions.is_empty() {
230        return Err(Error::new(
231            input.span(),
232            "No function definition found in python_overload parameter",
233        ));
234    }
235
236    // Validate all functions
237    let mut result = Vec::new();
238    for (func_def, is_async) in functions {
239        let func_name = func_def.name.to_string();
240
241        // Validate: function name must match expected name
242        if func_name != expected_function_name {
243            return Err(Error::new(
244                input.span(),
245                format!(
246                    "Function name '{}' in python_overload does not match Rust function name '{}'. Please ensure all overload function names match the Rust function name.",
247                    func_name, expected_function_name
248                ),
249            ));
250        }
251
252        // Validate: all functions must have @overload decorator
253        let is_overload = has_overload_decorator(&func_def.decorator_list);
254        if !is_overload {
255            return Err(Error::new(
256                input.span(),
257                format!(
258                    "Function '{}' in python_overload must have @overload decorator",
259                    func_name
260                ),
261            ));
262        }
263
264        // Create PyFunctionInfo
265        let stub = PythonFunctionStub {
266            func_def,
267            imports: imports.clone(),
268            is_async,
269            is_overload,
270        };
271        result.push(PyFunctionInfo::try_from(stub)?);
272    }
273
274    Ok(result)
275}
276
277/// Parse gen_function_from_python! input with optional module parameter
278pub fn parse_gen_function_from_python_input(
279    input: GenFunctionFromPythonInput,
280) -> Result<PyFunctionInfo> {
281    let mut info = parse_python_function_stub(input.python_stub)?;
282
283    // Set module if provided
284    if let Some(module) = input.module {
285        info.module = Some(module);
286    }
287
288    Ok(info)
289}
290
291#[cfg(test)]
292mod test {
293    use super::*;
294    use proc_macro2::TokenStream as TokenStream2;
295    use quote::{quote, ToTokens};
296
297    #[test]
298    fn test_basic_function() -> Result<()> {
299        let stub_str: LitStr = syn::parse2(quote! {
300            r#"
301            def foo(x: int) -> int:
302                """A simple function"""
303            "#
304        })?;
305        let info = parse_python_function_stub(stub_str)?;
306        let out = info.to_token_stream();
307        insta::assert_snapshot!(format_as_value(out), @r###"
308        ::pyo3_stub_gen::type_info::PyFunctionInfo {
309            name: "foo",
310            parameters: &[
311                ::pyo3_stub_gen::type_info::ParameterInfo {
312                    name: "x",
313                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
314                    type_info: || ::pyo3_stub_gen::TypeInfo {
315                        name: "int".to_string(),
316                        import: ::std::collections::HashSet::from([]),
317                    },
318                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
319                },
320            ],
321            r#return: || ::pyo3_stub_gen::TypeInfo {
322                name: "int".to_string(),
323                import: ::std::collections::HashSet::from([]),
324            },
325            doc: "A simple function",
326            module: None,
327            is_async: false,
328            deprecated: None,
329            type_ignored: None,
330            is_overload: false,
331            file: file!(),
332            line: line!(),
333            column: column!(),
334            index: 0usize,
335        }
336        "###);
337        Ok(())
338    }
339
340    #[test]
341    fn test_function_with_imports() -> Result<()> {
342        let stub_str: LitStr = syn::parse2(quote! {
343            r#"
344            import typing
345            from collections.abc import Callable
346
347            def process(func: Callable[[str], int]) -> typing.Optional[int]:
348                """Process a callback function"""
349            "#
350        })?;
351        let info = parse_python_function_stub(stub_str)?;
352        let out = info.to_token_stream();
353        insta::assert_snapshot!(format_as_value(out), @r###"
354        ::pyo3_stub_gen::type_info::PyFunctionInfo {
355            name: "process",
356            parameters: &[
357                ::pyo3_stub_gen::type_info::ParameterInfo {
358                    name: "func",
359                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
360                    type_info: || ::pyo3_stub_gen::TypeInfo {
361                        name: "Callable[[str], int]".to_string(),
362                        import: ::std::collections::HashSet::from([
363                            "typing".into(),
364                            "collections.abc".into(),
365                        ]),
366                    },
367                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
368                },
369            ],
370            r#return: || ::pyo3_stub_gen::TypeInfo {
371                name: "typing.Optional[int]".to_string(),
372                import: ::std::collections::HashSet::from([
373                    "typing".into(),
374                    "collections.abc".into(),
375                ]),
376            },
377            doc: "Process a callback function",
378            module: None,
379            is_async: false,
380            deprecated: None,
381            type_ignored: None,
382            is_overload: false,
383            file: file!(),
384            line: line!(),
385            column: column!(),
386            index: 0usize,
387        }
388        "###);
389        Ok(())
390    }
391
392    #[test]
393    fn test_complex_types() -> Result<()> {
394        let stub_str: LitStr = syn::parse2(quote! {
395            r#"
396            import collections.abc
397            import typing
398
399            def fn_override_type(cb: collections.abc.Callable[[str], typing.Any]) -> collections.abc.Callable[[str], typing.Any]:
400                """Example function with complex types"""
401            "#
402        })?;
403        let info = parse_python_function_stub(stub_str)?;
404        let out = info.to_token_stream();
405        insta::assert_snapshot!(format_as_value(out), @r###"
406        ::pyo3_stub_gen::type_info::PyFunctionInfo {
407            name: "fn_override_type",
408            parameters: &[
409                ::pyo3_stub_gen::type_info::ParameterInfo {
410                    name: "cb",
411                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
412                    type_info: || ::pyo3_stub_gen::TypeInfo {
413                        name: "collections.abc.Callable[[str], typing.Any]".to_string(),
414                        import: ::std::collections::HashSet::from([
415                            "collections.abc".into(),
416                            "typing".into(),
417                        ]),
418                    },
419                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
420                },
421            ],
422            r#return: || ::pyo3_stub_gen::TypeInfo {
423                name: "collections.abc.Callable[[str], typing.Any]".to_string(),
424                import: ::std::collections::HashSet::from([
425                    "collections.abc".into(),
426                    "typing".into(),
427                ]),
428            },
429            doc: "Example function with complex types",
430            module: None,
431            is_async: false,
432            deprecated: None,
433            type_ignored: None,
434            is_overload: false,
435            file: file!(),
436            line: line!(),
437            column: column!(),
438            index: 0usize,
439        }
440        "###);
441        Ok(())
442    }
443
444    #[test]
445    fn test_multiple_args() -> Result<()> {
446        let stub_str: LitStr = syn::parse2(quote! {
447            r#"
448            import typing
449
450            def add(a: int, b: int, c: typing.Optional[int]) -> int: ...
451            "#
452        })?;
453        let info = parse_python_function_stub(stub_str)?;
454        let out = info.to_token_stream();
455        insta::assert_snapshot!(format_as_value(out), @r###"
456        ::pyo3_stub_gen::type_info::PyFunctionInfo {
457            name: "add",
458            parameters: &[
459                ::pyo3_stub_gen::type_info::ParameterInfo {
460                    name: "a",
461                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
462                    type_info: || ::pyo3_stub_gen::TypeInfo {
463                        name: "int".to_string(),
464                        import: ::std::collections::HashSet::from(["typing".into()]),
465                    },
466                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
467                },
468                ::pyo3_stub_gen::type_info::ParameterInfo {
469                    name: "b",
470                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
471                    type_info: || ::pyo3_stub_gen::TypeInfo {
472                        name: "int".to_string(),
473                        import: ::std::collections::HashSet::from(["typing".into()]),
474                    },
475                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
476                },
477                ::pyo3_stub_gen::type_info::ParameterInfo {
478                    name: "c",
479                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
480                    type_info: || ::pyo3_stub_gen::TypeInfo {
481                        name: "typing.Optional[int]".to_string(),
482                        import: ::std::collections::HashSet::from(["typing".into()]),
483                    },
484                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
485                },
486            ],
487            r#return: || ::pyo3_stub_gen::TypeInfo {
488                name: "int".to_string(),
489                import: ::std::collections::HashSet::from(["typing".into()]),
490            },
491            doc: "",
492            module: None,
493            is_async: false,
494            deprecated: None,
495            type_ignored: None,
496            is_overload: false,
497            file: file!(),
498            line: line!(),
499            column: column!(),
500            index: 0usize,
501        }
502        "###);
503        Ok(())
504    }
505
506    #[test]
507    fn test_no_return_type() -> Result<()> {
508        let stub_str: LitStr = syn::parse2(quote! {
509            r#"
510            def print_hello(name: str):
511                """Print a greeting"""
512            "#
513        })?;
514        let info = parse_python_function_stub(stub_str)?;
515        let out = info.to_token_stream();
516        insta::assert_snapshot!(format_as_value(out), @r###"
517        ::pyo3_stub_gen::type_info::PyFunctionInfo {
518            name: "print_hello",
519            parameters: &[
520                ::pyo3_stub_gen::type_info::ParameterInfo {
521                    name: "name",
522                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
523                    type_info: || ::pyo3_stub_gen::TypeInfo {
524                        name: "str".to_string(),
525                        import: ::std::collections::HashSet::from([]),
526                    },
527                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
528                },
529            ],
530            r#return: ::pyo3_stub_gen::type_info::no_return_type_output,
531            doc: "Print a greeting",
532            module: None,
533            is_async: false,
534            deprecated: None,
535            type_ignored: None,
536            is_overload: false,
537            file: file!(),
538            line: line!(),
539            column: column!(),
540            index: 0usize,
541        }
542        "###);
543        Ok(())
544    }
545
546    #[test]
547    fn test_async_function() -> Result<()> {
548        let stub_str: LitStr = syn::parse2(quote! {
549            r#"
550            async def fetch_data(url: str) -> str:
551                """Fetch data from URL"""
552            "#
553        })?;
554        let info = parse_python_function_stub(stub_str)?;
555        let out = info.to_token_stream();
556        insta::assert_snapshot!(format_as_value(out), @r###"
557        ::pyo3_stub_gen::type_info::PyFunctionInfo {
558            name: "fetch_data",
559            parameters: &[
560                ::pyo3_stub_gen::type_info::ParameterInfo {
561                    name: "url",
562                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
563                    type_info: || ::pyo3_stub_gen::TypeInfo {
564                        name: "str".to_string(),
565                        import: ::std::collections::HashSet::from([]),
566                    },
567                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
568                },
569            ],
570            r#return: || ::pyo3_stub_gen::TypeInfo {
571                name: "str".to_string(),
572                import: ::std::collections::HashSet::from([]),
573            },
574            doc: "Fetch data from URL",
575            module: None,
576            is_async: true,
577            deprecated: None,
578            type_ignored: None,
579            is_overload: false,
580            file: file!(),
581            line: line!(),
582            column: column!(),
583            index: 0usize,
584        }
585        "###);
586        Ok(())
587    }
588
589    #[test]
590    fn test_deprecated_decorator() -> Result<()> {
591        let stub_str: LitStr = syn::parse2(quote! {
592            r#"
593            @deprecated
594            def old_function(x: int) -> int:
595                """This function is deprecated"""
596            "#
597        })?;
598        let info = parse_python_function_stub(stub_str)?;
599        let out = info.to_token_stream();
600        insta::assert_snapshot!(format_as_value(out), @r###"
601        ::pyo3_stub_gen::type_info::PyFunctionInfo {
602            name: "old_function",
603            parameters: &[
604                ::pyo3_stub_gen::type_info::ParameterInfo {
605                    name: "x",
606                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
607                    type_info: || ::pyo3_stub_gen::TypeInfo {
608                        name: "int".to_string(),
609                        import: ::std::collections::HashSet::from([]),
610                    },
611                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
612                },
613            ],
614            r#return: || ::pyo3_stub_gen::TypeInfo {
615                name: "int".to_string(),
616                import: ::std::collections::HashSet::from([]),
617            },
618            doc: "This function is deprecated",
619            module: None,
620            is_async: false,
621            deprecated: Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
622                since: None,
623                note: None,
624            }),
625            type_ignored: None,
626            is_overload: false,
627            file: file!(),
628            line: line!(),
629            column: column!(),
630            index: 0usize,
631        }
632        "###);
633        Ok(())
634    }
635
636    #[test]
637    fn test_deprecated_with_message() -> Result<()> {
638        let stub_str: LitStr = syn::parse2(quote! {
639            r#"
640            @deprecated("Use new_function instead")
641            def old_function(x: int) -> int:
642                """This function is deprecated"""
643            "#
644        })?;
645        let info = parse_python_function_stub(stub_str)?;
646        let out = info.to_token_stream();
647        insta::assert_snapshot!(format_as_value(out), @r###"
648        ::pyo3_stub_gen::type_info::PyFunctionInfo {
649            name: "old_function",
650            parameters: &[
651                ::pyo3_stub_gen::type_info::ParameterInfo {
652                    name: "x",
653                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
654                    type_info: || ::pyo3_stub_gen::TypeInfo {
655                        name: "int".to_string(),
656                        import: ::std::collections::HashSet::from([]),
657                    },
658                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
659                },
660            ],
661            r#return: || ::pyo3_stub_gen::TypeInfo {
662                name: "int".to_string(),
663                import: ::std::collections::HashSet::from([]),
664            },
665            doc: "This function is deprecated",
666            module: None,
667            is_async: false,
668            deprecated: Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
669                since: None,
670                note: Some("Use new_function instead"),
671            }),
672            type_ignored: None,
673            is_overload: false,
674            file: file!(),
675            line: line!(),
676            column: column!(),
677            index: 0usize,
678        }
679        "###);
680        Ok(())
681    }
682
683    #[test]
684    fn test_rust_type_marker() -> Result<()> {
685        let stub_str: LitStr = syn::parse2(quote! {
686            r#"
687            def process_data(x: pyo3_stub_gen.RustType["MyRustType"]) -> pyo3_stub_gen.RustType["MyRustType"]:
688                """Process data using Rust type marker"""
689            "#
690        })?;
691        let info = parse_python_function_stub(stub_str)?;
692        let out = info.to_token_stream();
693        insta::assert_snapshot!(format_as_value(out), @r###"
694        ::pyo3_stub_gen::type_info::PyFunctionInfo {
695            name: "process_data",
696            parameters: &[
697                ::pyo3_stub_gen::type_info::ParameterInfo {
698                    name: "x",
699                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
700                    type_info: <MyRustType as ::pyo3_stub_gen::PyStubType>::type_input,
701                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
702                },
703            ],
704            r#return: <MyRustType as pyo3_stub_gen::PyStubType>::type_output,
705            doc: "Process data using Rust type marker",
706            module: None,
707            is_async: false,
708            deprecated: None,
709            type_ignored: None,
710            is_overload: false,
711            file: file!(),
712            line: line!(),
713            column: column!(),
714            index: 0usize,
715        }
716        "###);
717        Ok(())
718    }
719
720    #[test]
721    fn test_rust_type_marker_with_path() -> Result<()> {
722        let stub_str: LitStr = syn::parse2(quote! {
723            r#"
724            def process(x: pyo3_stub_gen.RustType["crate::MyType"]) -> pyo3_stub_gen.RustType["Vec<String>"]:
725                """Test with type paths"""
726            "#
727        })?;
728        let info = parse_python_function_stub(stub_str)?;
729        let out = info.to_token_stream();
730        insta::assert_snapshot!(format_as_value(out), @r###"
731        ::pyo3_stub_gen::type_info::PyFunctionInfo {
732            name: "process",
733            parameters: &[
734                ::pyo3_stub_gen::type_info::ParameterInfo {
735                    name: "x",
736                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
737                    type_info: <crate::MyType as ::pyo3_stub_gen::PyStubType>::type_input,
738                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
739                },
740            ],
741            r#return: <Vec<String> as pyo3_stub_gen::PyStubType>::type_output,
742            doc: "Test with type paths",
743            module: None,
744            is_async: false,
745            deprecated: None,
746            type_ignored: None,
747            is_overload: false,
748            file: file!(),
749            line: line!(),
750            column: column!(),
751            index: 0usize,
752        }
753        "###);
754        Ok(())
755    }
756
757    #[test]
758    fn test_keyword_only_args() -> Result<()> {
759        let stub_str: LitStr = syn::parse2(quote! {
760            r#"
761            import typing
762
763            def configure(name: str, *, dtype: str, ndim: int, jagged: bool = False) -> None:
764                """Test keyword-only parameters"""
765            "#
766        })?;
767        let info = parse_python_function_stub(stub_str)?;
768        let out = info.to_token_stream();
769        insta::assert_snapshot!(format_as_value(out), @r###"
770        ::pyo3_stub_gen::type_info::PyFunctionInfo {
771            name: "configure",
772            parameters: &[
773                ::pyo3_stub_gen::type_info::ParameterInfo {
774                    name: "name",
775                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
776                    type_info: || ::pyo3_stub_gen::TypeInfo {
777                        name: "str".to_string(),
778                        import: ::std::collections::HashSet::from(["typing".into()]),
779                    },
780                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
781                },
782                ::pyo3_stub_gen::type_info::ParameterInfo {
783                    name: "dtype",
784                    kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
785                    type_info: || ::pyo3_stub_gen::TypeInfo {
786                        name: "str".to_string(),
787                        import: ::std::collections::HashSet::from(["typing".into()]),
788                    },
789                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
790                },
791                ::pyo3_stub_gen::type_info::ParameterInfo {
792                    name: "ndim",
793                    kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
794                    type_info: || ::pyo3_stub_gen::TypeInfo {
795                        name: "int".to_string(),
796                        import: ::std::collections::HashSet::from(["typing".into()]),
797                    },
798                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
799                },
800                ::pyo3_stub_gen::type_info::ParameterInfo {
801                    name: "jagged",
802                    kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
803                    type_info: || ::pyo3_stub_gen::TypeInfo {
804                        name: "bool".to_string(),
805                        import: ::std::collections::HashSet::from(["typing".into()]),
806                    },
807                    default: ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
808                        fn _fmt() -> String {
809                            "False".to_string()
810                        }
811                        _fmt
812                    }),
813                },
814            ],
815            r#return: || ::pyo3_stub_gen::TypeInfo {
816                name: "None".to_string(),
817                import: ::std::collections::HashSet::from(["typing".into()]),
818            },
819            doc: "Test keyword-only parameters",
820            module: None,
821            is_async: false,
822            deprecated: None,
823            type_ignored: None,
824            is_overload: false,
825            file: file!(),
826            line: line!(),
827            column: column!(),
828            index: 0usize,
829        }
830        "###);
831        Ok(())
832    }
833
834    #[test]
835    fn test_positional_only_args() -> Result<()> {
836        let stub_str: LitStr = syn::parse2(quote! {
837            r#"
838            def func(x: int, y: int, /, z: int) -> int:
839                """Test positional-only parameters"""
840            "#
841        })?;
842        let info = parse_python_function_stub(stub_str)?;
843        let out = info.to_token_stream();
844        insta::assert_snapshot!(format_as_value(out), @r###"
845        ::pyo3_stub_gen::type_info::PyFunctionInfo {
846            name: "func",
847            parameters: &[
848                ::pyo3_stub_gen::type_info::ParameterInfo {
849                    name: "x",
850                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOnly,
851                    type_info: || ::pyo3_stub_gen::TypeInfo {
852                        name: "int".to_string(),
853                        import: ::std::collections::HashSet::from([]),
854                    },
855                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
856                },
857                ::pyo3_stub_gen::type_info::ParameterInfo {
858                    name: "y",
859                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOnly,
860                    type_info: || ::pyo3_stub_gen::TypeInfo {
861                        name: "int".to_string(),
862                        import: ::std::collections::HashSet::from([]),
863                    },
864                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
865                },
866                ::pyo3_stub_gen::type_info::ParameterInfo {
867                    name: "z",
868                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
869                    type_info: || ::pyo3_stub_gen::TypeInfo {
870                        name: "int".to_string(),
871                        import: ::std::collections::HashSet::from([]),
872                    },
873                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
874                },
875            ],
876            r#return: || ::pyo3_stub_gen::TypeInfo {
877                name: "int".to_string(),
878                import: ::std::collections::HashSet::from([]),
879            },
880            doc: "Test positional-only parameters",
881            module: None,
882            is_async: false,
883            deprecated: None,
884            type_ignored: None,
885            is_overload: false,
886            file: file!(),
887            line: line!(),
888            column: column!(),
889            index: 0usize,
890        }
891        "###);
892        Ok(())
893    }
894
895    #[test]
896    fn test_single_overload() -> Result<()> {
897        let stub_str: LitStr = syn::parse2(quote! {
898            r#"
899            @overload
900            def foo(x: int) -> int:
901                """Integer overload"""
902            "#
903        })?;
904        let info = parse_python_function_stub(stub_str)?;
905        let out = info.to_token_stream();
906        insta::assert_snapshot!(format_as_value(out), @r###"
907        ::pyo3_stub_gen::type_info::PyFunctionInfo {
908            name: "foo",
909            parameters: &[
910                ::pyo3_stub_gen::type_info::ParameterInfo {
911                    name: "x",
912                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
913                    type_info: || ::pyo3_stub_gen::TypeInfo {
914                        name: "int".to_string(),
915                        import: ::std::collections::HashSet::from([]),
916                    },
917                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
918                },
919            ],
920            r#return: || ::pyo3_stub_gen::TypeInfo {
921                name: "int".to_string(),
922                import: ::std::collections::HashSet::from([]),
923            },
924            doc: "Integer overload",
925            module: None,
926            is_async: false,
927            deprecated: None,
928            type_ignored: None,
929            is_overload: true,
930            file: file!(),
931            line: line!(),
932            column: column!(),
933            index: 0usize,
934        }
935        "###);
936        Ok(())
937    }
938
939    #[test]
940    fn test_multiple_overloads() -> Result<()> {
941        let stub_str: LitStr = syn::parse2(quote! {
942            r#"
943            @overload
944            def foo(x: int) -> int:
945                """Integer overload"""
946
947            @overload
948            def foo(x: float) -> float:
949                """Float overload"""
950            "#
951        })?;
952        let infos = parse_python_overload_stubs(stub_str, "foo")?;
953        assert_eq!(infos.len(), 2);
954
955        let out1 = infos[0].to_token_stream();
956        insta::assert_snapshot!(format_as_value(out1), @r###"
957        ::pyo3_stub_gen::type_info::PyFunctionInfo {
958            name: "foo",
959            parameters: &[
960                ::pyo3_stub_gen::type_info::ParameterInfo {
961                    name: "x",
962                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
963                    type_info: || ::pyo3_stub_gen::TypeInfo {
964                        name: "int".to_string(),
965                        import: ::std::collections::HashSet::from([]),
966                    },
967                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
968                },
969            ],
970            r#return: || ::pyo3_stub_gen::TypeInfo {
971                name: "int".to_string(),
972                import: ::std::collections::HashSet::from([]),
973            },
974            doc: "Integer overload",
975            module: None,
976            is_async: false,
977            deprecated: None,
978            type_ignored: None,
979            is_overload: true,
980            file: file!(),
981            line: line!(),
982            column: column!(),
983            index: 0usize,
984        }
985        "###);
986
987        let out2 = infos[1].to_token_stream();
988        insta::assert_snapshot!(format_as_value(out2), @r###"
989        ::pyo3_stub_gen::type_info::PyFunctionInfo {
990            name: "foo",
991            parameters: &[
992                ::pyo3_stub_gen::type_info::ParameterInfo {
993                    name: "x",
994                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
995                    type_info: || ::pyo3_stub_gen::TypeInfo {
996                        name: "float".to_string(),
997                        import: ::std::collections::HashSet::from([]),
998                    },
999                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
1000                },
1001            ],
1002            r#return: || ::pyo3_stub_gen::TypeInfo {
1003                name: "float".to_string(),
1004                import: ::std::collections::HashSet::from([]),
1005            },
1006            doc: "Float overload",
1007            module: None,
1008            is_async: false,
1009            deprecated: None,
1010            type_ignored: None,
1011            is_overload: true,
1012            file: file!(),
1013            line: line!(),
1014            column: column!(),
1015            index: 0usize,
1016        }
1017        "###);
1018        Ok(())
1019    }
1020
1021    #[test]
1022    fn test_overload_with_literal_types() -> Result<()> {
1023        let stub_str: LitStr = syn::parse2(quote! {
1024            r#"
1025            import typing
1026            @overload
1027            def as_tuple(xs: list[int], *, tuple_out: typing.Literal[True]) -> tuple[int, ...]:
1028                """Return as tuple"""
1029            "#
1030        })?;
1031        let info = parse_python_function_stub(stub_str)?;
1032        let out = info.to_token_stream();
1033        insta::assert_snapshot!(format_as_value(out), @r###"
1034        ::pyo3_stub_gen::type_info::PyFunctionInfo {
1035            name: "as_tuple",
1036            parameters: &[
1037                ::pyo3_stub_gen::type_info::ParameterInfo {
1038                    name: "xs",
1039                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
1040                    type_info: || ::pyo3_stub_gen::TypeInfo {
1041                        name: "list[int]".to_string(),
1042                        import: ::std::collections::HashSet::from(["typing".into()]),
1043                    },
1044                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
1045                },
1046                ::pyo3_stub_gen::type_info::ParameterInfo {
1047                    name: "tuple_out",
1048                    kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
1049                    type_info: || ::pyo3_stub_gen::TypeInfo {
1050                        name: "typing.Literal[True]".to_string(),
1051                        import: ::std::collections::HashSet::from(["typing".into()]),
1052                    },
1053                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
1054                },
1055            ],
1056            r#return: || ::pyo3_stub_gen::TypeInfo {
1057                name: "tuple[int, ...]".to_string(),
1058                import: ::std::collections::HashSet::from(["typing".into()]),
1059            },
1060            doc: "Return as tuple",
1061            module: None,
1062            is_async: false,
1063            deprecated: None,
1064            type_ignored: None,
1065            is_overload: true,
1066            file: file!(),
1067            line: line!(),
1068            column: column!(),
1069            index: 0usize,
1070        }
1071        "###);
1072        Ok(())
1073    }
1074
1075    fn format_as_value(tt: TokenStream2) -> String {
1076        let ttt = quote! { const _: () = #tt; };
1077        let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
1078        formatted
1079            .trim()
1080            .strip_prefix("const _: () = ")
1081            .unwrap()
1082            .strip_suffix(';')
1083            .unwrap()
1084            .to_string()
1085    }
1086}