pyo3_stub_gen_derive/gen_stub/
arg.rs

1use quote::ToTokens;
2use syn::{
3    spanned::Spanned, FnArg, GenericArgument, PatType, PathArguments, Result, Type, TypePath,
4    TypeReference,
5};
6
7use crate::gen_stub::{attr::parse_gen_stub_override_type, util::TypeOrOverride};
8
9pub fn parse_args(iter: impl IntoIterator<Item = FnArg>) -> Result<Vec<ArgInfo>> {
10    let mut args = Vec::new();
11    for (n, arg) in iter.into_iter().enumerate() {
12        if let FnArg::Receiver(_) = arg {
13            continue;
14        }
15        let arg = ArgInfo::try_from(arg)?;
16        let (ArgInfo {
17            r#type: TypeOrOverride::RustType { r#type },
18            ..
19        }
20        | ArgInfo {
21            r#type: TypeOrOverride::OverrideType { r#type, .. },
22            ..
23        }) = &arg;
24        // Regard the first argument with `&Bound<'_, PyType>`
25        if let Type::Reference(TypeReference { elem, .. }) = &r#type {
26            if let Type::Path(TypePath { path, .. }) = elem.as_ref() {
27                let last = path.segments.last().unwrap();
28                if n == 0 && last.ident == "Bound" {
29                    if let PathArguments::AngleBracketed(args, ..) = &last.arguments {
30                        if let Some(last_type) = args.args.last() {
31                            if last_type.to_token_stream().to_string() == "PyType" {
32                                continue;
33                            }
34                        }
35                    }
36                }
37            }
38        }
39        if let Type::Path(TypePath { path, .. }) = &r#type {
40            let last = path.segments.last().unwrap();
41            if last.ident == "Python" {
42                continue;
43            }
44            // Regard the first argument with `PyRef<'_, Self>` and `PyMutRef<'_, Self>` types as a receiver.
45            if n == 0 && (last.ident == "PyRef" || last.ident == "PyRefMut") {
46                if let PathArguments::AngleBracketed(inner) = &last.arguments {
47                    if let GenericArgument::Type(Type::Path(TypePath { path, .. })) =
48                        &inner.args[inner.args.len() - 1]
49                    {
50                        let last = path.segments.last().unwrap();
51                        if last.ident == "Self" {
52                            continue;
53                        }
54                    }
55                }
56            }
57        }
58        args.push(arg);
59    }
60    Ok(args)
61}
62
63#[derive(Debug, Clone)]
64pub struct ArgInfo {
65    pub(crate) name: String,
66    pub(crate) r#type: TypeOrOverride,
67}
68
69impl TryFrom<FnArg> for ArgInfo {
70    type Error = syn::Error;
71    fn try_from(value: FnArg) -> Result<Self> {
72        let span = value.span();
73        if let FnArg::Typed(PatType { pat, ty, attrs, .. }) = value {
74            if let syn::Pat::Ident(mut ident) = *pat {
75                ident.mutability = None;
76                let name = ident.to_token_stream().to_string();
77                if let Some(attr) = parse_gen_stub_override_type(&attrs)? {
78                    return Ok(Self {
79                        name,
80                        r#type: TypeOrOverride::OverrideType {
81                            r#type: (*ty).clone(),
82                            type_repr: attr.type_repr,
83                            imports: attr.imports,
84                        },
85                    });
86                }
87                return Ok(Self {
88                    name,
89                    r#type: TypeOrOverride::RustType {
90                        r#type: (*ty).clone(),
91                    },
92                });
93            }
94
95            if let syn::Pat::Wild(_) = *pat {
96                return Ok(Self {
97                    name: "_".to_owned(),
98                    r#type: TypeOrOverride::RustType { r#type: *ty },
99                });
100            }
101        }
102        Err(syn::Error::new(span, "Expected typed argument"))
103    }
104}