pyo3_stub_gen_derive/gen_stub/
pyfunction.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens, TokenStreamExt};
3use syn::{
4    parse::{Parse, ParseStream},
5    Error, FnArg, ItemFn, Result,
6};
7
8use crate::gen_stub::util::TypeOrOverride;
9
10use super::{
11    attr::IgnoreTarget, extract_deprecated, extract_documents, extract_return_type,
12    parameter::Parameters, parse_args, parse_gen_stub_type_ignore, parse_pyo3_attrs, parse_python,
13    quote_option, Attr, DeprecatedInfo,
14};
15
16pub struct PyFunctionInfo {
17    pub(crate) name: String,
18    pub(crate) parameters: Parameters,
19    pub(crate) r#return: Option<TypeOrOverride>,
20    pub(crate) doc: String,
21    pub(crate) module: Option<String>,
22    pub(crate) is_async: bool,
23    pub(crate) deprecated: Option<DeprecatedInfo>,
24    pub(crate) type_ignored: Option<IgnoreTarget>,
25    pub(crate) is_overload: bool,
26    pub(crate) index: usize,
27}
28
29#[derive(Default)]
30pub(crate) struct PyFunctionAttr {
31    pub(crate) module: Option<String>,
32    pub(crate) python: Option<syn::LitStr>,
33    pub(crate) python_overload: Option<syn::LitStr>,
34    pub(crate) no_default_overload: bool,
35}
36
37impl Parse for PyFunctionAttr {
38    fn parse(input: ParseStream) -> Result<Self> {
39        let mut module = None;
40        let mut python = None;
41        let mut python_overload = None;
42        let mut no_default_overload = false;
43
44        // Parse comma-separated key-value pairs
45        while !input.is_empty() {
46            let key: syn::Ident = input.parse()?;
47
48            match key.to_string().as_str() {
49                "module" => {
50                    let _: syn::token::Eq = input.parse()?;
51                    let value: syn::LitStr = input.parse()?;
52                    module = Some(value.value());
53                }
54                "python" => {
55                    let _: syn::token::Eq = input.parse()?;
56                    let value: syn::LitStr = input.parse()?;
57                    python = Some(value);
58                }
59                "python_overload" => {
60                    let _: syn::token::Eq = input.parse()?;
61                    let value: syn::LitStr = input.parse()?;
62                    python_overload = Some(value);
63                }
64                "no_default_overload" => {
65                    let _: syn::token::Eq = input.parse()?;
66                    let value: syn::LitBool = input.parse()?;
67                    no_default_overload = value.value();
68                }
69                _ => {
70                    return Err(Error::new(
71                        key.span(),
72                        format!("Unknown parameter: {}", key),
73                    ));
74                }
75            }
76
77            // Check for comma separator
78            if input.peek(syn::token::Comma) {
79                let _: syn::token::Comma = input.parse()?;
80            } else {
81                break;
82            }
83        }
84
85        // Validate: cannot mix python and python_overload
86        if python.is_some() && python_overload.is_some() {
87            return Err(Error::new(
88                input.span(),
89                "Cannot specify both 'python' and 'python_overload' parameters. Use 'python' for single signatures or 'python_overload' for multiple overloads.",
90            ));
91        }
92
93        // Validate: no_default_overload requires python_overload
94        if no_default_overload && python_overload.is_none() {
95            return Err(Error::new(
96                input.span(),
97                "The 'no_default_overload' parameter can only be used with 'python_overload'. \
98                 Use 'python_overload' to define multiple overload signatures.",
99            ));
100        }
101
102        Ok(Self {
103            module,
104            python,
105            python_overload,
106            no_default_overload,
107        })
108    }
109}
110
111impl TryFrom<ItemFn> for PyFunctionInfo {
112    type Error = Error;
113    fn try_from(item: ItemFn) -> Result<Self> {
114        let doc = extract_documents(&item.attrs).join("\n");
115        let deprecated = extract_deprecated(&item.attrs);
116        let type_ignored = parse_gen_stub_type_ignore(&item.attrs)?;
117        let args = parse_args(item.sig.inputs)?;
118        let r#return = extract_return_type(&item.sig.output, &item.attrs)?;
119        let mut name = None;
120        let mut sig = None;
121        for attr in parse_pyo3_attrs(&item.attrs)? {
122            match attr {
123                Attr::Name(function_name) => name = Some(function_name),
124                Attr::Signature(signature) => sig = Some(signature),
125                _ => {}
126            }
127        }
128        let name = name.unwrap_or_else(|| item.sig.ident.to_string());
129
130        // Build parameters from args and signature
131        let parameters = if let Some(sig) = sig {
132            Parameters::new_with_sig(&args, &sig)?
133        } else {
134            Parameters::new(&args)
135        };
136
137        Ok(Self {
138            name,
139            parameters,
140            r#return,
141            doc,
142            module: None,
143            is_async: item.sig.asyncness.is_some(),
144            deprecated,
145            type_ignored,
146            is_overload: false, // Default to false, will be set by macro if needed
147            index: 0, // Default to 0, will be set by macro if multiple functions are generated
148        })
149    }
150}
151
152impl ToTokens for PyFunctionInfo {
153    fn to_tokens(&self, tokens: &mut TokenStream2) {
154        let Self {
155            r#return: ret,
156            name,
157            doc,
158            parameters,
159            module,
160            is_async,
161            deprecated,
162            type_ignored,
163            is_overload,
164            index,
165        } = self;
166        let ret_tt = if let Some(ret) = ret {
167            match ret {
168                TypeOrOverride::RustType { r#type } => {
169                    let ty = r#type.clone();
170                    quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
171                }
172                TypeOrOverride::OverrideType {
173                    type_repr, imports, ..
174                } => {
175                    let imports = imports.iter().collect::<Vec<&String>>();
176                    quote! {
177                        || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
178                    }
179                }
180            }
181        } else {
182            quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
183        };
184        // let sig_tt = quote_option(sig);
185        let module_tt = quote_option(module);
186        let deprecated_tt = deprecated
187            .as_ref()
188            .map(|d| quote! { Some(#d) })
189            .unwrap_or_else(|| quote! { None });
190        let type_ignored_tt = if let Some(target) = type_ignored {
191            match target {
192                IgnoreTarget::All => {
193                    quote! { Some(::pyo3_stub_gen::type_info::IgnoreTarget::All) }
194                }
195                IgnoreTarget::SpecifiedLits(rules) => {
196                    let rule_strs: Vec<String> = rules.iter().map(|lit| lit.value()).collect();
197                    quote! {
198                        Some(::pyo3_stub_gen::type_info::IgnoreTarget::Specified(
199                            &[#(#rule_strs),*] as &[&str]
200                        ))
201                    }
202                }
203            }
204        } else {
205            quote! { None }
206        };
207
208        tokens.append_all(quote! {
209            ::pyo3_stub_gen::type_info::PyFunctionInfo {
210                name: #name,
211                parameters: #parameters,
212                r#return: #ret_tt,
213                doc: #doc,
214                module: #module_tt,
215                is_async: #is_async,
216                deprecated: #deprecated_tt,
217                type_ignored: #type_ignored_tt,
218                is_overload: #is_overload,
219                file: file!(),
220                line: line!(),
221                column: column!(),
222                index: #index,
223            }
224        })
225    }
226}
227
228// `#[gen_stub(xxx)]` is not a valid proc_macro_attribute
229// it's only designed to receive user's setting.
230// We need to remove all `#[gen_stub(xxx)]` before print the item_fn back
231pub fn prune_attrs(item_fn: &mut ItemFn) {
232    super::attr::prune_attrs(&mut item_fn.attrs);
233    for arg in item_fn.sig.inputs.iter_mut() {
234        if let FnArg::Typed(ref mut pat_type) = arg {
235            super::attr::prune_attrs(&mut pat_type.attrs);
236        }
237    }
238}
239
240/// Represents one or more PyFunctionInfo with the original ItemFn.
241/// This handles the case where python_overload generates multiple function signatures.
242pub struct PyFunctionInfos {
243    pub(crate) item_fn: ItemFn,
244    pub(crate) infos: Vec<PyFunctionInfo>,
245}
246
247impl PyFunctionInfos {
248    /// Create PyFunctionInfos from ItemFn and PyFunctionAttr
249    pub fn from_parts(mut item_fn: ItemFn, attr: PyFunctionAttr) -> Result<Self> {
250        // Handle python stub syntax early (doesn't need base_info)
251        if let Some(python) = attr.python {
252            let mut python_info = parse_python::parse_python_function_stub(python)?;
253            python_info.module = attr.module;
254            prune_attrs(&mut item_fn);
255            return Ok(Self {
256                item_fn,
257                infos: vec![python_info],
258            });
259        }
260
261        // Convert ItemFn to base PyFunctionInfo for Rust-based generation
262        let mut base_info = PyFunctionInfo::try_from(item_fn.clone())?;
263        base_info.module = attr.module;
264
265        let infos = if let Some(python_overload) = attr.python_overload {
266            // Get function name for validation
267            let function_name = base_info.name.clone();
268
269            // Parse multiple overload definitions
270            let mut overload_infos =
271                parse_python::parse_python_overload_stubs(python_overload, &function_name)?;
272
273            // Preserve module information and assign indices
274            for (index, info) in overload_infos.iter_mut().enumerate() {
275                info.module = base_info.module.clone();
276                info.index = index;
277            }
278
279            // If no_default_overload is false (default), also generate from Rust type
280            if !attr.no_default_overload {
281                // Mark the Rust-generated function as overload
282                base_info.is_overload = true;
283                base_info.index = overload_infos.len();
284                overload_infos.push(base_info);
285            }
286
287            overload_infos
288        } else {
289            // No python or python_overload, use auto-generated
290            vec![base_info]
291        };
292
293        // Prune attributes from ItemFn
294        prune_attrs(&mut item_fn);
295
296        Ok(Self { item_fn, infos })
297    }
298}
299
300impl ToTokens for PyFunctionInfos {
301    fn to_tokens(&self, tokens: &mut TokenStream2) {
302        let item_fn = &self.item_fn;
303        let infos = &self.infos;
304
305        // Generate multiple submit! blocks
306        let submits = infos.iter().map(|info| {
307            quote! {
308                #[automatically_derived]
309                pyo3_stub_gen::inventory::submit! {
310                    #info
311                }
312            }
313        });
314
315        tokens.append_all(quote! {
316            #(#submits)*
317            #item_fn
318        })
319    }
320}