pyo3_stub_gen_derive/gen_stub/
parse_python.rs

1//! Parse Python stub syntax and generate PyFunctionInfo and MethodInfo
2//!
3//! This module provides functionality to parse Python stub syntax (type hints)
4//! and convert them into Rust metadata structures for stub generation.
5
6mod pyfunction;
7mod pymethods;
8mod type_alias;
9
10pub use pyfunction::{
11    parse_gen_function_from_python_input, parse_python_function_stub, parse_python_overload_stubs,
12    GenFunctionFromPythonInput,
13};
14pub use pymethods::parse_python_methods_stub;
15pub use type_alias::{parse_python_type_alias_stub, GenTypeAliasFromPythonInput};
16
17use indexmap::IndexSet;
18use rustpython_parser::ast;
19use syn::{Result, Type};
20
21use super::{
22    arg::ArgInfo,
23    attr::DeprecatedInfo,
24    parameter::DefaultExpr,
25    parameter::{ParameterKind, ParameterWithKind, Parameters},
26    util::TypeOrOverride,
27};
28
/// Remove common leading whitespace from all lines (similar to Python's `textwrap.dedent`).
///
/// The common indentation is the smallest number of leading-whitespace *bytes* found on
/// any line that has non-whitespace content; blank and whitespace-only lines do not
/// influence it. Lines shorter than that width are returned unchanged.
///
/// The result is re-joined with `\n`, so a trailing newline in `text` is not preserved.
///
/// Leading whitespace may contain multi-byte characters (for example U+3000). In that
/// case the cut point is moved back to the nearest character boundary, so the function
/// strips as much indentation as it safely can instead of panicking on a split character.
fn dedent(text: &str) -> String {
    // Width (in bytes) of the shallowest indentation among lines with real content.
    let min_indent = text
        .lines()
        .filter(|line| !line.trim().is_empty())
        .map(|line| line.len() - line.trim_start().len())
        .min()
        .unwrap_or(0);

    text.lines()
        .map(|line| {
            if line.len() < min_indent {
                // Short (necessarily whitespace-only) line: keep it as it is.
                return line;
            }
            // `min_indent` is a byte count taken from another line, so it may land in the
            // middle of a multi-byte whitespace character here. Index 0 is always a
            // boundary, which guarantees the loop terminates.
            let mut cut = min_indent;
            while !line.is_char_boundary(cut) {
                cut -= 1;
            }
            &line[cut..]
        })
        .collect::<Vec<_>>()
        .join("\n")
}
54
55/// Extract docstring from function definition
56fn extract_docstring(func_def: &ast::StmtFunctionDef) -> String {
57    if let Some(ast::Stmt::Expr(expr_stmt)) = func_def.body.first() {
58        if let ast::Expr::Constant(constant) = &*expr_stmt.value {
59            if let ast::Constant::Str(s) = &constant.value {
60                return s.to_string();
61            }
62        }
63    }
64    String::new()
65}
66
67/// Extract deprecated decorator information if present
68fn extract_deprecated_from_decorators(decorators: &[ast::Expr]) -> Option<DeprecatedInfo> {
69    for decorator in decorators {
70        // Check for @deprecated or @deprecated("message")
71        match decorator {
72            ast::Expr::Name(name) if name.id.as_str() == "deprecated" => {
73                return Some(DeprecatedInfo {
74                    since: None,
75                    note: None,
76                });
77            }
78            ast::Expr::Call(call) => {
79                if let ast::Expr::Name(name) = &*call.func {
80                    if name.id.as_str() == "deprecated" {
81                        // Try to extract the message from the first argument
82                        let note = call.args.first().and_then(|arg| match arg {
83                            ast::Expr::Constant(constant) => match &constant.value {
84                                ast::Constant::Str(s) => Some(s.to_string()),
85                                _ => None,
86                            },
87                            _ => None,
88                        });
89                        return Some(DeprecatedInfo { since: None, note });
90                    }
91                }
92            }
93            _ => {}
94        }
95    }
96    None
97}
98
99/// Check if decorator list contains @overload decorator
100fn has_overload_decorator(decorator_list: &[ast::Expr]) -> bool {
101    decorator_list.iter().any(|decorator| {
102        match decorator {
103            ast::Expr::Name(name) => name.id.as_str() == "overload",
104            ast::Expr::Attribute(attr) => {
105                // Handle typing.overload or t.overload
106                attr.attr.as_str() == "overload"
107            }
108            _ => false,
109        }
110    })
111}
112
113/// Build Parameters directly from Python AST Arguments
114///
115/// This function constructs Parameters with proper ParameterKind classification
116/// based on Python's argument structure (positional-only, keyword-only, varargs, etc.)
117pub(super) fn build_parameters_from_ast(
118    args: &ast::Arguments,
119    imports: &[String],
120) -> Result<Parameters> {
121    let dummy_type: Type = syn::parse_str("()").unwrap();
122    let mut parameters = Vec::new();
123
124    // Helper to process a single argument with default value
125    let process_arg_with_default =
126        |arg: &ast::ArgWithDefault, kind: ParameterKind| -> Result<Option<ParameterWithKind>> {
127            let arg_name = arg.def.arg.to_string();
128
129            // Skip 'self' and 'cls' arguments (they are added automatically in generation)
130            if arg_name == "self" || arg_name == "cls" {
131                return Ok(None);
132            }
133
134            let type_override = if let Some(annotation) = &arg.def.annotation {
135                type_annotation_to_type_override(annotation, imports, dummy_type.clone())?
136            } else {
137                // No type annotation - use Any
138                TypeOrOverride::OverrideType {
139                    r#type: dummy_type.clone(),
140                    type_repr: "typing.Any".to_string(),
141                    imports: IndexSet::from(["typing".to_string()]),
142                    rust_type_markers: vec![],
143                }
144            };
145
146            let arg_info = ArgInfo {
147                name: arg_name,
148                r#type: type_override,
149            };
150
151            // Convert default value from Python AST to Python string
152            let default_expr = if let Some(default) = &arg.default {
153                Some(DefaultExpr::Python(python_ast_to_python_string(default)?))
154            } else {
155                None
156            };
157
158            Ok(Some(ParameterWithKind {
159                arg_info,
160                kind,
161                default_expr,
162            }))
163        };
164
165    // Helper to process vararg or kwarg (ast::Arg, not ast::ArgWithDefault)
166    let process_var_arg = |arg: &ast::Arg, kind: ParameterKind| -> Result<ParameterWithKind> {
167        let arg_name = arg.arg.to_string();
168
169        let type_override = if let Some(annotation) = &arg.annotation {
170            type_annotation_to_type_override(annotation, imports, dummy_type.clone())?
171        } else {
172            // No type annotation - use Any
173            TypeOrOverride::OverrideType {
174                r#type: dummy_type.clone(),
175                type_repr: "typing.Any".to_string(),
176                imports: IndexSet::from(["typing".to_string()]),
177                rust_type_markers: vec![],
178            }
179        };
180
181        let arg_info = ArgInfo {
182            name: arg_name,
183            r#type: type_override,
184        };
185
186        Ok(ParameterWithKind {
187            arg_info,
188            kind,
189            default_expr: None,
190        })
191    };
192
193    // Process positional-only arguments (before /)
194    for arg in &args.posonlyargs {
195        if let Some(param) = process_arg_with_default(arg, ParameterKind::PositionalOnly)? {
196            parameters.push(param);
197        }
198    }
199
200    // Process regular positional/keyword arguments
201    for arg in &args.args {
202        if let Some(param) = process_arg_with_default(arg, ParameterKind::PositionalOrKeyword)? {
203            parameters.push(param);
204        }
205    }
206
207    // Process *args (vararg)
208    if let Some(vararg) = &args.vararg {
209        parameters.push(process_var_arg(vararg, ParameterKind::VarPositional)?);
210    }
211
212    // Process keyword-only arguments (after *)
213    for arg in &args.kwonlyargs {
214        if let Some(param) = process_arg_with_default(arg, ParameterKind::KeywordOnly)? {
215            parameters.push(param);
216        }
217    }
218
219    // Process **kwargs (kwarg)
220    if let Some(kwarg) = &args.kwarg {
221        parameters.push(process_var_arg(kwarg, ParameterKind::VarKeyword)?);
222    }
223
224    Ok(Parameters::from_vec(parameters))
225}
226
227/// Extract return type from function definition
228fn extract_return_type(
229    returns: &Option<Box<ast::Expr>>,
230    imports: &[String],
231) -> Result<Option<TypeOrOverride>> {
232    // Dummy type for TypeOrOverride (not used in ToTokens for OverrideType)
233    let dummy_type: Type = syn::parse_str("()").unwrap();
234
235    if let Some(return_annotation) = returns {
236        Ok(Some(type_annotation_to_type_override(
237            return_annotation,
238            imports,
239            dummy_type,
240        )?))
241    } else {
242        // No return type annotation - use None (void)
243        Ok(None)
244    }
245}
246
247/// Recursively collect all RustType markers from a Python AST expression
248///
249/// Returns a vector of Rust type names found in RustType["TypeName"] markers
250fn collect_rust_type_markers(expr: &ast::Expr) -> Result<Vec<String>> {
251    let mut markers = Vec::new();
252    collect_rust_type_markers_impl(expr, &mut markers)?;
253    Ok(markers)
254}
255
256fn collect_rust_type_markers_impl(expr: &ast::Expr, markers: &mut Vec<String>) -> Result<()> {
257    // Check if this expression itself is a RustType marker
258    if let Some(type_name) = extract_rust_type_marker(expr)? {
259        markers.push(type_name);
260        return Ok(());
261    }
262
263    // Recursively check children
264    match expr {
265        ast::Expr::Subscript(subscript) => {
266            collect_rust_type_markers_impl(&subscript.value, markers)?;
267            collect_rust_type_markers_impl(&subscript.slice, markers)?;
268        }
269        ast::Expr::Tuple(tuple) => {
270            for elt in &tuple.elts {
271                collect_rust_type_markers_impl(elt, markers)?;
272            }
273        }
274        ast::Expr::List(list) => {
275            for elt in &list.elts {
276                collect_rust_type_markers_impl(elt, markers)?;
277            }
278        }
279        ast::Expr::BinOp(binop) => {
280            collect_rust_type_markers_impl(&binop.left, markers)?;
281            collect_rust_type_markers_impl(&binop.right, markers)?;
282        }
283        _ => {}
284    }
285    Ok(())
286}
287
288/// Convert Python type annotation to TypeOrOverride
289fn type_annotation_to_type_override(
290    expr: &ast::Expr,
291    imports: &[String],
292    dummy_type: Type,
293) -> Result<TypeOrOverride> {
294    // Check for pyo3_stub_gen.RustType["TypeName"] marker
295    if let Some(type_name) = extract_rust_type_marker(expr)? {
296        let rust_type: Type = syn::parse_str(&type_name).map_err(|e| {
297            syn::Error::new(
298                proc_macro2::Span::call_site(),
299                format!("Failed to parse Rust type '{}': {}", type_name, e),
300            )
301        })?;
302        return Ok(TypeOrOverride::RustType { r#type: rust_type });
303    }
304
305    let type_str = expr_to_type_string(expr)?;
306
307    // Collect all RustType markers in compound expressions
308    let rust_type_markers = collect_rust_type_markers(expr)?;
309
310    // Convert imports to IndexSet
311    let import_set: IndexSet<String> = imports.iter().map(|s| s.to_string()).collect();
312
313    Ok(TypeOrOverride::OverrideType {
314        r#type: dummy_type,
315        type_repr: type_str,
316        imports: import_set,
317        rust_type_markers,
318    })
319}
320
321/// Extract type name from pyo3_stub_gen.RustType["TypeName"]
322///
323/// Returns Some(type_name) if the expression matches the pattern, None otherwise.
324/// Returns an error if the pattern matches but the type name is not a string literal.
325fn extract_rust_type_marker(expr: &ast::Expr) -> Result<Option<String>> {
326    // Match pattern: pyo3_stub_gen.RustType[...]
327    if let ast::Expr::Subscript(subscript) = expr {
328        if let ast::Expr::Attribute(attr) = &*subscript.value {
329            // Check attribute name is "RustType"
330            if attr.attr.as_str() == "RustType" {
331                // Check module name is "pyo3_stub_gen"
332                if let ast::Expr::Name(name) = &*attr.value {
333                    if name.id.as_str() == "pyo3_stub_gen" {
334                        // Extract type name from subscript (must be a string literal)
335                        if let ast::Expr::Constant(constant) = &*subscript.slice {
336                            if let ast::Constant::Str(s) = &constant.value {
337                                return Ok(Some(s.to_string()));
338                            }
339                        }
340                        return Err(syn::Error::new(
341                            proc_macro2::Span::call_site(),
342                            "pyo3_stub_gen.RustType requires a string literal (e.g., RustType[\"MyType\"])",
343                        ));
344                    }
345                }
346            }
347        }
348    }
349    Ok(None)
350}
351
/// Render a string as a Python string literal.
///
/// Quote selection: double quotes are used only when the text contains a single
/// quote and no double quote; single quotes are used in every other case.
///
/// Escaping: backslashes, the chosen quote character, newline, carriage return and
/// tab get their usual backslash escapes; any other ASCII control character is
/// written as a two-digit `\x` hex escape. All remaining characters are copied as is.
fn escape_python_string(s: &str) -> String {
    let has_single = s.contains('\'');
    let has_double = s.contains('"');
    let quote = if has_single && !has_double { '"' } else { '\'' };

    let mut literal = String::with_capacity(s.len() + 2);
    literal.push(quote);

    for ch in s.chars() {
        match ch {
            '\\' => literal.push_str(r"\\"),
            '\n' => literal.push_str(r"\n"),
            '\r' => literal.push_str(r"\r"),
            '\t' => literal.push_str(r"\t"),
            // Only the quote that delimits the literal needs escaping.
            c if c == quote => {
                literal.push('\\');
                literal.push(c);
            }
            // Remaining control characters (including NUL) as hex escapes.
            c if c.is_ascii_control() => {
                literal.push_str(&format!("\\x{:02x}", c as u32));
            }
            c => literal.push(c),
        }
    }

    literal.push(quote);
    literal
}
384
385/// Convert Python AST expression to Python syntax string
386///
387/// This converts Python AST expressions like `None`, `True`, `[1, 2]` to Python string representation
388/// that can be used directly in stub files.
389fn python_ast_to_python_string(expr: &ast::Expr) -> Result<String> {
390    match expr {
391        ast::Expr::Constant(constant) => match &constant.value {
392            ast::Constant::None => Ok("None".to_string()),
393            ast::Constant::Bool(true) => Ok("True".to_string()),
394            ast::Constant::Bool(false) => Ok("False".to_string()),
395            ast::Constant::Int(i) => Ok(i.to_string()),
396            ast::Constant::Float(f) => Ok(f.to_string()),
397            ast::Constant::Str(s) => Ok(escape_python_string(s)),
398            ast::Constant::Bytes(_) => Err(syn::Error::new(
399                proc_macro2::Span::call_site(),
400                "Bytes literals are not supported as default values",
401            )),
402            ast::Constant::Ellipsis => Ok("...".to_string()),
403            _ => Err(syn::Error::new(
404                proc_macro2::Span::call_site(),
405                format!("Unsupported constant type: {:?}", constant.value),
406            )),
407        },
408        ast::Expr::List(list) => {
409            // Recursively convert list elements
410            let elements: Result<Vec<_>> =
411                list.elts.iter().map(python_ast_to_python_string).collect();
412            Ok(format!("[{}]", elements?.join(", ")))
413        }
414        ast::Expr::Tuple(tuple) => {
415            // Recursively convert tuple elements
416            let elements: Result<Vec<_>> =
417                tuple.elts.iter().map(python_ast_to_python_string).collect();
418            let elements = elements?;
419            if elements.len() == 1 {
420                // Single-element tuple needs trailing comma
421                Ok(format!("({},)", elements[0]))
422            } else {
423                Ok(format!("({})", elements.join(", ")))
424            }
425        }
426        ast::Expr::Dict(dict) => {
427            // Recursively convert dict key-value pairs
428            let mut pairs = Vec::new();
429            for (key_opt, value) in dict.keys.iter().zip(dict.values.iter()) {
430                if let Some(key) = key_opt {
431                    let key_str = python_ast_to_python_string(key)?;
432                    let value_str = python_ast_to_python_string(value)?;
433                    pairs.push(format!("{}: {}", key_str, value_str));
434                } else {
435                    // Handle **kwargs expansion in dict literals
436                    return Ok("...".to_string());
437                }
438            }
439            Ok(format!("{{{}}}", pairs.join(", ")))
440        }
441        ast::Expr::Name(name) => Ok(name.id.to_string()),
442        ast::Expr::Attribute(_) => {
443            // Handle qualified names like `MyEnum.VARIANT`
444            expr_to_type_string(expr)
445        }
446        ast::Expr::UnaryOp(unary) => {
447            // Handle negative numbers
448            if matches!(unary.op, ast::UnaryOp::USub) {
449                if let ast::Expr::Constant(constant) = &*unary.operand {
450                    match &constant.value {
451                        ast::Constant::Int(i) => Ok(format!("-{}", i)),
452                        ast::Constant::Float(f) => Ok(format!("-{}", f)),
453                        _ => Ok("...".to_string()),
454                    }
455                } else {
456                    Ok("...".to_string())
457                }
458            } else {
459                Ok("...".to_string())
460            }
461        }
462        _ => {
463            // For other expressions, use "..." placeholder
464            Ok("...".to_string())
465        }
466    }
467}
468
469/// Convert Python expression to type string
470fn expr_to_type_string(expr: &ast::Expr) -> Result<String> {
471    expr_to_type_string_inner(expr, false)
472}
473
474/// Convert Python expression to type string with context
475fn expr_to_type_string_inner(expr: &ast::Expr, in_subscript: bool) -> Result<String> {
476    // Check for pyo3_stub_gen.RustType["TypeName"] marker first
477    // If found, return just the type name (the marker will be handled elsewhere)
478    if let Some(type_name) = extract_rust_type_marker(expr)? {
479        return Ok(type_name);
480    }
481
482    Ok(match expr {
483        ast::Expr::Name(name) => name.id.to_string(),
484        ast::Expr::Attribute(attr) => {
485            format!(
486                "{}.{}",
487                expr_to_type_string_inner(&attr.value, false)?,
488                attr.attr
489            )
490        }
491        ast::Expr::Subscript(subscript) => {
492            let base = expr_to_type_string_inner(&subscript.value, false)?;
493            let slice = expr_to_type_string_inner(&subscript.slice, true)?;
494            format!("{}[{}]", base, slice)
495        }
496        ast::Expr::List(list) => {
497            let elements: Result<Vec<String>> = list
498                .elts
499                .iter()
500                .map(|e| expr_to_type_string_inner(e, false))
501                .collect();
502            format!("[{}]", elements?.join(", "))
503        }
504        ast::Expr::Tuple(tuple) => {
505            let elements: Result<Vec<String>> = tuple
506                .elts
507                .iter()
508                .map(|e| expr_to_type_string_inner(e, in_subscript))
509                .collect();
510            let elements = elements?;
511            if in_subscript {
512                // In subscript context, preserve tuple structure without extra parentheses
513                elements.join(", ")
514            } else {
515                format!("({})", elements.join(", "))
516            }
517        }
518        ast::Expr::Constant(constant) => match &constant.value {
519            ast::Constant::Int(i) => i.to_string(),
520            ast::Constant::Str(s) => format!("\"{}\"", s),
521            ast::Constant::Bool(b) => if *b { "True" } else { "False" }.to_string(),
522            ast::Constant::None => "None".to_string(),
523            ast::Constant::Ellipsis => "...".to_string(),
524            _ => "Any".to_string(),
525        },
526        ast::Expr::BinOp(binop) => {
527            // Handle union types with | operator
528            if matches!(binop.op, ast::Operator::BitOr) {
529                let left = expr_to_type_string_inner(&binop.left, false)?;
530                let right = expr_to_type_string_inner(&binop.right, false)?;
531                format!("{} | {}", left, right)
532            } else {
533                "Any".to_string()
534            }
535        }
536        _ => "Any".to_string(),
537    })
538}
539
540#[cfg(test)]
541mod tests {
542    use super::*;
543    use rustpython_parser as parser;
544
545    /// Helper to parse a Python expression and convert it to Python string
546    fn parse_and_convert(python_expr: &str) -> Result<String> {
547        let source = format!("x = {}", python_expr);
548        let parsed = parser::parse(&source, parser::Mode::Module, "<test>")
549            .map_err(|e| syn::Error::new(proc_macro2::Span::call_site(), format!("{}", e)))?;
550
551        if let parser::ast::Mod::Module(module) = parsed {
552            if let Some(parser::ast::Stmt::Assign(assign)) = module.body.first() {
553                return python_ast_to_python_string(&assign.value);
554            }
555        }
556        Err(syn::Error::new(
557            proc_macro2::Span::call_site(),
558            "Failed to parse expression",
559        ))
560    }
561
562    #[test]
563    fn test_string_basic() -> Result<()> {
564        let result = parse_and_convert(r#""hello""#)?;
565        assert_eq!(result, r#"'hello'"#);
566        Ok(())
567    }
568
569    #[test]
570    fn test_string_with_single_quote() -> Result<()> {
571        // Python: "it's"
572        let result = parse_and_convert(r#""it's""#)?;
573        // Should use double quotes when string contains single quote
574        assert_eq!(result, r#""it's""#);
575        Ok(())
576    }
577
578    #[test]
579    fn test_string_with_double_quote() -> Result<()> {
580        // Python: 'say "hi"'
581        let result = parse_and_convert(r#"'say "hi"'"#)?;
582        // Should use single quotes when string contains double quote
583        assert_eq!(result, r#"'say "hi"'"#);
584        Ok(())
585    }
586
587    #[test]
588    fn test_string_with_newline() -> Result<()> {
589        // Python source with actual newline character
590        let result = parse_and_convert(r#""line1\nline2""#)?;
591        // Should preserve newline as \n (not \\n)
592        assert_eq!(result, "'line1\\nline2'");
593        Ok(())
594    }
595
596    #[test]
597    fn test_string_with_tab() -> Result<()> {
598        let result = parse_and_convert(r#""a\tb""#)?;
599        // Should preserve tab as \t (not \\t)
600        assert_eq!(result, "'a\\tb'");
601        Ok(())
602    }
603
604    #[test]
605    fn test_string_with_backslash() -> Result<()> {
606        // Python raw string or escaped backslash
607        let result = parse_and_convert(r#"r"path\to\file""#)?;
608        // When we parse r"path\to\file", the string value contains literal backslashes
609        // When converting back to Python syntax, we must escape those backslashes
610        // So 'path\to\file' (raw) becomes 'path\\to\\file' (escaped)
611        assert_eq!(result, r"'path\\to\\file'");
612        Ok(())
613    }
614
615    #[test]
616    fn test_string_with_both_quotes() -> Result<()> {
617        // String containing both ' and "
618        let result = parse_and_convert(r#""it's \"great\"""#)?;
619        // When a string contains both ' and ", we use single quotes and escape the '
620        assert_eq!(result, r#"'it\'s "great"'"#);
621        Ok(())
622    }
623
624    #[test]
625    fn test_string_empty() -> Result<()> {
626        let result = parse_and_convert(r#""""#)?;
627        assert_eq!(result, "''");
628        Ok(())
629    }
630
631    #[test]
632    fn test_none() -> Result<()> {
633        let result = parse_and_convert("None")?;
634        assert_eq!(result, "None");
635        Ok(())
636    }
637
638    #[test]
639    fn test_bool_true() -> Result<()> {
640        let result = parse_and_convert("True")?;
641        assert_eq!(result, "True");
642        Ok(())
643    }
644
645    #[test]
646    fn test_bool_false() -> Result<()> {
647        let result = parse_and_convert("False")?;
648        assert_eq!(result, "False");
649        Ok(())
650    }
651
652    #[test]
653    fn test_int() -> Result<()> {
654        let result = parse_and_convert("42")?;
655        assert_eq!(result, "42");
656        Ok(())
657    }
658
659    #[test]
660    fn test_float() -> Result<()> {
661        let result = parse_and_convert("3.14")?;
662        assert_eq!(result, "3.14");
663        Ok(())
664    }
665
666    #[test]
667    fn test_list() -> Result<()> {
668        let result = parse_and_convert("[1, 2, 3]")?;
669        assert_eq!(result, "[1, 2, 3]");
670        Ok(())
671    }
672
673    #[test]
674    fn test_tuple() -> Result<()> {
675        let result = parse_and_convert("(1, 2)")?;
676        assert_eq!(result, "(1, 2)");
677        Ok(())
678    }
679
680    #[test]
681    fn test_tuple_single_element() -> Result<()> {
682        let result = parse_and_convert("(1,)")?;
683        assert_eq!(result, "(1,)");
684        Ok(())
685    }
686
687    #[test]
688    fn test_dict() -> Result<()> {
689        let result = parse_and_convert(r#"{"a": 1, "b": 2}"#)?;
690        assert_eq!(result, "{'a': 1, 'b': 2}");
691        Ok(())
692    }
693
694    #[test]
695    fn test_negative_int() -> Result<()> {
696        let result = parse_and_convert("-42")?;
697        assert_eq!(result, "-42");
698        Ok(())
699    }
700
701    #[test]
702    fn test_negative_float() -> Result<()> {
703        let result = parse_and_convert("-3.14")?;
704        assert_eq!(result, "-3.14");
705        Ok(())
706    }
707}