pyo3_stub_gen_derive/gen_stub/parameter.rs
1//! Parameter intermediate representation for derive macros
2//!
3//! This module provides intermediate representations for parameters that are used
4//! during the code generation phase. These types exist only within the derive macro
5//! and are converted to `::pyo3_stub_gen::type_info::ParameterInfo` via `ToTokens`.
6
7use std::collections::HashMap;
8
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{quote, ToTokens, TokenStreamExt};
11use syn::{Expr, Result};
12
13use super::{remove_lifetime, signature::SignatureArg, util::TypeOrOverride, ArgInfo, Signature};
14
15/// Represents a default value expression from either Rust or Python source
16#[derive(Debug, Clone)]
17pub(crate) enum DefaultExpr {
18 /// Rust expression that needs to be converted to Python representation at runtime
19 /// Example: `vec![1, 2]`, `Number::Float`, `10`
20 Rust(Expr),
21 /// Python expression already in Python syntax (from Python stub)
22 /// Example: `"False"`, `"[1, 2]"`, `"Number.FLOAT"`
23 Python(String),
24}
25
26/// Intermediate representation for a parameter with its kind determined
27#[derive(Debug, Clone)]
28pub(crate) struct ParameterWithKind {
29 pub(crate) arg_info: ArgInfo,
30 pub(crate) kind: ParameterKind,
31 pub(crate) default_expr: Option<DefaultExpr>,
32}
33
34impl ToTokens for ParameterWithKind {
35 fn to_tokens(&self, tokens: &mut TokenStream2) {
36 let name = &self.arg_info.name;
37 let kind = &self.kind;
38
39 let default_tokens = match &self.default_expr {
40 Some(DefaultExpr::Rust(expr)) => {
41 // Rust expression: needs runtime conversion via fmt_py_obj
42 match &self.arg_info.r#type {
43 TypeOrOverride::RustType { r#type } => {
44 let default = if expr.to_token_stream().to_string() == "None" {
45 quote! { "None".to_string() }
46 } else {
47 quote! {
48 {
49 let v: #r#type = #expr;
50 let repr = ::pyo3_stub_gen::util::fmt_py_obj(v);
51 let type_info = <#r#type as ::pyo3_stub_gen::PyStubType>::type_output();
52
53 // Check if the type is defined in another module within the current package.
54 // For locally-defined types: qualified name is "sub_mod.ClassA" and
55 // import is Module("package.sub_mod"). The import ends with the module prefix.
56 // For external types: qualified name is "numpy.ndarray" and
57 // import is Module("numpy"). The import equals the module prefix.
58 let should_add_prefix = if let Some(dot_pos) = type_info.name.rfind('.') {
59 let module_prefix = &type_info.name[..dot_pos];
60 type_info.import.iter().any(|imp| {
61 if let ::pyo3_stub_gen::ImportRef::Module(module_ref) = imp {
62 if let Some(module_name) = module_ref.get() {
63 // Local cross-module ref: full path ends with module prefix
64 // e.g., "package.sub_mod" ends with ".sub_mod"
65 module_name.ends_with(&format!(".{}", module_prefix))
66 } else {
67 false
68 }
69 } else {
70 false
71 }
72 })
73 } else {
74 false
75 };
76
77 if should_add_prefix {
78 if let Some(dot_pos) = type_info.name.rfind('.') {
79 let module_prefix = &type_info.name[..dot_pos];
80 format!("{}.{}", module_prefix, repr)
81 } else {
82 repr
83 }
84 } else {
85 repr
86 }
87 }
88 }
89 };
90 quote! {
91 ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
92 fn _fmt() -> String {
93 #default
94 }
95 _fmt
96 })
97 }
98 }
99 TypeOrOverride::OverrideType { .. } => {
100 // For OverrideType, convert the default value expression directly to a string
101 // since r#type may be a dummy type and we can't use it for type annotations
102 let mut value_str = expr.to_token_stream().to_string();
103 // Convert Rust bool literals to Python bool literals
104 if value_str == "false" {
105 value_str = "False".to_string();
106 } else if value_str == "true" {
107 value_str = "True".to_string();
108 }
109 quote! {
110 ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
111 fn _fmt() -> String {
112 #value_str.to_string()
113 }
114 _fmt
115 })
116 }
117 }
118 }
119 }
120 Some(DefaultExpr::Python(py_str)) => {
121 // Python expression: already in Python syntax, use directly
122 quote! {
123 ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
124 fn _fmt() -> String {
125 #py_str.to_string()
126 }
127 _fmt
128 })
129 }
130 }
131 None => quote! { ::pyo3_stub_gen::type_info::ParameterDefault::None },
132 };
133
134 let param_info = match &self.arg_info.r#type {
135 TypeOrOverride::RustType { r#type } => {
136 quote! {
137 ::pyo3_stub_gen::type_info::ParameterInfo {
138 name: #name,
139 kind: #kind,
140 type_info: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
141 default: #default_tokens,
142 }
143 }
144 }
145 TypeOrOverride::OverrideType {
146 type_repr,
147 imports,
148 rust_type_markers,
149 ..
150 } => {
151 let imports = imports.iter().collect::<Vec<&String>>();
152
153 // Generate code to process RustType markers
154 let (type_name_code, type_refs_code) = if rust_type_markers.is_empty() {
155 (
156 quote! { #type_repr.to_string() },
157 quote! { ::std::collections::HashMap::new() },
158 )
159 } else {
160 // Parse rust_type_markers as syn::Type
161 let marker_types: Vec<syn::Type> = rust_type_markers
162 .iter()
163 .filter_map(|s| syn::parse_str(s).ok())
164 .collect();
165
166 let rust_names = rust_type_markers.iter().collect::<Vec<_>>();
167
168 (
169 quote! {
170 {
171 let mut type_name = #type_repr.to_string();
172 #(
173 let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
174 // Replace Rust type name with Python type name in the expression
175 type_name = type_name.replace(#rust_names, &type_info.name);
176 )*
177 type_name
178 }
179 },
180 quote! {
181 {
182 let mut type_refs = ::std::collections::HashMap::new();
183 #(
184 let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
185 // Add mapping from Python name to module
186 if let Some(module) = type_info.source_module {
187 type_refs.insert(
188 type_info.name.split('[').next().unwrap_or(&type_info.name).split('.').last().unwrap_or(&type_info.name).to_string(),
189 ::pyo3_stub_gen::TypeIdentifierRef {
190 module: module.into(),
191 import_kind: ::pyo3_stub_gen::ImportKind::Module,
192 }
193 );
194 }
195 type_refs.extend(type_info.type_refs);
196 )*
197 type_refs
198 }
199 },
200 )
201 };
202
203 quote! {
204 ::pyo3_stub_gen::type_info::ParameterInfo {
205 name: #name,
206 kind: #kind,
207 type_info: || ::pyo3_stub_gen::TypeInfo {
208 name: #type_name_code,
209 source_module: None,
210 import: ::std::collections::HashSet::from([#(#imports.into(),)*]),
211 type_refs: #type_refs_code,
212 },
213 default: #default_tokens,
214 }
215 }
216 }
217 };
218
219 tokens.append_all(param_info);
220 }
221}
222
/// How a Python caller may supply a parameter.
///
/// This is the derive-macro-side counterpart of
/// `::pyo3_stub_gen::type_info::ParameterKind`; its `ToTokens` impl emits the matching
/// runtime variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum ParameterKind {
    /// Declared before `/` in the signature.
    PositionalOnly,
    /// Ordinary parameter, passable by position or by keyword.
    PositionalOrKeyword,
    /// Declared after a bare `*` or after `*args`.
    KeywordOnly,
    /// The `*args` parameter.
    VarPositional,
    /// The `**kwargs` parameter.
    VarKeyword,
}
235
236impl ToTokens for ParameterKind {
237 fn to_tokens(&self, tokens: &mut TokenStream2) {
238 let kind_tokens = match self {
239 Self::PositionalOnly => {
240 quote! { ::pyo3_stub_gen::type_info::ParameterKind::PositionalOnly }
241 }
242 Self::PositionalOrKeyword => {
243 quote! { ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword }
244 }
245 Self::KeywordOnly => {
246 quote! { ::pyo3_stub_gen::type_info::ParameterKind::KeywordOnly }
247 }
248 Self::VarPositional => {
249 quote! { ::pyo3_stub_gen::type_info::ParameterKind::VarPositional }
250 }
251 Self::VarKeyword => {
252 quote! { ::pyo3_stub_gen::type_info::ParameterKind::VarKeyword }
253 }
254 };
255 tokens.append_all(kind_tokens);
256 }
257}
258
259/// Collection of parameters with their kinds determined
260///
261/// This newtype wraps `Vec<ParameterWithKind>` and provides constructors that
262/// parse PyO3 signature attributes and classify parameters accordingly.
263#[derive(Debug, Clone)]
264pub(crate) struct Parameters(Vec<ParameterWithKind>);
265
266impl Parameters {
267 /// Create Parameters from a Vec<ParameterWithKind>
268 ///
269 /// This is used when parameters are already classified (e.g., from Python AST).
270 pub(crate) fn from_vec(parameters: Vec<ParameterWithKind>) -> Self {
271 Self(parameters)
272 }
273
274 /// Get mutable access to internal parameters
275 pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = &mut ParameterWithKind> {
276 self.0.iter_mut()
277 }
278
279 /// Create parameters without signature attribute
280 ///
281 /// All parameters will be classified as `PositionalOrKeyword`.
282 pub(crate) fn new(args: &[ArgInfo]) -> Self {
283 let parameters = args
284 .iter()
285 .map(|arg| {
286 let mut arg_with_clean_type = arg.clone();
287 if let ArgInfo {
288 r#type: TypeOrOverride::RustType { r#type },
289 ..
290 } = &mut arg_with_clean_type
291 {
292 remove_lifetime(r#type);
293 }
294 ParameterWithKind {
295 arg_info: arg_with_clean_type,
296 kind: ParameterKind::PositionalOrKeyword,
297 default_expr: None,
298 }
299 })
300 .collect();
301 Self(parameters)
302 }
303
304 /// Create parameters with signature attribute
305 ///
306 /// Parses the signature to determine parameter kinds based on delimiters
307 /// (`/` for positional-only, `*` for keyword-only, etc.).
308 pub(crate) fn new_with_sig(args: &[ArgInfo], sig: &Signature) -> Result<Self> {
309 // Build a map of argument names to their type information
310 let args_map: HashMap<String, ArgInfo> = args
311 .iter()
312 .map(|arg| {
313 let mut arg_with_clean_type = arg.clone();
314 if let ArgInfo {
315 r#type: TypeOrOverride::RustType { r#type },
316 ..
317 } = &mut arg_with_clean_type
318 {
319 remove_lifetime(r#type);
320 }
321 (arg.name.clone(), arg_with_clean_type)
322 })
323 .collect();
324
325 // Track parameter kinds based on position and delimiters
326 // By default, parameters are PositionalOrKeyword unless `/` or `*` appear
327 let mut positional_only = false;
328 let mut after_star = false;
329 let mut parameters: Vec<ParameterWithKind> = Vec::new();
330
331 for sig_arg in sig.args() {
332 match sig_arg {
333 SignatureArg::Slash(_) => {
334 // `/` delimiter - mark all previous parameters as positional-only
335 for param in &mut parameters {
336 param.kind = ParameterKind::PositionalOnly;
337 }
338 positional_only = false;
339 }
340 SignatureArg::Star(_) => {
341 // Bare `*` - parameters after this are keyword-only
342 positional_only = false;
343 after_star = true;
344 }
345 SignatureArg::Ident(ident) => {
346 let name = ident.to_string();
347 let kind = if positional_only {
348 ParameterKind::PositionalOnly
349 } else if after_star {
350 ParameterKind::KeywordOnly
351 } else {
352 ParameterKind::PositionalOrKeyword
353 };
354
355 let arg_info = args_map
356 .get(&name)
357 .ok_or_else(|| {
358 syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
359 })?
360 .clone();
361
362 parameters.push(ParameterWithKind {
363 arg_info,
364 kind,
365 default_expr: None,
366 });
367 }
368 SignatureArg::Assign(ident, _eq, value) => {
369 let name = ident.to_string();
370 let kind = if positional_only {
371 ParameterKind::PositionalOnly
372 } else if after_star {
373 ParameterKind::KeywordOnly
374 } else {
375 ParameterKind::PositionalOrKeyword
376 };
377
378 let arg_info = args_map
379 .get(&name)
380 .ok_or_else(|| {
381 syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
382 })?
383 .clone();
384
385 parameters.push(ParameterWithKind {
386 arg_info,
387 kind,
388 default_expr: Some(DefaultExpr::Rust(value.clone())),
389 });
390 }
391 SignatureArg::Args(_, ident) => {
392 positional_only = false;
393 after_star = true; // After *args, everything is keyword-only
394 let name = ident.to_string();
395
396 let mut arg_info = args_map
397 .get(&name)
398 .ok_or_else(|| {
399 syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
400 })?
401 .clone();
402
403 // For VarPositional, if the type is auto-inferred from Rust (RustType),
404 // replace it with typing.Any. If it's OverrideType, keep the user's specification.
405 if matches!(arg_info.r#type, TypeOrOverride::RustType { .. }) {
406 arg_info.r#type = TypeOrOverride::OverrideType {
407 r#type: syn::parse_quote!(()), // Dummy type, won't be used
408 type_repr: "typing.Any".to_string(),
409 imports: ["typing".to_string()].into_iter().collect(),
410 rust_type_markers: vec![],
411 };
412 }
413
414 parameters.push(ParameterWithKind {
415 arg_info,
416 kind: ParameterKind::VarPositional,
417 default_expr: None,
418 });
419 }
420 SignatureArg::Keywords(_, _, ident) => {
421 positional_only = false;
422 let name = ident.to_string();
423
424 let mut arg_info = args_map
425 .get(&name)
426 .ok_or_else(|| {
427 syn::Error::new(ident.span(), format!("cannot find argument: {}", name))
428 })?
429 .clone();
430
431 // For VarKeyword, if the type is auto-inferred from Rust (RustType),
432 // replace it with typing.Any. If it's OverrideType, keep the user's specification.
433 if matches!(arg_info.r#type, TypeOrOverride::RustType { .. }) {
434 arg_info.r#type = TypeOrOverride::OverrideType {
435 r#type: syn::parse_quote!(()), // Dummy type, won't be used
436 type_repr: "typing.Any".to_string(),
437 imports: ["typing".to_string()].into_iter().collect(),
438 rust_type_markers: vec![],
439 };
440 }
441
442 parameters.push(ParameterWithKind {
443 arg_info,
444 kind: ParameterKind::VarKeyword,
445 default_expr: None,
446 });
447 }
448 }
449 }
450
451 Ok(Self(parameters))
452 }
453}
454
455impl ToTokens for Parameters {
456 fn to_tokens(&self, tokens: &mut TokenStream2) {
457 let params = &self.0;
458 tokens.append_all(quote! { &[ #(#params),* ] })
459 }
460}