pyo3_stub_gen_derive/gen_stub/
arg.rs

1use quote::ToTokens;
2use syn::{
3    spanned::Spanned, FnArg, GenericArgument, PatType, PathArguments, Result, Type, TypePath,
4    TypeReference,
5};
6
7use crate::gen_stub::{attr::parse_gen_stub_override_type, util::TypeOrOverride};
8
9pub fn parse_args(iter: impl IntoIterator<Item = FnArg>) -> Result<Vec<ArgInfo>> {
10    let mut args = Vec::new();
11    for (n, arg) in iter.into_iter().enumerate() {
12        if let FnArg::Receiver(_) = arg {
13            continue;
14        }
15        let arg = ArgInfo::try_from(arg)?;
16        let (ArgInfo {
17            r#type: TypeOrOverride::RustType { r#type },
18            ..
19        }
20        | ArgInfo {
21            r#type: TypeOrOverride::OverrideType { r#type, .. },
22            ..
23        }) = &arg;
24        // Regard the first argument with `&Bound<'_, PyType>` (classmethod
25        // `cls`) or `&Bound<'_, Self>` (explicit borrowed self) as a receiver.
26        if let Type::Reference(TypeReference { elem, .. }) = &r#type {
27            if let Type::Path(TypePath { path, .. }) = elem.as_ref() {
28                let last = path.segments.last().unwrap();
29                if n == 0 && last.ident == "Bound" {
30                    if let PathArguments::AngleBracketed(args, ..) = &last.arguments {
31                        if let Some(GenericArgument::Type(Type::Path(TypePath { path, .. }))) =
32                            args.args.last()
33                        {
34                            let inner_last = path.segments.last().unwrap();
35                            if inner_last.ident == "PyType" || inner_last.ident == "Self" {
36                                continue;
37                            }
38                        }
39                    }
40                }
41            }
42        }
43        if let Type::Path(TypePath { path, .. }) = &r#type {
44            let last = path.segments.last().unwrap();
45            if last.ident == "Python" {
46                continue;
47            }
48            // Regard the first argument with `PyRef<'_, Self>` /
49            // `PyRefMut<'_, Self>` / `Bound<'_, Self>` / `Py<Self>` as a
50            // receiver. PyO3 accepts all four shapes for self in `#[pymethods]`.
51            if n == 0
52                && (last.ident == "PyRef"
53                    || last.ident == "PyRefMut"
54                    || last.ident == "Bound"
55                    || last.ident == "Py")
56            {
57                if let PathArguments::AngleBracketed(inner) = &last.arguments {
58                    if let GenericArgument::Type(Type::Path(TypePath { path, .. })) =
59                        &inner.args[inner.args.len() - 1]
60                    {
61                        let last = path.segments.last().unwrap();
62                        if last.ident == "Self" {
63                            continue;
64                        }
65                    }
66                }
67            }
68        }
69        args.push(arg);
70    }
71    Ok(args)
72}
73
74#[derive(Debug, Clone)]
75pub struct ArgInfo {
76    pub(crate) name: String,
77    pub(crate) r#type: TypeOrOverride,
78}
79
80impl TryFrom<FnArg> for ArgInfo {
81    type Error = syn::Error;
82    fn try_from(value: FnArg) -> Result<Self> {
83        let span = value.span();
84        if let FnArg::Typed(PatType { pat, ty, attrs, .. }) = value {
85            if let syn::Pat::Ident(mut ident) = *pat {
86                ident.mutability = None;
87                let name = ident.to_token_stream().to_string();
88                if let Some(attr) = parse_gen_stub_override_type(&attrs)? {
89                    return Ok(Self {
90                        name,
91                        r#type: TypeOrOverride::OverrideType {
92                            r#type: (*ty).clone(),
93                            type_repr: attr.type_repr,
94                            imports: attr.imports,
95                            rust_type_markers: vec![],
96                        },
97                    });
98                }
99                return Ok(Self {
100                    name,
101                    r#type: TypeOrOverride::RustType {
102                        r#type: (*ty).clone(),
103                    },
104                });
105            }
106
107            if let syn::Pat::Wild(_) = *pat {
108                return Ok(Self {
109                    name: "_".to_owned(),
110                    r#type: TypeOrOverride::RustType { r#type: *ty },
111                });
112            }
113        }
114        Err(syn::Error::new(span, "Expected typed argument"))
115    }
116}