pyo3_stub_gen_derive/gen_stub/
util.rs

1use indexmap::IndexSet;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{quote, ToTokens};
4use syn::{
5    Attribute, GenericArgument, PathArguments, PathSegment, Result, ReturnType, Type, TypePath,
6};
7
8use crate::gen_stub::attr::parse_gen_stub_override_return_type;
9
10pub fn quote_option<T: ToTokens>(a: &Option<T>) -> TokenStream2 {
11    if let Some(a) = a {
12        quote! { Some(#a) }
13    } else {
14        quote! { None }
15    }
16}
17
18pub fn remove_lifetime(ty: &mut Type) {
19    match ty {
20        Type::Path(TypePath { path, .. }) => {
21            if let Some(PathSegment {
22                arguments: PathArguments::AngleBracketed(inner),
23                ..
24            }) = path.segments.last_mut()
25            {
26                for arg in &mut inner.args {
27                    match arg {
28                        GenericArgument::Lifetime(l) => {
29                            // `T::<'a, S>` becomes `T::<'_, S>`
30                            *l = syn::parse_quote!('_);
31                        }
32                        GenericArgument::Type(ty) => {
33                            remove_lifetime(ty);
34                        }
35                        _ => {}
36                    }
37                }
38            }
39        }
40        Type::Reference(rty) => {
41            rty.lifetime = None;
42            remove_lifetime(rty.elem.as_mut());
43        }
44        Type::Tuple(ty) => {
45            for elem in &mut ty.elems {
46                remove_lifetime(elem);
47            }
48        }
49        Type::Array(ary) => {
50            remove_lifetime(ary.elem.as_mut());
51        }
52        _ => {}
53    }
54}
55
56/// Extract `T` from `PyResult<T>` and apply `override_type` attribute if present.
57///
58/// For `PyResult<&'a T>` case, `'a` will be removed, i.e. returns `&T` for this case.
59pub fn extract_return_type(
60    ret: &ReturnType,
61    attrs: &[Attribute],
62) -> Result<Option<TypeOrOverride>> {
63    let ret = if let ReturnType::Type(_, ty) = ret {
64        unwrap_pyresult(ty)
65    } else {
66        return Ok(None);
67    };
68    let mut ret = ret.clone();
69    remove_lifetime(&mut ret);
70    if let Some(attr) = parse_gen_stub_override_return_type(attrs)? {
71        return Ok(Some(TypeOrOverride::OverrideType {
72            r#type: ret.clone(),
73            type_repr: attr.type_repr,
74            imports: attr.imports,
75        }));
76    }
77    Ok(Some(TypeOrOverride::RustType { r#type: ret }))
78}
79
80fn unwrap_pyresult(ty: &Type) -> &Type {
81    if let Type::Path(TypePath { path, .. }) = ty {
82        if let Some(last) = path.segments.last() {
83            if last.ident == "PyResult" {
84                if let PathArguments::AngleBracketed(inner) = &last.arguments {
85                    for arg in &inner.args {
86                        if let GenericArgument::Type(ty) = arg {
87                            return ty;
88                        }
89                    }
90                }
91            }
92        }
93    }
94    ty
95}
96
97#[derive(Debug, Clone)]
98pub enum TypeOrOverride {
99    RustType {
100        r#type: Type,
101    },
102    OverrideType {
103        r#type: Type,
104        type_repr: String,
105        imports: IndexSet<String>,
106    },
107}
108
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Result};

    #[test]
    fn test_unwrap_pyresult() -> Result<()> {
        // (input type, expected result of `unwrap_pyresult`)
        let cases = [
            ("PyResult<i32>", "i32"),
            ("PyResult<&PyString>", "&PyString"),
            ("PyResult<&'a PyString>", "&'a PyString"),
            ("::pyo3::PyResult<i32>", "i32"),
            ("::pyo3::PyResult<&PyString>", "&PyString"),
            ("::pyo3::PyResult<&'a PyString>", "&'a PyString"),
        ];
        for &(input, expected) in cases.iter() {
            let ty: Type = parse_str(input)?;
            let expected: Type = parse_str(expected)?;
            assert_eq!(unwrap_pyresult(&ty), &expected, "unwrapping `{}`", input);
        }
        Ok(())
    }
}