pyo3_stub_gen_derive/gen_stub/
pyfunction.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens, TokenStreamExt};
3use syn::{
4    parse::{Parse, ParseStream},
5    Error, FnArg, ItemFn, Result,
6};
7
8use crate::gen_stub::util::TypeOrOverride;
9
10use super::{
11    attr::IgnoreTarget, extract_deprecated, extract_documents, extract_return_type,
12    parameter::Parameters, parse_args, parse_gen_stub_type_ignore, parse_pyo3_attrs, quote_option,
13    Attr, DeprecatedInfo,
14};
15
16pub struct PyFunctionInfo {
17    pub(crate) name: String,
18    pub(crate) parameters: Parameters,
19    pub(crate) r#return: Option<TypeOrOverride>,
20    pub(crate) doc: String,
21    pub(crate) module: Option<String>,
22    pub(crate) is_async: bool,
23    pub(crate) deprecated: Option<DeprecatedInfo>,
24    pub(crate) type_ignored: Option<IgnoreTarget>,
25}
26
27struct PyFunctionAttr {
28    module: Option<String>,
29    python: Option<syn::LitStr>,
30}
31
32impl Parse for PyFunctionAttr {
33    fn parse(input: ParseStream) -> Result<Self> {
34        let mut module = None;
35        let mut python = None;
36
37        // Parse comma-separated key-value pairs
38        while !input.is_empty() {
39            let key: syn::Ident = input.parse()?;
40            let _: syn::token::Eq = input.parse()?;
41
42            match key.to_string().as_str() {
43                "module" => {
44                    let value: syn::LitStr = input.parse()?;
45                    module = Some(value.value());
46                }
47                "python" => {
48                    let value: syn::LitStr = input.parse()?;
49                    python = Some(value);
50                }
51                _ => {
52                    return Err(Error::new(
53                        key.span(),
54                        format!("Unknown parameter: {}", key),
55                    ));
56                }
57            }
58
59            // Check for comma separator
60            if input.peek(syn::token::Comma) {
61                let _: syn::token::Comma = input.parse()?;
62            } else {
63                break;
64            }
65        }
66
67        Ok(Self { module, python })
68    }
69}
70
71impl PyFunctionInfo {
72    /// Parse attribute and return python stub string if present
73    pub fn parse_attr(&mut self, attr: TokenStream2) -> Result<Option<syn::LitStr>> {
74        if attr.is_empty() {
75            return Ok(None);
76        }
77        let parsed_attr: PyFunctionAttr = syn::parse2(attr)?;
78
79        // Set module if provided
80        if let Some(module) = parsed_attr.module {
81            self.module = Some(module);
82        }
83
84        // Return python stub string if provided
85        Ok(parsed_attr.python)
86    }
87}
88
89impl TryFrom<ItemFn> for PyFunctionInfo {
90    type Error = Error;
91    fn try_from(item: ItemFn) -> Result<Self> {
92        let doc = extract_documents(&item.attrs).join("\n");
93        let deprecated = extract_deprecated(&item.attrs);
94        let type_ignored = parse_gen_stub_type_ignore(&item.attrs)?;
95        let args = parse_args(item.sig.inputs)?;
96        let r#return = extract_return_type(&item.sig.output, &item.attrs)?;
97        let mut name = None;
98        let mut sig = None;
99        for attr in parse_pyo3_attrs(&item.attrs)? {
100            match attr {
101                Attr::Name(function_name) => name = Some(function_name),
102                Attr::Signature(signature) => sig = Some(signature),
103                _ => {}
104            }
105        }
106        let name = name.unwrap_or_else(|| item.sig.ident.to_string());
107
108        // Build parameters from args and signature
109        let parameters = if let Some(sig) = sig {
110            Parameters::new_with_sig(&args, &sig)?
111        } else {
112            Parameters::new(&args)
113        };
114
115        Ok(Self {
116            name,
117            parameters,
118            r#return,
119            doc,
120            module: None,
121            is_async: item.sig.asyncness.is_some(),
122            deprecated,
123            type_ignored,
124        })
125    }
126}
127
128impl ToTokens for PyFunctionInfo {
129    fn to_tokens(&self, tokens: &mut TokenStream2) {
130        let Self {
131            r#return: ret,
132            name,
133            doc,
134            parameters,
135            module,
136            is_async,
137            deprecated,
138            type_ignored,
139        } = self;
140        let ret_tt = if let Some(ret) = ret {
141            match ret {
142                TypeOrOverride::RustType { r#type } => {
143                    let ty = r#type.clone();
144                    quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
145                }
146                TypeOrOverride::OverrideType {
147                    type_repr, imports, ..
148                } => {
149                    let imports = imports.iter().collect::<Vec<&String>>();
150                    quote! {
151                        || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
152                    }
153                }
154            }
155        } else {
156            quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
157        };
158        // let sig_tt = quote_option(sig);
159        let module_tt = quote_option(module);
160        let deprecated_tt = deprecated
161            .as_ref()
162            .map(|d| quote! { Some(#d) })
163            .unwrap_or_else(|| quote! { None });
164        let type_ignored_tt = if let Some(target) = type_ignored {
165            match target {
166                IgnoreTarget::All => {
167                    quote! { Some(::pyo3_stub_gen::type_info::IgnoreTarget::All) }
168                }
169                IgnoreTarget::SpecifiedLits(rules) => {
170                    let rule_strs: Vec<String> = rules.iter().map(|lit| lit.value()).collect();
171                    quote! {
172                        Some(::pyo3_stub_gen::type_info::IgnoreTarget::Specified(
173                            &[#(#rule_strs),*] as &[&str]
174                        ))
175                    }
176                }
177            }
178        } else {
179            quote! { None }
180        };
181
182        tokens.append_all(quote! {
183            ::pyo3_stub_gen::type_info::PyFunctionInfo {
184                name: #name,
185                parameters: #parameters,
186                r#return: #ret_tt,
187                doc: #doc,
188                module: #module_tt,
189                is_async: #is_async,
190                deprecated: #deprecated_tt,
191                type_ignored: #type_ignored_tt,
192            }
193        })
194    }
195}
196
197// `#[gen_stub(xxx)]` is not a valid proc_macro_attribute
198// it's only designed to receive user's setting.
199// We need to remove all `#[gen_stub(xxx)]` before print the item_fn back
200pub fn prune_attrs(item_fn: &mut ItemFn) {
201    super::attr::prune_attrs(&mut item_fn.attrs);
202    for arg in item_fn.sig.inputs.iter_mut() {
203        if let FnArg::Typed(ref mut pat_type) = arg {
204            super::attr::prune_attrs(&mut pat_type.attrs);
205        }
206    }
207}