pyo3_stub_gen_derive/gen_stub/
pyfunction.rs1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens, TokenStreamExt};
3use syn::{
4 parse::{Parse, ParseStream},
5 Error, FnArg, ItemFn, Result,
6};
7
8use crate::gen_stub::util::TypeOrOverride;
9
10use super::{
11 extract_deprecated, extract_documents, extract_return_type, parse_args, parse_pyo3_attrs,
12 quote_option, ArgInfo, ArgsWithSignature, Attr, DeprecatedInfo, Signature,
13};
14
15pub struct PyFunctionInfo {
16 name: String,
17 args: Vec<ArgInfo>,
18 r#return: Option<TypeOrOverride>,
19 sig: Option<Signature>,
20 doc: String,
21 module: Option<String>,
22 is_async: bool,
23 deprecated: Option<DeprecatedInfo>,
24}
25
26struct ModuleAttr {
27 _module: syn::Ident,
28 _eq_token: syn::token::Eq,
29 name: syn::LitStr,
30}
31
32impl Parse for ModuleAttr {
33 fn parse(input: ParseStream) -> Result<Self> {
34 Ok(Self {
35 _module: input.parse()?,
36 _eq_token: input.parse()?,
37 name: input.parse()?,
38 })
39 }
40}
41
42impl PyFunctionInfo {
43 pub fn parse_attr(&mut self, attr: TokenStream2) -> Result<()> {
44 if attr.is_empty() {
45 return Ok(());
46 }
47 let attr: ModuleAttr = syn::parse2(attr)?;
48 self.module = Some(attr.name.value());
49 Ok(())
50 }
51}
52
53impl TryFrom<ItemFn> for PyFunctionInfo {
54 type Error = Error;
55 fn try_from(item: ItemFn) -> Result<Self> {
56 let doc = extract_documents(&item.attrs).join("\n");
57 let deprecated = extract_deprecated(&item.attrs);
58 let args = parse_args(item.sig.inputs)?;
59 let r#return = extract_return_type(&item.sig.output, &item.attrs)?;
60 let mut name = None;
61 let mut sig = None;
62 for attr in parse_pyo3_attrs(&item.attrs)? {
63 match attr {
64 Attr::Name(function_name) => name = Some(function_name),
65 Attr::Signature(signature) => sig = Some(signature),
66 _ => {}
67 }
68 }
69 let name = name.unwrap_or_else(|| item.sig.ident.to_string());
70 Ok(Self {
71 args,
72 sig,
73 r#return,
74 name,
75 doc,
76 module: None,
77 is_async: item.sig.asyncness.is_some(),
78 deprecated,
79 })
80 }
81}
82
83impl ToTokens for PyFunctionInfo {
84 fn to_tokens(&self, tokens: &mut TokenStream2) {
85 let Self {
86 args,
87 r#return: ret,
88 name,
89 doc,
90 sig,
91 module,
92 is_async,
93 deprecated,
94 } = self;
95 let ret_tt = if let Some(ret) = ret {
96 match ret {
97 TypeOrOverride::RustType { r#type } => {
98 let ty = r#type.clone();
99 quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
100 }
101 TypeOrOverride::OverrideType {
102 type_repr, imports, ..
103 } => {
104 let imports = imports.iter().collect::<Vec<&String>>();
105 quote! {
106 || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
107 }
108 }
109 }
110 } else {
111 quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
112 };
113 let module_tt = quote_option(module);
115 let deprecated_tt = deprecated
116 .as_ref()
117 .map(|d| quote! { Some(#d) })
118 .unwrap_or_else(|| quote! { None });
119 let args_with_sig = ArgsWithSignature { args, sig };
120 tokens.append_all(quote! {
121 ::pyo3_stub_gen::type_info::PyFunctionInfo {
122 name: #name,
123 args: #args_with_sig,
124 r#return: #ret_tt,
125 doc: #doc,
126 module: #module_tt,
127 is_async: #is_async,
128 deprecated: #deprecated_tt,
129 }
130 })
131 }
132}
133
134pub fn prune_attrs(item_fn: &mut ItemFn) {
138 super::attr::prune_attrs(&mut item_fn.attrs);
139 for arg in item_fn.sig.inputs.iter_mut() {
140 if let FnArg::Typed(ref mut pat_type) = arg {
141 super::attr::prune_attrs(&mut pat_type.attrs);
142 }
143 }
144}