pyo3_stub_gen_derive/gen_stub/parameter.rs
1//! Parameter intermediate representation for derive macros
2//!
3//! This module provides intermediate representations for parameters that are used
4//! during the code generation phase. These types exist only within the derive macro
5//! and are converted to `::pyo3_stub_gen::type_info::ParameterInfo` via `ToTokens`.
6
7use std::collections::HashMap;
8
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{quote, ToTokens, TokenStreamExt};
11use syn::{Expr, Result};
12
13use super::{remove_lifetime, signature::SignatureArg, util::TypeOrOverride, ArgInfo, Signature};
14
15/// Represents a default value expression from either Rust or Python source
16#[derive(Debug, Clone)]
17pub(crate) enum DefaultExpr {
18 /// Rust expression that needs to be converted to Python representation at runtime
19 /// Example: `vec![1, 2]`, `Number::Float`, `10`
20 Rust(Expr),
21 /// Python expression already in Python syntax (from Python stub)
22 /// Example: `"False"`, `"[1, 2]"`, `"Number.FLOAT"`
23 Python(String),
24}
25
26/// Intermediate representation for a parameter with its kind determined
27#[derive(Debug, Clone)]
28pub(crate) struct ParameterWithKind {
29 pub(crate) arg_info: ArgInfo,
30 pub(crate) kind: ParameterKind,
31 pub(crate) default_expr: Option<DefaultExpr>,
32}
33
34impl ToTokens for ParameterWithKind {
35 fn to_tokens(&self, tokens: &mut TokenStream2) {
36 let name = &self.arg_info.name;
37 let kind = &self.kind;
38
39 let default_tokens = match &self.default_expr {
40 Some(DefaultExpr::Rust(expr)) => {
41 // Rust expression: needs runtime conversion via fmt_py_obj
42 match &self.arg_info.r#type {
43 TypeOrOverride::RustType { r#type } => {
44 let value = if expr.to_token_stream().to_string() == "None" {
45 quote! { "None".to_string() }
46 } else {
47 quote! {
48 {
49 let v: #r#type = #expr;
50 ::pyo3_stub_gen::util::fmt_py_obj(v)
51 }
52 }
53 };
54 // Use source_module from the type for module qualification at stub generation time
55 quote! {
56 ::pyo3_stub_gen::type_info::ParameterDefault::Expr {
57 value: {
58 fn _fmt() -> String {
59 #value
60 }
61 _fmt
62 },
63 source_module: Some({
64 fn _get_module() -> Option<::pyo3_stub_gen::ModuleRef> {
65 <#r#type as ::pyo3_stub_gen::PyStubType>::type_output().source_module
66 }
67 _get_module
68 }),
69 }
70 }
71 }
72 TypeOrOverride::OverrideType {
73 rust_type_markers, ..
74 } => {
75 // For OverrideType, convert the default value expression directly to a string
76 // since r#type may be a dummy type and we can't use it for type annotations
77 let mut value_str = expr.to_token_stream().to_string();
78 // Convert Rust bool literals to Python bool literals
79 if value_str == "false" {
80 value_str = "False".to_string();
81 } else if value_str == "true" {
82 value_str = "True".to_string();
83 }
84
85 // Check if the value is a literal that should not have module qualification
86 // Literals include: None, True, False, numeric literals, string literals
87 let is_literal = value_str == "None"
88 || value_str == "True"
89 || value_str == "False"
90 || value_str.parse::<f64>().is_ok()
91 || value_str.parse::<i64>().is_ok()
92 || (value_str.starts_with('"') && value_str.ends_with('"'))
93 || (value_str.starts_with('\'') && value_str.ends_with('\''));
94
95 // Find which rust_type_marker the default expression references
96 // Extract the first identifier from expressions like "MyEnum::Value" or "MyEnum.Value"
97 let referenced_type = value_str.split([':', '.']).next().map(|s| s.trim());
98
99 // Find the matching marker for this default expression
100 let matching_marker = if is_literal {
101 None
102 } else {
103 referenced_type.and_then(|ref_type| {
104 rust_type_markers.iter().find(|marker| {
105 // Extract the type name from the marker (e.g., "MyEnum" from "crate::MyEnum")
106 let marker_name = marker.rsplit("::").next().unwrap_or(marker);
107 marker_name == ref_type
108 })
109 })
110 };
111
112 // Use source_module from the matching marker if found,
113 // otherwise None to avoid using the wrong module
114 let source_module = if let Some(marker) = matching_marker {
115 if let Ok(marker_type) = syn::parse_str::<syn::Type>(marker) {
116 quote! {
117 Some({
118 fn _get_module() -> Option<::pyo3_stub_gen::ModuleRef> {
119 <#marker_type as ::pyo3_stub_gen::PyStubType>::type_output().source_module
120 }
121 _get_module
122 })
123 }
124 } else {
125 quote! { None }
126 }
127 } else {
128 quote! { None }
129 };
130
131 quote! {
132 ::pyo3_stub_gen::type_info::ParameterDefault::Expr {
133 value: {
134 fn _fmt() -> String {
135 #value_str.to_string()
136 }
137 _fmt
138 },
139 source_module: #source_module,
140 }
141 }
142 }
143 }
144 }
145 Some(DefaultExpr::Python(py_str)) => {
146 // Python expression: already in Python syntax, use directly
147 // No source_module since we don't know the module context from Python syntax
148 quote! {
149 ::pyo3_stub_gen::type_info::ParameterDefault::Expr {
150 value: {
151 fn _fmt() -> String {
152 #py_str.to_string()
153 }
154 _fmt
155 },
156 source_module: None,
157 }
158 }
159 }
160 None => quote! { ::pyo3_stub_gen::type_info::ParameterDefault::None },
161 };
162
163 let param_info = match &self.arg_info.r#type {
164 TypeOrOverride::RustType { r#type } => {
165 quote! {
166 ::pyo3_stub_gen::type_info::ParameterInfo {
167 name: #name,
168 kind: #kind,
169 type_info: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
170 default: #default_tokens,
171 }
172 }
173 }
174 TypeOrOverride::OverrideType {
175 type_repr,
176 imports,
177 rust_type_markers,
178 ..
179 } => {
180 let imports = imports.iter().collect::<Vec<&String>>();
181
182 // Generate code to process RustType markers
183 let (type_name_code, type_refs_code) = if rust_type_markers.is_empty() {
184 (
185 quote! { #type_repr.to_string() },
186 quote! { ::std::collections::HashMap::new() },
187 )
188 } else {
189 // Parse rust_type_markers as syn::Type
190 let marker_types: Vec<syn::Type> = rust_type_markers
191 .iter()
192 .filter_map(|s| syn::parse_str(s).ok())
193 .collect();
194
195 let rust_names = rust_type_markers.iter().collect::<Vec<_>>();
196
197 (
198 quote! {
199 {
200 let mut type_name = #type_repr.to_string();
201 #(
202 let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
203 // Replace Rust type name with Python type name in the expression
204 type_name = type_name.replace(#rust_names, &type_info.name);
205 )*
206 type_name
207 }
208 },
209 quote! {
210 {
211 let mut type_refs = ::std::collections::HashMap::new();
212 #(
213 let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
214 // Add mapping from Python name to module
215 if let Some(module) = type_info.source_module {
216 type_refs.insert(
217 type_info.name.split('[').next().unwrap_or(&type_info.name).split('.').last().unwrap_or(&type_info.name).to_string(),
218 ::pyo3_stub_gen::TypeIdentifierRef {
219 module: module.into(),
220 import_kind: ::pyo3_stub_gen::ImportKind::Module,
221 }
222 );
223 }
224 type_refs.extend(type_info.type_refs);
225 )*
226 type_refs
227 }
228 },
229 )
230 };
231
232 quote! {
233 ::pyo3_stub_gen::type_info::ParameterInfo {
234 name: #name,
235 kind: #kind,
236 type_info: || ::pyo3_stub_gen::TypeInfo {
237 name: #type_name_code,
238 source_module: None,
239 import: ::std::collections::HashSet::from([#(#imports.into(),)*]),
240 type_refs: #type_refs_code,
241 },
242 default: #default_tokens,
243 }
244 }
245 }
246 };
247
248 tokens.append_all(param_info);
249 }
250}
251
/// How a Python caller may pass a parameter.
///
/// Derive-macro-side counterpart of `::pyo3_stub_gen::type_info::ParameterKind`; it exists so the
/// kind can be tracked during signature parsing and then emitted as tokens.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ParameterKind {
    /// Declared before a `/` delimiter.
    PositionalOnly,
    /// May be passed positionally or by name (the default classification).
    PositionalOrKeyword,
    /// Declared after a bare `*` or after `*args`.
    KeywordOnly,
    /// The `*args` parameter.
    VarPositional,
    /// The `**kwargs` parameter.
    VarKeyword,
}
264
265impl ToTokens for ParameterKind {
266 fn to_tokens(&self, tokens: &mut TokenStream2) {
267 let kind_tokens = match self {
268 Self::PositionalOnly => {
269 quote! { ::pyo3_stub_gen::type_info::ParameterKind::PositionalOnly }
270 }
271 Self::PositionalOrKeyword => {
272 quote! { ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword }
273 }
274 Self::KeywordOnly => {
275 quote! { ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly }
276 }
277 Self::VarPositional => {
278 quote! { ::pyo3_stub_gen::type_info::ParameterKind::VarPositional }
279 }
280 Self::VarKeyword => {
281 quote! { ::pyo3_stub_gen::type_info::ParameterKind::VarKeyword }
282 }
283 };
284 tokens.append_all(kind_tokens);
285 }
286}
287
288/// Collection of parameters with their kinds determined
289///
290/// This newtype wraps `Vec<ParameterWithKind>` and provides constructors that
291/// parse PyO3 signature attributes and classify parameters accordingly.
292#[derive(Debug, Clone)]
293pub(crate) struct Parameters(Vec<ParameterWithKind>);
294
295impl Parameters {
296 /// Create Parameters from a Vec<ParameterWithKind>
297 ///
298 /// This is used when parameters are already classified (e.g., from Python AST).
299 pub(crate) fn from_vec(parameters: Vec<ParameterWithKind>) -> Self {
300 Self(parameters)
301 }
302
303 /// Get mutable access to internal parameters
304 pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = &mut ParameterWithKind> {
305 self.0.iter_mut()
306 }
307
308 /// Create parameters without signature attribute
309 ///
310 /// All parameters will be classified as `PositionalOrKeyword`.
311 pub(crate) fn new(args: &[ArgInfo]) -> Self {
312 let parameters = args
313 .iter()
314 .map(|arg| {
315 let mut arg_with_clean_type = arg.clone();
316 if let ArgInfo {
317 r#type: TypeOrOverride::RustType { r#type },
318 ..
319 } = &mut arg_with_clean_type
320 {
321 remove_lifetime(r#type);
322 }
323 ParameterWithKind {
324 arg_info: arg_with_clean_type,
325 kind: ParameterKind::PositionalOrKeyword,
326 default_expr: None,
327 }
328 })
329 .collect();
330 Self(parameters)
331 }
332
333 /// Create parameters with signature attribute
334 ///
335 /// Parses the signature to determine parameter kinds based on delimiters
336 /// (`/` for positional-only, `*` for keyword-only, etc.).
337 pub(crate) fn new_with_sig(args: &[ArgInfo], sig: &Signature) -> Result<Self> {
338 // Build a map of argument names to their type information
339 let args_map: HashMap<String, ArgInfo> = args
340 .iter()
341 .map(|arg| {
342 let mut arg_with_clean_type = arg.clone();
343 if let ArgInfo {
344 r#type: TypeOrOverride::RustType { r#type },
345 ..
346 } = &mut arg_with_clean_type
347 {
348 remove_lifetime(r#type);
349 }
350 (arg.name.clone(), arg_with_clean_type)
351 })
352 .collect();
353
354 // Track parameter kinds based on position and delimiters
355 // By default, parameters are PositionalOrKeyword unless `/` or `*` appear
356 let mut positional_only = false;
357 let mut after_star = false;
358 let mut parameters: Vec<ParameterWithKind> = Vec::new();
359
360 for sig_arg in sig.args() {
361 match sig_arg {
362 SignatureArg::Slash(_) => {
363 // `/` delimiter - mark all previous parameters as positional-only
364 for param in &mut parameters {
365 param.kind = ParameterKind::PositionalOnly;
366 }
367 positional_only = false;
368 }
369 SignatureArg::Star(_) => {
370 // Bare `*` - parameters after this are keyword-only
371 positional_only = false;
372 after_star = true;
373 }
374 SignatureArg::Ident(ident) => {
375 let name = ident.to_string();
376 let kind = if positional_only {
377 ParameterKind::PositionalOnly
378 } else if after_star {
379 ParameterKind::KeywordOnly
380 } else {
381 ParameterKind::PositionalOrKeyword
382 };
383
384 let arg_info = args_map
385 .get(&name)
386 .ok_or_else(|| {
387 syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
388 })?
389 .clone();
390
391 parameters.push(ParameterWithKind {
392 arg_info,
393 kind,
394 default_expr: None,
395 });
396 }
397 SignatureArg::Assign(ident, _eq, value) => {
398 let name = ident.to_string();
399 let kind = if positional_only {
400 ParameterKind::PositionalOnly
401 } else if after_star {
402 ParameterKind::KeywordOnly
403 } else {
404 ParameterKind::PositionalOrKeyword
405 };
406
407 let arg_info = args_map
408 .get(&name)
409 .ok_or_else(|| {
410 syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
411 })?
412 .clone();
413
414 parameters.push(ParameterWithKind {
415 arg_info,
416 kind,
417 default_expr: Some(DefaultExpr::Rust(value.clone())),
418 });
419 }
420 SignatureArg::Args(_, ident) => {
421 positional_only = false;
422 after_star = true; // After *args, everything is keyword-only
423 let name = ident.to_string();
424
425 let mut arg_info = args_map
426 .get(&name)
427 .ok_or_else(|| {
428 syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
429 })?
430 .clone();
431
432 // For VarPositional, if the type is auto-inferred from Rust (RustType),
433 // replace it with typing.Any. If it's OverrideType, keep the user's specification.
434 if matches!(arg_info.r#type, TypeOrOverride::RustType { .. }) {
435 arg_info.r#type = TypeOrOverride::OverrideType {
436 r#type: syn::parse_quote!(()), // Dummy type, won't be used
437 type_repr: "typing.Any".to_string(),
438 imports: ["typing".to_string()].into_iter().collect(),
439 rust_type_markers: vec![],
440 };
441 }
442
443 parameters.push(ParameterWithKind {
444 arg_info,
445 kind: ParameterKind::VarPositional,
446 default_expr: None,
447 });
448 }
449 SignatureArg::Keywords(_, _, ident) => {
450 positional_only = false;
451 let name = ident.to_string();
452
453 let mut arg_info = args_map
454 .get(&name)
455 .ok_or_else(|| {
456 syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
457 })?
458 .clone();
459
460 // For VarKeyword, if the type is auto-inferred from Rust (RustType),
461 // replace it with typing.Any. If it's OverrideType, keep the user's specification.
462 if matches!(arg_info.r#type, TypeOrOverride::RustType { .. }) {
463 arg_info.r#type = TypeOrOverride::OverrideType {
464 r#type: syn::parse_quote!(()), // Dummy type, won't be used
465 type_repr: "typing.Any".to_string(),
466 imports: ["typing".to_string()].into_iter().collect(),
467 rust_type_markers: vec![],
468 };
469 }
470
471 parameters.push(ParameterWithKind {
472 arg_info,
473 kind: ParameterKind::VarKeyword,
474 default_expr: None,
475 });
476 }
477 }
478 }
479
480 Ok(Self(parameters))
481 }
482}
483
484impl ToTokens for Parameters {
485 fn to_tokens(&self, tokens: &mut TokenStream2) {
486 let params = &self.0;
487 tokens.append_all(quote! { &[ #(#params),* ] })
488 }
489}