1mod pyfunction;
7mod pymethods;
8mod type_alias;
9
10pub use pyfunction::{
11 parse_gen_function_from_python_input, parse_python_function_stub, parse_python_overload_stubs,
12 GenFunctionFromPythonInput,
13};
14pub use pymethods::parse_python_methods_stub;
15pub use type_alias::{parse_python_type_alias_stub, GenTypeAliasFromPythonInput};
16
17use indexmap::IndexSet;
18use rustpython_parser::ast;
19use syn::{Result, Type};
20
21use super::{
22 arg::ArgInfo,
23 attr::DeprecatedInfo,
24 parameter::DefaultExpr,
25 parameter::{ParameterKind, ParameterWithKind, Parameters},
26 util::TypeOrOverride,
27};
28
/// Removes the common leading whitespace from every line of `text`.
///
/// The common indent is measured in `char`s over the non-blank lines only, so
/// blank lines never force it to zero. Each line then loses that many leading
/// `char`s. A line with fewer characters than the common indent (which can
/// only be a whitespace-only line) is kept unchanged. Lines are re-joined with
/// `\n`, so a trailing newline in `text` is not preserved.
///
/// Indentation is counted in characters rather than bytes. This way,
/// multi-byte Unicode whitespace in an indent (e.g. U+3000 IDEOGRAPHIC SPACE)
/// can never cause a slice at a non-`char` boundary, which would panic.
fn dedent(text: &str) -> String {
    let lines: Vec<&str> = text.lines().collect();

    // Smallest leading-whitespace width, in chars, among non-blank lines.
    let min_indent = lines
        .iter()
        .filter(|line| !line.trim().is_empty())
        .map(|line| line.chars().take_while(|c| c.is_whitespace()).count())
        .min()
        .unwrap_or(0);

    lines
        .iter()
        .map(|&line| {
            // Byte offset just past the first `min_indent` chars, or `None`
            // when the line has fewer chars than that.
            let cut = line
                .char_indices()
                .map(|(offset, _)| offset)
                .chain(std::iter::once(line.len()))
                .nth(min_indent);
            match cut {
                Some(offset) => &line[offset..],
                None => line,
            }
        })
        .collect::<Vec<_>>()
        .join("\n")
}
54
55fn extract_docstring(func_def: &ast::StmtFunctionDef) -> String {
57 if let Some(ast::Stmt::Expr(expr_stmt)) = func_def.body.first() {
58 if let ast::Expr::Constant(constant) = &*expr_stmt.value {
59 if let ast::Constant::Str(s) = &constant.value {
60 return s.to_string();
61 }
62 }
63 }
64 String::new()
65}
66
67fn extract_deprecated_from_decorators(decorators: &[ast::Expr]) -> Option<DeprecatedInfo> {
69 for decorator in decorators {
70 match decorator {
72 ast::Expr::Name(name) if name.id.as_str() == "deprecated" => {
73 return Some(DeprecatedInfo {
74 since: None,
75 note: None,
76 });
77 }
78 ast::Expr::Call(call) => {
79 if let ast::Expr::Name(name) = &*call.func {
80 if name.id.as_str() == "deprecated" {
81 let note = call.args.first().and_then(|arg| match arg {
83 ast::Expr::Constant(constant) => match &constant.value {
84 ast::Constant::Str(s) => Some(s.to_string()),
85 _ => None,
86 },
87 _ => None,
88 });
89 return Some(DeprecatedInfo { since: None, note });
90 }
91 }
92 }
93 _ => {}
94 }
95 }
96 None
97}
98
99fn has_overload_decorator(decorator_list: &[ast::Expr]) -> bool {
101 decorator_list.iter().any(|decorator| {
102 match decorator {
103 ast::Expr::Name(name) => name.id.as_str() == "overload",
104 ast::Expr::Attribute(attr) => {
105 attr.attr.as_str() == "overload"
107 }
108 _ => false,
109 }
110 })
111}
112
113pub(super) fn build_parameters_from_ast(
118 args: &ast::Arguments,
119 imports: &[String],
120) -> Result<Parameters> {
121 let dummy_type: Type = syn::parse_str("()").unwrap();
122 let mut parameters = Vec::new();
123
124 let process_arg_with_default =
126 |arg: &ast::ArgWithDefault, kind: ParameterKind| -> Result<Option<ParameterWithKind>> {
127 let arg_name = arg.def.arg.to_string();
128
129 if arg_name == "self" || arg_name == "cls" {
131 return Ok(None);
132 }
133
134 let type_override = if let Some(annotation) = &arg.def.annotation {
135 type_annotation_to_type_override(annotation, imports, dummy_type.clone())?
136 } else {
137 TypeOrOverride::OverrideType {
139 r#type: dummy_type.clone(),
140 type_repr: "typing.Any".to_string(),
141 imports: IndexSet::from(["typing".to_string()]),
142 rust_type_markers: vec![],
143 }
144 };
145
146 let arg_info = ArgInfo {
147 name: arg_name,
148 r#type: type_override,
149 };
150
151 let default_expr = if let Some(default) = &arg.default {
153 Some(DefaultExpr::Python(python_ast_to_python_string(default)?))
154 } else {
155 None
156 };
157
158 Ok(Some(ParameterWithKind {
159 arg_info,
160 kind,
161 default_expr,
162 }))
163 };
164
165 let process_var_arg = |arg: &ast::Arg, kind: ParameterKind| -> Result<ParameterWithKind> {
167 let arg_name = arg.arg.to_string();
168
169 let type_override = if let Some(annotation) = &arg.annotation {
170 type_annotation_to_type_override(annotation, imports, dummy_type.clone())?
171 } else {
172 TypeOrOverride::OverrideType {
174 r#type: dummy_type.clone(),
175 type_repr: "typing.Any".to_string(),
176 imports: IndexSet::from(["typing".to_string()]),
177 rust_type_markers: vec![],
178 }
179 };
180
181 let arg_info = ArgInfo {
182 name: arg_name,
183 r#type: type_override,
184 };
185
186 Ok(ParameterWithKind {
187 arg_info,
188 kind,
189 default_expr: None,
190 })
191 };
192
193 for arg in &args.posonlyargs {
195 if let Some(param) = process_arg_with_default(arg, ParameterKind::PositionalOnly)? {
196 parameters.push(param);
197 }
198 }
199
200 for arg in &args.args {
202 if let Some(param) = process_arg_with_default(arg, ParameterKind::PositionalOrKeyword)? {
203 parameters.push(param);
204 }
205 }
206
207 if let Some(vararg) = &args.vararg {
209 parameters.push(process_var_arg(vararg, ParameterKind::VarPositional)?);
210 }
211
212 for arg in &args.kwonlyargs {
214 if let Some(param) = process_arg_with_default(arg, ParameterKind::KeywordOnly)? {
215 parameters.push(param);
216 }
217 }
218
219 if let Some(kwarg) = &args.kwarg {
221 parameters.push(process_var_arg(kwarg, ParameterKind::VarKeyword)?);
222 }
223
224 Ok(Parameters::from_vec(parameters))
225}
226
227fn extract_return_type(
229 returns: &Option<Box<ast::Expr>>,
230 imports: &[String],
231) -> Result<Option<TypeOrOverride>> {
232 let dummy_type: Type = syn::parse_str("()").unwrap();
234
235 if let Some(return_annotation) = returns {
236 Ok(Some(type_annotation_to_type_override(
237 return_annotation,
238 imports,
239 dummy_type,
240 )?))
241 } else {
242 Ok(None)
244 }
245}
246
247fn collect_rust_type_markers(expr: &ast::Expr) -> Result<Vec<String>> {
251 let mut markers = Vec::new();
252 collect_rust_type_markers_impl(expr, &mut markers)?;
253 Ok(markers)
254}
255
256fn collect_rust_type_markers_impl(expr: &ast::Expr, markers: &mut Vec<String>) -> Result<()> {
257 if let Some(type_name) = extract_rust_type_marker(expr)? {
259 markers.push(type_name);
260 return Ok(());
261 }
262
263 match expr {
265 ast::Expr::Subscript(subscript) => {
266 collect_rust_type_markers_impl(&subscript.value, markers)?;
267 collect_rust_type_markers_impl(&subscript.slice, markers)?;
268 }
269 ast::Expr::Tuple(tuple) => {
270 for elt in &tuple.elts {
271 collect_rust_type_markers_impl(elt, markers)?;
272 }
273 }
274 ast::Expr::List(list) => {
275 for elt in &list.elts {
276 collect_rust_type_markers_impl(elt, markers)?;
277 }
278 }
279 ast::Expr::BinOp(binop) => {
280 collect_rust_type_markers_impl(&binop.left, markers)?;
281 collect_rust_type_markers_impl(&binop.right, markers)?;
282 }
283 _ => {}
284 }
285 Ok(())
286}
287
288fn type_annotation_to_type_override(
290 expr: &ast::Expr,
291 imports: &[String],
292 dummy_type: Type,
293) -> Result<TypeOrOverride> {
294 if let Some(type_name) = extract_rust_type_marker(expr)? {
296 let rust_type: Type = syn::parse_str(&type_name).map_err(|e| {
297 syn::Error::new(
298 proc_macro2::Span::call_site(),
299 format!("Failed to parse Rust type '{}': {}", type_name, e),
300 )
301 })?;
302 return Ok(TypeOrOverride::RustType { r#type: rust_type });
303 }
304
305 let type_str = expr_to_type_string(expr)?;
306
307 let rust_type_markers = collect_rust_type_markers(expr)?;
309
310 let import_set: IndexSet<String> = imports.iter().map(|s| s.to_string()).collect();
312
313 Ok(TypeOrOverride::OverrideType {
314 r#type: dummy_type,
315 type_repr: type_str,
316 imports: import_set,
317 rust_type_markers,
318 })
319}
320
321fn extract_rust_type_marker(expr: &ast::Expr) -> Result<Option<String>> {
326 if let ast::Expr::Subscript(subscript) = expr {
328 if let ast::Expr::Attribute(attr) = &*subscript.value {
329 if attr.attr.as_str() == "RustType" {
331 if let ast::Expr::Name(name) = &*attr.value {
333 if name.id.as_str() == "pyo3_stub_gen" {
334 if let ast::Expr::Constant(constant) = &*subscript.slice {
336 if let ast::Constant::Str(s) = &constant.value {
337 return Ok(Some(s.to_string()));
338 }
339 }
340 return Err(syn::Error::new(
341 proc_macro2::Span::call_site(),
342 "pyo3_stub_gen.RustType requires a string literal (e.g., RustType[\"MyType\"])",
343 ));
344 }
345 }
346 }
347 }
348 }
349 Ok(None)
350}
351
/// Quotes `s` as a Python string literal.
///
/// Double quotes are chosen only when `s` contains a `'` but no `"`, so that
/// nothing needs escaping. Otherwise single quotes are used and any `'` is
/// escaped. Backslashes, `\n`, `\r`, `\t`, and other ASCII control characters
/// are escaped (the latter as `\xNN`). All other characters are emitted
/// verbatim.
fn escape_python_string(s: &str) -> String {
    let quote = if s.contains('\'') && !s.contains('"') {
        '"'
    } else {
        '\''
    };

    let mut out = String::with_capacity(s.len() + 2);
    out.push(quote);
    for ch in s.chars() {
        match ch {
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // Only the active delimiter needs escaping.
            c if c == quote => {
                out.push('\\');
                out.push(c);
            }
            c if c.is_ascii_control() => out.push_str(&format!("\\x{:02x}", u32::from(c))),
            c => out.push(c),
        }
    }
    out.push(quote);
    out
}
384
385fn python_ast_to_python_string(expr: &ast::Expr) -> Result<String> {
390 match expr {
391 ast::Expr::Constant(constant) => match &constant.value {
392 ast::Constant::None => Ok("None".to_string()),
393 ast::Constant::Bool(true) => Ok("True".to_string()),
394 ast::Constant::Bool(false) => Ok("False".to_string()),
395 ast::Constant::Int(i) => Ok(i.to_string()),
396 ast::Constant::Float(f) => Ok(f.to_string()),
397 ast::Constant::Str(s) => Ok(escape_python_string(s)),
398 ast::Constant::Bytes(_) => Err(syn::Error::new(
399 proc_macro2::Span::call_site(),
400 "Bytes literals are not supported as default values",
401 )),
402 ast::Constant::Ellipsis => Ok("...".to_string()),
403 _ => Err(syn::Error::new(
404 proc_macro2::Span::call_site(),
405 format!("Unsupported constant type: {:?}", constant.value),
406 )),
407 },
408 ast::Expr::List(list) => {
409 let elements: Result<Vec<_>> =
411 list.elts.iter().map(python_ast_to_python_string).collect();
412 Ok(format!("[{}]", elements?.join(", ")))
413 }
414 ast::Expr::Tuple(tuple) => {
415 let elements: Result<Vec<_>> =
417 tuple.elts.iter().map(python_ast_to_python_string).collect();
418 let elements = elements?;
419 if elements.len() == 1 {
420 Ok(format!("({},)", elements[0]))
422 } else {
423 Ok(format!("({})", elements.join(", ")))
424 }
425 }
426 ast::Expr::Dict(dict) => {
427 let mut pairs = Vec::new();
429 for (key_opt, value) in dict.keys.iter().zip(dict.values.iter()) {
430 if let Some(key) = key_opt {
431 let key_str = python_ast_to_python_string(key)?;
432 let value_str = python_ast_to_python_string(value)?;
433 pairs.push(format!("{}: {}", key_str, value_str));
434 } else {
435 return Ok("...".to_string());
437 }
438 }
439 Ok(format!("{{{}}}", pairs.join(", ")))
440 }
441 ast::Expr::Name(name) => Ok(name.id.to_string()),
442 ast::Expr::Attribute(_) => {
443 expr_to_type_string(expr)
445 }
446 ast::Expr::UnaryOp(unary) => {
447 if matches!(unary.op, ast::UnaryOp::USub) {
449 if let ast::Expr::Constant(constant) = &*unary.operand {
450 match &constant.value {
451 ast::Constant::Int(i) => Ok(format!("-{}", i)),
452 ast::Constant::Float(f) => Ok(format!("-{}", f)),
453 _ => Ok("...".to_string()),
454 }
455 } else {
456 Ok("...".to_string())
457 }
458 } else {
459 Ok("...".to_string())
460 }
461 }
462 _ => {
463 Ok("...".to_string())
465 }
466 }
467}
468
469fn expr_to_type_string(expr: &ast::Expr) -> Result<String> {
471 expr_to_type_string_inner(expr, false)
472}
473
474fn expr_to_type_string_inner(expr: &ast::Expr, in_subscript: bool) -> Result<String> {
476 if let Some(type_name) = extract_rust_type_marker(expr)? {
479 return Ok(type_name);
480 }
481
482 Ok(match expr {
483 ast::Expr::Name(name) => name.id.to_string(),
484 ast::Expr::Attribute(attr) => {
485 format!(
486 "{}.{}",
487 expr_to_type_string_inner(&attr.value, false)?,
488 attr.attr
489 )
490 }
491 ast::Expr::Subscript(subscript) => {
492 let base = expr_to_type_string_inner(&subscript.value, false)?;
493 let slice = expr_to_type_string_inner(&subscript.slice, true)?;
494 format!("{}[{}]", base, slice)
495 }
496 ast::Expr::List(list) => {
497 let elements: Result<Vec<String>> = list
498 .elts
499 .iter()
500 .map(|e| expr_to_type_string_inner(e, false))
501 .collect();
502 format!("[{}]", elements?.join(", "))
503 }
504 ast::Expr::Tuple(tuple) => {
505 let elements: Result<Vec<String>> = tuple
506 .elts
507 .iter()
508 .map(|e| expr_to_type_string_inner(e, in_subscript))
509 .collect();
510 let elements = elements?;
511 if in_subscript {
512 elements.join(", ")
514 } else {
515 format!("({})", elements.join(", "))
516 }
517 }
518 ast::Expr::Constant(constant) => match &constant.value {
519 ast::Constant::Int(i) => i.to_string(),
520 ast::Constant::Str(s) => format!("\"{}\"", s),
521 ast::Constant::Bool(b) => if *b { "True" } else { "False" }.to_string(),
522 ast::Constant::None => "None".to_string(),
523 ast::Constant::Ellipsis => "...".to_string(),
524 _ => "Any".to_string(),
525 },
526 ast::Expr::BinOp(binop) => {
527 if matches!(binop.op, ast::Operator::BitOr) {
529 let left = expr_to_type_string_inner(&binop.left, false)?;
530 let right = expr_to_type_string_inner(&binop.right, false)?;
531 format!("{} | {}", left, right)
532 } else {
533 "Any".to_string()
534 }
535 }
536 _ => "Any".to_string(),
537 })
538}
539
540#[cfg(test)]
541mod tests {
542 use super::*;
543 use rustpython_parser as parser;
544
545 fn parse_and_convert(python_expr: &str) -> Result<String> {
547 let source = format!("x = {}", python_expr);
548 let parsed = parser::parse(&source, parser::Mode::Module, "<test>")
549 .map_err(|e| syn::Error::new(proc_macro2::Span::call_site(), format!("{}", e)))?;
550
551 if let parser::ast::Mod::Module(module) = parsed {
552 if let Some(parser::ast::Stmt::Assign(assign)) = module.body.first() {
553 return python_ast_to_python_string(&assign.value);
554 }
555 }
556 Err(syn::Error::new(
557 proc_macro2::Span::call_site(),
558 "Failed to parse expression",
559 ))
560 }
561
562 #[test]
563 fn test_string_basic() -> Result<()> {
564 let result = parse_and_convert(r#""hello""#)?;
565 assert_eq!(result, r#"'hello'"#);
566 Ok(())
567 }
568
569 #[test]
570 fn test_string_with_single_quote() -> Result<()> {
571 let result = parse_and_convert(r#""it's""#)?;
573 assert_eq!(result, r#""it's""#);
575 Ok(())
576 }
577
578 #[test]
579 fn test_string_with_double_quote() -> Result<()> {
580 let result = parse_and_convert(r#"'say "hi"'"#)?;
582 assert_eq!(result, r#"'say "hi"'"#);
584 Ok(())
585 }
586
587 #[test]
588 fn test_string_with_newline() -> Result<()> {
589 let result = parse_and_convert(r#""line1\nline2""#)?;
591 assert_eq!(result, "'line1\\nline2'");
593 Ok(())
594 }
595
596 #[test]
597 fn test_string_with_tab() -> Result<()> {
598 let result = parse_and_convert(r#""a\tb""#)?;
599 assert_eq!(result, "'a\\tb'");
601 Ok(())
602 }
603
604 #[test]
605 fn test_string_with_backslash() -> Result<()> {
606 let result = parse_and_convert(r#"r"path\to\file""#)?;
608 assert_eq!(result, r"'path\\to\\file'");
612 Ok(())
613 }
614
615 #[test]
616 fn test_string_with_both_quotes() -> Result<()> {
617 let result = parse_and_convert(r#""it's \"great\"""#)?;
619 assert_eq!(result, r#"'it\'s "great"'"#);
621 Ok(())
622 }
623
624 #[test]
625 fn test_string_empty() -> Result<()> {
626 let result = parse_and_convert(r#""""#)?;
627 assert_eq!(result, "''");
628 Ok(())
629 }
630
631 #[test]
632 fn test_none() -> Result<()> {
633 let result = parse_and_convert("None")?;
634 assert_eq!(result, "None");
635 Ok(())
636 }
637
638 #[test]
639 fn test_bool_true() -> Result<()> {
640 let result = parse_and_convert("True")?;
641 assert_eq!(result, "True");
642 Ok(())
643 }
644
645 #[test]
646 fn test_bool_false() -> Result<()> {
647 let result = parse_and_convert("False")?;
648 assert_eq!(result, "False");
649 Ok(())
650 }
651
652 #[test]
653 fn test_int() -> Result<()> {
654 let result = parse_and_convert("42")?;
655 assert_eq!(result, "42");
656 Ok(())
657 }
658
659 #[test]
660 fn test_float() -> Result<()> {
661 let result = parse_and_convert("3.14")?;
662 assert_eq!(result, "3.14");
663 Ok(())
664 }
665
666 #[test]
667 fn test_list() -> Result<()> {
668 let result = parse_and_convert("[1, 2, 3]")?;
669 assert_eq!(result, "[1, 2, 3]");
670 Ok(())
671 }
672
673 #[test]
674 fn test_tuple() -> Result<()> {
675 let result = parse_and_convert("(1, 2)")?;
676 assert_eq!(result, "(1, 2)");
677 Ok(())
678 }
679
680 #[test]
681 fn test_tuple_single_element() -> Result<()> {
682 let result = parse_and_convert("(1,)")?;
683 assert_eq!(result, "(1,)");
684 Ok(())
685 }
686
687 #[test]
688 fn test_dict() -> Result<()> {
689 let result = parse_and_convert(r#"{"a": 1, "b": 2}"#)?;
690 assert_eq!(result, "{'a': 1, 'b': 2}");
691 Ok(())
692 }
693
694 #[test]
695 fn test_negative_int() -> Result<()> {
696 let result = parse_and_convert("-42")?;
697 assert_eq!(result, "-42");
698 Ok(())
699 }
700
701 #[test]
702 fn test_negative_float() -> Result<()> {
703 let result = parse_and_convert("-3.14")?;
704 assert_eq!(result, "-3.14");
705 Ok(())
706 }
707}