pyo3_stub_gen_derive/gen_stub/
pyfunction.rs1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens, TokenStreamExt};
3use syn::{
4 parse::{Parse, ParseStream},
5 Error, FnArg, ItemFn, Result,
6};
7
8use crate::gen_stub::util::TypeOrOverride;
9
10use super::{
11 attr::IgnoreTarget, extract_deprecated, extract_documents, extract_return_type,
12 parameter::Parameters, parse_args, parse_gen_stub_type_ignore, parse_pyo3_attrs, parse_python,
13 quote_option, Attr, DeprecatedInfo,
14};
15
16pub struct PyFunctionInfo {
17 pub(crate) name: String,
18 pub(crate) parameters: Parameters,
19 pub(crate) r#return: Option<TypeOrOverride>,
20 pub(crate) doc: String,
21 pub(crate) module: Option<String>,
22 pub(crate) is_async: bool,
23 pub(crate) deprecated: Option<DeprecatedInfo>,
24 pub(crate) type_ignored: Option<IgnoreTarget>,
25 pub(crate) is_overload: bool,
26 pub(crate) index: usize,
27}
28
29#[derive(Default)]
30pub(crate) struct PyFunctionAttr {
31 pub(crate) module: Option<String>,
32 pub(crate) python: Option<syn::LitStr>,
33 pub(crate) python_overload: Option<syn::LitStr>,
34 pub(crate) no_default_overload: bool,
35}
36
37impl Parse for PyFunctionAttr {
38 fn parse(input: ParseStream) -> Result<Self> {
39 let mut module = None;
40 let mut python = None;
41 let mut python_overload = None;
42 let mut no_default_overload = false;
43
44 while !input.is_empty() {
46 let key: syn::Ident = input.parse()?;
47
48 match key.to_string().as_str() {
49 "module" => {
50 let _: syn::token::Eq = input.parse()?;
51 let value: syn::LitStr = input.parse()?;
52 module = Some(value.value());
53 }
54 "python" => {
55 let _: syn::token::Eq = input.parse()?;
56 let value: syn::LitStr = input.parse()?;
57 python = Some(value);
58 }
59 "python_overload" => {
60 let _: syn::token::Eq = input.parse()?;
61 let value: syn::LitStr = input.parse()?;
62 python_overload = Some(value);
63 }
64 "no_default_overload" => {
65 let _: syn::token::Eq = input.parse()?;
66 let value: syn::LitBool = input.parse()?;
67 no_default_overload = value.value();
68 }
69 _ => {
70 return Err(Error::new(
71 key.span(),
72 format!("Unknown parameter: {}", key),
73 ));
74 }
75 }
76
77 if input.peek(syn::token::Comma) {
79 let _: syn::token::Comma = input.parse()?;
80 } else {
81 break;
82 }
83 }
84
85 if python.is_some() && python_overload.is_some() {
87 return Err(Error::new(
88 input.span(),
89 "Cannot specify both 'python' and 'python_overload' parameters. Use 'python' for single signatures or 'python_overload' for multiple overloads.",
90 ));
91 }
92
93 if no_default_overload && python_overload.is_none() {
95 return Err(Error::new(
96 input.span(),
97 "The 'no_default_overload' parameter can only be used with 'python_overload'. \
98 Use 'python_overload' to define multiple overload signatures.",
99 ));
100 }
101
102 Ok(Self {
103 module,
104 python,
105 python_overload,
106 no_default_overload,
107 })
108 }
109}
110
111impl TryFrom<ItemFn> for PyFunctionInfo {
112 type Error = Error;
113 fn try_from(item: ItemFn) -> Result<Self> {
114 let doc = extract_documents(&item.attrs).join("\n");
115 let deprecated = extract_deprecated(&item.attrs);
116 let type_ignored = parse_gen_stub_type_ignore(&item.attrs)?;
117 let args = parse_args(item.sig.inputs)?;
118 let r#return = extract_return_type(&item.sig.output, &item.attrs)?;
119 let mut name = None;
120 let mut sig = None;
121 for attr in parse_pyo3_attrs(&item.attrs)? {
122 match attr {
123 Attr::Name(function_name) => name = Some(function_name),
124 Attr::Signature(signature) => sig = Some(signature),
125 _ => {}
126 }
127 }
128 let name = name.unwrap_or_else(|| item.sig.ident.to_string());
129
130 let parameters = if let Some(sig) = sig {
132 Parameters::new_with_sig(&args, &sig)?
133 } else {
134 Parameters::new(&args)
135 };
136
137 Ok(Self {
138 name,
139 parameters,
140 r#return,
141 doc,
142 module: None,
143 is_async: item.sig.asyncness.is_some(),
144 deprecated,
145 type_ignored,
146 is_overload: false, index: 0, })
149 }
150}
151
152impl ToTokens for PyFunctionInfo {
153 fn to_tokens(&self, tokens: &mut TokenStream2) {
154 let Self {
155 r#return: ret,
156 name,
157 doc,
158 parameters,
159 module,
160 is_async,
161 deprecated,
162 type_ignored,
163 is_overload,
164 index,
165 } = self;
166 let ret_tt = if let Some(ret) = ret {
167 match ret {
168 TypeOrOverride::RustType { r#type } => {
169 let ty = r#type.clone();
170 quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
171 }
172 TypeOrOverride::OverrideType {
173 type_repr, imports, ..
174 } => {
175 let imports = imports.iter().collect::<Vec<&String>>();
176 quote! {
177 || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
178 }
179 }
180 }
181 } else {
182 quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
183 };
184 let module_tt = quote_option(module);
186 let deprecated_tt = deprecated
187 .as_ref()
188 .map(|d| quote! { Some(#d) })
189 .unwrap_or_else(|| quote! { None });
190 let type_ignored_tt = if let Some(target) = type_ignored {
191 match target {
192 IgnoreTarget::All => {
193 quote! { Some(::pyo3_stub_gen::type_info::IgnoreTarget::All) }
194 }
195 IgnoreTarget::SpecifiedLits(rules) => {
196 let rule_strs: Vec<String> = rules.iter().map(|lit| lit.value()).collect();
197 quote! {
198 Some(::pyo3_stub_gen::type_info::IgnoreTarget::Specified(
199 &[#(#rule_strs),*] as &[&str]
200 ))
201 }
202 }
203 }
204 } else {
205 quote! { None }
206 };
207
208 tokens.append_all(quote! {
209 ::pyo3_stub_gen::type_info::PyFunctionInfo {
210 name: #name,
211 parameters: #parameters,
212 r#return: #ret_tt,
213 doc: #doc,
214 module: #module_tt,
215 is_async: #is_async,
216 deprecated: #deprecated_tt,
217 type_ignored: #type_ignored_tt,
218 is_overload: #is_overload,
219 file: file!(),
220 line: line!(),
221 column: column!(),
222 index: #index,
223 }
224 })
225 }
226}
227
228pub fn prune_attrs(item_fn: &mut ItemFn) {
232 super::attr::prune_attrs(&mut item_fn.attrs);
233 for arg in item_fn.sig.inputs.iter_mut() {
234 if let FnArg::Typed(ref mut pat_type) = arg {
235 super::attr::prune_attrs(&mut pat_type.attrs);
236 }
237 }
238}
239
240pub struct PyFunctionInfos {
243 pub(crate) item_fn: ItemFn,
244 pub(crate) infos: Vec<PyFunctionInfo>,
245}
246
247impl PyFunctionInfos {
248 pub fn from_parts(mut item_fn: ItemFn, attr: PyFunctionAttr) -> Result<Self> {
250 if let Some(python) = attr.python {
252 let mut python_info = parse_python::parse_python_function_stub(python)?;
253 python_info.module = attr.module;
254 prune_attrs(&mut item_fn);
255 return Ok(Self {
256 item_fn,
257 infos: vec![python_info],
258 });
259 }
260
261 let mut base_info = PyFunctionInfo::try_from(item_fn.clone())?;
263 base_info.module = attr.module;
264
265 let infos = if let Some(python_overload) = attr.python_overload {
266 let function_name = base_info.name.clone();
268
269 let mut overload_infos =
271 parse_python::parse_python_overload_stubs(python_overload, &function_name)?;
272
273 for (index, info) in overload_infos.iter_mut().enumerate() {
275 info.module = base_info.module.clone();
276 info.index = index;
277 }
278
279 if !attr.no_default_overload {
281 base_info.is_overload = true;
283 base_info.index = overload_infos.len();
284 overload_infos.push(base_info);
285 }
286
287 overload_infos
288 } else {
289 vec![base_info]
291 };
292
293 prune_attrs(&mut item_fn);
295
296 Ok(Self { item_fn, infos })
297 }
298}
299
300impl ToTokens for PyFunctionInfos {
301 fn to_tokens(&self, tokens: &mut TokenStream2) {
302 let item_fn = &self.item_fn;
303 let infos = &self.infos;
304
305 let submits = infos.iter().map(|info| {
307 quote! {
308 #[automatically_derived]
309 pyo3_stub_gen::inventory::submit! {
310 #info
311 }
312 }
313 });
314
315 tokens.append_all(quote! {
316 #(#submits)*
317 #item_fn
318 })
319 }
320}