pyo3_stub_gen_derive/gen_stub/
pyfunction.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens, TokenStreamExt};
3use syn::{
4    parse::{Parse, ParseStream},
5    Error, FnArg, ItemFn, Result,
6};
7
8use crate::gen_stub::util::TypeOrOverride;
9
10use super::{
11    extract_deprecated, extract_documents, extract_return_type, parse_args, parse_pyo3_attrs,
12    quote_option, ArgInfo, ArgsWithSignature, Attr, DeprecatedInfo, Signature,
13};
14
15pub struct PyFunctionInfo {
16    name: String,
17    args: Vec<ArgInfo>,
18    r#return: Option<TypeOrOverride>,
19    sig: Option<Signature>,
20    doc: String,
21    module: Option<String>,
22    is_async: bool,
23    deprecated: Option<DeprecatedInfo>,
24}
25
26struct ModuleAttr {
27    _module: syn::Ident,
28    _eq_token: syn::token::Eq,
29    name: syn::LitStr,
30}
31
32impl Parse for ModuleAttr {
33    fn parse(input: ParseStream) -> Result<Self> {
34        Ok(Self {
35            _module: input.parse()?,
36            _eq_token: input.parse()?,
37            name: input.parse()?,
38        })
39    }
40}
41
42impl PyFunctionInfo {
43    pub fn parse_attr(&mut self, attr: TokenStream2) -> Result<()> {
44        if attr.is_empty() {
45            return Ok(());
46        }
47        let attr: ModuleAttr = syn::parse2(attr)?;
48        self.module = Some(attr.name.value());
49        Ok(())
50    }
51}
52
53impl TryFrom<ItemFn> for PyFunctionInfo {
54    type Error = Error;
55    fn try_from(item: ItemFn) -> Result<Self> {
56        let doc = extract_documents(&item.attrs).join("\n");
57        let deprecated = extract_deprecated(&item.attrs);
58        let args = parse_args(item.sig.inputs)?;
59        let r#return = extract_return_type(&item.sig.output, &item.attrs)?;
60        let mut name = None;
61        let mut sig = None;
62        for attr in parse_pyo3_attrs(&item.attrs)? {
63            match attr {
64                Attr::Name(function_name) => name = Some(function_name),
65                Attr::Signature(signature) => sig = Some(signature),
66                _ => {}
67            }
68        }
69        let name = name.unwrap_or_else(|| item.sig.ident.to_string());
70        Ok(Self {
71            args,
72            sig,
73            r#return,
74            name,
75            doc,
76            module: None,
77            is_async: item.sig.asyncness.is_some(),
78            deprecated,
79        })
80    }
81}
82
83impl ToTokens for PyFunctionInfo {
84    fn to_tokens(&self, tokens: &mut TokenStream2) {
85        let Self {
86            args,
87            r#return: ret,
88            name,
89            doc,
90            sig,
91            module,
92            is_async,
93            deprecated,
94        } = self;
95        let ret_tt = if let Some(ret) = ret {
96            match ret {
97                TypeOrOverride::RustType { r#type } => {
98                    let ty = r#type.clone();
99                    quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
100                }
101                TypeOrOverride::OverrideType {
102                    type_repr, imports, ..
103                } => {
104                    let imports = imports.iter().collect::<Vec<&String>>();
105                    quote! {
106                        || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
107                    }
108                }
109            }
110        } else {
111            quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
112        };
113        // let sig_tt = quote_option(sig);
114        let module_tt = quote_option(module);
115        let deprecated_tt = deprecated
116            .as_ref()
117            .map(|d| quote! { Some(#d) })
118            .unwrap_or_else(|| quote! { None });
119        let args_with_sig = ArgsWithSignature { args, sig };
120        tokens.append_all(quote! {
121            ::pyo3_stub_gen::type_info::PyFunctionInfo {
122                name: #name,
123                args: #args_with_sig,
124                r#return: #ret_tt,
125                doc: #doc,
126                module: #module_tt,
127                is_async: #is_async,
128                deprecated: #deprecated_tt,
129            }
130        })
131    }
132}
133
134// `#[gen_stub(xxx)]` is not a valid proc_macro_attribute
135// it's only designed to receive user's setting.
136// We need to remove all `#[gen_stub(xxx)]` before print the item_fn back
137pub fn prune_attrs(item_fn: &mut ItemFn) {
138    super::attr::prune_attrs(&mut item_fn.attrs);
139    for arg in item_fn.sig.inputs.iter_mut() {
140        if let FnArg::Typed(ref mut pat_type) = arg {
141            super::attr::prune_attrs(&mut pat_type.attrs);
142        }
143    }
144}