pyo3_stub_gen_derive/gen_stub/
pyfunction.rs1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens, TokenStreamExt};
3use syn::{
4 parse::{Parse, ParseStream},
5 Error, FnArg, ItemFn, Result,
6};
7
8use crate::gen_stub::util::TypeOrOverride;
9
10use super::{
11 attr::IgnoreTarget, extract_deprecated, extract_documents, extract_return_type,
12 parameter::Parameters, parse_args, parse_gen_stub_type_ignore, parse_pyo3_attrs, quote_option,
13 Attr, DeprecatedInfo,
14};
15
16pub struct PyFunctionInfo {
17 pub(crate) name: String,
18 pub(crate) parameters: Parameters,
19 pub(crate) r#return: Option<TypeOrOverride>,
20 pub(crate) doc: String,
21 pub(crate) module: Option<String>,
22 pub(crate) is_async: bool,
23 pub(crate) deprecated: Option<DeprecatedInfo>,
24 pub(crate) type_ignored: Option<IgnoreTarget>,
25}
26
27struct PyFunctionAttr {
28 module: Option<String>,
29 python: Option<syn::LitStr>,
30}
31
32impl Parse for PyFunctionAttr {
33 fn parse(input: ParseStream) -> Result<Self> {
34 let mut module = None;
35 let mut python = None;
36
37 while !input.is_empty() {
39 let key: syn::Ident = input.parse()?;
40 let _: syn::token::Eq = input.parse()?;
41
42 match key.to_string().as_str() {
43 "module" => {
44 let value: syn::LitStr = input.parse()?;
45 module = Some(value.value());
46 }
47 "python" => {
48 let value: syn::LitStr = input.parse()?;
49 python = Some(value);
50 }
51 _ => {
52 return Err(Error::new(
53 key.span(),
54 format!("Unknown parameter: {}", key),
55 ));
56 }
57 }
58
59 if input.peek(syn::token::Comma) {
61 let _: syn::token::Comma = input.parse()?;
62 } else {
63 break;
64 }
65 }
66
67 Ok(Self { module, python })
68 }
69}
70
71impl PyFunctionInfo {
72 pub fn parse_attr(&mut self, attr: TokenStream2) -> Result<Option<syn::LitStr>> {
74 if attr.is_empty() {
75 return Ok(None);
76 }
77 let parsed_attr: PyFunctionAttr = syn::parse2(attr)?;
78
79 if let Some(module) = parsed_attr.module {
81 self.module = Some(module);
82 }
83
84 Ok(parsed_attr.python)
86 }
87}
88
89impl TryFrom<ItemFn> for PyFunctionInfo {
90 type Error = Error;
91 fn try_from(item: ItemFn) -> Result<Self> {
92 let doc = extract_documents(&item.attrs).join("\n");
93 let deprecated = extract_deprecated(&item.attrs);
94 let type_ignored = parse_gen_stub_type_ignore(&item.attrs)?;
95 let args = parse_args(item.sig.inputs)?;
96 let r#return = extract_return_type(&item.sig.output, &item.attrs)?;
97 let mut name = None;
98 let mut sig = None;
99 for attr in parse_pyo3_attrs(&item.attrs)? {
100 match attr {
101 Attr::Name(function_name) => name = Some(function_name),
102 Attr::Signature(signature) => sig = Some(signature),
103 _ => {}
104 }
105 }
106 let name = name.unwrap_or_else(|| item.sig.ident.to_string());
107
108 let parameters = if let Some(sig) = sig {
110 Parameters::new_with_sig(&args, &sig)?
111 } else {
112 Parameters::new(&args)
113 };
114
115 Ok(Self {
116 name,
117 parameters,
118 r#return,
119 doc,
120 module: None,
121 is_async: item.sig.asyncness.is_some(),
122 deprecated,
123 type_ignored,
124 })
125 }
126}
127
128impl ToTokens for PyFunctionInfo {
129 fn to_tokens(&self, tokens: &mut TokenStream2) {
130 let Self {
131 r#return: ret,
132 name,
133 doc,
134 parameters,
135 module,
136 is_async,
137 deprecated,
138 type_ignored,
139 } = self;
140 let ret_tt = if let Some(ret) = ret {
141 match ret {
142 TypeOrOverride::RustType { r#type } => {
143 let ty = r#type.clone();
144 quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
145 }
146 TypeOrOverride::OverrideType {
147 type_repr, imports, ..
148 } => {
149 let imports = imports.iter().collect::<Vec<&String>>();
150 quote! {
151 || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
152 }
153 }
154 }
155 } else {
156 quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
157 };
158 let module_tt = quote_option(module);
160 let deprecated_tt = deprecated
161 .as_ref()
162 .map(|d| quote! { Some(#d) })
163 .unwrap_or_else(|| quote! { None });
164 let type_ignored_tt = if let Some(target) = type_ignored {
165 match target {
166 IgnoreTarget::All => {
167 quote! { Some(::pyo3_stub_gen::type_info::IgnoreTarget::All) }
168 }
169 IgnoreTarget::SpecifiedLits(rules) => {
170 let rule_strs: Vec<String> = rules.iter().map(|lit| lit.value()).collect();
171 quote! {
172 Some(::pyo3_stub_gen::type_info::IgnoreTarget::Specified(
173 &[#(#rule_strs),*] as &[&str]
174 ))
175 }
176 }
177 }
178 } else {
179 quote! { None }
180 };
181
182 tokens.append_all(quote! {
183 ::pyo3_stub_gen::type_info::PyFunctionInfo {
184 name: #name,
185 parameters: #parameters,
186 r#return: #ret_tt,
187 doc: #doc,
188 module: #module_tt,
189 is_async: #is_async,
190 deprecated: #deprecated_tt,
191 type_ignored: #type_ignored_tt,
192 }
193 })
194 }
195}
196
197pub fn prune_attrs(item_fn: &mut ItemFn) {
201 super::attr::prune_attrs(&mut item_fn.attrs);
202 for arg in item_fn.sig.inputs.iter_mut() {
203 if let FnArg::Typed(ref mut pat_type) = arg {
204 super::attr::prune_attrs(&mut pat_type.attrs);
205 }
206 }
207}