pyo3_stub_gen_derive/gen_stub/parse_python/
pyfunction.rs

1//! Parse Python function stub syntax and generate PyFunctionInfo
2
3use rustpython_parser::{ast, Parse};
4use syn::{parse::Parse as SynParse, parse::ParseStream, Error, LitStr, Result};
5
6use super::{
7    build_parameters_from_ast, dedent, extract_deprecated_from_decorators, extract_docstring,
8    extract_return_type, has_overload_decorator,
9};
10use crate::gen_stub::pyfunction::PyFunctionInfo;
11
12/// Input for gen_function_from_python! macro
13pub struct GenFunctionFromPythonInput {
14    module: Option<String>,
15    python_stub: LitStr,
16}
17
18impl SynParse for GenFunctionFromPythonInput {
19    fn parse(input: ParseStream) -> Result<Self> {
20        // Check if first token is an identifier (for module parameter)
21        if input.peek(syn::Ident) {
22            let key: syn::Ident = input.parse()?;
23            if key == "module" {
24                let _: syn::token::Eq = input.parse()?;
25                let value: LitStr = input.parse()?;
26                let _: syn::token::Comma = input.parse()?;
27                let python_stub: LitStr = input.parse()?;
28                return Ok(Self {
29                    module: Some(value.value()),
30                    python_stub,
31                });
32            } else {
33                return Err(Error::new(
34                    key.span(),
35                    format!(
36                        "Unknown parameter: {}. Expected 'module' or a string literal",
37                        key
38                    ),
39                ));
40            }
41        }
42
43        // No module parameter, just parse the string literal
44        let python_stub: LitStr = input.parse()?;
45        Ok(Self {
46            module: None,
47            python_stub,
48        })
49    }
50}
51
52/// Intermediate representation for Python function stub
53pub struct PythonFunctionStub {
54    pub func_def: ast::StmtFunctionDef,
55    pub imports: Vec<String>,
56    pub is_async: bool,
57    pub is_overload: bool,
58}
59
60impl TryFrom<PythonFunctionStub> for PyFunctionInfo {
61    type Error = syn::Error;
62
63    fn try_from(stub: PythonFunctionStub) -> Result<Self> {
64        let func_name = stub.func_def.name.to_string();
65
66        // Extract docstring
67        let doc = extract_docstring(&stub.func_def);
68
69        // Build Parameters directly from Python AST with proper kind classification
70        let parameters = build_parameters_from_ast(&stub.func_def.args, &stub.imports)?;
71
72        // Extract return type
73        let return_type = extract_return_type(&stub.func_def.returns, &stub.imports)?;
74
75        // Try to extract deprecated decorator
76        let deprecated = extract_deprecated_from_decorators(&stub.func_def.decorator_list);
77
78        // Note: type_ignored (# type: ignore comments) cannot be extracted from Python AST
79        // as rustpython-parser doesn't preserve comments
80
81        // Construct PyFunctionInfo
82        Ok(PyFunctionInfo {
83            name: func_name,
84            parameters, // Use pre-built Parameters from Python AST
85            r#return: return_type,
86            doc,
87            module: None,
88            is_async: stub.is_async,
89            deprecated,
90            type_ignored: None,
91            is_overload: stub.is_overload,
92            index: 0, // Will be set by caller when generating multiple overloads
93        })
94    }
95}
96
97/// Parse Python stub string and return PyFunctionInfo
98pub fn parse_python_function_stub(input: LitStr) -> Result<PyFunctionInfo> {
99    let stub_content = input.value();
100
101    // Remove common indentation to allow indented Python code in raw strings
102    let dedented_content = dedent(&stub_content);
103
104    // Parse Python code using rustpython-parser
105    let parsed = ast::Suite::parse(&dedented_content, "<stub>")
106        .map_err(|e| Error::new(input.span(), format!("Failed to parse Python stub: {}", e)))?;
107
108    // Extract imports and function definitions
109    let mut imports = Vec::new();
110    let mut function: Option<(ast::StmtFunctionDef, bool)> = None;
111
112    for stmt in parsed {
113        match stmt {
114            ast::Stmt::Import(import_stmt) => {
115                for alias in &import_stmt.names {
116                    imports.push(alias.name.to_string());
117                }
118            }
119            ast::Stmt::ImportFrom(import_from_stmt) => {
120                if let Some(module) = &import_from_stmt.module {
121                    imports.push(module.to_string());
122                }
123            }
124            ast::Stmt::FunctionDef(func_def) => {
125                if function.is_some() {
126                    return Err(Error::new(
127                        input.span(),
128                        "Multiple function definitions found. Only one function is allowed per gen_function_from_python! call",
129                    ));
130                }
131                function = Some((func_def, false));
132            }
133            ast::Stmt::AsyncFunctionDef(func_def) => {
134                if function.is_some() {
135                    return Err(Error::new(
136                        input.span(),
137                        "Multiple function definitions found. Only one function is allowed per gen_function_from_python! call",
138                    ));
139                }
140                // Convert AsyncFunctionDef to FunctionDef for uniform processing
141                let sync_func = ast::StmtFunctionDef {
142                    range: func_def.range,
143                    name: func_def.name,
144                    type_params: func_def.type_params,
145                    args: func_def.args,
146                    body: func_def.body,
147                    decorator_list: func_def.decorator_list,
148                    returns: func_def.returns,
149                    type_comment: func_def.type_comment,
150                };
151                function = Some((sync_func, true));
152            }
153            _ => {
154                // Ignore other statements
155            }
156        }
157    }
158
159    // Check that exactly one function is defined
160    let (func_def, is_async) = function
161        .ok_or_else(|| Error::new(input.span(), "No function definition found in Python stub"))?;
162
163    // Check if function has @overload decorator
164    let is_overload = has_overload_decorator(&func_def.decorator_list);
165
166    // Generate PyFunctionInfo using TryFrom
167    let stub = PythonFunctionStub {
168        func_def,
169        imports,
170        is_async,
171        is_overload,
172    };
173    PyFunctionInfo::try_from(stub)
174}
175
176/// Parse multiple overload function definitions from Python stub string
177/// Used for the `python_overload` parameter
178pub fn parse_python_overload_stubs(
179    input: LitStr,
180    expected_function_name: &str,
181) -> Result<Vec<PyFunctionInfo>> {
182    let stub_content = input.value();
183    let dedented_content = dedent(&stub_content);
184
185    // Parse Python code using rustpython-parser
186    let parsed = ast::Suite::parse(&dedented_content, "<stub>")
187        .map_err(|e| Error::new(input.span(), format!("Failed to parse Python stub: {}", e)))?;
188
189    // Extract imports and function definitions
190    let mut imports = Vec::new();
191    let mut functions: Vec<(ast::StmtFunctionDef, bool)> = Vec::new();
192
193    for stmt in parsed {
194        match stmt {
195            ast::Stmt::Import(import_stmt) => {
196                for alias in &import_stmt.names {
197                    imports.push(alias.name.to_string());
198                }
199            }
200            ast::Stmt::ImportFrom(import_from_stmt) => {
201                if let Some(module) = &import_from_stmt.module {
202                    imports.push(module.to_string());
203                }
204            }
205            ast::Stmt::FunctionDef(func_def) => {
206                functions.push((func_def, false));
207            }
208            ast::Stmt::AsyncFunctionDef(func_def) => {
209                // Convert AsyncFunctionDef to FunctionDef for uniform processing
210                let sync_func = ast::StmtFunctionDef {
211                    range: func_def.range,
212                    name: func_def.name,
213                    type_params: func_def.type_params,
214                    args: func_def.args,
215                    body: func_def.body,
216                    decorator_list: func_def.decorator_list,
217                    returns: func_def.returns,
218                    type_comment: func_def.type_comment,
219                };
220                functions.push((sync_func, true));
221            }
222            _ => {
223                // Ignore other statements
224            }
225        }
226    }
227
228    // Check that at least one function is defined
229    if functions.is_empty() {
230        return Err(Error::new(
231            input.span(),
232            "No function definition found in python_overload parameter",
233        ));
234    }
235
236    // Validate all functions
237    let mut result = Vec::new();
238    for (func_def, is_async) in functions {
239        let func_name = func_def.name.to_string();
240
241        // Validate: function name must match expected name
242        if func_name != expected_function_name {
243            return Err(Error::new(
244                input.span(),
245                format!(
246                    "Function name '{}' in python_overload does not match Rust function name '{}'. Please ensure all overload function names match the Rust function name.",
247                    func_name, expected_function_name
248                ),
249            ));
250        }
251
252        // Validate: all functions must have @overload decorator
253        let is_overload = has_overload_decorator(&func_def.decorator_list);
254        if !is_overload {
255            return Err(Error::new(
256                input.span(),
257                format!(
258                    "Function '{}' in python_overload must have @overload decorator",
259                    func_name
260                ),
261            ));
262        }
263
264        // Create PyFunctionInfo
265        let stub = PythonFunctionStub {
266            func_def,
267            imports: imports.clone(),
268            is_async,
269            is_overload,
270        };
271        result.push(PyFunctionInfo::try_from(stub)?);
272    }
273
274    Ok(result)
275}
276
277/// Parse gen_function_from_python! input with optional module parameter
278pub fn parse_gen_function_from_python_input(
279    input: GenFunctionFromPythonInput,
280) -> Result<PyFunctionInfo> {
281    let mut info = parse_python_function_stub(input.python_stub)?;
282
283    // Set module if provided
284    if let Some(module) = input.module {
285        info.module = Some(module);
286    }
287
288    Ok(info)
289}
290
291#[cfg(test)]
292mod test {
293    use super::*;
294    use proc_macro2::TokenStream as TokenStream2;
295    use quote::{quote, ToTokens};
296
297    #[test]
298    fn test_basic_function() -> Result<()> {
299        let stub_str: LitStr = syn::parse2(quote! {
300            r#"
301            def foo(x: int) -> int:
302                """A simple function"""
303            "#
304        })?;
305        let info = parse_python_function_stub(stub_str)?;
306        let out = info.to_token_stream();
307        insta::assert_snapshot!(format_as_value(out), @r#"
308        ::pyo3_stub_gen::type_info::PyFunctionInfo {
309            name: "foo",
310            parameters: &[
311                ::pyo3_stub_gen::type_info::ParameterInfo {
312                    name: "x",
313                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
314                    type_info: || ::pyo3_stub_gen::TypeInfo {
315                        name: "int".to_string(),
316                        source_module: None,
317                        import: ::std::collections::HashSet::from([]),
318                        type_refs: ::std::collections::HashMap::new(),
319                    },
320                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
321                },
322            ],
323            r#return: || ::pyo3_stub_gen::TypeInfo {
324                name: "int".to_string(),
325                source_module: None,
326                import: ::std::collections::HashSet::from([]),
327                type_refs: ::std::collections::HashMap::new(),
328            },
329            doc: "A simple function",
330            module: None,
331            is_async: false,
332            deprecated: None,
333            type_ignored: None,
334            is_overload: false,
335            file: file!(),
336            line: line!(),
337            column: column!(),
338            index: 0usize,
339        }
340        "#);
341        Ok(())
342    }
343
344    #[test]
345    fn test_function_with_imports() -> Result<()> {
346        let stub_str: LitStr = syn::parse2(quote! {
347            r#"
348            import typing
349            from collections.abc import Callable
350
351            def process(func: Callable[[str], int]) -> typing.Optional[int]:
352                """Process a callback function"""
353            "#
354        })?;
355        let info = parse_python_function_stub(stub_str)?;
356        let out = info.to_token_stream();
357        insta::assert_snapshot!(format_as_value(out), @r#"
358        ::pyo3_stub_gen::type_info::PyFunctionInfo {
359            name: "process",
360            parameters: &[
361                ::pyo3_stub_gen::type_info::ParameterInfo {
362                    name: "func",
363                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
364                    type_info: || ::pyo3_stub_gen::TypeInfo {
365                        name: "Callable[[str], int]".to_string(),
366                        source_module: None,
367                        import: ::std::collections::HashSet::from([
368                            "typing".into(),
369                            "collections.abc".into(),
370                        ]),
371                        type_refs: ::std::collections::HashMap::new(),
372                    },
373                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
374                },
375            ],
376            r#return: || ::pyo3_stub_gen::TypeInfo {
377                name: "typing.Optional[int]".to_string(),
378                source_module: None,
379                import: ::std::collections::HashSet::from([
380                    "typing".into(),
381                    "collections.abc".into(),
382                ]),
383                type_refs: ::std::collections::HashMap::new(),
384            },
385            doc: "Process a callback function",
386            module: None,
387            is_async: false,
388            deprecated: None,
389            type_ignored: None,
390            is_overload: false,
391            file: file!(),
392            line: line!(),
393            column: column!(),
394            index: 0usize,
395        }
396        "#);
397        Ok(())
398    }
399
400    #[test]
401    fn test_complex_types() -> Result<()> {
402        let stub_str: LitStr = syn::parse2(quote! {
403            r#"
404            import collections.abc
405            import typing
406
407            def fn_override_type(cb: collections.abc.Callable[[str], typing.Any]) -> collections.abc.Callable[[str], typing.Any]:
408                """Example function with complex types"""
409            "#
410        })?;
411        let info = parse_python_function_stub(stub_str)?;
412        let out = info.to_token_stream();
413        insta::assert_snapshot!(format_as_value(out), @r#"
414        ::pyo3_stub_gen::type_info::PyFunctionInfo {
415            name: "fn_override_type",
416            parameters: &[
417                ::pyo3_stub_gen::type_info::ParameterInfo {
418                    name: "cb",
419                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
420                    type_info: || ::pyo3_stub_gen::TypeInfo {
421                        name: "collections.abc.Callable[[str], typing.Any]".to_string(),
422                        source_module: None,
423                        import: ::std::collections::HashSet::from([
424                            "collections.abc".into(),
425                            "typing".into(),
426                        ]),
427                        type_refs: ::std::collections::HashMap::new(),
428                    },
429                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
430                },
431            ],
432            r#return: || ::pyo3_stub_gen::TypeInfo {
433                name: "collections.abc.Callable[[str], typing.Any]".to_string(),
434                source_module: None,
435                import: ::std::collections::HashSet::from([
436                    "collections.abc".into(),
437                    "typing".into(),
438                ]),
439                type_refs: ::std::collections::HashMap::new(),
440            },
441            doc: "Example function with complex types",
442            module: None,
443            is_async: false,
444            deprecated: None,
445            type_ignored: None,
446            is_overload: false,
447            file: file!(),
448            line: line!(),
449            column: column!(),
450            index: 0usize,
451        }
452        "#);
453        Ok(())
454    }
455
456    #[test]
457    fn test_multiple_args() -> Result<()> {
458        let stub_str: LitStr = syn::parse2(quote! {
459            r#"
460            import typing
461
462            def add(a: int, b: int, c: typing.Optional[int]) -> int: ...
463            "#
464        })?;
465        let info = parse_python_function_stub(stub_str)?;
466        let out = info.to_token_stream();
467        insta::assert_snapshot!(format_as_value(out), @r#"
468        ::pyo3_stub_gen::type_info::PyFunctionInfo {
469            name: "add",
470            parameters: &[
471                ::pyo3_stub_gen::type_info::ParameterInfo {
472                    name: "a",
473                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
474                    type_info: || ::pyo3_stub_gen::TypeInfo {
475                        name: "int".to_string(),
476                        source_module: None,
477                        import: ::std::collections::HashSet::from(["typing".into()]),
478                        type_refs: ::std::collections::HashMap::new(),
479                    },
480                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
481                },
482                ::pyo3_stub_gen::type_info::ParameterInfo {
483                    name: "b",
484                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
485                    type_info: || ::pyo3_stub_gen::TypeInfo {
486                        name: "int".to_string(),
487                        source_module: None,
488                        import: ::std::collections::HashSet::from(["typing".into()]),
489                        type_refs: ::std::collections::HashMap::new(),
490                    },
491                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
492                },
493                ::pyo3_stub_gen::type_info::ParameterInfo {
494                    name: "c",
495                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
496                    type_info: || ::pyo3_stub_gen::TypeInfo {
497                        name: "typing.Optional[int]".to_string(),
498                        source_module: None,
499                        import: ::std::collections::HashSet::from(["typing".into()]),
500                        type_refs: ::std::collections::HashMap::new(),
501                    },
502                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
503                },
504            ],
505            r#return: || ::pyo3_stub_gen::TypeInfo {
506                name: "int".to_string(),
507                source_module: None,
508                import: ::std::collections::HashSet::from(["typing".into()]),
509                type_refs: ::std::collections::HashMap::new(),
510            },
511            doc: "",
512            module: None,
513            is_async: false,
514            deprecated: None,
515            type_ignored: None,
516            is_overload: false,
517            file: file!(),
518            line: line!(),
519            column: column!(),
520            index: 0usize,
521        }
522        "#);
523        Ok(())
524    }
525
526    #[test]
527    fn test_no_return_type() -> Result<()> {
528        let stub_str: LitStr = syn::parse2(quote! {
529            r#"
530            def print_hello(name: str):
531                """Print a greeting"""
532            "#
533        })?;
534        let info = parse_python_function_stub(stub_str)?;
535        let out = info.to_token_stream();
536        insta::assert_snapshot!(format_as_value(out), @r#"
537        ::pyo3_stub_gen::type_info::PyFunctionInfo {
538            name: "print_hello",
539            parameters: &[
540                ::pyo3_stub_gen::type_info::ParameterInfo {
541                    name: "name",
542                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
543                    type_info: || ::pyo3_stub_gen::TypeInfo {
544                        name: "str".to_string(),
545                        source_module: None,
546                        import: ::std::collections::HashSet::from([]),
547                        type_refs: ::std::collections::HashMap::new(),
548                    },
549                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
550                },
551            ],
552            r#return: ::pyo3_stub_gen::type_info::no_return_type_output,
553            doc: "Print a greeting",
554            module: None,
555            is_async: false,
556            deprecated: None,
557            type_ignored: None,
558            is_overload: false,
559            file: file!(),
560            line: line!(),
561            column: column!(),
562            index: 0usize,
563        }
564        "#);
565        Ok(())
566    }
567
568    #[test]
569    fn test_async_function() -> Result<()> {
570        let stub_str: LitStr = syn::parse2(quote! {
571            r#"
572            async def fetch_data(url: str) -> str:
573                """Fetch data from URL"""
574            "#
575        })?;
576        let info = parse_python_function_stub(stub_str)?;
577        let out = info.to_token_stream();
578        insta::assert_snapshot!(format_as_value(out), @r#"
579        ::pyo3_stub_gen::type_info::PyFunctionInfo {
580            name: "fetch_data",
581            parameters: &[
582                ::pyo3_stub_gen::type_info::ParameterInfo {
583                    name: "url",
584                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
585                    type_info: || ::pyo3_stub_gen::TypeInfo {
586                        name: "str".to_string(),
587                        source_module: None,
588                        import: ::std::collections::HashSet::from([]),
589                        type_refs: ::std::collections::HashMap::new(),
590                    },
591                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
592                },
593            ],
594            r#return: || ::pyo3_stub_gen::TypeInfo {
595                name: "str".to_string(),
596                source_module: None,
597                import: ::std::collections::HashSet::from([]),
598                type_refs: ::std::collections::HashMap::new(),
599            },
600            doc: "Fetch data from URL",
601            module: None,
602            is_async: true,
603            deprecated: None,
604            type_ignored: None,
605            is_overload: false,
606            file: file!(),
607            line: line!(),
608            column: column!(),
609            index: 0usize,
610        }
611        "#);
612        Ok(())
613    }
614
615    #[test]
616    fn test_deprecated_decorator() -> Result<()> {
617        let stub_str: LitStr = syn::parse2(quote! {
618            r#"
619            @deprecated
620            def old_function(x: int) -> int:
621                """This function is deprecated"""
622            "#
623        })?;
624        let info = parse_python_function_stub(stub_str)?;
625        let out = info.to_token_stream();
626        insta::assert_snapshot!(format_as_value(out), @r#"
627        ::pyo3_stub_gen::type_info::PyFunctionInfo {
628            name: "old_function",
629            parameters: &[
630                ::pyo3_stub_gen::type_info::ParameterInfo {
631                    name: "x",
632                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
633                    type_info: || ::pyo3_stub_gen::TypeInfo {
634                        name: "int".to_string(),
635                        source_module: None,
636                        import: ::std::collections::HashSet::from([]),
637                        type_refs: ::std::collections::HashMap::new(),
638                    },
639                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
640                },
641            ],
642            r#return: || ::pyo3_stub_gen::TypeInfo {
643                name: "int".to_string(),
644                source_module: None,
645                import: ::std::collections::HashSet::from([]),
646                type_refs: ::std::collections::HashMap::new(),
647            },
648            doc: "This function is deprecated",
649            module: None,
650            is_async: false,
651            deprecated: Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
652                since: None,
653                note: None,
654            }),
655            type_ignored: None,
656            is_overload: false,
657            file: file!(),
658            line: line!(),
659            column: column!(),
660            index: 0usize,
661        }
662        "#);
663        Ok(())
664    }
665
666    #[test]
667    fn test_deprecated_with_message() -> Result<()> {
668        let stub_str: LitStr = syn::parse2(quote! {
669            r#"
670            @deprecated("Use new_function instead")
671            def old_function(x: int) -> int:
672                """This function is deprecated"""
673            "#
674        })?;
675        let info = parse_python_function_stub(stub_str)?;
676        let out = info.to_token_stream();
677        insta::assert_snapshot!(format_as_value(out), @r#"
678        ::pyo3_stub_gen::type_info::PyFunctionInfo {
679            name: "old_function",
680            parameters: &[
681                ::pyo3_stub_gen::type_info::ParameterInfo {
682                    name: "x",
683                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
684                    type_info: || ::pyo3_stub_gen::TypeInfo {
685                        name: "int".to_string(),
686                        source_module: None,
687                        import: ::std::collections::HashSet::from([]),
688                        type_refs: ::std::collections::HashMap::new(),
689                    },
690                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
691                },
692            ],
693            r#return: || ::pyo3_stub_gen::TypeInfo {
694                name: "int".to_string(),
695                source_module: None,
696                import: ::std::collections::HashSet::from([]),
697                type_refs: ::std::collections::HashMap::new(),
698            },
699            doc: "This function is deprecated",
700            module: None,
701            is_async: false,
702            deprecated: Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
703                since: None,
704                note: Some("Use new_function instead"),
705            }),
706            type_ignored: None,
707            is_overload: false,
708            file: file!(),
709            line: line!(),
710            column: column!(),
711            index: 0usize,
712        }
713        "#);
714        Ok(())
715    }
716
717    #[test]
718    fn test_rust_type_marker() -> Result<()> {
719        let stub_str: LitStr = syn::parse2(quote! {
720            r#"
721            def process_data(x: pyo3_stub_gen.RustType["MyRustType"]) -> pyo3_stub_gen.RustType["MyRustType"]:
722                """Process data using Rust type marker"""
723            "#
724        })?;
725        let info = parse_python_function_stub(stub_str)?;
726        let out = info.to_token_stream();
727        insta::assert_snapshot!(format_as_value(out), @r###"
728        ::pyo3_stub_gen::type_info::PyFunctionInfo {
729            name: "process_data",
730            parameters: &[
731                ::pyo3_stub_gen::type_info::ParameterInfo {
732                    name: "x",
733                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
734                    type_info: <MyRustType as ::pyo3_stub_gen::PyStubType>::type_input,
735                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
736                },
737            ],
738            r#return: <MyRustType as pyo3_stub_gen::PyStubType>::type_output,
739            doc: "Process data using Rust type marker",
740            module: None,
741            is_async: false,
742            deprecated: None,
743            type_ignored: None,
744            is_overload: false,
745            file: file!(),
746            line: line!(),
747            column: column!(),
748            index: 0usize,
749        }
750        "###);
751        Ok(())
752    }
753
754    #[test]
755    fn test_rust_type_marker_with_path() -> Result<()> {
756        let stub_str: LitStr = syn::parse2(quote! {
757            r#"
758            def process(x: pyo3_stub_gen.RustType["crate::MyType"]) -> pyo3_stub_gen.RustType["Vec<String>"]:
759                """Test with type paths"""
760            "#
761        })?;
762        let info = parse_python_function_stub(stub_str)?;
763        let out = info.to_token_stream();
764        insta::assert_snapshot!(format_as_value(out), @r###"
765        ::pyo3_stub_gen::type_info::PyFunctionInfo {
766            name: "process",
767            parameters: &[
768                ::pyo3_stub_gen::type_info::ParameterInfo {
769                    name: "x",
770                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
771                    type_info: <crate::MyType as ::pyo3_stub_gen::PyStubType>::type_input,
772                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
773                },
774            ],
775            r#return: <Vec<String> as pyo3_stub_gen::PyStubType>::type_output,
776            doc: "Test with type paths",
777            module: None,
778            is_async: false,
779            deprecated: None,
780            type_ignored: None,
781            is_overload: false,
782            file: file!(),
783            line: line!(),
784            column: column!(),
785            index: 0usize,
786        }
787        "###);
788        Ok(())
789    }
790
791    #[test]
792    fn test_keyword_only_args() -> Result<()> {
793        let stub_str: LitStr = syn::parse2(quote! {
794            r#"
795            import typing
796
797            def configure(name: str, *, dtype: str, ndim: int, jagged: bool = False) -> None:
798                """Test keyword-only parameters"""
799            "#
800        })?;
801        let info = parse_python_function_stub(stub_str)?;
802        let out = info.to_token_stream();
803        insta::assert_snapshot!(format_as_value(out), @r#"
804        ::pyo3_stub_gen::type_info::PyFunctionInfo {
805            name: "configure",
806            parameters: &[
807                ::pyo3_stub_gen::type_info::ParameterInfo {
808                    name: "name",
809                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
810                    type_info: || ::pyo3_stub_gen::TypeInfo {
811                        name: "str".to_string(),
812                        source_module: None,
813                        import: ::std::collections::HashSet::from(["typing".into()]),
814                        type_refs: ::std::collections::HashMap::new(),
815                    },
816                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
817                },
818                ::pyo3_stub_gen::type_info::ParameterInfo {
819                    name: "dtype",
820                    kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
821                    type_info: || ::pyo3_stub_gen::TypeInfo {
822                        name: "str".to_string(),
823                        source_module: None,
824                        import: ::std::collections::HashSet::from(["typing".into()]),
825                        type_refs: ::std::collections::HashMap::new(),
826                    },
827                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
828                },
829                ::pyo3_stub_gen::type_info::ParameterInfo {
830                    name: "ndim",
831                    kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
832                    type_info: || ::pyo3_stub_gen::TypeInfo {
833                        name: "int".to_string(),
834                        source_module: None,
835                        import: ::std::collections::HashSet::from(["typing".into()]),
836                        type_refs: ::std::collections::HashMap::new(),
837                    },
838                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
839                },
840                ::pyo3_stub_gen::type_info::ParameterInfo {
841                    name: "jagged",
842                    kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
843                    type_info: || ::pyo3_stub_gen::TypeInfo {
844                        name: "bool".to_string(),
845                        source_module: None,
846                        import: ::std::collections::HashSet::from(["typing".into()]),
847                        type_refs: ::std::collections::HashMap::new(),
848                    },
849                    default: ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
850                        fn _fmt() -> String {
851                            "False".to_string()
852                        }
853                        _fmt
854                    }),
855                },
856            ],
857            r#return: || ::pyo3_stub_gen::TypeInfo {
858                name: "None".to_string(),
859                source_module: None,
860                import: ::std::collections::HashSet::from(["typing".into()]),
861                type_refs: ::std::collections::HashMap::new(),
862            },
863            doc: "Test keyword-only parameters",
864            module: None,
865            is_async: false,
866            deprecated: None,
867            type_ignored: None,
868            is_overload: false,
869            file: file!(),
870            line: line!(),
871            column: column!(),
872            index: 0usize,
873        }
874        "#);
875        Ok(())
876    }
877
878    #[test]
879    fn test_positional_only_args() -> Result<()> {
880        let stub_str: LitStr = syn::parse2(quote! {
881            r#"
882            def func(x: int, y: int, /, z: int) -> int:
883                """Test positional-only parameters"""
884            "#
885        })?;
886        let info = parse_python_function_stub(stub_str)?;
887        let out = info.to_token_stream();
888        insta::assert_snapshot!(format_as_value(out), @r#"
889        ::pyo3_stub_gen::type_info::PyFunctionInfo {
890            name: "func",
891            parameters: &[
892                ::pyo3_stub_gen::type_info::ParameterInfo {
893                    name: "x",
894                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOnly,
895                    type_info: || ::pyo3_stub_gen::TypeInfo {
896                        name: "int".to_string(),
897                        source_module: None,
898                        import: ::std::collections::HashSet::from([]),
899                        type_refs: ::std::collections::HashMap::new(),
900                    },
901                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
902                },
903                ::pyo3_stub_gen::type_info::ParameterInfo {
904                    name: "y",
905                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOnly,
906                    type_info: || ::pyo3_stub_gen::TypeInfo {
907                        name: "int".to_string(),
908                        source_module: None,
909                        import: ::std::collections::HashSet::from([]),
910                        type_refs: ::std::collections::HashMap::new(),
911                    },
912                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
913                },
914                ::pyo3_stub_gen::type_info::ParameterInfo {
915                    name: "z",
916                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
917                    type_info: || ::pyo3_stub_gen::TypeInfo {
918                        name: "int".to_string(),
919                        source_module: None,
920                        import: ::std::collections::HashSet::from([]),
921                        type_refs: ::std::collections::HashMap::new(),
922                    },
923                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
924                },
925            ],
926            r#return: || ::pyo3_stub_gen::TypeInfo {
927                name: "int".to_string(),
928                source_module: None,
929                import: ::std::collections::HashSet::from([]),
930                type_refs: ::std::collections::HashMap::new(),
931            },
932            doc: "Test positional-only parameters",
933            module: None,
934            is_async: false,
935            deprecated: None,
936            type_ignored: None,
937            is_overload: false,
938            file: file!(),
939            line: line!(),
940            column: column!(),
941            index: 0usize,
942        }
943        "#);
944        Ok(())
945    }
946
947    #[test]
948    fn test_single_overload() -> Result<()> {
949        let stub_str: LitStr = syn::parse2(quote! {
950            r#"
951            @overload
952            def foo(x: int) -> int:
953                """Integer overload"""
954            "#
955        })?;
956        let info = parse_python_function_stub(stub_str)?;
957        let out = info.to_token_stream();
958        insta::assert_snapshot!(format_as_value(out), @r#"
959        ::pyo3_stub_gen::type_info::PyFunctionInfo {
960            name: "foo",
961            parameters: &[
962                ::pyo3_stub_gen::type_info::ParameterInfo {
963                    name: "x",
964                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
965                    type_info: || ::pyo3_stub_gen::TypeInfo {
966                        name: "int".to_string(),
967                        source_module: None,
968                        import: ::std::collections::HashSet::from([]),
969                        type_refs: ::std::collections::HashMap::new(),
970                    },
971                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
972                },
973            ],
974            r#return: || ::pyo3_stub_gen::TypeInfo {
975                name: "int".to_string(),
976                source_module: None,
977                import: ::std::collections::HashSet::from([]),
978                type_refs: ::std::collections::HashMap::new(),
979            },
980            doc: "Integer overload",
981            module: None,
982            is_async: false,
983            deprecated: None,
984            type_ignored: None,
985            is_overload: true,
986            file: file!(),
987            line: line!(),
988            column: column!(),
989            index: 0usize,
990        }
991        "#);
992        Ok(())
993    }
994
995    #[test]
996    fn test_multiple_overloads() -> Result<()> {
997        let stub_str: LitStr = syn::parse2(quote! {
998            r#"
999            @overload
1000            def foo(x: int) -> int:
1001                """Integer overload"""
1002
1003            @overload
1004            def foo(x: float) -> float:
1005                """Float overload"""
1006            "#
1007        })?;
1008        let infos = parse_python_overload_stubs(stub_str, "foo")?;
1009        assert_eq!(infos.len(), 2);
1010
1011        let out1 = infos[0].to_token_stream();
1012        insta::assert_snapshot!(format_as_value(out1), @r#"
1013        ::pyo3_stub_gen::type_info::PyFunctionInfo {
1014            name: "foo",
1015            parameters: &[
1016                ::pyo3_stub_gen::type_info::ParameterInfo {
1017                    name: "x",
1018                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
1019                    type_info: || ::pyo3_stub_gen::TypeInfo {
1020                        name: "int".to_string(),
1021                        source_module: None,
1022                        import: ::std::collections::HashSet::from([]),
1023                        type_refs: ::std::collections::HashMap::new(),
1024                    },
1025                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
1026                },
1027            ],
1028            r#return: || ::pyo3_stub_gen::TypeInfo {
1029                name: "int".to_string(),
1030                source_module: None,
1031                import: ::std::collections::HashSet::from([]),
1032                type_refs: ::std::collections::HashMap::new(),
1033            },
1034            doc: "Integer overload",
1035            module: None,
1036            is_async: false,
1037            deprecated: None,
1038            type_ignored: None,
1039            is_overload: true,
1040            file: file!(),
1041            line: line!(),
1042            column: column!(),
1043            index: 0usize,
1044        }
1045        "#);
1046
1047        let out2 = infos[1].to_token_stream();
1048        insta::assert_snapshot!(format_as_value(out2), @r#"
1049        ::pyo3_stub_gen::type_info::PyFunctionInfo {
1050            name: "foo",
1051            parameters: &[
1052                ::pyo3_stub_gen::type_info::ParameterInfo {
1053                    name: "x",
1054                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
1055                    type_info: || ::pyo3_stub_gen::TypeInfo {
1056                        name: "float".to_string(),
1057                        source_module: None,
1058                        import: ::std::collections::HashSet::from([]),
1059                        type_refs: ::std::collections::HashMap::new(),
1060                    },
1061                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
1062                },
1063            ],
1064            r#return: || ::pyo3_stub_gen::TypeInfo {
1065                name: "float".to_string(),
1066                source_module: None,
1067                import: ::std::collections::HashSet::from([]),
1068                type_refs: ::std::collections::HashMap::new(),
1069            },
1070            doc: "Float overload",
1071            module: None,
1072            is_async: false,
1073            deprecated: None,
1074            type_ignored: None,
1075            is_overload: true,
1076            file: file!(),
1077            line: line!(),
1078            column: column!(),
1079            index: 0usize,
1080        }
1081        "#);
1082        Ok(())
1083    }
1084
1085    #[test]
1086    fn test_overload_with_literal_types() -> Result<()> {
1087        let stub_str: LitStr = syn::parse2(quote! {
1088            r#"
1089            import typing
1090            @overload
1091            def as_tuple(xs: list[int], *, tuple_out: typing.Literal[True]) -> tuple[int, ...]:
1092                """Return as tuple"""
1093            "#
1094        })?;
1095        let info = parse_python_function_stub(stub_str)?;
1096        let out = info.to_token_stream();
1097        insta::assert_snapshot!(format_as_value(out), @r#"
1098        ::pyo3_stub_gen::type_info::PyFunctionInfo {
1099            name: "as_tuple",
1100            parameters: &[
1101                ::pyo3_stub_gen::type_info::ParameterInfo {
1102                    name: "xs",
1103                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
1104                    type_info: || ::pyo3_stub_gen::TypeInfo {
1105                        name: "list[int]".to_string(),
1106                        source_module: None,
1107                        import: ::std::collections::HashSet::from(["typing".into()]),
1108                        type_refs: ::std::collections::HashMap::new(),
1109                    },
1110                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
1111                },
1112                ::pyo3_stub_gen::type_info::ParameterInfo {
1113                    name: "tuple_out",
1114                    kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
1115                    type_info: || ::pyo3_stub_gen::TypeInfo {
1116                        name: "typing.Literal[True]".to_string(),
1117                        source_module: None,
1118                        import: ::std::collections::HashSet::from(["typing".into()]),
1119                        type_refs: ::std::collections::HashMap::new(),
1120                    },
1121                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
1122                },
1123            ],
1124            r#return: || ::pyo3_stub_gen::TypeInfo {
1125                name: "tuple[int, ...]".to_string(),
1126                source_module: None,
1127                import: ::std::collections::HashSet::from(["typing".into()]),
1128                type_refs: ::std::collections::HashMap::new(),
1129            },
1130            doc: "Return as tuple",
1131            module: None,
1132            is_async: false,
1133            deprecated: None,
1134            type_ignored: None,
1135            is_overload: true,
1136            file: file!(),
1137            line: line!(),
1138            column: column!(),
1139            index: 0usize,
1140        }
1141        "#);
1142        Ok(())
1143    }
1144
1145    fn format_as_value(tt: TokenStream2) -> String {
1146        let ttt = quote! { const _: () = #tt; };
1147        let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
1148        formatted
1149            .trim()
1150            .strip_prefix("const _: () = ")
1151            .unwrap()
1152            .strip_suffix(';')
1153            .unwrap()
1154            .to_string()
1155    }
1156}