// pyo3_stub_gen_derive/gen_stub/pyfunction.rs
use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens, TokenStreamExt};
use syn::{
    parse::{Parse, ParseStream},
    Error, FnArg, ItemFn, Result,
};

use crate::gen_stub::util::TypeOrOverride;

use super::{
    attr::IgnoreTarget, extract_deprecated, extract_documents, extract_return_type,
    parameter::Parameters, parse_args, parse_gen_stub_type_ignore, parse_pyo3_attrs, parse_python,
    quote_option, Attr, DeprecatedInfo,
};
15
16pub struct PyFunctionInfo {
17 pub(crate) name: String,
18 pub(crate) parameters: Parameters,
19 pub(crate) r#return: Option<TypeOrOverride>,
20 pub(crate) doc: String,
21 pub(crate) module: Option<String>,
22 pub(crate) is_async: bool,
23 pub(crate) deprecated: Option<DeprecatedInfo>,
24 pub(crate) type_ignored: Option<IgnoreTarget>,
25 pub(crate) is_overload: bool,
26 pub(crate) index: usize,
27}
28
29#[derive(Default)]
30pub(crate) struct PyFunctionAttr {
31 pub(crate) module: Option<String>,
32 pub(crate) python: Option<syn::LitStr>,
33 pub(crate) python_overload: Option<syn::LitStr>,
34 pub(crate) no_default_overload: bool,
35}
36
37impl Parse for PyFunctionAttr {
38 fn parse(input: ParseStream) -> Result<Self> {
39 let mut module = None;
40 let mut python = None;
41 let mut python_overload = None;
42 let mut no_default_overload = false;
43
44 while !input.is_empty() {
46 let key: syn::Ident = input.parse()?;
47
48 match key.to_string().as_str() {
49 "module" => {
50 let _: syn::token::Eq = input.parse()?;
51 let value: syn::LitStr = input.parse()?;
52 module = Some(value.value());
53 }
54 "python" => {
55 let _: syn::token::Eq = input.parse()?;
56 let value: syn::LitStr = input.parse()?;
57 python = Some(value);
58 }
59 "python_overload" => {
60 let _: syn::token::Eq = input.parse()?;
61 let value: syn::LitStr = input.parse()?;
62 python_overload = Some(value);
63 }
64 "no_default_overload" => {
65 let _: syn::token::Eq = input.parse()?;
66 let value: syn::LitBool = input.parse()?;
67 no_default_overload = value.value();
68 }
69 _ => {
70 return Err(Error::new(
71 key.span(),
72 format!("Unknown parameter: {}", key),
73 ));
74 }
75 }
76
77 if input.peek(syn::token::Comma) {
79 let _: syn::token::Comma = input.parse()?;
80 } else {
81 break;
82 }
83 }
84
85 if python.is_some() && python_overload.is_some() {
87 return Err(Error::new(
88 input.span(),
89 "Cannot specify both 'python' and 'python_overload' parameters. Use 'python' for single signatures or 'python_overload' for multiple overloads.",
90 ));
91 }
92
93 if no_default_overload && python_overload.is_none() {
95 return Err(Error::new(
96 input.span(),
97 "The 'no_default_overload' parameter can only be used with 'python_overload'. \
98 Use 'python_overload' to define multiple overload signatures.",
99 ));
100 }
101
102 Ok(Self {
103 module,
104 python,
105 python_overload,
106 no_default_overload,
107 })
108 }
109}
110
111impl TryFrom<ItemFn> for PyFunctionInfo {
112 type Error = Error;
113 fn try_from(item: ItemFn) -> Result<Self> {
114 let doc = extract_documents(&item.attrs).join("\n");
115 let deprecated = extract_deprecated(&item.attrs);
116 let type_ignored = parse_gen_stub_type_ignore(&item.attrs)?;
117 let args = parse_args(item.sig.inputs)?;
118 let r#return = extract_return_type(&item.sig.output, &item.attrs)?;
119 let mut name = None;
120 let mut sig = None;
121 let mut pyo3_module = None;
122 for attr in parse_pyo3_attrs(&item.attrs)? {
123 match attr {
124 Attr::Name(function_name) => name = Some(function_name),
125 Attr::Signature(signature) => sig = Some(signature),
126 Attr::Module(module_name) => pyo3_module = Some(module_name),
127 _ => {}
128 }
129 }
130 let name = name.unwrap_or_else(|| item.sig.ident.to_string());
131
132 let parameters = if let Some(sig) = sig {
134 Parameters::new_with_sig(&args, &sig)?
135 } else {
136 Parameters::new(&args)
137 };
138
139 Ok(Self {
140 name,
141 parameters,
142 r#return,
143 doc,
144 module: pyo3_module,
145 is_async: item.sig.asyncness.is_some(),
146 deprecated,
147 type_ignored,
148 is_overload: false, index: 0, })
151 }
152}
153
154impl ToTokens for PyFunctionInfo {
155 fn to_tokens(&self, tokens: &mut TokenStream2) {
156 let Self {
157 r#return: ret,
158 name,
159 doc,
160 parameters,
161 module,
162 is_async,
163 deprecated,
164 type_ignored,
165 is_overload,
166 index,
167 } = self;
168 let ret_tt = if let Some(ret) = ret {
169 match ret {
170 TypeOrOverride::RustType { r#type } => {
171 let ty = r#type.clone();
172 quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
173 }
174 TypeOrOverride::OverrideType {
175 type_repr,
176 imports,
177 rust_type_markers,
178 ..
179 } => {
180 let imports = imports.iter().collect::<Vec<&String>>();
181
182 let (type_name_code, type_refs_code) = if rust_type_markers.is_empty() {
184 (
185 quote! { #type_repr.to_string() },
186 quote! { ::std::collections::HashMap::new() },
187 )
188 } else {
189 let marker_types: Vec<syn::Type> = rust_type_markers
190 .iter()
191 .filter_map(|s| syn::parse_str(s).ok())
192 .collect();
193
194 let rust_names = rust_type_markers.iter().collect::<Vec<_>>();
195
196 (
197 quote! {
198 {
199 let mut type_name = #type_repr.to_string();
200 #(
201 let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
202 type_name = type_name.replace(#rust_names, &type_info.name);
203 )*
204 type_name
205 }
206 },
207 quote! {
208 {
209 let mut type_refs = ::std::collections::HashMap::new();
210 #(
211 let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
212 if let Some(module) = type_info.source_module {
213 type_refs.insert(
214 type_info.name.split('[').next().unwrap_or(&type_info.name).split('.').last().unwrap_or(&type_info.name).to_string(),
215 ::pyo3_stub_gen::TypeIdentifierRef {
216 module: module.into(),
217 import_kind: ::pyo3_stub_gen::ImportKind::Module,
218 }
219 );
220 }
221 type_refs.extend(type_info.type_refs);
222 )*
223 type_refs
224 }
225 },
226 )
227 };
228
229 quote! {
230 || ::pyo3_stub_gen::TypeInfo { name: #type_name_code, source_module: None, import: ::std::collections::HashSet::from([#(#imports.into(),)*]), type_refs: #type_refs_code }
231 }
232 }
233 }
234 } else {
235 quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
236 };
237 let module_tt = quote_option(module);
239 let deprecated_tt = deprecated
240 .as_ref()
241 .map(|d| quote! { Some(#d) })
242 .unwrap_or_else(|| quote! { None });
243 let type_ignored_tt = if let Some(target) = type_ignored {
244 match target {
245 IgnoreTarget::All => {
246 quote! { Some(::pyo3_stub_gen::type_info::IgnoreTarget::All) }
247 }
248 IgnoreTarget::SpecifiedLits(rules) => {
249 let rule_strs: Vec<String> = rules.iter().map(|lit| lit.value()).collect();
250 quote! {
251 Some(::pyo3_stub_gen::type_info::IgnoreTarget::Specified(
252 &[#(#rule_strs),*] as &[&str]
253 ))
254 }
255 }
256 }
257 } else {
258 quote! { None }
259 };
260
261 tokens.append_all(quote! {
262 ::pyo3_stub_gen::type_info::PyFunctionInfo {
263 name: #name,
264 parameters: #parameters,
265 r#return: #ret_tt,
266 doc: #doc,
267 module: #module_tt,
268 is_async: #is_async,
269 deprecated: #deprecated_tt,
270 type_ignored: #type_ignored_tt,
271 is_overload: #is_overload,
272 file: file!(),
273 line: line!(),
274 column: column!(),
275 index: #index,
276 }
277 })
278 }
279}
280
281pub fn prune_attrs(item_fn: &mut ItemFn) {
285 super::attr::prune_attrs(&mut item_fn.attrs);
286 for arg in item_fn.sig.inputs.iter_mut() {
287 if let FnArg::Typed(ref mut pat_type) = arg {
288 super::attr::prune_attrs(&mut pat_type.attrs);
289 }
290 }
291}
292
293pub struct PyFunctionInfos {
296 pub(crate) item_fn: ItemFn,
297 pub(crate) infos: Vec<PyFunctionInfo>,
298}
299
300impl PyFunctionInfos {
301 pub fn from_parts(mut item_fn: ItemFn, attr: PyFunctionAttr) -> Result<Self> {
303 let mut gen_stub_standalone_module = None;
305 for attr_item in parse_pyo3_attrs(&item_fn.attrs)? {
306 if let Attr::GenStubModule(module_name) = attr_item {
307 gen_stub_standalone_module = Some(module_name);
308 }
309 }
310
311 if let (Some(inline_mod), Some(standalone_mod)) =
313 (&attr.module, &gen_stub_standalone_module)
314 {
315 if inline_mod != standalone_mod {
316 return Err(Error::new(
317 item_fn.sig.ident.span(),
318 format!(
319 "Conflicting module specifications: #[gen_stub_pyfunction(module = \"{}\")] \
320 and #[gen_stub(module = \"{}\")]. Please use only one.",
321 inline_mod, standalone_mod
322 ),
323 ));
324 }
325 }
326
327 if let Some(python) = attr.python {
329 let mut python_info = parse_python::parse_python_function_stub(python)?;
330 python_info.module = if let Some(inline_mod) = attr.module {
332 Some(inline_mod) } else if let Some(standalone_mod) = gen_stub_standalone_module {
334 Some(standalone_mod) } else {
336 python_info.module };
338 prune_attrs(&mut item_fn);
339 return Ok(Self {
340 item_fn,
341 infos: vec![python_info],
342 });
343 }
344
345 let mut base_info = PyFunctionInfo::try_from(item_fn.clone())?;
347
348 let pyo3_module = base_info.module.clone();
351 base_info.module = if let Some(inline_mod) = attr.module {
352 Some(inline_mod) } else if let Some(standalone_mod) = gen_stub_standalone_module {
354 Some(standalone_mod) } else {
356 pyo3_module };
358
359 let infos = if let Some(python_overload) = attr.python_overload {
360 let function_name = base_info.name.clone();
362
363 let mut overload_infos =
365 parse_python::parse_python_overload_stubs(python_overload, &function_name)?;
366
367 for (index, info) in overload_infos.iter_mut().enumerate() {
369 info.module = base_info.module.clone();
370 info.index = index;
371 }
372
373 if !attr.no_default_overload {
375 base_info.is_overload = true;
377 base_info.index = overload_infos.len();
378 overload_infos.push(base_info);
379 }
380
381 overload_infos
382 } else {
383 vec![base_info]
385 };
386
387 prune_attrs(&mut item_fn);
389
390 Ok(Self { item_fn, infos })
391 }
392}
393
394impl ToTokens for PyFunctionInfos {
395 fn to_tokens(&self, tokens: &mut TokenStream2) {
396 let item_fn = &self.item_fn;
397 let infos = &self.infos;
398
399 let submits = infos.iter().map(|info| {
401 quote! {
402 #[automatically_derived]
403 pyo3_stub_gen::inventory::submit! {
404 #info
405 }
406 }
407 });
408
409 tokens.append_all(quote! {
410 #(#submits)*
411 #item_fn
412 })
413 }
414}