pyo3_stub_gen_derive/gen_stub/
pyfunction.rs1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens, TokenStreamExt};
3use syn::{
4 parse::{Parse, ParseStream},
5 Error, FnArg, ItemFn, Result,
6};
7
8use crate::gen_stub::util::TypeOrOverride;
9
10use super::{
11 attr::IgnoreTarget, extract_deprecated, extract_documents, extract_return_type, parse_args,
12 parse_gen_stub_type_ignore, parse_pyo3_attrs, quote_option, ArgInfo, ArgsWithSignature, Attr,
13 DeprecatedInfo, Signature,
14};
15
/// Stub metadata collected for a single Rust function exposed to Python.
///
/// Built from an `ItemFn` through `TryFrom`, optionally adjusted by
/// `parse_attr`, and emitted as a `::pyo3_stub_gen::type_info::PyFunctionInfo`
/// struct expression through `ToTokens`.
pub struct PyFunctionInfo {
    /// Python-visible name: the pyo3 `name` attribute when present, otherwise the Rust identifier.
    pub(crate) name: String,
    /// Arguments produced by `parse_args` from the function's inputs.
    pub(crate) args: Vec<ArgInfo>,
    /// Return type from `extract_return_type`; `None` makes the generated code use
    /// `::pyo3_stub_gen::type_info::no_return_type_output`.
    pub(crate) r#return: Option<TypeOrOverride>,
    /// Signature taken from a pyo3 `signature` attribute, if one was given.
    pub(crate) sig: Option<Signature>,
    /// The item's doc-comment lines joined with `\n`.
    pub(crate) doc: String,
    /// Target module from a `module = "..."` macro argument; starts as `None` and is
    /// only set by `parse_attr`.
    pub(crate) module: Option<String>,
    /// Whether the Rust function is declared `async`.
    pub(crate) is_async: bool,
    /// Deprecation info from `extract_deprecated`, if any.
    pub(crate) deprecated: Option<DeprecatedInfo>,
    /// Ignore target from `parse_gen_stub_type_ignore`, if any.
    pub(crate) type_ignored: Option<IgnoreTarget>,
}
27
28struct PyFunctionAttr {
29 module: Option<String>,
30 python: Option<syn::LitStr>,
31}
32
33impl Parse for PyFunctionAttr {
34 fn parse(input: ParseStream) -> Result<Self> {
35 let mut module = None;
36 let mut python = None;
37
38 while !input.is_empty() {
40 let key: syn::Ident = input.parse()?;
41 let _: syn::token::Eq = input.parse()?;
42
43 match key.to_string().as_str() {
44 "module" => {
45 let value: syn::LitStr = input.parse()?;
46 module = Some(value.value());
47 }
48 "python" => {
49 let value: syn::LitStr = input.parse()?;
50 python = Some(value);
51 }
52 _ => {
53 return Err(Error::new(
54 key.span(),
55 format!("Unknown parameter: {}", key),
56 ));
57 }
58 }
59
60 if input.peek(syn::token::Comma) {
62 let _: syn::token::Comma = input.parse()?;
63 } else {
64 break;
65 }
66 }
67
68 Ok(Self { module, python })
69 }
70}
71
72impl PyFunctionInfo {
73 pub fn parse_attr(&mut self, attr: TokenStream2) -> Result<Option<syn::LitStr>> {
75 if attr.is_empty() {
76 return Ok(None);
77 }
78 let parsed_attr: PyFunctionAttr = syn::parse2(attr)?;
79
80 if let Some(module) = parsed_attr.module {
82 self.module = Some(module);
83 }
84
85 Ok(parsed_attr.python)
87 }
88}
89
90impl TryFrom<ItemFn> for PyFunctionInfo {
91 type Error = Error;
92 fn try_from(item: ItemFn) -> Result<Self> {
93 let doc = extract_documents(&item.attrs).join("\n");
94 let deprecated = extract_deprecated(&item.attrs);
95 let type_ignored = parse_gen_stub_type_ignore(&item.attrs)?;
96 let args = parse_args(item.sig.inputs)?;
97 let r#return = extract_return_type(&item.sig.output, &item.attrs)?;
98 let mut name = None;
99 let mut sig = None;
100 for attr in parse_pyo3_attrs(&item.attrs)? {
101 match attr {
102 Attr::Name(function_name) => name = Some(function_name),
103 Attr::Signature(signature) => sig = Some(signature),
104 _ => {}
105 }
106 }
107 let name = name.unwrap_or_else(|| item.sig.ident.to_string());
108 Ok(Self {
109 args,
110 sig,
111 r#return,
112 name,
113 doc,
114 module: None,
115 is_async: item.sig.asyncness.is_some(),
116 deprecated,
117 type_ignored,
118 })
119 }
120}
121
122impl ToTokens for PyFunctionInfo {
123 fn to_tokens(&self, tokens: &mut TokenStream2) {
124 let Self {
125 args,
126 r#return: ret,
127 name,
128 doc,
129 sig,
130 module,
131 is_async,
132 deprecated,
133 type_ignored,
134 } = self;
135 let ret_tt = if let Some(ret) = ret {
136 match ret {
137 TypeOrOverride::RustType { r#type } => {
138 let ty = r#type.clone();
139 quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
140 }
141 TypeOrOverride::OverrideType {
142 type_repr, imports, ..
143 } => {
144 let imports = imports.iter().collect::<Vec<&String>>();
145 quote! {
146 || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
147 }
148 }
149 }
150 } else {
151 quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
152 };
153 let module_tt = quote_option(module);
155 let deprecated_tt = deprecated
156 .as_ref()
157 .map(|d| quote! { Some(#d) })
158 .unwrap_or_else(|| quote! { None });
159 let type_ignored_tt = if let Some(target) = type_ignored {
160 match target {
161 IgnoreTarget::All => {
162 quote! { Some(::pyo3_stub_gen::type_info::IgnoreTarget::All) }
163 }
164 IgnoreTarget::SpecifiedLits(rules) => {
165 let rule_strs: Vec<String> = rules.iter().map(|lit| lit.value()).collect();
166 quote! {
167 Some(::pyo3_stub_gen::type_info::IgnoreTarget::Specified(
168 &[#(#rule_strs),*] as &[&str]
169 ))
170 }
171 }
172 }
173 } else {
174 quote! { None }
175 };
176 let args_with_sig = ArgsWithSignature { args, sig };
177 tokens.append_all(quote! {
178 ::pyo3_stub_gen::type_info::PyFunctionInfo {
179 name: #name,
180 args: #args_with_sig,
181 r#return: #ret_tt,
182 doc: #doc,
183 module: #module_tt,
184 is_async: #is_async,
185 deprecated: #deprecated_tt,
186 type_ignored: #type_ignored_tt,
187 }
188 })
189 }
190}
191
192pub fn prune_attrs(item_fn: &mut ItemFn) {
196 super::attr::prune_attrs(&mut item_fn.attrs);
197 for arg in item_fn.sig.inputs.iter_mut() {
198 if let FnArg::Typed(ref mut pat_type) = arg {
199 super::attr::prune_attrs(&mut pat_type.attrs);
200 }
201 }
202}