pyo3_stub_gen_derive/gen_stub/
parameter.rs

1//! Parameter intermediate representation for derive macros
2//!
3//! This module provides intermediate representations for parameters that are used
4//! during the code generation phase. These types exist only within the derive macro
5//! and are converted to `::pyo3_stub_gen::type_info::ParameterInfo` via `ToTokens`.
6
7use std::collections::HashMap;
8
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{quote, ToTokens, TokenStreamExt};
11use syn::{Expr, Result};
12
13use super::{remove_lifetime, signature::SignatureArg, util::TypeOrOverride, ArgInfo, Signature};
14
15/// Represents a default value expression from either Rust or Python source
16#[derive(Debug, Clone)]
17pub(crate) enum DefaultExpr {
18    /// Rust expression that needs to be converted to Python representation at runtime
19    /// Example: `vec![1, 2]`, `Number::Float`, `10`
20    Rust(Expr),
21    /// Python expression already in Python syntax (from Python stub)
22    /// Example: `"False"`, `"[1, 2]"`, `"Number.FLOAT"`
23    Python(String),
24}
25
26/// Intermediate representation for a parameter with its kind determined
27#[derive(Debug, Clone)]
28pub(crate) struct ParameterWithKind {
29    pub(crate) arg_info: ArgInfo,
30    pub(crate) kind: ParameterKind,
31    pub(crate) default_expr: Option<DefaultExpr>,
32}
33
34impl ToTokens for ParameterWithKind {
35    fn to_tokens(&self, tokens: &mut TokenStream2) {
36        let name = &self.arg_info.name;
37        let kind = &self.kind;
38
39        let default_tokens = match &self.default_expr {
40            Some(DefaultExpr::Rust(expr)) => {
41                // Rust expression: needs runtime conversion via fmt_py_obj
42                match &self.arg_info.r#type {
43                    TypeOrOverride::RustType { r#type } => {
44                        let default = if expr.to_token_stream().to_string() == "None" {
45                            quote! { "None".to_string() }
46                        } else {
47                            quote! {
48                                {
49                                    let v: #r#type = #expr;
50                                    let repr = ::pyo3_stub_gen::util::fmt_py_obj(v);
51                                    let type_info = <#r#type as ::pyo3_stub_gen::PyStubType>::type_output();
52
53                                    // Check if the type is defined in another module within the current package.
54                                    // For locally-defined types: qualified name is "sub_mod.ClassA" and
55                                    // import is Module("package.sub_mod"). The import ends with the module prefix.
56                                    // For external types: qualified name is "numpy.ndarray" and
57                                    // import is Module("numpy"). The import equals the module prefix.
58                                    let should_add_prefix = if let Some(dot_pos) = type_info.name.rfind('.') {
59                                        let module_prefix = &type_info.name[..dot_pos];
60                                        type_info.import.iter().any(|imp| {
61                                            if let ::pyo3_stub_gen::ImportRef::Module(module_ref) = imp {
62                                                if let Some(module_name) = module_ref.get() {
63                                                    // Local cross-module ref: full path ends with module prefix
64                                                    // e.g., "package.sub_mod" ends with ".sub_mod"
65                                                    module_name.ends_with(&format!(".{}", module_prefix))
66                                                } else {
67                                                    false
68                                                }
69                                            } else {
70                                                false
71                                            }
72                                        })
73                                    } else {
74                                        false
75                                    };
76
77                                    if should_add_prefix {
78                                        if let Some(dot_pos) = type_info.name.rfind('.') {
79                                            let module_prefix = &type_info.name[..dot_pos];
80                                            format!("{}.{}", module_prefix, repr)
81                                        } else {
82                                            repr
83                                        }
84                                    } else {
85                                        repr
86                                    }
87                                }
88                            }
89                        };
90                        quote! {
91                            ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
92                                fn _fmt() -> String {
93                                    #default
94                                }
95                                _fmt
96                            })
97                        }
98                    }
99                    TypeOrOverride::OverrideType { .. } => {
100                        // For OverrideType, convert the default value expression directly to a string
101                        // since r#type may be a dummy type and we can't use it for type annotations
102                        let mut value_str = expr.to_token_stream().to_string();
103                        // Convert Rust bool literals to Python bool literals
104                        if value_str == "false" {
105                            value_str = "False".to_string();
106                        } else if value_str == "true" {
107                            value_str = "True".to_string();
108                        }
109                        quote! {
110                            ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
111                                fn _fmt() -> String {
112                                    #value_str.to_string()
113                                }
114                                _fmt
115                            })
116                        }
117                    }
118                }
119            }
120            Some(DefaultExpr::Python(py_str)) => {
121                // Python expression: already in Python syntax, use directly
122                quote! {
123                    ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
124                        fn _fmt() -> String {
125                            #py_str.to_string()
126                        }
127                        _fmt
128                    })
129                }
130            }
131            None => quote! { ::pyo3_stub_gen::type_info::ParameterDefault::None },
132        };
133
134        let param_info = match &self.arg_info.r#type {
135            TypeOrOverride::RustType { r#type } => {
136                quote! {
137                    ::pyo3_stub_gen::type_info::ParameterInfo {
138                        name: #name,
139                        kind: #kind,
140                        type_info: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
141                        default: #default_tokens,
142                    }
143                }
144            }
145            TypeOrOverride::OverrideType {
146                type_repr,
147                imports,
148                rust_type_markers,
149                ..
150            } => {
151                let imports = imports.iter().collect::<Vec<&String>>();
152
153                // Generate code to process RustType markers
154                let (type_name_code, type_refs_code) = if rust_type_markers.is_empty() {
155                    (
156                        quote! { #type_repr.to_string() },
157                        quote! { ::std::collections::HashMap::new() },
158                    )
159                } else {
160                    // Parse rust_type_markers as syn::Type
161                    let marker_types: Vec<syn::Type> = rust_type_markers
162                        .iter()
163                        .filter_map(|s| syn::parse_str(s).ok())
164                        .collect();
165
166                    let rust_names = rust_type_markers.iter().collect::<Vec<_>>();
167
168                    (
169                        quote! {
170                            {
171                                let mut type_name = #type_repr.to_string();
172                                #(
173                                    let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
174                                    // Replace Rust type name with Python type name in the expression
175                                    type_name = type_name.replace(#rust_names, &type_info.name);
176                                )*
177                                type_name
178                            }
179                        },
180                        quote! {
181                            {
182                                let mut type_refs = ::std::collections::HashMap::new();
183                                #(
184                                    let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
185                                    // Add mapping from Python name to module
186                                    if let Some(module) = type_info.source_module {
187                                        type_refs.insert(
188                                            type_info.name.split('[').next().unwrap_or(&type_info.name).split('.').last().unwrap_or(&type_info.name).to_string(),
189                                            ::pyo3_stub_gen::TypeIdentifierRef {
190                                                module: module.into(),
191                                                import_kind: ::pyo3_stub_gen::ImportKind::Module,
192                                            }
193                                        );
194                                    }
195                                    type_refs.extend(type_info.type_refs);
196                                )*
197                                type_refs
198                            }
199                        },
200                    )
201                };
202
203                quote! {
204                    ::pyo3_stub_gen::type_info::ParameterInfo {
205                        name: #name,
206                        kind: #kind,
207                        type_info: || ::pyo3_stub_gen::TypeInfo {
208                            name: #type_name_code,
209                            source_module: None,
210                            import: ::std::collections::HashSet::from([#(#imports.into(),)*]),
211                            type_refs: #type_refs_code,
212                        },
213                        default: #default_tokens,
214                    }
215                }
216            }
217        };
218
219        tokens.append_all(param_info);
220    }
221}
222
223/// Parameter kind for intermediate representation in derive macro
224///
225/// This enum mirrors `::pyo3_stub_gen::type_info::ParameterKind` but exists
226/// in the derive macro context for code generation purposes.
227#[derive(Debug, Clone, Copy, PartialEq, Eq)]
228pub(crate) enum ParameterKind {
229    PositionalOnly,
230    PositionalOrKeyword,
231    KeywordOnly,
232    VarPositional,
233    VarKeyword,
234}
235
236impl ToTokens for ParameterKind {
237    fn to_tokens(&self, tokens: &mut TokenStream2) {
238        let kind_tokens = match self {
239            Self::PositionalOnly => {
240                quote! { ::pyo3_stub_gen::type_info::ParameterKind::PositionalOnly }
241            }
242            Self::PositionalOrKeyword => {
243                quote! { ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword }
244            }
245            Self::KeywordOnly => {
246                quote! { ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly }
247            }
248            Self::VarPositional => {
249                quote! { ::pyo3_stub_gen::type_info::ParameterKind::VarPositional }
250            }
251            Self::VarKeyword => {
252                quote! { ::pyo3_stub_gen::type_info::ParameterKind::VarKeyword }
253            }
254        };
255        tokens.append_all(kind_tokens);
256    }
257}
258
259/// Collection of parameters with their kinds determined
260///
261/// This newtype wraps `Vec<ParameterWithKind>` and provides constructors that
262/// parse PyO3 signature attributes and classify parameters accordingly.
263#[derive(Debug, Clone)]
264pub(crate) struct Parameters(Vec<ParameterWithKind>);
265
266impl Parameters {
267    /// Create Parameters from a Vec<ParameterWithKind>
268    ///
269    /// This is used when parameters are already classified (e.g., from Python AST).
270    pub(crate) fn from_vec(parameters: Vec<ParameterWithKind>) -> Self {
271        Self(parameters)
272    }
273
274    /// Get mutable access to internal parameters
275    pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = &mut ParameterWithKind> {
276        self.0.iter_mut()
277    }
278
279    /// Create parameters without signature attribute
280    ///
281    /// All parameters will be classified as `PositionalOrKeyword`.
282    pub(crate) fn new(args: &[ArgInfo]) -> Self {
283        let parameters = args
284            .iter()
285            .map(|arg| {
286                let mut arg_with_clean_type = arg.clone();
287                if let ArgInfo {
288                    r#type: TypeOrOverride::RustType { r#type },
289                    ..
290                } = &mut arg_with_clean_type
291                {
292                    remove_lifetime(r#type);
293                }
294                ParameterWithKind {
295                    arg_info: arg_with_clean_type,
296                    kind: ParameterKind::PositionalOrKeyword,
297                    default_expr: None,
298                }
299            })
300            .collect();
301        Self(parameters)
302    }
303
304    /// Create parameters with signature attribute
305    ///
306    /// Parses the signature to determine parameter kinds based on delimiters
307    /// (`/` for positional-only, `*` for keyword-only, etc.).
308    pub(crate) fn new_with_sig(args: &[ArgInfo], sig: &Signature) -> Result<Self> {
309        // Build a map of argument names to their type information
310        let args_map: HashMap<String, ArgInfo> = args
311            .iter()
312            .map(|arg| {
313                let mut arg_with_clean_type = arg.clone();
314                if let ArgInfo {
315                    r#type: TypeOrOverride::RustType { r#type },
316                    ..
317                } = &mut arg_with_clean_type
318                {
319                    remove_lifetime(r#type);
320                }
321                (arg.name.clone(), arg_with_clean_type)
322            })
323            .collect();
324
325        // Track parameter kinds based on position and delimiters
326        // By default, parameters are PositionalOrKeyword unless `/` or `*` appear
327        let mut positional_only = false;
328        let mut after_star = false;
329        let mut parameters: Vec<ParameterWithKind> = Vec::new();
330
331        for sig_arg in sig.args() {
332            match sig_arg {
333                SignatureArg::Slash(_) => {
334                    // `/` delimiter - mark all previous parameters as positional-only
335                    for param in &mut parameters {
336                        param.kind = ParameterKind::PositionalOnly;
337                    }
338                    positional_only = false;
339                }
340                SignatureArg::Star(_) => {
341                    // Bare `*` - parameters after this are keyword-only
342                    positional_only = false;
343                    after_star = true;
344                }
345                SignatureArg::Ident(ident) => {
346                    let name = ident.to_string();
347                    let kind = if positional_only {
348                        ParameterKind::PositionalOnly
349                    } else if after_star {
350                        ParameterKind::KeywordOnly
351                    } else {
352                        ParameterKind::PositionalOrKeyword
353                    };
354
355                    let arg_info = args_map
356                        .get(&name)
357                        .ok_or_else(|| {
358                            syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
359                        })?
360                        .clone();
361
362                    parameters.push(ParameterWithKind {
363                        arg_info,
364                        kind,
365                        default_expr: None,
366                    });
367                }
368                SignatureArg::Assign(ident, _eq, value) => {
369                    let name = ident.to_string();
370                    let kind = if positional_only {
371                        ParameterKind::PositionalOnly
372                    } else if after_star {
373                        ParameterKind::KeywordOnly
374                    } else {
375                        ParameterKind::PositionalOrKeyword
376                    };
377
378                    let arg_info = args_map
379                        .get(&name)
380                        .ok_or_else(|| {
381                            syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
382                        })?
383                        .clone();
384
385                    parameters.push(ParameterWithKind {
386                        arg_info,
387                        kind,
388                        default_expr: Some(DefaultExpr::Rust(value.clone())),
389                    });
390                }
391                SignatureArg::Args(_, ident) => {
392                    positional_only = false;
393                    after_star = true; // After *args, everything is keyword-only
394                    let name = ident.to_string();
395
396                    let mut arg_info = args_map
397                        .get(&name)
398                        .ok_or_else(|| {
399                            syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
400                        })?
401                        .clone();
402
403                    // For VarPositional, if the type is auto-inferred from Rust (RustType),
404                    // replace it with typing.Any. If it's OverrideType, keep the user's specification.
405                    if matches!(arg_info.r#type, TypeOrOverride::RustType { .. }) {
406                        arg_info.r#type = TypeOrOverride::OverrideType {
407                            r#type: syn::parse_quote!(()), // Dummy type, won't be used
408                            type_repr: "typing.Any".to_string(),
409                            imports: ["typing".to_string()].into_iter().collect(),
410                            rust_type_markers: vec![],
411                        };
412                    }
413
414                    parameters.push(ParameterWithKind {
415                        arg_info,
416                        kind: ParameterKind::VarPositional,
417                        default_expr: None,
418                    });
419                }
420                SignatureArg::Keywords(_, _, ident) => {
421                    positional_only = false;
422                    let name = ident.to_string();
423
424                    let mut arg_info = args_map
425                        .get(&name)
426                        .ok_or_else(|| {
427                            syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
428                        })?
429                        .clone();
430
431                    // For VarKeyword, if the type is auto-inferred from Rust (RustType),
432                    // replace it with typing.Any. If it's OverrideType, keep the user's specification.
433                    if matches!(arg_info.r#type, TypeOrOverride::RustType { .. }) {
434                        arg_info.r#type = TypeOrOverride::OverrideType {
435                            r#type: syn::parse_quote!(()), // Dummy type, won't be used
436                            type_repr: "typing.Any".to_string(),
437                            imports: ["typing".to_string()].into_iter().collect(),
438                            rust_type_markers: vec![],
439                        };
440                    }
441
442                    parameters.push(ParameterWithKind {
443                        arg_info,
444                        kind: ParameterKind::VarKeyword,
445                        default_expr: None,
446                    });
447                }
448            }
449        }
450
451        Ok(Self(parameters))
452    }
453}
454
455impl ToTokens for Parameters {
456    fn to_tokens(&self, tokens: &mut TokenStream2) {
457        let params = &self.0;
458        tokens.append_all(quote! { &[ #(#params),* ] })
459    }
460}