pyo3_stub_gen_derive/gen_stub/
parse_python.rs

1//! Parse Python stub syntax and generate PyFunctionInfo and MethodInfo
2//!
3//! This module provides functionality to parse Python stub syntax (type hints)
4//! and convert them into Rust metadata structures for stub generation.
5
6mod pyfunction;
7mod pymethods;
8
9pub use pyfunction::{
10    parse_gen_function_from_python_input, parse_python_function_stub, parse_python_overload_stubs,
11    GenFunctionFromPythonInput,
12};
13pub use pymethods::parse_python_methods_stub;
14
15use indexmap::IndexSet;
16use rustpython_parser::ast;
17use syn::{Result, Type};
18
19use super::{
20    arg::ArgInfo,
21    attr::DeprecatedInfo,
22    parameter::DefaultExpr,
23    parameter::{ParameterKind, ParameterWithKind, Parameters},
24    util::TypeOrOverride,
25};
26
/// Remove the common leading whitespace from every line of `text`
/// (similar to Python's `textwrap.dedent`).
///
/// The common indentation is the smallest number of leading-whitespace bytes
/// found on any line that contains at least one non-whitespace character;
/// blank and whitespace-only lines do not influence it.
///
/// Lines are rejoined with `\n`, so a trailing newline in `text` is not
/// preserved. A line is returned unchanged when it is shorter than the common
/// indentation, or when the cut point would fall inside a multi-byte
/// whitespace character (e.g. U+3000) instead of on a character boundary.
fn dedent(text: &str) -> String {
    let lines: Vec<&str> = text.lines().collect();

    // Minimum indentation, in bytes, over all lines with visible content.
    let min_indent = lines
        .iter()
        .filter(|line| !line.trim().is_empty())
        .map(|line| line.len() - line.trim_start().len())
        .min()
        .unwrap_or(0);

    // `str::get` is the checked form of slicing: it yields `None` both for
    // out-of-range offsets and for offsets that are not on a char boundary,
    // where `&line[min_indent..]` would panic.
    lines
        .iter()
        .map(|&line| line.get(min_indent..).unwrap_or(line))
        .collect::<Vec<_>>()
        .join("\n")
}
52
53/// Extract docstring from function definition
54fn extract_docstring(func_def: &ast::StmtFunctionDef) -> String {
55    if let Some(ast::Stmt::Expr(expr_stmt)) = func_def.body.first() {
56        if let ast::Expr::Constant(constant) = &*expr_stmt.value {
57            if let ast::Constant::Str(s) = &constant.value {
58                return s.to_string();
59            }
60        }
61    }
62    String::new()
63}
64
65/// Extract deprecated decorator information if present
66fn extract_deprecated_from_decorators(decorators: &[ast::Expr]) -> Option<DeprecatedInfo> {
67    for decorator in decorators {
68        // Check for @deprecated or @deprecated("message")
69        match decorator {
70            ast::Expr::Name(name) if name.id.as_str() == "deprecated" => {
71                return Some(DeprecatedInfo {
72                    since: None,
73                    note: None,
74                });
75            }
76            ast::Expr::Call(call) => {
77                if let ast::Expr::Name(name) = &*call.func {
78                    if name.id.as_str() == "deprecated" {
79                        // Try to extract the message from the first argument
80                        let note = call.args.first().and_then(|arg| match arg {
81                            ast::Expr::Constant(constant) => match &constant.value {
82                                ast::Constant::Str(s) => Some(s.to_string()),
83                                _ => None,
84                            },
85                            _ => None,
86                        });
87                        return Some(DeprecatedInfo { since: None, note });
88                    }
89                }
90            }
91            _ => {}
92        }
93    }
94    None
95}
96
97/// Check if decorator list contains @overload decorator
98fn has_overload_decorator(decorator_list: &[ast::Expr]) -> bool {
99    decorator_list.iter().any(|decorator| {
100        match decorator {
101            ast::Expr::Name(name) => name.id.as_str() == "overload",
102            ast::Expr::Attribute(attr) => {
103                // Handle typing.overload or t.overload
104                attr.attr.as_str() == "overload"
105            }
106            _ => false,
107        }
108    })
109}
110
111/// Build Parameters directly from Python AST Arguments
112///
113/// This function constructs Parameters with proper ParameterKind classification
114/// based on Python's argument structure (positional-only, keyword-only, varargs, etc.)
115pub(super) fn build_parameters_from_ast(
116    args: &ast::Arguments,
117    imports: &[String],
118) -> Result<Parameters> {
119    let dummy_type: Type = syn::parse_str("()").unwrap();
120    let mut parameters = Vec::new();
121
122    // Helper to process a single argument with default value
123    let process_arg_with_default =
124        |arg: &ast::ArgWithDefault, kind: ParameterKind| -> Result<Option<ParameterWithKind>> {
125            let arg_name = arg.def.arg.to_string();
126
127            // Skip 'self' and 'cls' arguments (they are added automatically in generation)
128            if arg_name == "self" || arg_name == "cls" {
129                return Ok(None);
130            }
131
132            let type_override = if let Some(annotation) = &arg.def.annotation {
133                type_annotation_to_type_override(annotation, imports, dummy_type.clone())?
134            } else {
135                // No type annotation - use Any
136                TypeOrOverride::OverrideType {
137                    r#type: dummy_type.clone(),
138                    type_repr: "typing.Any".to_string(),
139                    imports: IndexSet::from(["typing".to_string()]),
140                }
141            };
142
143            let arg_info = ArgInfo {
144                name: arg_name,
145                r#type: type_override,
146            };
147
148            // Convert default value from Python AST to Python string
149            let default_expr = if let Some(default) = &arg.default {
150                Some(DefaultExpr::Python(python_ast_to_python_string(default)?))
151            } else {
152                None
153            };
154
155            Ok(Some(ParameterWithKind {
156                arg_info,
157                kind,
158                default_expr,
159            }))
160        };
161
162    // Helper to process vararg or kwarg (ast::Arg, not ast::ArgWithDefault)
163    let process_var_arg = |arg: &ast::Arg, kind: ParameterKind| -> Result<ParameterWithKind> {
164        let arg_name = arg.arg.to_string();
165
166        let type_override = if let Some(annotation) = &arg.annotation {
167            type_annotation_to_type_override(annotation, imports, dummy_type.clone())?
168        } else {
169            // No type annotation - use Any
170            TypeOrOverride::OverrideType {
171                r#type: dummy_type.clone(),
172                type_repr: "typing.Any".to_string(),
173                imports: IndexSet::from(["typing".to_string()]),
174            }
175        };
176
177        let arg_info = ArgInfo {
178            name: arg_name,
179            r#type: type_override,
180        };
181
182        Ok(ParameterWithKind {
183            arg_info,
184            kind,
185            default_expr: None,
186        })
187    };
188
189    // Process positional-only arguments (before /)
190    for arg in &args.posonlyargs {
191        if let Some(param) = process_arg_with_default(arg, ParameterKind::PositionalOnly)? {
192            parameters.push(param);
193        }
194    }
195
196    // Process regular positional/keyword arguments
197    for arg in &args.args {
198        if let Some(param) = process_arg_with_default(arg, ParameterKind::PositionalOrKeyword)? {
199            parameters.push(param);
200        }
201    }
202
203    // Process *args (vararg)
204    if let Some(vararg) = &args.vararg {
205        parameters.push(process_var_arg(vararg, ParameterKind::VarPositional)?);
206    }
207
208    // Process keyword-only arguments (after *)
209    for arg in &args.kwonlyargs {
210        if let Some(param) = process_arg_with_default(arg, ParameterKind::KeywordOnly)? {
211            parameters.push(param);
212        }
213    }
214
215    // Process **kwargs (kwarg)
216    if let Some(kwarg) = &args.kwarg {
217        parameters.push(process_var_arg(kwarg, ParameterKind::VarKeyword)?);
218    }
219
220    Ok(Parameters::from_vec(parameters))
221}
222
223/// Extract return type from function definition
224fn extract_return_type(
225    returns: &Option<Box<ast::Expr>>,
226    imports: &[String],
227) -> Result<Option<TypeOrOverride>> {
228    // Dummy type for TypeOrOverride (not used in ToTokens for OverrideType)
229    let dummy_type: Type = syn::parse_str("()").unwrap();
230
231    if let Some(return_annotation) = returns {
232        Ok(Some(type_annotation_to_type_override(
233            return_annotation,
234            imports,
235            dummy_type,
236        )?))
237    } else {
238        // No return type annotation - use None (void)
239        Ok(None)
240    }
241}
242
243/// Convert Python type annotation to TypeOrOverride
244fn type_annotation_to_type_override(
245    expr: &ast::Expr,
246    imports: &[String],
247    dummy_type: Type,
248) -> Result<TypeOrOverride> {
249    // Check for pyo3_stub_gen.RustType["TypeName"] marker
250    if let Some(type_name) = extract_rust_type_marker(expr)? {
251        let rust_type: Type = syn::parse_str(&type_name).map_err(|e| {
252            syn::Error::new(
253                proc_macro2::Span::call_site(),
254                format!("Failed to parse Rust type '{}': {}", type_name, e),
255            )
256        })?;
257        return Ok(TypeOrOverride::RustType { r#type: rust_type });
258    }
259
260    let type_str = expr_to_type_string(expr)?;
261
262    // Convert imports to IndexSet
263    let import_set: IndexSet<String> = imports.iter().map(|s| s.to_string()).collect();
264
265    Ok(TypeOrOverride::OverrideType {
266        r#type: dummy_type,
267        type_repr: type_str,
268        imports: import_set,
269    })
270}
271
272/// Extract type name from pyo3_stub_gen.RustType["TypeName"]
273///
274/// Returns Some(type_name) if the expression matches the pattern, None otherwise.
275/// Returns an error if the pattern matches but the type name is not a string literal.
276fn extract_rust_type_marker(expr: &ast::Expr) -> Result<Option<String>> {
277    // Match pattern: pyo3_stub_gen.RustType[...]
278    if let ast::Expr::Subscript(subscript) = expr {
279        if let ast::Expr::Attribute(attr) = &*subscript.value {
280            // Check attribute name is "RustType"
281            if attr.attr.as_str() == "RustType" {
282                // Check module name is "pyo3_stub_gen"
283                if let ast::Expr::Name(name) = &*attr.value {
284                    if name.id.as_str() == "pyo3_stub_gen" {
285                        // Extract type name from subscript (must be a string literal)
286                        if let ast::Expr::Constant(constant) = &*subscript.slice {
287                            if let ast::Constant::Str(s) = &constant.value {
288                                return Ok(Some(s.to_string()));
289                            }
290                        }
291                        return Err(syn::Error::new(
292                            proc_macro2::Span::call_site(),
293                            "pyo3_stub_gen.RustType requires a string literal (e.g., RustType[\"MyType\"])",
294                        ));
295                    }
296                }
297            }
298        }
299    }
300    Ok(None)
301}
302
/// Render `s` as a quoted Python string literal.
///
/// Single quotes are the default delimiter. Double quotes are used only when
/// the text contains a single quote and no double quote. Backslashes, the
/// chosen delimiter, `\n`, `\r` and `\t` are backslash-escaped; every other
/// ASCII control character is written as a `\xNN` escape. All remaining
/// characters are copied through unchanged.
fn escape_python_string(s: &str) -> String {
    let has_single = s.contains('\'');
    let has_double = s.contains('"');
    let quote = if has_single && !has_double { '"' } else { '\'' };

    let mut out = String::with_capacity(s.len() + 2);
    out.push(quote);

    for ch in s.chars() {
        // Only the delimiter in use needs escaping; the other quote is literal.
        if ch == quote {
            out.push('\\');
            out.push(ch);
            continue;
        }
        match ch {
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // Remaining ASCII control characters (including NUL) as hex escapes
            _ if ch.is_ascii_control() => out.push_str(&format!("\\x{:02x}", ch as u8)),
            _ => out.push(ch),
        }
    }

    out.push(quote);
    out
}
335
336/// Convert Python AST expression to Python syntax string
337///
338/// This converts Python AST expressions like `None`, `True`, `[1, 2]` to Python string representation
339/// that can be used directly in stub files.
340fn python_ast_to_python_string(expr: &ast::Expr) -> Result<String> {
341    match expr {
342        ast::Expr::Constant(constant) => match &constant.value {
343            ast::Constant::None => Ok("None".to_string()),
344            ast::Constant::Bool(true) => Ok("True".to_string()),
345            ast::Constant::Bool(false) => Ok("False".to_string()),
346            ast::Constant::Int(i) => Ok(i.to_string()),
347            ast::Constant::Float(f) => Ok(f.to_string()),
348            ast::Constant::Str(s) => Ok(escape_python_string(s)),
349            ast::Constant::Bytes(_) => Err(syn::Error::new(
350                proc_macro2::Span::call_site(),
351                "Bytes literals are not supported as default values",
352            )),
353            ast::Constant::Ellipsis => Ok("...".to_string()),
354            _ => Err(syn::Error::new(
355                proc_macro2::Span::call_site(),
356                format!("Unsupported constant type: {:?}", constant.value),
357            )),
358        },
359        ast::Expr::List(list) => {
360            // Recursively convert list elements
361            let elements: Result<Vec<_>> =
362                list.elts.iter().map(python_ast_to_python_string).collect();
363            Ok(format!("[{}]", elements?.join(", ")))
364        }
365        ast::Expr::Tuple(tuple) => {
366            // Recursively convert tuple elements
367            let elements: Result<Vec<_>> =
368                tuple.elts.iter().map(python_ast_to_python_string).collect();
369            let elements = elements?;
370            if elements.len() == 1 {
371                // Single-element tuple needs trailing comma
372                Ok(format!("({},)", elements[0]))
373            } else {
374                Ok(format!("({})", elements.join(", ")))
375            }
376        }
377        ast::Expr::Dict(dict) => {
378            // Recursively convert dict key-value pairs
379            let mut pairs = Vec::new();
380            for (key_opt, value) in dict.keys.iter().zip(dict.values.iter()) {
381                if let Some(key) = key_opt {
382                    let key_str = python_ast_to_python_string(key)?;
383                    let value_str = python_ast_to_python_string(value)?;
384                    pairs.push(format!("{}: {}", key_str, value_str));
385                } else {
386                    // Handle **kwargs expansion in dict literals
387                    return Ok("...".to_string());
388                }
389            }
390            Ok(format!("{{{}}}", pairs.join(", ")))
391        }
392        ast::Expr::Name(name) => Ok(name.id.to_string()),
393        ast::Expr::Attribute(_) => {
394            // Handle qualified names like `MyEnum.VARIANT`
395            expr_to_type_string(expr)
396        }
397        ast::Expr::UnaryOp(unary) => {
398            // Handle negative numbers
399            if matches!(unary.op, ast::UnaryOp::USub) {
400                if let ast::Expr::Constant(constant) = &*unary.operand {
401                    match &constant.value {
402                        ast::Constant::Int(i) => Ok(format!("-{}", i)),
403                        ast::Constant::Float(f) => Ok(format!("-{}", f)),
404                        _ => Ok("...".to_string()),
405                    }
406                } else {
407                    Ok("...".to_string())
408                }
409            } else {
410                Ok("...".to_string())
411            }
412        }
413        _ => {
414            // For other expressions, use "..." placeholder
415            Ok("...".to_string())
416        }
417    }
418}
419
420/// Convert Python expression to type string
421fn expr_to_type_string(expr: &ast::Expr) -> Result<String> {
422    expr_to_type_string_inner(expr, false)
423}
424
425/// Convert Python expression to type string with context
426fn expr_to_type_string_inner(expr: &ast::Expr, in_subscript: bool) -> Result<String> {
427    // Check for pyo3_stub_gen.RustType["TypeName"] marker first
428    // If found, return just the type name (the marker will be handled elsewhere)
429    if let Some(type_name) = extract_rust_type_marker(expr)? {
430        return Ok(type_name);
431    }
432
433    Ok(match expr {
434        ast::Expr::Name(name) => name.id.to_string(),
435        ast::Expr::Attribute(attr) => {
436            format!(
437                "{}.{}",
438                expr_to_type_string_inner(&attr.value, false)?,
439                attr.attr
440            )
441        }
442        ast::Expr::Subscript(subscript) => {
443            let base = expr_to_type_string_inner(&subscript.value, false)?;
444            let slice = expr_to_type_string_inner(&subscript.slice, true)?;
445            format!("{}[{}]", base, slice)
446        }
447        ast::Expr::List(list) => {
448            let elements: Result<Vec<String>> = list
449                .elts
450                .iter()
451                .map(|e| expr_to_type_string_inner(e, false))
452                .collect();
453            format!("[{}]", elements?.join(", "))
454        }
455        ast::Expr::Tuple(tuple) => {
456            let elements: Result<Vec<String>> = tuple
457                .elts
458                .iter()
459                .map(|e| expr_to_type_string_inner(e, in_subscript))
460                .collect();
461            let elements = elements?;
462            if in_subscript {
463                // In subscript context, preserve tuple structure without extra parentheses
464                elements.join(", ")
465            } else {
466                format!("({})", elements.join(", "))
467            }
468        }
469        ast::Expr::Constant(constant) => match &constant.value {
470            ast::Constant::Int(i) => i.to_string(),
471            ast::Constant::Str(s) => format!("\"{}\"", s),
472            ast::Constant::Bool(b) => if *b { "True" } else { "False" }.to_string(),
473            ast::Constant::None => "None".to_string(),
474            ast::Constant::Ellipsis => "...".to_string(),
475            _ => "Any".to_string(),
476        },
477        _ => "Any".to_string(),
478    })
479}
480
481#[cfg(test)]
482mod tests {
483    use super::*;
484    use rustpython_parser as parser;
485
486    /// Helper to parse a Python expression and convert it to Python string
487    fn parse_and_convert(python_expr: &str) -> Result<String> {
488        let source = format!("x = {}", python_expr);
489        let parsed = parser::parse(&source, parser::Mode::Module, "<test>")
490            .map_err(|e| syn::Error::new(proc_macro2::Span::call_site(), format!("{}", e)))?;
491
492        if let parser::ast::Mod::Module(module) = parsed {
493            if let Some(parser::ast::Stmt::Assign(assign)) = module.body.first() {
494                return python_ast_to_python_string(&assign.value);
495            }
496        }
497        Err(syn::Error::new(
498            proc_macro2::Span::call_site(),
499            "Failed to parse expression",
500        ))
501    }
502
503    #[test]
504    fn test_string_basic() -> Result<()> {
505        let result = parse_and_convert(r#""hello""#)?;
506        assert_eq!(result, r#"'hello'"#);
507        Ok(())
508    }
509
510    #[test]
511    fn test_string_with_single_quote() -> Result<()> {
512        // Python: "it's"
513        let result = parse_and_convert(r#""it's""#)?;
514        // Should use double quotes when string contains single quote
515        assert_eq!(result, r#""it's""#);
516        Ok(())
517    }
518
519    #[test]
520    fn test_string_with_double_quote() -> Result<()> {
521        // Python: 'say "hi"'
522        let result = parse_and_convert(r#"'say "hi"'"#)?;
523        // Should use single quotes when string contains double quote
524        assert_eq!(result, r#"'say "hi"'"#);
525        Ok(())
526    }
527
528    #[test]
529    fn test_string_with_newline() -> Result<()> {
530        // Python source with actual newline character
531        let result = parse_and_convert(r#""line1\nline2""#)?;
532        // Should preserve newline as \n (not \\n)
533        assert_eq!(result, "'line1\\nline2'");
534        Ok(())
535    }
536
537    #[test]
538    fn test_string_with_tab() -> Result<()> {
539        let result = parse_and_convert(r#""a\tb""#)?;
540        // Should preserve tab as \t (not \\t)
541        assert_eq!(result, "'a\\tb'");
542        Ok(())
543    }
544
545    #[test]
546    fn test_string_with_backslash() -> Result<()> {
547        // Python raw string or escaped backslash
548        let result = parse_and_convert(r#"r"path\to\file""#)?;
549        // When we parse r"path\to\file", the string value contains literal backslashes
550        // When converting back to Python syntax, we must escape those backslashes
551        // So 'path\to\file' (raw) becomes 'path\\to\\file' (escaped)
552        assert_eq!(result, r"'path\\to\\file'");
553        Ok(())
554    }
555
556    #[test]
557    fn test_string_with_both_quotes() -> Result<()> {
558        // String containing both ' and "
559        let result = parse_and_convert(r#""it's \"great\"""#)?;
560        // When a string contains both ' and ", we use single quotes and escape the '
561        assert_eq!(result, r#"'it\'s "great"'"#);
562        Ok(())
563    }
564
565    #[test]
566    fn test_string_empty() -> Result<()> {
567        let result = parse_and_convert(r#""""#)?;
568        assert_eq!(result, "''");
569        Ok(())
570    }
571
572    #[test]
573    fn test_none() -> Result<()> {
574        let result = parse_and_convert("None")?;
575        assert_eq!(result, "None");
576        Ok(())
577    }
578
579    #[test]
580    fn test_bool_true() -> Result<()> {
581        let result = parse_and_convert("True")?;
582        assert_eq!(result, "True");
583        Ok(())
584    }
585
586    #[test]
587    fn test_bool_false() -> Result<()> {
588        let result = parse_and_convert("False")?;
589        assert_eq!(result, "False");
590        Ok(())
591    }
592
593    #[test]
594    fn test_int() -> Result<()> {
595        let result = parse_and_convert("42")?;
596        assert_eq!(result, "42");
597        Ok(())
598    }
599
600    #[test]
601    fn test_float() -> Result<()> {
602        let result = parse_and_convert("3.14")?;
603        assert_eq!(result, "3.14");
604        Ok(())
605    }
606
607    #[test]
608    fn test_list() -> Result<()> {
609        let result = parse_and_convert("[1, 2, 3]")?;
610        assert_eq!(result, "[1, 2, 3]");
611        Ok(())
612    }
613
614    #[test]
615    fn test_tuple() -> Result<()> {
616        let result = parse_and_convert("(1, 2)")?;
617        assert_eq!(result, "(1, 2)");
618        Ok(())
619    }
620
621    #[test]
622    fn test_tuple_single_element() -> Result<()> {
623        let result = parse_and_convert("(1,)")?;
624        assert_eq!(result, "(1,)");
625        Ok(())
626    }
627
628    #[test]
629    fn test_dict() -> Result<()> {
630        let result = parse_and_convert(r#"{"a": 1, "b": 2}"#)?;
631        assert_eq!(result, "{'a': 1, 'b': 2}");
632        Ok(())
633    }
634
635    #[test]
636    fn test_negative_int() -> Result<()> {
637        let result = parse_and_convert("-42")?;
638        assert_eq!(result, "-42");
639        Ok(())
640    }
641
642    #[test]
643    fn test_negative_float() -> Result<()> {
644        let result = parse_and_convert("-3.14")?;
645        assert_eq!(result, "-3.14");
646        Ok(())
647    }
648}