1use rustpython_parser::{ast, Parse};
4use syn::{parse::Parse as SynParse, parse::ParseStream, Error, LitStr, Result};
5
6use super::{
7 build_parameters_from_ast, dedent, extract_deprecated_from_decorators, extract_docstring,
8 extract_return_type, has_overload_decorator,
9};
10use crate::gen_stub::pyfunction::PyFunctionInfo;
11
/// Parsed arguments of a `gen_function_from_python!` invocation.
///
/// Two forms are accepted by the `Parse` implementation: a bare string
/// literal containing the Python stub, or `module = "<name>", "<stub>"`.
pub struct GenFunctionFromPythonInput {
    /// Value of the optional leading `module = "..."` argument.
    module: Option<String>,
    /// String literal holding the Python stub source.
    python_stub: LitStr,
}
17
18impl SynParse for GenFunctionFromPythonInput {
19 fn parse(input: ParseStream) -> Result<Self> {
20 if input.peek(syn::Ident) {
22 let key: syn::Ident = input.parse()?;
23 if key == "module" {
24 let _: syn::token::Eq = input.parse()?;
25 let value: LitStr = input.parse()?;
26 let _: syn::token::Comma = input.parse()?;
27 let python_stub: LitStr = input.parse()?;
28 return Ok(Self {
29 module: Some(value.value()),
30 python_stub,
31 });
32 } else {
33 return Err(Error::new(
34 key.span(),
35 format!(
36 "Unknown parameter: {}. Expected 'module' or a string literal",
37 key
38 ),
39 ));
40 }
41 }
42
43 let python_stub: LitStr = input.parse()?;
45 Ok(Self {
46 module: None,
47 python_stub,
48 })
49 }
50}
51
/// One Python function definition taken from a stub, plus the context needed
/// to convert it into a [`PyFunctionInfo`].
pub struct PythonFunctionStub {
    /// AST node of the function. The parsers in this file also store
    /// `async def` functions in this form and record the difference in `is_async`.
    pub func_def: ast::StmtFunctionDef,
    /// Module names gathered from the stub's `import` / `from ... import` statements.
    pub imports: Vec<String>,
    /// `true` when the function was declared with `async def`.
    pub is_async: bool,
    /// Result of `has_overload_decorator` on the function's decorator list.
    pub is_overload: bool,
}
59
60impl TryFrom<PythonFunctionStub> for PyFunctionInfo {
61 type Error = syn::Error;
62
63 fn try_from(stub: PythonFunctionStub) -> Result<Self> {
64 let func_name = stub.func_def.name.to_string();
65
66 let doc = extract_docstring(&stub.func_def);
68
69 let parameters = build_parameters_from_ast(&stub.func_def.args, &stub.imports)?;
71
72 let return_type = extract_return_type(&stub.func_def.returns, &stub.imports)?;
74
75 let deprecated = extract_deprecated_from_decorators(&stub.func_def.decorator_list);
77
78 Ok(PyFunctionInfo {
83 name: func_name,
84 parameters, r#return: return_type,
86 doc,
87 module: None,
88 is_async: stub.is_async,
89 deprecated,
90 type_ignored: None,
91 is_overload: stub.is_overload,
92 index: 0, })
94 }
95}
96
97pub fn parse_python_function_stub(input: LitStr) -> Result<PyFunctionInfo> {
99 let stub_content = input.value();
100
101 let dedented_content = dedent(&stub_content);
103
104 let parsed = ast::Suite::parse(&dedented_content, "<stub>")
106 .map_err(|e| Error::new(input.span(), format!("Failed to parse Python stub: {}", e)))?;
107
108 let mut imports = Vec::new();
110 let mut function: Option<(ast::StmtFunctionDef, bool)> = None;
111
112 for stmt in parsed {
113 match stmt {
114 ast::Stmt::Import(import_stmt) => {
115 for alias in &import_stmt.names {
116 imports.push(alias.name.to_string());
117 }
118 }
119 ast::Stmt::ImportFrom(import_from_stmt) => {
120 if let Some(module) = &import_from_stmt.module {
121 imports.push(module.to_string());
122 }
123 }
124 ast::Stmt::FunctionDef(func_def) => {
125 if function.is_some() {
126 return Err(Error::new(
127 input.span(),
128 "Multiple function definitions found. Only one function is allowed per gen_function_from_python! call",
129 ));
130 }
131 function = Some((func_def, false));
132 }
133 ast::Stmt::AsyncFunctionDef(func_def) => {
134 if function.is_some() {
135 return Err(Error::new(
136 input.span(),
137 "Multiple function definitions found. Only one function is allowed per gen_function_from_python! call",
138 ));
139 }
140 let sync_func = ast::StmtFunctionDef {
142 range: func_def.range,
143 name: func_def.name,
144 type_params: func_def.type_params,
145 args: func_def.args,
146 body: func_def.body,
147 decorator_list: func_def.decorator_list,
148 returns: func_def.returns,
149 type_comment: func_def.type_comment,
150 };
151 function = Some((sync_func, true));
152 }
153 _ => {
154 }
156 }
157 }
158
159 let (func_def, is_async) = function
161 .ok_or_else(|| Error::new(input.span(), "No function definition found in Python stub"))?;
162
163 let is_overload = has_overload_decorator(&func_def.decorator_list);
165
166 let stub = PythonFunctionStub {
168 func_def,
169 imports,
170 is_async,
171 is_overload,
172 };
173 PyFunctionInfo::try_from(stub)
174}
175
176pub fn parse_python_overload_stubs(
179 input: LitStr,
180 expected_function_name: &str,
181) -> Result<Vec<PyFunctionInfo>> {
182 let stub_content = input.value();
183 let dedented_content = dedent(&stub_content);
184
185 let parsed = ast::Suite::parse(&dedented_content, "<stub>")
187 .map_err(|e| Error::new(input.span(), format!("Failed to parse Python stub: {}", e)))?;
188
189 let mut imports = Vec::new();
191 let mut functions: Vec<(ast::StmtFunctionDef, bool)> = Vec::new();
192
193 for stmt in parsed {
194 match stmt {
195 ast::Stmt::Import(import_stmt) => {
196 for alias in &import_stmt.names {
197 imports.push(alias.name.to_string());
198 }
199 }
200 ast::Stmt::ImportFrom(import_from_stmt) => {
201 if let Some(module) = &import_from_stmt.module {
202 imports.push(module.to_string());
203 }
204 }
205 ast::Stmt::FunctionDef(func_def) => {
206 functions.push((func_def, false));
207 }
208 ast::Stmt::AsyncFunctionDef(func_def) => {
209 let sync_func = ast::StmtFunctionDef {
211 range: func_def.range,
212 name: func_def.name,
213 type_params: func_def.type_params,
214 args: func_def.args,
215 body: func_def.body,
216 decorator_list: func_def.decorator_list,
217 returns: func_def.returns,
218 type_comment: func_def.type_comment,
219 };
220 functions.push((sync_func, true));
221 }
222 _ => {
223 }
225 }
226 }
227
228 if functions.is_empty() {
230 return Err(Error::new(
231 input.span(),
232 "No function definition found in python_overload parameter",
233 ));
234 }
235
236 let mut result = Vec::new();
238 for (func_def, is_async) in functions {
239 let func_name = func_def.name.to_string();
240
241 if func_name != expected_function_name {
243 return Err(Error::new(
244 input.span(),
245 format!(
246 "Function name '{}' in python_overload does not match Rust function name '{}'. Please ensure all overload function names match the Rust function name.",
247 func_name, expected_function_name
248 ),
249 ));
250 }
251
252 let is_overload = has_overload_decorator(&func_def.decorator_list);
254 if !is_overload {
255 return Err(Error::new(
256 input.span(),
257 format!(
258 "Function '{}' in python_overload must have @overload decorator",
259 func_name
260 ),
261 ));
262 }
263
264 let stub = PythonFunctionStub {
266 func_def,
267 imports: imports.clone(),
268 is_async,
269 is_overload,
270 };
271 result.push(PyFunctionInfo::try_from(stub)?);
272 }
273
274 Ok(result)
275}
276
277pub fn parse_gen_function_from_python_input(
279 input: GenFunctionFromPythonInput,
280) -> Result<PyFunctionInfo> {
281 let mut info = parse_python_function_stub(input.python_stub)?;
282
283 if let Some(module) = input.module {
285 info.module = Some(module);
286 }
287
288 Ok(info)
289}
290
291#[cfg(test)]
292mod test {
293 use super::*;
294 use proc_macro2::TokenStream as TokenStream2;
295 use quote::{quote, ToTokens};
296
297 #[test]
298 fn test_basic_function() -> Result<()> {
299 let stub_str: LitStr = syn::parse2(quote! {
300 r#"
301 def foo(x: int) -> int:
302 """A simple function"""
303 "#
304 })?;
305 let info = parse_python_function_stub(stub_str)?;
306 let out = info.to_token_stream();
307 insta::assert_snapshot!(format_as_value(out), @r#"
308 ::pyo3_stub_gen::type_info::PyFunctionInfo {
309 name: "foo",
310 parameters: &[
311 ::pyo3_stub_gen::type_info::ParameterInfo {
312 name: "x",
313 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
314 type_info: || ::pyo3_stub_gen::TypeInfo {
315 name: "int".to_string(),
316 source_module: None,
317 import: ::std::collections::HashSet::from([]),
318 type_refs: ::std::collections::HashMap::new(),
319 },
320 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
321 },
322 ],
323 r#return: || ::pyo3_stub_gen::TypeInfo {
324 name: "int".to_string(),
325 source_module: None,
326 import: ::std::collections::HashSet::from([]),
327 type_refs: ::std::collections::HashMap::new(),
328 },
329 doc: "A simple function",
330 module: None,
331 is_async: false,
332 deprecated: None,
333 type_ignored: None,
334 is_overload: false,
335 file: file!(),
336 line: line!(),
337 column: column!(),
338 index: 0usize,
339 }
340 "#);
341 Ok(())
342 }
343
344 #[test]
345 fn test_function_with_imports() -> Result<()> {
346 let stub_str: LitStr = syn::parse2(quote! {
347 r#"
348 import typing
349 from collections.abc import Callable
350
351 def process(func: Callable[[str], int]) -> typing.Optional[int]:
352 """Process a callback function"""
353 "#
354 })?;
355 let info = parse_python_function_stub(stub_str)?;
356 let out = info.to_token_stream();
357 insta::assert_snapshot!(format_as_value(out), @r#"
358 ::pyo3_stub_gen::type_info::PyFunctionInfo {
359 name: "process",
360 parameters: &[
361 ::pyo3_stub_gen::type_info::ParameterInfo {
362 name: "func",
363 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
364 type_info: || ::pyo3_stub_gen::TypeInfo {
365 name: "Callable[[str], int]".to_string(),
366 source_module: None,
367 import: ::std::collections::HashSet::from([
368 "typing".into(),
369 "collections.abc".into(),
370 ]),
371 type_refs: ::std::collections::HashMap::new(),
372 },
373 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
374 },
375 ],
376 r#return: || ::pyo3_stub_gen::TypeInfo {
377 name: "typing.Optional[int]".to_string(),
378 source_module: None,
379 import: ::std::collections::HashSet::from([
380 "typing".into(),
381 "collections.abc".into(),
382 ]),
383 type_refs: ::std::collections::HashMap::new(),
384 },
385 doc: "Process a callback function",
386 module: None,
387 is_async: false,
388 deprecated: None,
389 type_ignored: None,
390 is_overload: false,
391 file: file!(),
392 line: line!(),
393 column: column!(),
394 index: 0usize,
395 }
396 "#);
397 Ok(())
398 }
399
400 #[test]
401 fn test_complex_types() -> Result<()> {
402 let stub_str: LitStr = syn::parse2(quote! {
403 r#"
404 import collections.abc
405 import typing
406
407 def fn_override_type(cb: collections.abc.Callable[[str], typing.Any]) -> collections.abc.Callable[[str], typing.Any]:
408 """Example function with complex types"""
409 "#
410 })?;
411 let info = parse_python_function_stub(stub_str)?;
412 let out = info.to_token_stream();
413 insta::assert_snapshot!(format_as_value(out), @r#"
414 ::pyo3_stub_gen::type_info::PyFunctionInfo {
415 name: "fn_override_type",
416 parameters: &[
417 ::pyo3_stub_gen::type_info::ParameterInfo {
418 name: "cb",
419 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
420 type_info: || ::pyo3_stub_gen::TypeInfo {
421 name: "collections.abc.Callable[[str], typing.Any]".to_string(),
422 source_module: None,
423 import: ::std::collections::HashSet::from([
424 "collections.abc".into(),
425 "typing".into(),
426 ]),
427 type_refs: ::std::collections::HashMap::new(),
428 },
429 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
430 },
431 ],
432 r#return: || ::pyo3_stub_gen::TypeInfo {
433 name: "collections.abc.Callable[[str], typing.Any]".to_string(),
434 source_module: None,
435 import: ::std::collections::HashSet::from([
436 "collections.abc".into(),
437 "typing".into(),
438 ]),
439 type_refs: ::std::collections::HashMap::new(),
440 },
441 doc: "Example function with complex types",
442 module: None,
443 is_async: false,
444 deprecated: None,
445 type_ignored: None,
446 is_overload: false,
447 file: file!(),
448 line: line!(),
449 column: column!(),
450 index: 0usize,
451 }
452 "#);
453 Ok(())
454 }
455
456 #[test]
457 fn test_multiple_args() -> Result<()> {
458 let stub_str: LitStr = syn::parse2(quote! {
459 r#"
460 import typing
461
462 def add(a: int, b: int, c: typing.Optional[int]) -> int: ...
463 "#
464 })?;
465 let info = parse_python_function_stub(stub_str)?;
466 let out = info.to_token_stream();
467 insta::assert_snapshot!(format_as_value(out), @r#"
468 ::pyo3_stub_gen::type_info::PyFunctionInfo {
469 name: "add",
470 parameters: &[
471 ::pyo3_stub_gen::type_info::ParameterInfo {
472 name: "a",
473 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
474 type_info: || ::pyo3_stub_gen::TypeInfo {
475 name: "int".to_string(),
476 source_module: None,
477 import: ::std::collections::HashSet::from(["typing".into()]),
478 type_refs: ::std::collections::HashMap::new(),
479 },
480 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
481 },
482 ::pyo3_stub_gen::type_info::ParameterInfo {
483 name: "b",
484 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
485 type_info: || ::pyo3_stub_gen::TypeInfo {
486 name: "int".to_string(),
487 source_module: None,
488 import: ::std::collections::HashSet::from(["typing".into()]),
489 type_refs: ::std::collections::HashMap::new(),
490 },
491 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
492 },
493 ::pyo3_stub_gen::type_info::ParameterInfo {
494 name: "c",
495 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
496 type_info: || ::pyo3_stub_gen::TypeInfo {
497 name: "typing.Optional[int]".to_string(),
498 source_module: None,
499 import: ::std::collections::HashSet::from(["typing".into()]),
500 type_refs: ::std::collections::HashMap::new(),
501 },
502 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
503 },
504 ],
505 r#return: || ::pyo3_stub_gen::TypeInfo {
506 name: "int".to_string(),
507 source_module: None,
508 import: ::std::collections::HashSet::from(["typing".into()]),
509 type_refs: ::std::collections::HashMap::new(),
510 },
511 doc: "",
512 module: None,
513 is_async: false,
514 deprecated: None,
515 type_ignored: None,
516 is_overload: false,
517 file: file!(),
518 line: line!(),
519 column: column!(),
520 index: 0usize,
521 }
522 "#);
523 Ok(())
524 }
525
526 #[test]
527 fn test_no_return_type() -> Result<()> {
528 let stub_str: LitStr = syn::parse2(quote! {
529 r#"
530 def print_hello(name: str):
531 """Print a greeting"""
532 "#
533 })?;
534 let info = parse_python_function_stub(stub_str)?;
535 let out = info.to_token_stream();
536 insta::assert_snapshot!(format_as_value(out), @r#"
537 ::pyo3_stub_gen::type_info::PyFunctionInfo {
538 name: "print_hello",
539 parameters: &[
540 ::pyo3_stub_gen::type_info::ParameterInfo {
541 name: "name",
542 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
543 type_info: || ::pyo3_stub_gen::TypeInfo {
544 name: "str".to_string(),
545 source_module: None,
546 import: ::std::collections::HashSet::from([]),
547 type_refs: ::std::collections::HashMap::new(),
548 },
549 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
550 },
551 ],
552 r#return: ::pyo3_stub_gen::type_info::no_return_type_output,
553 doc: "Print a greeting",
554 module: None,
555 is_async: false,
556 deprecated: None,
557 type_ignored: None,
558 is_overload: false,
559 file: file!(),
560 line: line!(),
561 column: column!(),
562 index: 0usize,
563 }
564 "#);
565 Ok(())
566 }
567
568 #[test]
569 fn test_async_function() -> Result<()> {
570 let stub_str: LitStr = syn::parse2(quote! {
571 r#"
572 async def fetch_data(url: str) -> str:
573 """Fetch data from URL"""
574 "#
575 })?;
576 let info = parse_python_function_stub(stub_str)?;
577 let out = info.to_token_stream();
578 insta::assert_snapshot!(format_as_value(out), @r#"
579 ::pyo3_stub_gen::type_info::PyFunctionInfo {
580 name: "fetch_data",
581 parameters: &[
582 ::pyo3_stub_gen::type_info::ParameterInfo {
583 name: "url",
584 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
585 type_info: || ::pyo3_stub_gen::TypeInfo {
586 name: "str".to_string(),
587 source_module: None,
588 import: ::std::collections::HashSet::from([]),
589 type_refs: ::std::collections::HashMap::new(),
590 },
591 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
592 },
593 ],
594 r#return: || ::pyo3_stub_gen::TypeInfo {
595 name: "str".to_string(),
596 source_module: None,
597 import: ::std::collections::HashSet::from([]),
598 type_refs: ::std::collections::HashMap::new(),
599 },
600 doc: "Fetch data from URL",
601 module: None,
602 is_async: true,
603 deprecated: None,
604 type_ignored: None,
605 is_overload: false,
606 file: file!(),
607 line: line!(),
608 column: column!(),
609 index: 0usize,
610 }
611 "#);
612 Ok(())
613 }
614
615 #[test]
616 fn test_deprecated_decorator() -> Result<()> {
617 let stub_str: LitStr = syn::parse2(quote! {
618 r#"
619 @deprecated
620 def old_function(x: int) -> int:
621 """This function is deprecated"""
622 "#
623 })?;
624 let info = parse_python_function_stub(stub_str)?;
625 let out = info.to_token_stream();
626 insta::assert_snapshot!(format_as_value(out), @r#"
627 ::pyo3_stub_gen::type_info::PyFunctionInfo {
628 name: "old_function",
629 parameters: &[
630 ::pyo3_stub_gen::type_info::ParameterInfo {
631 name: "x",
632 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
633 type_info: || ::pyo3_stub_gen::TypeInfo {
634 name: "int".to_string(),
635 source_module: None,
636 import: ::std::collections::HashSet::from([]),
637 type_refs: ::std::collections::HashMap::new(),
638 },
639 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
640 },
641 ],
642 r#return: || ::pyo3_stub_gen::TypeInfo {
643 name: "int".to_string(),
644 source_module: None,
645 import: ::std::collections::HashSet::from([]),
646 type_refs: ::std::collections::HashMap::new(),
647 },
648 doc: "This function is deprecated",
649 module: None,
650 is_async: false,
651 deprecated: Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
652 since: None,
653 note: None,
654 }),
655 type_ignored: None,
656 is_overload: false,
657 file: file!(),
658 line: line!(),
659 column: column!(),
660 index: 0usize,
661 }
662 "#);
663 Ok(())
664 }
665
666 #[test]
667 fn test_deprecated_with_message() -> Result<()> {
668 let stub_str: LitStr = syn::parse2(quote! {
669 r#"
670 @deprecated("Use new_function instead")
671 def old_function(x: int) -> int:
672 """This function is deprecated"""
673 "#
674 })?;
675 let info = parse_python_function_stub(stub_str)?;
676 let out = info.to_token_stream();
677 insta::assert_snapshot!(format_as_value(out), @r#"
678 ::pyo3_stub_gen::type_info::PyFunctionInfo {
679 name: "old_function",
680 parameters: &[
681 ::pyo3_stub_gen::type_info::ParameterInfo {
682 name: "x",
683 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
684 type_info: || ::pyo3_stub_gen::TypeInfo {
685 name: "int".to_string(),
686 source_module: None,
687 import: ::std::collections::HashSet::from([]),
688 type_refs: ::std::collections::HashMap::new(),
689 },
690 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
691 },
692 ],
693 r#return: || ::pyo3_stub_gen::TypeInfo {
694 name: "int".to_string(),
695 source_module: None,
696 import: ::std::collections::HashSet::from([]),
697 type_refs: ::std::collections::HashMap::new(),
698 },
699 doc: "This function is deprecated",
700 module: None,
701 is_async: false,
702 deprecated: Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
703 since: None,
704 note: Some("Use new_function instead"),
705 }),
706 type_ignored: None,
707 is_overload: false,
708 file: file!(),
709 line: line!(),
710 column: column!(),
711 index: 0usize,
712 }
713 "#);
714 Ok(())
715 }
716
717 #[test]
718 fn test_rust_type_marker() -> Result<()> {
719 let stub_str: LitStr = syn::parse2(quote! {
720 r#"
721 def process_data(x: pyo3_stub_gen.RustType["MyRustType"]) -> pyo3_stub_gen.RustType["MyRustType"]:
722 """Process data using Rust type marker"""
723 "#
724 })?;
725 let info = parse_python_function_stub(stub_str)?;
726 let out = info.to_token_stream();
727 insta::assert_snapshot!(format_as_value(out), @r###"
728 ::pyo3_stub_gen::type_info::PyFunctionInfo {
729 name: "process_data",
730 parameters: &[
731 ::pyo3_stub_gen::type_info::ParameterInfo {
732 name: "x",
733 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
734 type_info: <MyRustType as ::pyo3_stub_gen::PyStubType>::type_input,
735 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
736 },
737 ],
738 r#return: <MyRustType as pyo3_stub_gen::PyStubType>::type_output,
739 doc: "Process data using Rust type marker",
740 module: None,
741 is_async: false,
742 deprecated: None,
743 type_ignored: None,
744 is_overload: false,
745 file: file!(),
746 line: line!(),
747 column: column!(),
748 index: 0usize,
749 }
750 "###);
751 Ok(())
752 }
753
754 #[test]
755 fn test_rust_type_marker_with_path() -> Result<()> {
756 let stub_str: LitStr = syn::parse2(quote! {
757 r#"
758 def process(x: pyo3_stub_gen.RustType["crate::MyType"]) -> pyo3_stub_gen.RustType["Vec<String>"]:
759 """Test with type paths"""
760 "#
761 })?;
762 let info = parse_python_function_stub(stub_str)?;
763 let out = info.to_token_stream();
764 insta::assert_snapshot!(format_as_value(out), @r###"
765 ::pyo3_stub_gen::type_info::PyFunctionInfo {
766 name: "process",
767 parameters: &[
768 ::pyo3_stub_gen::type_info::ParameterInfo {
769 name: "x",
770 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
771 type_info: <crate::MyType as ::pyo3_stub_gen::PyStubType>::type_input,
772 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
773 },
774 ],
775 r#return: <Vec<String> as pyo3_stub_gen::PyStubType>::type_output,
776 doc: "Test with type paths",
777 module: None,
778 is_async: false,
779 deprecated: None,
780 type_ignored: None,
781 is_overload: false,
782 file: file!(),
783 line: line!(),
784 column: column!(),
785 index: 0usize,
786 }
787 "###);
788 Ok(())
789 }
790
791 #[test]
792 fn test_keyword_only_args() -> Result<()> {
793 let stub_str: LitStr = syn::parse2(quote! {
794 r#"
795 import typing
796
797 def configure(name: str, *, dtype: str, ndim: int, jagged: bool = False) -> None:
798 """Test keyword-only parameters"""
799 "#
800 })?;
801 let info = parse_python_function_stub(stub_str)?;
802 let out = info.to_token_stream();
803 insta::assert_snapshot!(format_as_value(out), @r#"
804 ::pyo3_stub_gen::type_info::PyFunctionInfo {
805 name: "configure",
806 parameters: &[
807 ::pyo3_stub_gen::type_info::ParameterInfo {
808 name: "name",
809 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
810 type_info: || ::pyo3_stub_gen::TypeInfo {
811 name: "str".to_string(),
812 source_module: None,
813 import: ::std::collections::HashSet::from(["typing".into()]),
814 type_refs: ::std::collections::HashMap::new(),
815 },
816 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
817 },
818 ::pyo3_stub_gen::type_info::ParameterInfo {
819 name: "dtype",
820 kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
821 type_info: || ::pyo3_stub_gen::TypeInfo {
822 name: "str".to_string(),
823 source_module: None,
824 import: ::std::collections::HashSet::from(["typing".into()]),
825 type_refs: ::std::collections::HashMap::new(),
826 },
827 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
828 },
829 ::pyo3_stub_gen::type_info::ParameterInfo {
830 name: "ndim",
831 kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
832 type_info: || ::pyo3_stub_gen::TypeInfo {
833 name: "int".to_string(),
834 source_module: None,
835 import: ::std::collections::HashSet::from(["typing".into()]),
836 type_refs: ::std::collections::HashMap::new(),
837 },
838 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
839 },
840 ::pyo3_stub_gen::type_info::ParameterInfo {
841 name: "jagged",
842 kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
843 type_info: || ::pyo3_stub_gen::TypeInfo {
844 name: "bool".to_string(),
845 source_module: None,
846 import: ::std::collections::HashSet::from(["typing".into()]),
847 type_refs: ::std::collections::HashMap::new(),
848 },
849 default: ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
850 fn _fmt() -> String {
851 "False".to_string()
852 }
853 _fmt
854 }),
855 },
856 ],
857 r#return: || ::pyo3_stub_gen::TypeInfo {
858 name: "None".to_string(),
859 source_module: None,
860 import: ::std::collections::HashSet::from(["typing".into()]),
861 type_refs: ::std::collections::HashMap::new(),
862 },
863 doc: "Test keyword-only parameters",
864 module: None,
865 is_async: false,
866 deprecated: None,
867 type_ignored: None,
868 is_overload: false,
869 file: file!(),
870 line: line!(),
871 column: column!(),
872 index: 0usize,
873 }
874 "#);
875 Ok(())
876 }
877
878 #[test]
879 fn test_positional_only_args() -> Result<()> {
880 let stub_str: LitStr = syn::parse2(quote! {
881 r#"
882 def func(x: int, y: int, /, z: int) -> int:
883 """Test positional-only parameters"""
884 "#
885 })?;
886 let info = parse_python_function_stub(stub_str)?;
887 let out = info.to_token_stream();
888 insta::assert_snapshot!(format_as_value(out), @r#"
889 ::pyo3_stub_gen::type_info::PyFunctionInfo {
890 name: "func",
891 parameters: &[
892 ::pyo3_stub_gen::type_info::ParameterInfo {
893 name: "x",
894 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOnly,
895 type_info: || ::pyo3_stub_gen::TypeInfo {
896 name: "int".to_string(),
897 source_module: None,
898 import: ::std::collections::HashSet::from([]),
899 type_refs: ::std::collections::HashMap::new(),
900 },
901 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
902 },
903 ::pyo3_stub_gen::type_info::ParameterInfo {
904 name: "y",
905 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOnly,
906 type_info: || ::pyo3_stub_gen::TypeInfo {
907 name: "int".to_string(),
908 source_module: None,
909 import: ::std::collections::HashSet::from([]),
910 type_refs: ::std::collections::HashMap::new(),
911 },
912 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
913 },
914 ::pyo3_stub_gen::type_info::ParameterInfo {
915 name: "z",
916 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
917 type_info: || ::pyo3_stub_gen::TypeInfo {
918 name: "int".to_string(),
919 source_module: None,
920 import: ::std::collections::HashSet::from([]),
921 type_refs: ::std::collections::HashMap::new(),
922 },
923 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
924 },
925 ],
926 r#return: || ::pyo3_stub_gen::TypeInfo {
927 name: "int".to_string(),
928 source_module: None,
929 import: ::std::collections::HashSet::from([]),
930 type_refs: ::std::collections::HashMap::new(),
931 },
932 doc: "Test positional-only parameters",
933 module: None,
934 is_async: false,
935 deprecated: None,
936 type_ignored: None,
937 is_overload: false,
938 file: file!(),
939 line: line!(),
940 column: column!(),
941 index: 0usize,
942 }
943 "#);
944 Ok(())
945 }
946
947 #[test]
948 fn test_single_overload() -> Result<()> {
949 let stub_str: LitStr = syn::parse2(quote! {
950 r#"
951 @overload
952 def foo(x: int) -> int:
953 """Integer overload"""
954 "#
955 })?;
956 let info = parse_python_function_stub(stub_str)?;
957 let out = info.to_token_stream();
958 insta::assert_snapshot!(format_as_value(out), @r#"
959 ::pyo3_stub_gen::type_info::PyFunctionInfo {
960 name: "foo",
961 parameters: &[
962 ::pyo3_stub_gen::type_info::ParameterInfo {
963 name: "x",
964 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
965 type_info: || ::pyo3_stub_gen::TypeInfo {
966 name: "int".to_string(),
967 source_module: None,
968 import: ::std::collections::HashSet::from([]),
969 type_refs: ::std::collections::HashMap::new(),
970 },
971 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
972 },
973 ],
974 r#return: || ::pyo3_stub_gen::TypeInfo {
975 name: "int".to_string(),
976 source_module: None,
977 import: ::std::collections::HashSet::from([]),
978 type_refs: ::std::collections::HashMap::new(),
979 },
980 doc: "Integer overload",
981 module: None,
982 is_async: false,
983 deprecated: None,
984 type_ignored: None,
985 is_overload: true,
986 file: file!(),
987 line: line!(),
988 column: column!(),
989 index: 0usize,
990 }
991 "#);
992 Ok(())
993 }
994
995 #[test]
996 fn test_multiple_overloads() -> Result<()> {
997 let stub_str: LitStr = syn::parse2(quote! {
998 r#"
999 @overload
1000 def foo(x: int) -> int:
1001 """Integer overload"""
1002
1003 @overload
1004 def foo(x: float) -> float:
1005 """Float overload"""
1006 "#
1007 })?;
1008 let infos = parse_python_overload_stubs(stub_str, "foo")?;
1009 assert_eq!(infos.len(), 2);
1010
1011 let out1 = infos[0].to_token_stream();
1012 insta::assert_snapshot!(format_as_value(out1), @r#"
1013 ::pyo3_stub_gen::type_info::PyFunctionInfo {
1014 name: "foo",
1015 parameters: &[
1016 ::pyo3_stub_gen::type_info::ParameterInfo {
1017 name: "x",
1018 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
1019 type_info: || ::pyo3_stub_gen::TypeInfo {
1020 name: "int".to_string(),
1021 source_module: None,
1022 import: ::std::collections::HashSet::from([]),
1023 type_refs: ::std::collections::HashMap::new(),
1024 },
1025 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
1026 },
1027 ],
1028 r#return: || ::pyo3_stub_gen::TypeInfo {
1029 name: "int".to_string(),
1030 source_module: None,
1031 import: ::std::collections::HashSet::from([]),
1032 type_refs: ::std::collections::HashMap::new(),
1033 },
1034 doc: "Integer overload",
1035 module: None,
1036 is_async: false,
1037 deprecated: None,
1038 type_ignored: None,
1039 is_overload: true,
1040 file: file!(),
1041 line: line!(),
1042 column: column!(),
1043 index: 0usize,
1044 }
1045 "#);
1046
1047 let out2 = infos[1].to_token_stream();
1048 insta::assert_snapshot!(format_as_value(out2), @r#"
1049 ::pyo3_stub_gen::type_info::PyFunctionInfo {
1050 name: "foo",
1051 parameters: &[
1052 ::pyo3_stub_gen::type_info::ParameterInfo {
1053 name: "x",
1054 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
1055 type_info: || ::pyo3_stub_gen::TypeInfo {
1056 name: "float".to_string(),
1057 source_module: None,
1058 import: ::std::collections::HashSet::from([]),
1059 type_refs: ::std::collections::HashMap::new(),
1060 },
1061 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
1062 },
1063 ],
1064 r#return: || ::pyo3_stub_gen::TypeInfo {
1065 name: "float".to_string(),
1066 source_module: None,
1067 import: ::std::collections::HashSet::from([]),
1068 type_refs: ::std::collections::HashMap::new(),
1069 },
1070 doc: "Float overload",
1071 module: None,
1072 is_async: false,
1073 deprecated: None,
1074 type_ignored: None,
1075 is_overload: true,
1076 file: file!(),
1077 line: line!(),
1078 column: column!(),
1079 index: 0usize,
1080 }
1081 "#);
1082 Ok(())
1083 }
1084
1085 #[test]
1086 fn test_overload_with_literal_types() -> Result<()> {
1087 let stub_str: LitStr = syn::parse2(quote! {
1088 r#"
1089 import typing
1090 @overload
1091 def as_tuple(xs: list[int], *, tuple_out: typing.Literal[True]) -> tuple[int, ...]:
1092 """Return as tuple"""
1093 "#
1094 })?;
1095 let info = parse_python_function_stub(stub_str)?;
1096 let out = info.to_token_stream();
1097 insta::assert_snapshot!(format_as_value(out), @r#"
1098 ::pyo3_stub_gen::type_info::PyFunctionInfo {
1099 name: "as_tuple",
1100 parameters: &[
1101 ::pyo3_stub_gen::type_info::ParameterInfo {
1102 name: "xs",
1103 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
1104 type_info: || ::pyo3_stub_gen::TypeInfo {
1105 name: "list[int]".to_string(),
1106 source_module: None,
1107 import: ::std::collections::HashSet::from(["typing".into()]),
1108 type_refs: ::std::collections::HashMap::new(),
1109 },
1110 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
1111 },
1112 ::pyo3_stub_gen::type_info::ParameterInfo {
1113 name: "tuple_out",
1114 kind: ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly,
1115 type_info: || ::pyo3_stub_gen::TypeInfo {
1116 name: "typing.Literal[True]".to_string(),
1117 source_module: None,
1118 import: ::std::collections::HashSet::from(["typing".into()]),
1119 type_refs: ::std::collections::HashMap::new(),
1120 },
1121 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
1122 },
1123 ],
1124 r#return: || ::pyo3_stub_gen::TypeInfo {
1125 name: "tuple[int, ...]".to_string(),
1126 source_module: None,
1127 import: ::std::collections::HashSet::from(["typing".into()]),
1128 type_refs: ::std::collections::HashMap::new(),
1129 },
1130 doc: "Return as tuple",
1131 module: None,
1132 is_async: false,
1133 deprecated: None,
1134 type_ignored: None,
1135 is_overload: true,
1136 file: file!(),
1137 line: line!(),
1138 column: column!(),
1139 index: 0usize,
1140 }
1141 "#);
1142 Ok(())
1143 }
1144
1145 fn format_as_value(tt: TokenStream2) -> String {
1146 let ttt = quote! { const _: () = #tt; };
1147 let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
1148 formatted
1149 .trim()
1150 .strip_prefix("const _: () = ")
1151 .unwrap()
1152 .strip_suffix(';')
1153 .unwrap()
1154 .to_string()
1155 }
1156}