pyo3_stub_gen_derive/gen_stub/pyfunction.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens, TokenStreamExt};
3use syn::{
4    parse::{Parse, ParseStream},
5    Error, FnArg, ItemFn, Result,
6};
7
8use crate::gen_stub::util::TypeOrOverride;
9
10use super::{
11    attr::IgnoreTarget, extract_deprecated, extract_documents, extract_return_type,
12    parameter::Parameters, parse_args, parse_gen_stub_type_ignore, parse_pyo3_attrs, parse_python,
13    quote_option, Attr, DeprecatedInfo,
14};
15
16pub struct PyFunctionInfo {
17    pub(crate) name: String,
18    pub(crate) parameters: Parameters,
19    pub(crate) r#return: Option<TypeOrOverride>,
20    pub(crate) doc: String,
21    pub(crate) module: Option<String>,
22    pub(crate) is_async: bool,
23    pub(crate) deprecated: Option<DeprecatedInfo>,
24    pub(crate) type_ignored: Option<IgnoreTarget>,
25    pub(crate) is_overload: bool,
26    pub(crate) index: usize,
27}
28
29#[derive(Default)]
30pub(crate) struct PyFunctionAttr {
31    pub(crate) module: Option<String>,
32    pub(crate) python: Option<syn::LitStr>,
33    pub(crate) python_overload: Option<syn::LitStr>,
34    pub(crate) no_default_overload: bool,
35}
36
37impl Parse for PyFunctionAttr {
38    fn parse(input: ParseStream) -> Result<Self> {
39        let mut module = None;
40        let mut python = None;
41        let mut python_overload = None;
42        let mut no_default_overload = false;
43
44        // Parse comma-separated key-value pairs
45        while !input.is_empty() {
46            let key: syn::Ident = input.parse()?;
47
48            match key.to_string().as_str() {
49                "module" => {
50                    let _: syn::token::Eq = input.parse()?;
51                    let value: syn::LitStr = input.parse()?;
52                    module = Some(value.value());
53                }
54                "python" => {
55                    let _: syn::token::Eq = input.parse()?;
56                    let value: syn::LitStr = input.parse()?;
57                    python = Some(value);
58                }
59                "python_overload" => {
60                    let _: syn::token::Eq = input.parse()?;
61                    let value: syn::LitStr = input.parse()?;
62                    python_overload = Some(value);
63                }
64                "no_default_overload" => {
65                    let _: syn::token::Eq = input.parse()?;
66                    let value: syn::LitBool = input.parse()?;
67                    no_default_overload = value.value();
68                }
69                _ => {
70                    return Err(Error::new(
71                        key.span(),
72                        format!("Unknown parameter: {}", key),
73                    ));
74                }
75            }
76
77            // Check for comma separator
78            if input.peek(syn::token::Comma) {
79                let _: syn::token::Comma = input.parse()?;
80            } else {
81                break;
82            }
83        }
84
85        // Validate: cannot mix python and python_overload
86        if python.is_some() && python_overload.is_some() {
87            return Err(Error::new(
88                input.span(),
89                "Cannot specify both 'python' and 'python_overload' parameters. Use 'python' for single signatures or 'python_overload' for multiple overloads.",
90            ));
91        }
92
93        // Validate: no_default_overload requires python_overload
94        if no_default_overload && python_overload.is_none() {
95            return Err(Error::new(
96                input.span(),
97                "The 'no_default_overload' parameter can only be used with 'python_overload'. \
98                 Use 'python_overload' to define multiple overload signatures.",
99            ));
100        }
101
102        Ok(Self {
103            module,
104            python,
105            python_overload,
106            no_default_overload,
107        })
108    }
109}
110
111impl TryFrom<ItemFn> for PyFunctionInfo {
112    type Error = Error;
113    fn try_from(item: ItemFn) -> Result<Self> {
114        let doc = extract_documents(&item.attrs).join("\n");
115        let deprecated = extract_deprecated(&item.attrs);
116        let type_ignored = parse_gen_stub_type_ignore(&item.attrs)?;
117        let args = parse_args(item.sig.inputs)?;
118        let r#return = extract_return_type(&item.sig.output, &item.attrs)?;
119        let mut name = None;
120        let mut sig = None;
121        let mut pyo3_module = None;
122        for attr in parse_pyo3_attrs(&item.attrs)? {
123            match attr {
124                Attr::Name(function_name) => name = Some(function_name),
125                Attr::Signature(signature) => sig = Some(signature),
126                Attr::Module(module_name) => pyo3_module = Some(module_name),
127                _ => {}
128            }
129        }
130        let name = name.unwrap_or_else(|| item.sig.ident.to_string());
131
132        // Build parameters from args and signature
133        let parameters = if let Some(sig) = sig {
134            Parameters::new_with_sig(&args, &sig)?
135        } else {
136            Parameters::new(&args)
137        };
138
139        Ok(Self {
140            name,
141            parameters,
142            r#return,
143            doc,
144            module: pyo3_module,
145            is_async: item.sig.asyncness.is_some(),
146            deprecated,
147            type_ignored,
148            is_overload: false, // Default to false, will be set by macro if needed
149            index: 0, // Default to 0, will be set by macro if multiple functions are generated
150        })
151    }
152}
153
154impl ToTokens for PyFunctionInfo {
155    fn to_tokens(&self, tokens: &mut TokenStream2) {
156        let Self {
157            r#return: ret,
158            name,
159            doc,
160            parameters,
161            module,
162            is_async,
163            deprecated,
164            type_ignored,
165            is_overload,
166            index,
167        } = self;
168        let ret_tt = if let Some(ret) = ret {
169            match ret {
170                TypeOrOverride::RustType { r#type } => {
171                    let ty = r#type.clone();
172                    quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
173                }
174                TypeOrOverride::OverrideType {
175                    type_repr,
176                    imports,
177                    rust_type_markers,
178                    ..
179                } => {
180                    let imports = imports.iter().collect::<Vec<&String>>();
181
182                    // Generate code to process RustType markers
183                    let (type_name_code, type_refs_code) = if rust_type_markers.is_empty() {
184                        (
185                            quote! { #type_repr.to_string() },
186                            quote! { ::std::collections::HashMap::new() },
187                        )
188                    } else {
189                        let marker_types: Vec<syn::Type> = rust_type_markers
190                            .iter()
191                            .filter_map(|s| syn::parse_str(s).ok())
192                            .collect();
193
194                        let rust_names = rust_type_markers.iter().collect::<Vec<_>>();
195
196                        (
197                            quote! {
198                                {
199                                    let mut type_name = #type_repr.to_string();
200                                    #(
201                                        let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
202                                        type_name = type_name.replace(#rust_names, &type_info.name);
203                                    )*
204                                    type_name
205                                }
206                            },
207                            quote! {
208                                {
209                                    let mut type_refs = ::std::collections::HashMap::new();
210                                    #(
211                                        let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
212                                        if let Some(module) = type_info.source_module {
213                                            type_refs.insert(
214                                                type_info.name.split('[').next().unwrap_or(&type_info.name).split('.').last().unwrap_or(&type_info.name).to_string(),
215                                                ::pyo3_stub_gen::TypeIdentifierRef {
216                                                    module: module.into(),
217                                                    import_kind: ::pyo3_stub_gen::ImportKind::Module,
218                                                }
219                                            );
220                                        }
221                                        type_refs.extend(type_info.type_refs);
222                                    )*
223                                    type_refs
224                                }
225                            },
226                        )
227                    };
228
229                    quote! {
230                        || ::pyo3_stub_gen::TypeInfo { name: #type_name_code, source_module: None, import: ::std::collections::HashSet::from([#(#imports.into(),)*]), type_refs: #type_refs_code }
231                    }
232                }
233            }
234        } else {
235            quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
236        };
237        // let sig_tt = quote_option(sig);
238        let module_tt = quote_option(module);
239        let deprecated_tt = deprecated
240            .as_ref()
241            .map(|d| quote! { Some(#d) })
242            .unwrap_or_else(|| quote! { None });
243        let type_ignored_tt = if let Some(target) = type_ignored {
244            match target {
245                IgnoreTarget::All => {
246                    quote! { Some(::pyo3_stub_gen::type_info::IgnoreTarget::All) }
247                }
248                IgnoreTarget::SpecifiedLits(rules) => {
249                    let rule_strs: Vec<String> = rules.iter().map(|lit| lit.value()).collect();
250                    quote! {
251                        Some(::pyo3_stub_gen::type_info::IgnoreTarget::Specified(
252                            &[#(#rule_strs),*] as &[&str]
253                        ))
254                    }
255                }
256            }
257        } else {
258            quote! { None }
259        };
260
261        tokens.append_all(quote! {
262            ::pyo3_stub_gen::type_info::PyFunctionInfo {
263                name: #name,
264                parameters: #parameters,
265                r#return: #ret_tt,
266                doc: #doc,
267                module: #module_tt,
268                is_async: #is_async,
269                deprecated: #deprecated_tt,
270                type_ignored: #type_ignored_tt,
271                is_overload: #is_overload,
272                file: file!(),
273                line: line!(),
274                column: column!(),
275                index: #index,
276            }
277        })
278    }
279}
280
281// `#[gen_stub(xxx)]` is not a valid proc_macro_attribute
282// it's only designed to receive user's setting.
283// We need to remove all `#[gen_stub(xxx)]` before print the item_fn back
284pub fn prune_attrs(item_fn: &mut ItemFn) {
285    super::attr::prune_attrs(&mut item_fn.attrs);
286    for arg in item_fn.sig.inputs.iter_mut() {
287        if let FnArg::Typed(ref mut pat_type) = arg {
288            super::attr::prune_attrs(&mut pat_type.attrs);
289        }
290    }
291}
292
293/// Represents one or more PyFunctionInfo with the original ItemFn.
294/// This handles the case where python_overload generates multiple function signatures.
295pub struct PyFunctionInfos {
296    pub(crate) item_fn: ItemFn,
297    pub(crate) infos: Vec<PyFunctionInfo>,
298}
299
300impl PyFunctionInfos {
301    /// Create PyFunctionInfos from ItemFn and PyFunctionAttr
302    pub fn from_parts(mut item_fn: ItemFn, attr: PyFunctionAttr) -> Result<Self> {
303        // Extract standalone gen_stub module from item attributes
304        let mut gen_stub_standalone_module = None;
305        for attr_item in parse_pyo3_attrs(&item_fn.attrs)? {
306            if let Attr::GenStubModule(module_name) = attr_item {
307                gen_stub_standalone_module = Some(module_name);
308            }
309        }
310
311        // Validate: inline and standalone gen_stub modules must not conflict
312        if let (Some(inline_mod), Some(standalone_mod)) =
313            (&attr.module, &gen_stub_standalone_module)
314        {
315            if inline_mod != standalone_mod {
316                return Err(Error::new(
317                    item_fn.sig.ident.span(),
318                    format!(
319                        "Conflicting module specifications: #[gen_stub_pyfunction(module = \"{}\")] \
320                         and #[gen_stub(module = \"{}\")]. Please use only one.",
321                        inline_mod, standalone_mod
322                    ),
323                ));
324            }
325        }
326
327        // Handle python stub syntax early (doesn't need base_info)
328        if let Some(python) = attr.python {
329            let mut python_info = parse_python::parse_python_function_stub(python)?;
330            // Priority: inline > standalone > pyo3 (pyo3 already in python_info from python stub)
331            python_info.module = if let Some(inline_mod) = attr.module {
332                Some(inline_mod) // Priority 1
333            } else if let Some(standalone_mod) = gen_stub_standalone_module {
334                Some(standalone_mod) // Priority 2
335            } else {
336                python_info.module // Priority 3: from Python stub or None
337            };
338            prune_attrs(&mut item_fn);
339            return Ok(Self {
340                item_fn,
341                infos: vec![python_info],
342            });
343        }
344
345        // Convert ItemFn to base PyFunctionInfo for Rust-based generation
346        let mut base_info = PyFunctionInfo::try_from(item_fn.clone())?;
347
348        // Priority: inline > standalone > pyo3 > default
349        // base_info.module already contains pyo3 module from try_from
350        let pyo3_module = base_info.module.clone();
351        base_info.module = if let Some(inline_mod) = attr.module {
352            Some(inline_mod) // Priority 1: #[gen_stub_pyfunction(module = "...")]
353        } else if let Some(standalone_mod) = gen_stub_standalone_module {
354            Some(standalone_mod) // Priority 2: #[gen_stub(module = "...")]
355        } else {
356            pyo3_module // Priority 3: #[pyo3(module = "...")]
357        };
358
359        let infos = if let Some(python_overload) = attr.python_overload {
360            // Get function name for validation
361            let function_name = base_info.name.clone();
362
363            // Parse multiple overload definitions
364            let mut overload_infos =
365                parse_python::parse_python_overload_stubs(python_overload, &function_name)?;
366
367            // Preserve module information and assign indices
368            for (index, info) in overload_infos.iter_mut().enumerate() {
369                info.module = base_info.module.clone();
370                info.index = index;
371            }
372
373            // If no_default_overload is false (default), also generate from Rust type
374            if !attr.no_default_overload {
375                // Mark the Rust-generated function as overload
376                base_info.is_overload = true;
377                base_info.index = overload_infos.len();
378                overload_infos.push(base_info);
379            }
380
381            overload_infos
382        } else {
383            // No python or python_overload, use auto-generated
384            vec![base_info]
385        };
386
387        // Prune attributes from ItemFn
388        prune_attrs(&mut item_fn);
389
390        Ok(Self { item_fn, infos })
391    }
392}
393
394impl ToTokens for PyFunctionInfos {
395    fn to_tokens(&self, tokens: &mut TokenStream2) {
396        let item_fn = &self.item_fn;
397        let infos = &self.infos;
398
399        // Generate multiple submit! blocks
400        let submits = infos.iter().map(|info| {
401            quote! {
402                #[automatically_derived]
403                pyo3_stub_gen::inventory::submit! {
404                    #info
405                }
406            }
407        });
408
409        tokens.append_all(quote! {
410            #(#submits)*
411            #item_fn
412        })
413    }
414}