pyo3_stub_gen_derive/gen_stub/
parameter.rs

1//! Parameter intermediate representation for derive macros
2//!
3//! This module provides intermediate representations for parameters that are used
4//! during the code generation phase. These types exist only within the derive macro
5//! and are converted to `::pyo3_stub_gen::type_info::ParameterInfo` via `ToTokens`.
6
7use std::collections::HashMap;
8
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{quote, ToTokens, TokenStreamExt};
11use syn::{Expr, Result};
12
13use super::{remove_lifetime, signature::SignatureArg, util::TypeOrOverride, ArgInfo, Signature};
14
15/// Represents a default value expression from either Rust or Python source
16#[derive(Debug, Clone)]
17pub(crate) enum DefaultExpr {
18    /// Rust expression that needs to be converted to Python representation at runtime
19    /// Example: `vec![1, 2]`, `Number::Float`, `10`
20    Rust(Expr),
21    /// Python expression already in Python syntax (from Python stub)
22    /// Example: `"False"`, `"[1, 2]"`, `"Number.FLOAT"`
23    Python(String),
24}
25
26/// Intermediate representation for a parameter with its kind determined
27#[derive(Debug, Clone)]
28pub(crate) struct ParameterWithKind {
29    pub(crate) arg_info: ArgInfo,
30    pub(crate) kind: ParameterKind,
31    pub(crate) default_expr: Option<DefaultExpr>,
32}
33
34impl ToTokens for ParameterWithKind {
35    fn to_tokens(&self, tokens: &mut TokenStream2) {
36        let name = &self.arg_info.name;
37        let kind = &self.kind;
38
39        let default_tokens = match &self.default_expr {
40            Some(DefaultExpr::Rust(expr)) => {
41                // Rust expression: needs runtime conversion via fmt_py_obj
42                match &self.arg_info.r#type {
43                    TypeOrOverride::RustType { r#type } => {
44                        let value = if expr.to_token_stream().to_string() == "None" {
45                            quote! { "None".to_string() }
46                        } else {
47                            quote! {
48                                {
49                                    let v: #r#type = #expr;
50                                    ::pyo3_stub_gen::util::fmt_py_obj(v)
51                                }
52                            }
53                        };
54                        // Use source_module from the type for module qualification at stub generation time
55                        quote! {
56                            ::pyo3_stub_gen::type_info::ParameterDefault::Expr {
57                                value: {
58                                    fn _fmt() -> String {
59                                        #value
60                                    }
61                                    _fmt
62                                },
63                                source_module: Some({
64                                    fn _get_module() -> Option<::pyo3_stub_gen::ModuleRef> {
65                                        <#r#type as ::pyo3_stub_gen::PyStubType>::type_output().source_module
66                                    }
67                                    _get_module
68                                }),
69                            }
70                        }
71                    }
72                    TypeOrOverride::OverrideType {
73                        rust_type_markers, ..
74                    } => {
75                        // For OverrideType, convert the default value expression directly to a string
76                        // since r#type may be a dummy type and we can't use it for type annotations
77                        let mut value_str = expr.to_token_stream().to_string();
78                        // Convert Rust bool literals to Python bool literals
79                        if value_str == "false" {
80                            value_str = "False".to_string();
81                        } else if value_str == "true" {
82                            value_str = "True".to_string();
83                        }
84
85                        // Check if the value is a literal that should not have module qualification
86                        // Literals include: None, True, False, numeric literals, string literals
87                        let is_literal = value_str == "None"
88                            || value_str == "True"
89                            || value_str == "False"
90                            || value_str.parse::<f64>().is_ok()
91                            || value_str.parse::<i64>().is_ok()
92                            || (value_str.starts_with('"') && value_str.ends_with('"'))
93                            || (value_str.starts_with('\'') && value_str.ends_with('\''));
94
95                        // Find which rust_type_marker the default expression references
96                        // Extract the first identifier from expressions like "MyEnum::Value" or "MyEnum.Value"
97                        let referenced_type = value_str.split([':', '.']).next().map(|s| s.trim());
98
99                        // Find the matching marker for this default expression
100                        let matching_marker = if is_literal {
101                            None
102                        } else {
103                            referenced_type.and_then(|ref_type| {
104                                rust_type_markers.iter().find(|marker| {
105                                    // Extract the type name from the marker (e.g., "MyEnum" from "crate::MyEnum")
106                                    let marker_name = marker.rsplit("::").next().unwrap_or(marker);
107                                    marker_name == ref_type
108                                })
109                            })
110                        };
111
112                        // Use source_module from the matching marker if found,
113                        // otherwise None to avoid using the wrong module
114                        let source_module = if let Some(marker) = matching_marker {
115                            if let Ok(marker_type) = syn::parse_str::<syn::Type>(marker) {
116                                quote! {
117                                    Some({
118                                        fn _get_module() -> Option<::pyo3_stub_gen::ModuleRef> {
119                                            <#marker_type as ::pyo3_stub_gen::PyStubType>::type_output().source_module
120                                        }
121                                        _get_module
122                                    })
123                                }
124                            } else {
125                                quote! { None }
126                            }
127                        } else {
128                            quote! { None }
129                        };
130
131                        quote! {
132                            ::pyo3_stub_gen::type_info::ParameterDefault::Expr {
133                                value: {
134                                    fn _fmt() -> String {
135                                        #value_str.to_string()
136                                    }
137                                    _fmt
138                                },
139                                source_module: #source_module,
140                            }
141                        }
142                    }
143                }
144            }
145            Some(DefaultExpr::Python(py_str)) => {
146                // Python expression: already in Python syntax, use directly
147                // No source_module since we don't know the module context from Python syntax
148                quote! {
149                    ::pyo3_stub_gen::type_info::ParameterDefault::Expr {
150                        value: {
151                            fn _fmt() -> String {
152                                #py_str.to_string()
153                            }
154                            _fmt
155                        },
156                        source_module: None,
157                    }
158                }
159            }
160            None => quote! { ::pyo3_stub_gen::type_info::ParameterDefault::None },
161        };
162
163        let param_info = match &self.arg_info.r#type {
164            TypeOrOverride::RustType { r#type } => {
165                quote! {
166                    ::pyo3_stub_gen::type_info::ParameterInfo {
167                        name: #name,
168                        kind: #kind,
169                        type_info: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
170                        default: #default_tokens,
171                    }
172                }
173            }
174            TypeOrOverride::OverrideType {
175                type_repr,
176                imports,
177                rust_type_markers,
178                ..
179            } => {
180                let imports = imports.iter().collect::<Vec<&String>>();
181
182                // Generate code to process RustType markers
183                let (type_name_code, type_refs_code) = if rust_type_markers.is_empty() {
184                    (
185                        quote! { #type_repr.to_string() },
186                        quote! { ::std::collections::HashMap::new() },
187                    )
188                } else {
189                    // Parse rust_type_markers as syn::Type
190                    let marker_types: Vec<syn::Type> = rust_type_markers
191                        .iter()
192                        .filter_map(|s| syn::parse_str(s).ok())
193                        .collect();
194
195                    let rust_names = rust_type_markers.iter().collect::<Vec<_>>();
196
197                    (
198                        quote! {
199                            {
200                                let mut type_name = #type_repr.to_string();
201                                #(
202                                    let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
203                                    // Replace Rust type name with Python type name in the expression
204                                    type_name = type_name.replace(#rust_names, &type_info.name);
205                                )*
206                                type_name
207                            }
208                        },
209                        quote! {
210                            {
211                                let mut type_refs = ::std::collections::HashMap::new();
212                                #(
213                                    let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
214                                    // Add mapping from Python name to module
215                                    if let Some(module) = type_info.source_module {
216                                        type_refs.insert(
217                                            type_info.name.split('[').next().unwrap_or(&type_info.name).split('.').last().unwrap_or(&type_info.name).to_string(),
218                                            ::pyo3_stub_gen::TypeIdentifierRef {
219                                                module: module.into(),
220                                                import_kind: ::pyo3_stub_gen::ImportKind::Module,
221                                            }
222                                        );
223                                    }
224                                    type_refs.extend(type_info.type_refs);
225                                )*
226                                type_refs
227                            }
228                        },
229                    )
230                };
231
232                quote! {
233                    ::pyo3_stub_gen::type_info::ParameterInfo {
234                        name: #name,
235                        kind: #kind,
236                        type_info: || ::pyo3_stub_gen::TypeInfo {
237                            name: #type_name_code,
238                            source_module: None,
239                            import: ::std::collections::HashSet::from([#(#imports.into(),)*]),
240                            type_refs: #type_refs_code,
241                        },
242                        default: #default_tokens,
243                    }
244                }
245            }
246        };
247
248        tokens.append_all(param_info);
249    }
250}
251
/// How a parameter may be passed from Python, as decided inside the derive macro.
///
/// Mirrors `::pyo3_stub_gen::type_info::ParameterKind` variant for variant. Its
/// `ToTokens` impl expands each variant to the matching runtime variant path.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ParameterKind {
    /// Can only be supplied by position.
    PositionalOnly,
    /// Can be supplied by position or by keyword.
    PositionalOrKeyword,
    /// Can only be supplied by keyword.
    KeywordOnly,
    /// Variadic positional parameter (`*args`).
    VarPositional,
    /// Variadic keyword parameter (`**kwargs`).
    VarKeyword,
}
264
265impl ToTokens for ParameterKind {
266    fn to_tokens(&self, tokens: &mut TokenStream2) {
267        let kind_tokens = match self {
268            Self::PositionalOnly => {
269                quote! { ::pyo3_stub_gen::type_info::ParameterKind::PositionalOnly }
270            }
271            Self::PositionalOrKeyword => {
272                quote! { ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword }
273            }
274            Self::KeywordOnly => {
275                quote! { ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly }
276            }
277            Self::VarPositional => {
278                quote! { ::pyo3_stub_gen::type_info::ParameterKind::VarPositional }
279            }
280            Self::VarKeyword => {
281                quote! { ::pyo3_stub_gen::type_info::ParameterKind::VarKeyword }
282            }
283        };
284        tokens.append_all(kind_tokens);
285    }
286}
287
288/// Collection of parameters with their kinds determined
289///
290/// This newtype wraps `Vec<ParameterWithKind>` and provides constructors that
291/// parse PyO3 signature attributes and classify parameters accordingly.
292#[derive(Debug, Clone)]
293pub(crate) struct Parameters(Vec<ParameterWithKind>);
294
295impl Parameters {
296    /// Create Parameters from a Vec<ParameterWithKind>
297    ///
298    /// This is used when parameters are already classified (e.g., from Python AST).
299    pub(crate) fn from_vec(parameters: Vec<ParameterWithKind>) -> Self {
300        Self(parameters)
301    }
302
303    /// Get mutable access to internal parameters
304    pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = &mut ParameterWithKind> {
305        self.0.iter_mut()
306    }
307
308    /// Create parameters without signature attribute
309    ///
310    /// All parameters will be classified as `PositionalOrKeyword`.
311    pub(crate) fn new(args: &[ArgInfo]) -> Self {
312        let parameters = args
313            .iter()
314            .map(|arg| {
315                let mut arg_with_clean_type = arg.clone();
316                if let ArgInfo {
317                    r#type: TypeOrOverride::RustType { r#type },
318                    ..
319                } = &mut arg_with_clean_type
320                {
321                    remove_lifetime(r#type);
322                }
323                ParameterWithKind {
324                    arg_info: arg_with_clean_type,
325                    kind: ParameterKind::PositionalOrKeyword,
326                    default_expr: None,
327                }
328            })
329            .collect();
330        Self(parameters)
331    }
332
333    /// Create parameters with signature attribute
334    ///
335    /// Parses the signature to determine parameter kinds based on delimiters
336    /// (`/` for positional-only, `*` for keyword-only, etc.).
337    pub(crate) fn new_with_sig(args: &[ArgInfo], sig: &Signature) -> Result<Self> {
338        // Build a map of argument names to their type information
339        let args_map: HashMap<String, ArgInfo> = args
340            .iter()
341            .map(|arg| {
342                let mut arg_with_clean_type = arg.clone();
343                if let ArgInfo {
344                    r#type: TypeOrOverride::RustType { r#type },
345                    ..
346                } = &mut arg_with_clean_type
347                {
348                    remove_lifetime(r#type);
349                }
350                (arg.name.clone(), arg_with_clean_type)
351            })
352            .collect();
353
354        // Track parameter kinds based on position and delimiters
355        // By default, parameters are PositionalOrKeyword unless `/` or `*` appear
356        let mut positional_only = false;
357        let mut after_star = false;
358        let mut parameters: Vec<ParameterWithKind> = Vec::new();
359
360        for sig_arg in sig.args() {
361            match sig_arg {
362                SignatureArg::Slash(_) => {
363                    // `/` delimiter - mark all previous parameters as positional-only
364                    for param in &mut parameters {
365                        param.kind = ParameterKind::PositionalOnly;
366                    }
367                    positional_only = false;
368                }
369                SignatureArg::Star(_) => {
370                    // Bare `*` - parameters after this are keyword-only
371                    positional_only = false;
372                    after_star = true;
373                }
374                SignatureArg::Ident(ident) => {
375                    let name = ident.to_string();
376                    let kind = if positional_only {
377                        ParameterKind::PositionalOnly
378                    } else if after_star {
379                        ParameterKind::KeywordOnly
380                    } else {
381                        ParameterKind::PositionalOrKeyword
382                    };
383
384                    let arg_info = args_map
385                        .get(&name)
386                        .ok_or_else(|| {
387                            syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
388                        })?
389                        .clone();
390
391                    parameters.push(ParameterWithKind {
392                        arg_info,
393                        kind,
394                        default_expr: None,
395                    });
396                }
397                SignatureArg::Assign(ident, _eq, value) => {
398                    let name = ident.to_string();
399                    let kind = if positional_only {
400                        ParameterKind::PositionalOnly
401                    } else if after_star {
402                        ParameterKind::KeywordOnly
403                    } else {
404                        ParameterKind::PositionalOrKeyword
405                    };
406
407                    let arg_info = args_map
408                        .get(&name)
409                        .ok_or_else(|| {
410                            syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
411                        })?
412                        .clone();
413
414                    parameters.push(ParameterWithKind {
415                        arg_info,
416                        kind,
417                        default_expr: Some(DefaultExpr::Rust(value.clone())),
418                    });
419                }
420                SignatureArg::Args(_, ident) => {
421                    positional_only = false;
422                    after_star = true; // After *args, everything is keyword-only
423                    let name = ident.to_string();
424
425                    let mut arg_info = args_map
426                        .get(&name)
427                        .ok_or_else(|| {
428                            syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
429                        })?
430                        .clone();
431
432                    // For VarPositional, if the type is auto-inferred from Rust (RustType),
433                    // replace it with typing.Any. If it's OverrideType, keep the user's specification.
434                    if matches!(arg_info.r#type, TypeOrOverride::RustType { .. }) {
435                        arg_info.r#type = TypeOrOverride::OverrideType {
436                            r#type: syn::parse_quote!(()), // Dummy type, won't be used
437                            type_repr: "typing.Any".to_string(),
438                            imports: ["typing".to_string()].into_iter().collect(),
439                            rust_type_markers: vec![],
440                        };
441                    }
442
443                    parameters.push(ParameterWithKind {
444                        arg_info,
445                        kind: ParameterKind::VarPositional,
446                        default_expr: None,
447                    });
448                }
449                SignatureArg::Keywords(_, _, ident) => {
450                    positional_only = false;
451                    let name = ident.to_string();
452
453                    let mut arg_info = args_map
454                        .get(&name)
455                        .ok_or_else(|| {
456                            syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
457                        })?
458                        .clone();
459
460                    // For VarKeyword, if the type is auto-inferred from Rust (RustType),
461                    // replace it with typing.Any. If it's OverrideType, keep the user's specification.
462                    if matches!(arg_info.r#type, TypeOrOverride::RustType { .. }) {
463                        arg_info.r#type = TypeOrOverride::OverrideType {
464                            r#type: syn::parse_quote!(()), // Dummy type, won't be used
465                            type_repr: "typing.Any".to_string(),
466                            imports: ["typing".to_string()].into_iter().collect(),
467                            rust_type_markers: vec![],
468                        };
469                    }
470
471                    parameters.push(ParameterWithKind {
472                        arg_info,
473                        kind: ParameterKind::VarKeyword,
474                        default_expr: None,
475                    });
476                }
477            }
478        }
479
480        Ok(Self(parameters))
481    }
482}
483
484impl ToTokens for Parameters {
485    fn to_tokens(&self, tokens: &mut TokenStream2) {
486        let params = &self.0;
487        tokens.append_all(quote! { &[ #(#params),* ] })
488    }
489}