pyo3_stub_gen_derive/gen_stub/arg.rs

1use quote::ToTokens;
2use syn::{
3    spanned::Spanned, FnArg, GenericArgument, PatType, PathArguments, Result, Type, TypePath,
4};
5
6pub fn parse_args(iter: impl IntoIterator<Item = FnArg>) -> Result<Vec<ArgInfo>> {
7    let mut args = Vec::new();
8    for (n, arg) in iter.into_iter().enumerate() {
9        if let FnArg::Receiver(_) = arg {
10            continue;
11        }
12        let arg = ArgInfo::try_from(arg)?;
13        if let Type::Path(TypePath { path, .. }) = &arg.r#type {
14            let last = path.segments.last().unwrap();
15            if last.ident == "Python" {
16                continue;
17            }
18            // Regard the first argument with `PyRef<'_, Self>` and `PyMutRef<'_, Self>` types as a receiver.
19            if n == 0 && (last.ident == "PyRef" || last.ident == "PyRefMut") {
20                if let PathArguments::AngleBracketed(inner) = &last.arguments {
21                    if let GenericArgument::Type(Type::Path(TypePath { path, .. })) =
22                        &inner.args[inner.args.len() - 1]
23                    {
24                        let last = path.segments.last().unwrap();
25                        if last.ident == "Self" {
26                            continue;
27                        }
28                    }
29                }
30            }
31        }
32        args.push(arg);
33    }
34    Ok(args)
35}
36
37#[derive(Debug)]
38pub struct ArgInfo {
39    pub name: String,
40    pub r#type: Type,
41}
42
43impl TryFrom<FnArg> for ArgInfo {
44    type Error = syn::Error;
45    fn try_from(value: FnArg) -> Result<Self> {
46        let span = value.span();
47        if let FnArg::Typed(PatType { pat, ty, .. }) = value {
48            if let syn::Pat::Ident(mut ident) = *pat {
49                ident.mutability = None;
50                let name = ident.to_token_stream().to_string();
51                return Ok(Self { name, r#type: *ty });
52            }
53        }
54        Err(syn::Error::new(span, "Expected typed argument"))
55    }
56}