1use rustpython_parser::{ast, Parse};
4use syn::{parse::Parse as SynParse, parse::ParseStream, Error, LitStr, Result};
5
6use super::{
7 build_parameters_from_ast, dedent, extract_deprecated_from_decorators, extract_docstring,
8 extract_return_type, has_overload_decorator,
9};
10use crate::gen_stub::pyfunction::PyFunctionInfo;
11
/// Parsed arguments of a `gen_function_from_python!` invocation.
///
/// The accompanying `syn` parser accepts either a lone string literal holding
/// the Python stub, or `module = "<name>", "<stub>"`.
pub struct GenFunctionFromPythonInput {
    /// Module name supplied through `module = "..."`; `None` when omitted.
    module: Option<String>,
    /// String literal containing the Python stub source.
    python_stub: LitStr,
}
17
18impl SynParse for GenFunctionFromPythonInput {
19 fn parse(input: ParseStream) -> Result<Self> {
20 if input.peek(syn::Ident) {
22 let key: syn::Ident = input.parse()?;
23 if key == "module" {
24 let _: syn::token::Eq = input.parse()?;
25 let value: LitStr = input.parse()?;
26 let _: syn::token::Comma = input.parse()?;
27 let python_stub: LitStr = input.parse()?;
28 return Ok(Self {
29 module: Some(value.value()),
30 python_stub,
31 });
32 } else {
33 return Err(Error::new(
34 key.span(),
35 format!(
36 "Unknown parameter: {}. Expected 'module' or a string literal",
37 key
38 ),
39 ));
40 }
41 }
42
43 let python_stub: LitStr = input.parse()?;
45 Ok(Self {
46 module: None,
47 python_stub,
48 })
49 }
50}
51
/// Intermediate form of a single Python function definition taken from a stub,
/// convertible into a [`PyFunctionInfo`] via `TryFrom`.
pub struct PythonFunctionStub {
    /// The function definition node; `async def` definitions are stored in this
    /// same node type, with `is_async` recording the difference.
    pub func_def: ast::StmtFunctionDef,
    /// Module names collected from the stub's `import` / `from ... import` statements.
    pub imports: Vec<String>,
    /// Whether the function was declared with `async def`.
    pub is_async: bool,
    /// Whether the function is marked as an overload (as reported by
    /// `has_overload_decorator` in the parsers of this file).
    pub is_overload: bool,
}
59
60impl TryFrom<PythonFunctionStub> for PyFunctionInfo {
61 type Error = syn::Error;
62
63 fn try_from(stub: PythonFunctionStub) -> Result<Self> {
64 let func_name = stub.func_def.name.to_string();
65
66 let doc = extract_docstring(&stub.func_def);
68
69 let parameters = build_parameters_from_ast(&stub.func_def.args, &stub.imports)?;
71
72 let return_type = extract_return_type(&stub.func_def.returns, &stub.imports)?;
74
75 let deprecated = extract_deprecated_from_decorators(&stub.func_def.decorator_list);
77
78 Ok(PyFunctionInfo {
83 name: func_name,
84 parameters, r#return: return_type,
86 doc,
87 module: None,
88 is_async: stub.is_async,
89 deprecated,
90 type_ignored: None,
91 is_overload: stub.is_overload,
92 index: 0, })
94 }
95}
96
97pub fn parse_python_function_stub(input: LitStr) -> Result<PyFunctionInfo> {
99 let stub_content = input.value();
100
101 let dedented_content = dedent(&stub_content);
103
104 let parsed = ast::Suite::parse(&dedented_content, "<stub>")
106 .map_err(|e| Error::new(input.span(), format!("Failed to parse Python stub: {}", e)))?;
107
108 let mut imports = Vec::new();
110 let mut function: Option<(ast::StmtFunctionDef, bool)> = None;
111
112 for stmt in parsed {
113 match stmt {
114 ast::Stmt::Import(import_stmt) => {
115 for alias in &import_stmt.names {
116 imports.push(alias.name.to_string());
117 }
118 }
119 ast::Stmt::ImportFrom(import_from_stmt) => {
120 if let Some(module) = &import_from_stmt.module {
121 imports.push(module.to_string());
122 }
123 }
124 ast::Stmt::FunctionDef(func_def) => {
125 if function.is_some() {
126 return Err(Error::new(
127 input.span(),
128 "Multiple function definitions found. Only one function is allowed per gen_function_from_python! call",
129 ));
130 }
131 function = Some((func_def, false));
132 }
133 ast::Stmt::AsyncFunctionDef(func_def) => {
134 if function.is_some() {
135 return Err(Error::new(
136 input.span(),
137 "Multiple function definitions found. Only one function is allowed per gen_function_from_python! call",
138 ));
139 }
140 let sync_func = ast::StmtFunctionDef {
142 range: func_def.range,
143 name: func_def.name,
144 type_params: func_def.type_params,
145 args: func_def.args,
146 body: func_def.body,
147 decorator_list: func_def.decorator_list,
148 returns: func_def.returns,
149 type_comment: func_def.type_comment,
150 };
151 function = Some((sync_func, true));
152 }
153 _ => {
154 }
156 }
157 }
158
159 let (func_def, is_async) = function
161 .ok_or_else(|| Error::new(input.span(), "No function definition found in Python stub"))?;
162
163 let is_overload = has_overload_decorator(&func_def.decorator_list);
165
166 let stub = PythonFunctionStub {
168 func_def,
169 imports,
170 is_async,
171 is_overload,
172 };
173 PyFunctionInfo::try_from(stub)
174}
175
176pub fn parse_python_overload_stubs(
179 input: LitStr,
180 expected_function_name: &str,
181) -> Result<Vec<PyFunctionInfo>> {
182 let stub_content = input.value();
183 let dedented_content = dedent(&stub_content);
184
185 let parsed = ast::Suite::parse(&dedented_content, "<stub>")
187 .map_err(|e| Error::new(input.span(), format!("Failed to parse Python stub: {}", e)))?;
188
189 let mut imports = Vec::new();
191 let mut functions: Vec<(ast::StmtFunctionDef, bool)> = Vec::new();
192
193 for stmt in parsed {
194 match stmt {
195 ast::Stmt::Import(import_stmt) => {
196 for alias in &import_stmt.names {
197 imports.push(alias.name.to_string());
198 }
199 }
200 ast::Stmt::ImportFrom(import_from_stmt) => {
201 if let Some(module) = &import_from_stmt.module {
202 imports.push(module.to_string());
203 }
204 }
205 ast::Stmt::FunctionDef(func_def) => {
206 functions.push((func_def, false));
207 }
208 ast::Stmt::AsyncFunctionDef(func_def) => {
209 let sync_func = ast::StmtFunctionDef {
211 range: func_def.range,
212 name: func_def.name,
213 type_params: func_def.type_params,
214 args: func_def.args,
215 body: func_def.body,
216 decorator_list: func_def.decorator_list,
217 returns: func_def.returns,
218 type_comment: func_def.type_comment,
219 };
220 functions.push((sync_func, true));
221 }
222 _ => {
223 }
225 }
226 }
227
228 if functions.is_empty() {
230 return Err(Error::new(
231 input.span(),
232 "No function definition found in python_overload parameter",
233 ));
234 }
235
236 let mut result = Vec::new();
238 for (func_def, is_async) in functions {
239 let func_name = func_def.name.to_string();
240
241 if func_name != expected_function_name {
243 return Err(Error::new(
244 input.span(),
245 format!(
246 "Function name '{}' in python_overload does not match Rust function name '{}'. Please ensure all overload function names match the Rust function name.",
247 func_name, expected_function_name
248 ),
249 ));
250 }
251
252 let is_overload = has_overload_decorator(&func_def.decorator_list);
254 if !is_overload {
255 return Err(Error::new(
256 input.span(),
257 format!(
258 "Function '{}' in python_overload must have @overload decorator",
259 func_name
260 ),
261 ));
262 }
263
264 let stub = PythonFunctionStub {
266 func_def,
267 imports: imports.clone(),
268 is_async,
269 is_overload,
270 };
271 result.push(PyFunctionInfo::try_from(stub)?);
272 }
273
274 Ok(result)
275}
276
277pub fn parse_gen_function_from_python_input(
279 input: GenFunctionFromPythonInput,
280) -> Result<PyFunctionInfo> {
281 let mut info = parse_python_function_stub(input.python_stub)?;
282
283 if let Some(module) = input.module {
285 info.module = Some(module);
286 }
287
288 Ok(info)
289}
290
// Snapshot tests for the Python stub parsers. Each test feeds a stub string to a
// parser and compares the pretty-printed token output against an inline snapshot.
#[cfg(test)]
mod test {
    use super::*;
    use proc_macro2::TokenStream as TokenStream2;
    use quote::{quote, ToTokens};

    // Minimal case: one typed parameter, a return annotation and a docstring.
    #[test]
    fn test_basic_function() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            def foo(x: int) -> int:
                """A simple function"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "foo",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "x",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "A simple function",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // Module names from `import` and `from ... import` statements are collected
    // and attached to the parameter and return `TypeInfo`.
    #[test]
    fn test_function_with_imports() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import typing
            from collections.abc import Callable

            def process(func: Callable[[str], int]) -> typing.Optional[int]:
                """Process a callback function"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "process",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "func",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "Callable[[str], int]".to_string(),
                        import: ::std::collections::HashSet::from([
                            "typing".into(),
                            "collections.abc".into(),
                        ]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "typing.Optional[int]".to_string(),
                import: ::std::collections::HashSet::from([
                    "typing".into(),
                    "collections.abc".into(),
                ]),
            },
            doc: "Process a callback function",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // Fully qualified annotations (`collections.abc.Callable[...]`, `typing.Any`)
    // are kept verbatim as the type name.
    #[test]
    fn test_complex_types() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import collections.abc
            import typing

            def fn_override_type(cb: collections.abc.Callable[[str], typing.Any]) -> collections.abc.Callable[[str], typing.Any]:
                """Example function with complex types"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "fn_override_type",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "cb",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "collections.abc.Callable[[str], typing.Any]".to_string(),
                        import: ::std::collections::HashSet::from([
                            "collections.abc".into(),
                            "typing".into(),
                        ]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "collections.abc.Callable[[str], typing.Any]".to_string(),
                import: ::std::collections::HashSet::from([
                    "collections.abc".into(),
                    "typing".into(),
                ]),
            },
            doc: "Example function with complex types",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // Several positional-or-keyword parameters with a `...` body; the resulting
    // doc string is empty.
    #[test]
    fn test_multiple_args() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import typing

            def add(a: int, b: int, c: typing.Optional[int]) -> int: ...
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "add",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "a",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "b",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "c",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "typing.Optional[int]".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from(["typing".into()]),
            },
            doc: "",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // A missing return annotation maps to `no_return_type_output`.
    #[test]
    fn test_no_return_type() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            def print_hello(name: str):
                """Print a greeting"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "print_hello",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "name",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: ::pyo3_stub_gen::type_info::no_return_type_output,
            doc: "Print a greeting",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // `async def` is accepted and reported through `is_async: true`.
    #[test]
    fn test_async_function() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            async def fetch_data(url: str) -> str:
                """Fetch data from URL"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "fetch_data",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "url",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "str".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Fetch data from URL",
            module: None,
            is_async: true,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // A bare `@deprecated` decorator yields `DeprecatedInfo` with neither `since`
    // nor `note`.
    #[test]
    fn test_deprecated_decorator() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            @deprecated
            def old_function(x: int) -> int:
                """This function is deprecated"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "old_function",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "x",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "This function is deprecated",
            module: None,
            is_async: false,
            deprecated: Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
                since: None,
                note: None,
            }),
            type_ignored: None,
            is_overload: false,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // The string argument of `@deprecated("...")` becomes the deprecation `note`.
    #[test]
    fn test_deprecated_with_message() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            @deprecated("Use new_function instead")
            def old_function(x: int) -> int:
                """This function is deprecated"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "old_function",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "x",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "This function is deprecated",
            module: None,
            is_async: false,
            deprecated: Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
                since: None,
                note: Some("Use new_function instead"),
            }),
            type_ignored: None,
            is_overload: false,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // `pyo3_stub_gen.RustType["T"]` annotations resolve to `<T as PyStubType>`
    // paths instead of a literal `TypeInfo`.
    #[test]
    fn test_rust_type_marker() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            def process_data(x: pyo3_stub_gen.RustType["MyRustType"]) -> pyo3_stub_gen.RustType["MyRustType"]:
                """Process data using Rust type marker"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "process_data",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "x",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: <MyRustType as ::pyo3_stub_gen::PyStubType>::type_input,
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: <MyRustType as pyo3_stub_gen::PyStubType>::type_output,
            doc: "Process data using Rust type marker",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // `RustType[...]` also accepts Rust paths and generic types
    // (`crate::MyType`, `Vec<String>`).
    #[test]
    fn test_rust_type_marker_with_path() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            def process(x: pyo3_stub_gen.RustType["crate::MyType"]) -> pyo3_stub_gen.RustType["Vec<String>"]:
                """Test with type paths"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "process",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "x",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: <crate::MyType as ::pyo3_stub_gen::PyStubType>::type_input,
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: <Vec<String> as pyo3_stub_gen::PyStubType>::type_output,
            doc: "Test with type paths",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // Parameters after a bare `*` are `KeywordOnly`; a default value is emitted
    // as `ParameterDefault::Expr`.
    #[test]
    fn test_keyword_only_args() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import typing

            def configure(name: str, *, dtype: str, ndim: int, jagged: bool = False) -> None:
                """Test keyword-only parameters"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "configure",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "name",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "dtype",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "ndim",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "jagged",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "bool".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
                        fn _fmt() -> String {
                            "False".to_string()
                        }
                        _fmt
                    }),
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "None".to_string(),
                import: ::std::collections::HashSet::from(["typing".into()]),
            },
            doc: "Test keyword-only parameters",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // Parameters before `/` are `PositionalOnly`; those after it stay
    // `PositionalOrKeyword`.
    #[test]
    fn test_positional_only_args() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            def func(x: int, y: int, /, z: int) -> int:
                """Test positional-only parameters"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "func",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "x",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOnly,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "y",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOnly,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "z",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Test positional-only parameters",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: false,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // A single `@overload`-decorated function parsed through the single-function
    // entry point sets `is_overload: true`.
    #[test]
    fn test_single_overload() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            @overload
            def foo(x: int) -> int:
                """Integer overload"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "foo",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "x",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Integer overload",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: true,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // `parse_python_overload_stubs` returns one `PyFunctionInfo` per `@overload`
    // definition, in source order.
    #[test]
    fn test_multiple_overloads() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            @overload
            def foo(x: int) -> int:
                """Integer overload"""

            @overload
            def foo(x: float) -> float:
                """Float overload"""
            "#
        })?;
        let infos = parse_python_overload_stubs(stub_str, "foo")?;
        assert_eq!(infos.len(), 2);

        let out1 = infos[0].to_token_stream();
        insta::assert_snapshot!(format_as_value(out1), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "foo",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "x",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Integer overload",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: true,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);

        let out2 = infos[1].to_token_stream();
        insta::assert_snapshot!(format_as_value(out2), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "foo",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "x",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "float".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "float".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Float overload",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: true,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // `typing.Literal[...]` and `tuple[int, ...]` annotations in an overload are
    // preserved as written.
    #[test]
    fn test_overload_with_literal_types() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import typing
            @overload
            def as_tuple(xs: list[int], *, tuple_out: typing.Literal[True]) -> tuple[int, ...]:
                """Return as tuple"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "as_tuple",
            parameters: &[
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "xs",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "list[int]".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
                ::pyo3_stub_gen::type_info::ParameterInfo {
                    name: "tuple_out",
                    kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
                    type_info: || ::pyo3_stub_gen::TypeInfo {
                        name: "typing.Literal[True]".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "tuple[int, ...]".to_string(),
                import: ::std::collections::HashSet::from(["typing".into()]),
            },
            doc: "Return as tuple",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
            is_overload: true,
            file: file!(),
            line: line!(),
            column: column!(),
            index: 0usize,
        }
        "###);
        Ok(())
    }

    // Renders a token stream as a pretty-printed expression: wraps it in
    // `const _: () = ...;` so it parses as a file, formats it with prettyplease,
    // then strips the wrapper again.
    fn format_as_value(tt: TokenStream2) -> String {
        let ttt = quote! { const _: () = #tt; };
        let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
        formatted
            .trim()
            .strip_prefix("const _: () = ")
            .unwrap()
            .strip_suffix(';')
            .unwrap()
            .to_string()
    }
}