pyo3_stub_gen_derive/gen_stub/
pyfunction.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens, TokenStreamExt};
3use syn::{
4    parse::{Parse, ParseStream},
5    Error, FnArg, ItemFn, Result,
6};
7
8use crate::gen_stub::util::TypeOrOverride;
9
10use super::{
11    attr::IgnoreTarget, extract_deprecated, extract_documents, extract_return_type, parse_args,
12    parse_gen_stub_type_ignore, parse_pyo3_attrs, quote_option, ArgInfo, ArgsWithSignature, Attr,
13    DeprecatedInfo, Signature,
14};
15
16pub struct PyFunctionInfo {
17    name: String,
18    args: Vec<ArgInfo>,
19    r#return: Option<TypeOrOverride>,
20    sig: Option<Signature>,
21    doc: String,
22    module: Option<String>,
23    is_async: bool,
24    deprecated: Option<DeprecatedInfo>,
25    type_ignored: Option<IgnoreTarget>,
26}
27
28struct ModuleAttr {
29    _module: syn::Ident,
30    _eq_token: syn::token::Eq,
31    name: syn::LitStr,
32}
33
34impl Parse for ModuleAttr {
35    fn parse(input: ParseStream) -> Result<Self> {
36        Ok(Self {
37            _module: input.parse()?,
38            _eq_token: input.parse()?,
39            name: input.parse()?,
40        })
41    }
42}
43
44impl PyFunctionInfo {
45    pub fn parse_attr(&mut self, attr: TokenStream2) -> Result<()> {
46        if attr.is_empty() {
47            return Ok(());
48        }
49        let attr: ModuleAttr = syn::parse2(attr)?;
50        self.module = Some(attr.name.value());
51        Ok(())
52    }
53}
54
55impl TryFrom<ItemFn> for PyFunctionInfo {
56    type Error = Error;
57    fn try_from(item: ItemFn) -> Result<Self> {
58        let doc = extract_documents(&item.attrs).join("\n");
59        let deprecated = extract_deprecated(&item.attrs);
60        let type_ignored = parse_gen_stub_type_ignore(&item.attrs)?;
61        let args = parse_args(item.sig.inputs)?;
62        let r#return = extract_return_type(&item.sig.output, &item.attrs)?;
63        let mut name = None;
64        let mut sig = None;
65        for attr in parse_pyo3_attrs(&item.attrs)? {
66            match attr {
67                Attr::Name(function_name) => name = Some(function_name),
68                Attr::Signature(signature) => sig = Some(signature),
69                _ => {}
70            }
71        }
72        let name = name.unwrap_or_else(|| item.sig.ident.to_string());
73        Ok(Self {
74            args,
75            sig,
76            r#return,
77            name,
78            doc,
79            module: None,
80            is_async: item.sig.asyncness.is_some(),
81            deprecated,
82            type_ignored,
83        })
84    }
85}
86
87impl ToTokens for PyFunctionInfo {
88    fn to_tokens(&self, tokens: &mut TokenStream2) {
89        let Self {
90            args,
91            r#return: ret,
92            name,
93            doc,
94            sig,
95            module,
96            is_async,
97            deprecated,
98            type_ignored,
99        } = self;
100        let ret_tt = if let Some(ret) = ret {
101            match ret {
102                TypeOrOverride::RustType { r#type } => {
103                    let ty = r#type.clone();
104                    quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
105                }
106                TypeOrOverride::OverrideType {
107                    type_repr, imports, ..
108                } => {
109                    let imports = imports.iter().collect::<Vec<&String>>();
110                    quote! {
111                        || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
112                    }
113                }
114            }
115        } else {
116            quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
117        };
118        // let sig_tt = quote_option(sig);
119        let module_tt = quote_option(module);
120        let deprecated_tt = deprecated
121            .as_ref()
122            .map(|d| quote! { Some(#d) })
123            .unwrap_or_else(|| quote! { None });
124        let type_ignored_tt = if let Some(target) = type_ignored {
125            match target {
126                IgnoreTarget::All => {
127                    quote! { Some(::pyo3_stub_gen::type_info::IgnoreTarget::All) }
128                }
129                IgnoreTarget::SpecifiedLits(rules) => {
130                    let rule_strs: Vec<String> = rules.iter().map(|lit| lit.value()).collect();
131                    quote! {
132                        Some(::pyo3_stub_gen::type_info::IgnoreTarget::Specified(
133                            &[#(#rule_strs),*] as &[&str]
134                        ))
135                    }
136                }
137            }
138        } else {
139            quote! { None }
140        };
141        let args_with_sig = ArgsWithSignature { args, sig };
142        tokens.append_all(quote! {
143            ::pyo3_stub_gen::type_info::PyFunctionInfo {
144                name: #name,
145                args: #args_with_sig,
146                r#return: #ret_tt,
147                doc: #doc,
148                module: #module_tt,
149                is_async: #is_async,
150                deprecated: #deprecated_tt,
151                type_ignored: #type_ignored_tt,
152            }
153        })
154    }
155}
156
157// `#[gen_stub(xxx)]` is not a valid proc_macro_attribute
158// it's only designed to receive user's setting.
159// We need to remove all `#[gen_stub(xxx)]` before print the item_fn back
160pub fn prune_attrs(item_fn: &mut ItemFn) {
161    super::attr::prune_attrs(&mut item_fn.attrs);
162    for arg in item_fn.sig.inputs.iter_mut() {
163        if let FnArg::Typed(ref mut pat_type) = arg {
164            super::attr::prune_attrs(&mut pat_type.attrs);
165        }
166    }
167}