pyo3_stub_gen_derive/gen_stub/
parameter.rs

1//! Parameter intermediate representation for derive macros
2//!
3//! This module provides intermediate representations for parameters that are used
4//! during the code generation phase. These types exist only within the derive macro
5//! and are converted to `::pyo3_stub_gen::type_info::ParameterInfo` via `ToTokens`.
6
7use std::collections::HashMap;
8
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{quote, ToTokens, TokenStreamExt};
11use syn::{Expr, Result};
12
13use super::{remove_lifetime, signature::SignatureArg, util::TypeOrOverride, ArgInfo, Signature};
14
15/// Represents a default value expression from either Rust or Python source
16#[derive(Debug, Clone)]
17pub(crate) enum DefaultExpr {
18    /// Rust expression that needs to be converted to Python representation at runtime
19    /// Example: `vec![1, 2]`, `Number::Float`, `10`
20    Rust(Expr),
21    /// Python expression already in Python syntax (from Python stub)
22    /// Example: `"False"`, `"[1, 2]"`, `"Number.FLOAT"`
23    Python(String),
24}
25
26/// Intermediate representation for a parameter with its kind determined
27#[derive(Debug, Clone)]
28pub(crate) struct ParameterWithKind {
29    pub(crate) arg_info: ArgInfo,
30    pub(crate) kind: ParameterKind,
31    pub(crate) default_expr: Option<DefaultExpr>,
32}
33
34impl ToTokens for ParameterWithKind {
35    fn to_tokens(&self, tokens: &mut TokenStream2) {
36        let name = &self.arg_info.name;
37        let kind = &self.kind;
38
39        let default_tokens = match &self.default_expr {
40            Some(DefaultExpr::Rust(expr)) => {
41                // Rust expression: needs runtime conversion via fmt_py_obj
42                match &self.arg_info.r#type {
43                    TypeOrOverride::RustType { r#type } => {
44                        let default = if expr.to_token_stream().to_string() == "None" {
45                            quote! { "None".to_string() }
46                        } else {
47                            quote! {
48                                let v: #r#type = #expr;
49                                ::pyo3_stub_gen::util::fmt_py_obj(v)
50                            }
51                        };
52                        quote! {
53                            ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
54                                fn _fmt() -> String {
55                                    #default
56                                }
57                                _fmt
58                            })
59                        }
60                    }
61                    TypeOrOverride::OverrideType { .. } => {
62                        // For OverrideType, convert the default value expression directly to a string
63                        // since r#type may be a dummy type and we can't use it for type annotations
64                        let mut value_str = expr.to_token_stream().to_string();
65                        // Convert Rust bool literals to Python bool literals
66                        if value_str == "false" {
67                            value_str = "False".to_string();
68                        } else if value_str == "true" {
69                            value_str = "True".to_string();
70                        }
71                        quote! {
72                            ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
73                                fn _fmt() -> String {
74                                    #value_str.to_string()
75                                }
76                                _fmt
77                            })
78                        }
79                    }
80                }
81            }
82            Some(DefaultExpr::Python(py_str)) => {
83                // Python expression: already in Python syntax, use directly
84                quote! {
85                    ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
86                        fn _fmt() -> String {
87                            #py_str.to_string()
88                        }
89                        _fmt
90                    })
91                }
92            }
93            None => quote! { ::pyo3_stub_gen::type_info::ParameterDefault::None },
94        };
95
96        let param_info = match &self.arg_info.r#type {
97            TypeOrOverride::RustType { r#type } => {
98                quote! {
99                    ::pyo3_stub_gen::type_info::ParameterInfo {
100                        name: #name,
101                        kind: #kind,
102                        type_info: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
103                        default: #default_tokens,
104                    }
105                }
106            }
107            TypeOrOverride::OverrideType {
108                type_repr, imports, ..
109            } => {
110                let imports = imports.iter().collect::<Vec<&String>>();
111                quote! {
112                    ::pyo3_stub_gen::type_info::ParameterInfo {
113                        name: #name,
114                        kind: #kind,
115                        type_info: || ::pyo3_stub_gen::TypeInfo {
116                            name: #type_repr.to_string(),
117                            import: ::std::collections::HashSet::from([#(#imports.into(),)*])
118                        },
119                        default: #default_tokens,
120                    }
121                }
122            }
123        };
124
125        tokens.append_all(param_info);
126    }
127}
128
/// Derive-macro-side mirror of `::pyo3_stub_gen::type_info::ParameterKind`.
///
/// Used only while generating code; its `ToTokens` impl emits the path of the
/// matching runtime variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum ParameterKind {
    /// Parameter listed before a `/` delimiter.
    PositionalOnly,
    /// Ordinary parameter (the kind used when nothing says otherwise).
    PositionalOrKeyword,
    /// Parameter listed after a bare `*` or after `*args`.
    KeywordOnly,
    /// The `*args` parameter.
    VarPositional,
    /// The `**kwargs` parameter.
    VarKeyword,
}
141
142impl ToTokens for ParameterKind {
143    fn to_tokens(&self, tokens: &mut TokenStream2) {
144        let kind_tokens = match self {
145            Self::PositionalOnly => {
146                quote! { ::pyo3_stub_gen::type_info::ParameterKind::PositionalOnly }
147            }
148            Self::PositionalOrKeyword => {
149                quote! { ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword }
150            }
151            Self::KeywordOnly => {
152                quote! { ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly }
153            }
154            Self::VarPositional => {
155                quote! { ::pyo3_stub_gen::type_info::ParameterKind::VarPositional }
156            }
157            Self::VarKeyword => {
158                quote! { ::pyo3_stub_gen::type_info::ParameterKind::VarKeyword }
159            }
160        };
161        tokens.append_all(kind_tokens);
162    }
163}
164
165/// Collection of parameters with their kinds determined
166///
167/// This newtype wraps `Vec<ParameterWithKind>` and provides constructors that
168/// parse PyO3 signature attributes and classify parameters accordingly.
169#[derive(Debug, Clone)]
170pub(crate) struct Parameters(Vec<ParameterWithKind>);
171
172impl Parameters {
173    /// Create Parameters from a Vec<ParameterWithKind>
174    ///
175    /// This is used when parameters are already classified (e.g., from Python AST).
176    pub(crate) fn from_vec(parameters: Vec<ParameterWithKind>) -> Self {
177        Self(parameters)
178    }
179
180    /// Get mutable access to internal parameters
181    pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = &mut ParameterWithKind> {
182        self.0.iter_mut()
183    }
184
185    /// Create parameters without signature attribute
186    ///
187    /// All parameters will be classified as `PositionalOrKeyword`.
188    pub(crate) fn new(args: &[ArgInfo]) -> Self {
189        let parameters = args
190            .iter()
191            .map(|arg| {
192                let mut arg_with_clean_type = arg.clone();
193                if let ArgInfo {
194                    r#type: TypeOrOverride::RustType { r#type },
195                    ..
196                } = &mut arg_with_clean_type
197                {
198                    remove_lifetime(r#type);
199                }
200                ParameterWithKind {
201                    arg_info: arg_with_clean_type,
202                    kind: ParameterKind::PositionalOrKeyword,
203                    default_expr: None,
204                }
205            })
206            .collect();
207        Self(parameters)
208    }
209
210    /// Create parameters with signature attribute
211    ///
212    /// Parses the signature to determine parameter kinds based on delimiters
213    /// (`/` for positional-only, `*` for keyword-only, etc.).
214    pub(crate) fn new_with_sig(args: &[ArgInfo], sig: &Signature) -> Result<Self> {
215        // Build a map of argument names to their type information
216        let args_map: HashMap<String, ArgInfo> = args
217            .iter()
218            .map(|arg| {
219                let mut arg_with_clean_type = arg.clone();
220                if let ArgInfo {
221                    r#type: TypeOrOverride::RustType { r#type },
222                    ..
223                } = &mut arg_with_clean_type
224                {
225                    remove_lifetime(r#type);
226                }
227                (arg.name.clone(), arg_with_clean_type)
228            })
229            .collect();
230
231        // Track parameter kinds based on position and delimiters
232        // By default, parameters are PositionalOrKeyword unless `/` or `*` appear
233        let mut positional_only = false;
234        let mut after_star = false;
235        let mut parameters: Vec<ParameterWithKind> = Vec::new();
236
237        for sig_arg in sig.args() {
238            match sig_arg {
239                SignatureArg::Slash(_) => {
240                    // `/` delimiter - mark all previous parameters as positional-only
241                    for param in &mut parameters {
242                        param.kind = ParameterKind::PositionalOnly;
243                    }
244                    positional_only = false;
245                }
246                SignatureArg::Star(_) => {
247                    // Bare `*` - parameters after this are keyword-only
248                    positional_only = false;
249                    after_star = true;
250                }
251                SignatureArg::Ident(ident) => {
252                    let name = ident.to_string();
253                    let kind = if positional_only {
254                        ParameterKind::PositionalOnly
255                    } else if after_star {
256                        ParameterKind::KeywordOnly
257                    } else {
258                        ParameterKind::PositionalOrKeyword
259                    };
260
261                    let arg_info = args_map
262                        .get(&name)
263                        .ok_or_else(|| {
264                            syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
265                        })?
266                        .clone();
267
268                    parameters.push(ParameterWithKind {
269                        arg_info,
270                        kind,
271                        default_expr: None,
272                    });
273                }
274                SignatureArg::Assign(ident, _eq, value) => {
275                    let name = ident.to_string();
276                    let kind = if positional_only {
277                        ParameterKind::PositionalOnly
278                    } else if after_star {
279                        ParameterKind::KeywordOnly
280                    } else {
281                        ParameterKind::PositionalOrKeyword
282                    };
283
284                    let arg_info = args_map
285                        .get(&name)
286                        .ok_or_else(|| {
287                            syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
288                        })?
289                        .clone();
290
291                    parameters.push(ParameterWithKind {
292                        arg_info,
293                        kind,
294                        default_expr: Some(DefaultExpr::Rust(value.clone())),
295                    });
296                }
297                SignatureArg::Args(_, ident) => {
298                    positional_only = false;
299                    after_star = true; // After *args, everything is keyword-only
300                    let name = ident.to_string();
301
302                    let mut arg_info = args_map
303                        .get(&name)
304                        .ok_or_else(|| {
305                            syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
306                        })?
307                        .clone();
308
309                    // For VarPositional, if the type is auto-inferred from Rust (RustType),
310                    // replace it with typing.Any. If it's OverrideType, keep the user's specification.
311                    if matches!(arg_info.r#type, TypeOrOverride::RustType { .. }) {
312                        arg_info.r#type = TypeOrOverride::OverrideType {
313                            r#type: syn::parse_quote!(()), // Dummy type, won't be used
314                            type_repr: "typing.Any".to_string(),
315                            imports: ["typing".to_string()].into_iter().collect(),
316                        };
317                    }
318
319                    parameters.push(ParameterWithKind {
320                        arg_info,
321                        kind: ParameterKind::VarPositional,
322                        default_expr: None,
323                    });
324                }
325                SignatureArg::Keywords(_, _, ident) => {
326                    positional_only = false;
327                    let name = ident.to_string();
328
329                    let mut arg_info = args_map
330                        .get(&name)
331                        .ok_or_else(|| {
332                            syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
333                        })?
334                        .clone();
335
336                    // For VarKeyword, if the type is auto-inferred from Rust (RustType),
337                    // replace it with typing.Any. If it's OverrideType, keep the user's specification.
338                    if matches!(arg_info.r#type, TypeOrOverride::RustType { .. }) {
339                        arg_info.r#type = TypeOrOverride::OverrideType {
340                            r#type: syn::parse_quote!(()), // Dummy type, won't be used
341                            type_repr: "typing.Any".to_string(),
342                            imports: ["typing".to_string()].into_iter().collect(),
343                        };
344                    }
345
346                    parameters.push(ParameterWithKind {
347                        arg_info,
348                        kind: ParameterKind::VarKeyword,
349                        default_expr: None,
350                    });
351                }
352            }
353        }
354
355        Ok(Self(parameters))
356    }
357}
358
359impl ToTokens for Parameters {
360    fn to_tokens(&self, tokens: &mut TokenStream2) {
361        let params = &self.0;
362        tokens.append_all(quote! { &[ #(#params),* ] })
363    }
364}