pyo3_stub_gen_derive/gen_stub/
pyfunction.rs1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens, TokenStreamExt};
3use syn::{
4 parse::{Parse, ParseStream},
5 Error, FnArg, ItemFn, Result,
6};
7
8use crate::gen_stub::util::TypeOrOverride;
9
10use super::{
11 attr::IgnoreTarget, extract_deprecated, extract_documents, extract_return_type, parse_args,
12 parse_gen_stub_type_ignore, parse_pyo3_attrs, quote_option, ArgInfo, ArgsWithSignature, Attr,
13 DeprecatedInfo, Signature,
14};
15
16pub struct PyFunctionInfo {
17 name: String,
18 args: Vec<ArgInfo>,
19 r#return: Option<TypeOrOverride>,
20 sig: Option<Signature>,
21 doc: String,
22 module: Option<String>,
23 is_async: bool,
24 deprecated: Option<DeprecatedInfo>,
25 type_ignored: Option<IgnoreTarget>,
26}
27
28struct ModuleAttr {
29 _module: syn::Ident,
30 _eq_token: syn::token::Eq,
31 name: syn::LitStr,
32}
33
34impl Parse for ModuleAttr {
35 fn parse(input: ParseStream) -> Result<Self> {
36 Ok(Self {
37 _module: input.parse()?,
38 _eq_token: input.parse()?,
39 name: input.parse()?,
40 })
41 }
42}
43
44impl PyFunctionInfo {
45 pub fn parse_attr(&mut self, attr: TokenStream2) -> Result<()> {
46 if attr.is_empty() {
47 return Ok(());
48 }
49 let attr: ModuleAttr = syn::parse2(attr)?;
50 self.module = Some(attr.name.value());
51 Ok(())
52 }
53}
54
55impl TryFrom<ItemFn> for PyFunctionInfo {
56 type Error = Error;
57 fn try_from(item: ItemFn) -> Result<Self> {
58 let doc = extract_documents(&item.attrs).join("\n");
59 let deprecated = extract_deprecated(&item.attrs);
60 let type_ignored = parse_gen_stub_type_ignore(&item.attrs)?;
61 let args = parse_args(item.sig.inputs)?;
62 let r#return = extract_return_type(&item.sig.output, &item.attrs)?;
63 let mut name = None;
64 let mut sig = None;
65 for attr in parse_pyo3_attrs(&item.attrs)? {
66 match attr {
67 Attr::Name(function_name) => name = Some(function_name),
68 Attr::Signature(signature) => sig = Some(signature),
69 _ => {}
70 }
71 }
72 let name = name.unwrap_or_else(|| item.sig.ident.to_string());
73 Ok(Self {
74 args,
75 sig,
76 r#return,
77 name,
78 doc,
79 module: None,
80 is_async: item.sig.asyncness.is_some(),
81 deprecated,
82 type_ignored,
83 })
84 }
85}
86
87impl ToTokens for PyFunctionInfo {
88 fn to_tokens(&self, tokens: &mut TokenStream2) {
89 let Self {
90 args,
91 r#return: ret,
92 name,
93 doc,
94 sig,
95 module,
96 is_async,
97 deprecated,
98 type_ignored,
99 } = self;
100 let ret_tt = if let Some(ret) = ret {
101 match ret {
102 TypeOrOverride::RustType { r#type } => {
103 let ty = r#type.clone();
104 quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
105 }
106 TypeOrOverride::OverrideType {
107 type_repr, imports, ..
108 } => {
109 let imports = imports.iter().collect::<Vec<&String>>();
110 quote! {
111 || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
112 }
113 }
114 }
115 } else {
116 quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
117 };
118 let module_tt = quote_option(module);
120 let deprecated_tt = deprecated
121 .as_ref()
122 .map(|d| quote! { Some(#d) })
123 .unwrap_or_else(|| quote! { None });
124 let type_ignored_tt = if let Some(target) = type_ignored {
125 match target {
126 IgnoreTarget::All => {
127 quote! { Some(::pyo3_stub_gen::type_info::IgnoreTarget::All) }
128 }
129 IgnoreTarget::SpecifiedLits(rules) => {
130 let rule_strs: Vec<String> = rules.iter().map(|lit| lit.value()).collect();
131 quote! {
132 Some(::pyo3_stub_gen::type_info::IgnoreTarget::Specified(
133 &[#(#rule_strs),*] as &[&str]
134 ))
135 }
136 }
137 }
138 } else {
139 quote! { None }
140 };
141 let args_with_sig = ArgsWithSignature { args, sig };
142 tokens.append_all(quote! {
143 ::pyo3_stub_gen::type_info::PyFunctionInfo {
144 name: #name,
145 args: #args_with_sig,
146 r#return: #ret_tt,
147 doc: #doc,
148 module: #module_tt,
149 is_async: #is_async,
150 deprecated: #deprecated_tt,
151 type_ignored: #type_ignored_tt,
152 }
153 })
154 }
155}
156
157pub fn prune_attrs(item_fn: &mut ItemFn) {
161 super::attr::prune_attrs(&mut item_fn.attrs);
162 for arg in item_fn.sig.inputs.iter_mut() {
163 if let FnArg::Typed(ref mut pat_type) = arg {
164 super::attr::prune_attrs(&mut pat_type.attrs);
165 }
166 }
167}