pyo3_stub_gen_derive/gen_stub/util.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens};
3use syn::{GenericArgument, PathArguments, PathSegment, ReturnType, Type, TypePath};
4
5pub fn quote_option<T: ToTokens>(a: &Option<T>) -> TokenStream2 {
6    if let Some(a) = a {
7        quote! { Some(#a) }
8    } else {
9        quote! { None }
10    }
11}
12
13pub fn remove_lifetime(ty: &mut Type) {
14    match ty {
15        Type::Path(TypePath { path, .. }) => {
16            if let Some(PathSegment {
17                arguments: PathArguments::AngleBracketed(inner),
18                ..
19            }) = path.segments.last_mut()
20            {
21                for arg in &mut inner.args {
22                    match arg {
23                        GenericArgument::Lifetime(l) => {
24                            // `T::<'a, S>` becomes `T::<'_, S>`
25                            *l = syn::parse_quote!('_);
26                        }
27                        GenericArgument::Type(ty) => {
28                            remove_lifetime(ty);
29                        }
30                        _ => {}
31                    }
32                }
33            }
34        }
35        Type::Reference(rty) => {
36            rty.lifetime = None;
37            remove_lifetime(rty.elem.as_mut());
38        }
39        Type::Tuple(ty) => {
40            for elem in &mut ty.elems {
41                remove_lifetime(elem);
42            }
43        }
44        Type::Array(ary) => {
45            remove_lifetime(ary.elem.as_mut());
46        }
47        _ => {}
48    }
49}
50
51/// Extract `T` from `PyResult<T>`.
52///
53/// For `PyResult<&'a T>` case, `'a` will be removed, i.e. returns `&T` for this case.
54pub fn escape_return_type(ret: &ReturnType) -> Option<Type> {
55    let ret = if let ReturnType::Type(_, ty) = ret {
56        unwrap_pyresult(ty)
57    } else {
58        return None;
59    };
60    let mut ret = ret.clone();
61    remove_lifetime(&mut ret);
62    Some(ret)
63}
64
65fn unwrap_pyresult(ty: &Type) -> &Type {
66    if let Type::Path(TypePath { path, .. }) = ty {
67        if let Some(last) = path.segments.last() {
68            if last.ident == "PyResult" {
69                if let PathArguments::AngleBracketed(inner) = &last.arguments {
70                    for arg in &inner.args {
71                        if let GenericArgument::Type(ty) = arg {
72                            return ty;
73                        }
74                    }
75                }
76            }
77        }
78    }
79    ty
80}
81
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Result};

    #[test]
    fn test_unwrap_pyresult() -> Result<()> {
        // (input, expected) pairs: bare and fully-qualified `PyResult` are both recognized,
        // and lifetimes are NOT touched by `unwrap_pyresult` itself.
        let cases = [
            ("PyResult<i32>", "i32"),
            ("PyResult<&PyString>", "&PyString"),
            ("PyResult<&'a PyString>", "&'a PyString"),
            ("::pyo3::PyResult<i32>", "i32"),
            ("::pyo3::PyResult<&PyString>", "&PyString"),
            ("::pyo3::PyResult<&'a PyString>", "&'a PyString"),
            // Anything that is not `PyResult` is returned unchanged.
            ("Option<i32>", "Option<i32>"),
        ];
        for &(input, expected) in cases.iter() {
            let ty: Type = parse_str(input)?;
            let expected: Type = parse_str(expected)?;
            assert_eq!(unwrap_pyresult(&ty), &expected, "input: {}", input);
        }
        Ok(())
    }

    #[test]
    fn test_escape_return_type() -> Result<()> {
        // A signature without `-> ...` has nothing to escape.
        assert_eq!(escape_return_type(&ReturnType::Default), None);

        // (return type, expected escaped type) pairs covering the documented lifetime erasure.
        let cases = [
            ("-> PyResult<&'a PyString>", "&PyString"),
            ("-> PyResult<Vec<&'a str>>", "Vec<&str>"),
            ("-> ::pyo3::PyResult<(i32, &'a str)>", "(i32, &str)"),
            ("-> Bound<'py, PyAny>", "Bound<'_, PyAny>"),
            ("-> [&'a str; 2]", "[&str; 2]"),
        ];
        for &(input, expected) in cases.iter() {
            let ret: ReturnType = parse_str(input)?;
            let expected: Type = parse_str(expected)?;
            assert_eq!(escape_return_type(&ret), Some(expected), "input: {}", input);
        }
        Ok(())
    }
}