pyo3_stub_gen_derive/gen_stub/pyfunction.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens, TokenStreamExt};
3use syn::{
4    parse::{Parse, ParseStream},
5    Error, FnArg, ItemFn, Result,
6};
7
8use crate::gen_stub::util::TypeOrOverride;
9
10use super::{
11    attr::IgnoreTarget, extract_deprecated, extract_documents, extract_return_type, parse_args,
12    parse_gen_stub_type_ignore, parse_pyo3_attrs, quote_option, ArgInfo, ArgsWithSignature, Attr,
13    DeprecatedInfo, Signature,
14};
15
16pub struct PyFunctionInfo {
17    pub(crate) name: String,
18    pub(crate) args: Vec<ArgInfo>,
19    pub(crate) r#return: Option<TypeOrOverride>,
20    pub(crate) sig: Option<Signature>,
21    pub(crate) doc: String,
22    pub(crate) module: Option<String>,
23    pub(crate) is_async: bool,
24    pub(crate) deprecated: Option<DeprecatedInfo>,
25    pub(crate) type_ignored: Option<IgnoreTarget>,
26}
27
28struct PyFunctionAttr {
29    module: Option<String>,
30    python: Option<syn::LitStr>,
31}
32
33impl Parse for PyFunctionAttr {
34    fn parse(input: ParseStream) -> Result<Self> {
35        let mut module = None;
36        let mut python = None;
37
38        // Parse comma-separated key-value pairs
39        while !input.is_empty() {
40            let key: syn::Ident = input.parse()?;
41            let _: syn::token::Eq = input.parse()?;
42
43            match key.to_string().as_str() {
44                "module" => {
45                    let value: syn::LitStr = input.parse()?;
46                    module = Some(value.value());
47                }
48                "python" => {
49                    let value: syn::LitStr = input.parse()?;
50                    python = Some(value);
51                }
52                _ => {
53                    return Err(Error::new(
54                        key.span(),
55                        format!("Unknown parameter: {}", key),
56                    ));
57                }
58            }
59
60            // Check for comma separator
61            if input.peek(syn::token::Comma) {
62                let _: syn::token::Comma = input.parse()?;
63            } else {
64                break;
65            }
66        }
67
68        Ok(Self { module, python })
69    }
70}
71
72impl PyFunctionInfo {
73    /// Parse attribute and return python stub string if present
74    pub fn parse_attr(&mut self, attr: TokenStream2) -> Result<Option<syn::LitStr>> {
75        if attr.is_empty() {
76            return Ok(None);
77        }
78        let parsed_attr: PyFunctionAttr = syn::parse2(attr)?;
79
80        // Set module if provided
81        if let Some(module) = parsed_attr.module {
82            self.module = Some(module);
83        }
84
85        // Return python stub string if provided
86        Ok(parsed_attr.python)
87    }
88}
89
90impl TryFrom<ItemFn> for PyFunctionInfo {
91    type Error = Error;
92    fn try_from(item: ItemFn) -> Result<Self> {
93        let doc = extract_documents(&item.attrs).join("\n");
94        let deprecated = extract_deprecated(&item.attrs);
95        let type_ignored = parse_gen_stub_type_ignore(&item.attrs)?;
96        let args = parse_args(item.sig.inputs)?;
97        let r#return = extract_return_type(&item.sig.output, &item.attrs)?;
98        let mut name = None;
99        let mut sig = None;
100        for attr in parse_pyo3_attrs(&item.attrs)? {
101            match attr {
102                Attr::Name(function_name) => name = Some(function_name),
103                Attr::Signature(signature) => sig = Some(signature),
104                _ => {}
105            }
106        }
107        let name = name.unwrap_or_else(|| item.sig.ident.to_string());
108        Ok(Self {
109            args,
110            sig,
111            r#return,
112            name,
113            doc,
114            module: None,
115            is_async: item.sig.asyncness.is_some(),
116            deprecated,
117            type_ignored,
118        })
119    }
120}
121
122impl ToTokens for PyFunctionInfo {
123    fn to_tokens(&self, tokens: &mut TokenStream2) {
124        let Self {
125            args,
126            r#return: ret,
127            name,
128            doc,
129            sig,
130            module,
131            is_async,
132            deprecated,
133            type_ignored,
134        } = self;
135        let ret_tt = if let Some(ret) = ret {
136            match ret {
137                TypeOrOverride::RustType { r#type } => {
138                    let ty = r#type.clone();
139                    quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
140                }
141                TypeOrOverride::OverrideType {
142                    type_repr, imports, ..
143                } => {
144                    let imports = imports.iter().collect::<Vec<&String>>();
145                    quote! {
146                        || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
147                    }
148                }
149            }
150        } else {
151            quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
152        };
153        // let sig_tt = quote_option(sig);
154        let module_tt = quote_option(module);
155        let deprecated_tt = deprecated
156            .as_ref()
157            .map(|d| quote! { Some(#d) })
158            .unwrap_or_else(|| quote! { None });
159        let type_ignored_tt = if let Some(target) = type_ignored {
160            match target {
161                IgnoreTarget::All => {
162                    quote! { Some(::pyo3_stub_gen::type_info::IgnoreTarget::All) }
163                }
164                IgnoreTarget::SpecifiedLits(rules) => {
165                    let rule_strs: Vec<String> = rules.iter().map(|lit| lit.value()).collect();
166                    quote! {
167                        Some(::pyo3_stub_gen::type_info::IgnoreTarget::Specified(
168                            &[#(#rule_strs),*] as &[&str]
169                        ))
170                    }
171                }
172            }
173        } else {
174            quote! { None }
175        };
176        let args_with_sig = ArgsWithSignature { args, sig };
177        tokens.append_all(quote! {
178            ::pyo3_stub_gen::type_info::PyFunctionInfo {
179                name: #name,
180                args: #args_with_sig,
181                r#return: #ret_tt,
182                doc: #doc,
183                module: #module_tt,
184                is_async: #is_async,
185                deprecated: #deprecated_tt,
186                type_ignored: #type_ignored_tt,
187            }
188        })
189    }
190}
191
192// `#[gen_stub(xxx)]` is not a valid proc_macro_attribute
193// it's only designed to receive user's setting.
194// We need to remove all `#[gen_stub(xxx)]` before print the item_fn back
195pub fn prune_attrs(item_fn: &mut ItemFn) {
196    super::attr::prune_attrs(&mut item_fn.attrs);
197    for arg in item_fn.sig.inputs.iter_mut() {
198        if let FnArg::Typed(ref mut pat_type) = arg {
199            super::attr::prune_attrs(&mut pat_type.attrs);
200        }
201    }
202}