pyo3_stub_gen_derive/gen_stub/
util.rs

1use std::collections::HashSet;
2
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens};
5use syn::{
6    Attribute, GenericArgument, PathArguments, PathSegment, Result, ReturnType, Type, TypePath,
7};
8
9use crate::gen_stub::attr::parse_gen_stub_override_return_type;
10
11pub fn quote_option<T: ToTokens>(a: &Option<T>) -> TokenStream2 {
12    if let Some(a) = a {
13        quote! { Some(#a) }
14    } else {
15        quote! { None }
16    }
17}
18
19pub fn remove_lifetime(ty: &mut Type) {
20    match ty {
21        Type::Path(TypePath { path, .. }) => {
22            if let Some(PathSegment {
23                arguments: PathArguments::AngleBracketed(inner),
24                ..
25            }) = path.segments.last_mut()
26            {
27                for arg in &mut inner.args {
28                    match arg {
29                        GenericArgument::Lifetime(l) => {
30                            // `T::<'a, S>` becomes `T::<'_, S>`
31                            *l = syn::parse_quote!('_);
32                        }
33                        GenericArgument::Type(ty) => {
34                            remove_lifetime(ty);
35                        }
36                        _ => {}
37                    }
38                }
39            }
40        }
41        Type::Reference(rty) => {
42            rty.lifetime = None;
43            remove_lifetime(rty.elem.as_mut());
44        }
45        Type::Tuple(ty) => {
46            for elem in &mut ty.elems {
47                remove_lifetime(elem);
48            }
49        }
50        Type::Array(ary) => {
51            remove_lifetime(ary.elem.as_mut());
52        }
53        _ => {}
54    }
55}
56
57/// Extract `T` from `PyResult<T>` and apply `override_type` attribute if present.
58///
59/// For `PyResult<&'a T>` case, `'a` will be removed, i.e. returns `&T` for this case.
60pub fn extract_return_type(
61    ret: &ReturnType,
62    attrs: &[Attribute],
63) -> Result<Option<TypeOrOverride>> {
64    let ret = if let ReturnType::Type(_, ty) = ret {
65        unwrap_pyresult(ty)
66    } else {
67        return Ok(None);
68    };
69    let mut ret = ret.clone();
70    remove_lifetime(&mut ret);
71    if let Some(attr) = parse_gen_stub_override_return_type(attrs)? {
72        return Ok(Some(TypeOrOverride::OverrideType {
73            r#type: ret.clone(),
74            type_repr: attr.type_repr,
75            imports: attr.imports,
76        }));
77    }
78    Ok(Some(TypeOrOverride::RustType { r#type: ret }))
79}
80
81fn unwrap_pyresult(ty: &Type) -> &Type {
82    if let Type::Path(TypePath { path, .. }) = ty {
83        if let Some(last) = path.segments.last() {
84            if last.ident == "PyResult" {
85                if let PathArguments::AngleBracketed(inner) = &last.arguments {
86                    for arg in &inner.args {
87                        if let GenericArgument::Type(ty) = arg {
88                            return ty;
89                        }
90                    }
91                }
92            }
93        }
94    }
95    ty
96}
97
98#[derive(Debug, Clone)]
99pub enum TypeOrOverride {
100    RustType {
101        r#type: Type,
102    },
103    OverrideType {
104        r#type: Type,
105        type_repr: String,
106        imports: HashSet<String>,
107    },
108}
109
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Result};

    #[test]
    fn test_unwrap_pyresult() -> Result<()> {
        // (input type, expected output of `unwrap_pyresult`)
        let cases = [
            ("PyResult<i32>", "i32"),
            ("PyResult<&PyString>", "&PyString"),
            ("PyResult<&'a PyString>", "&'a PyString"),
            ("::pyo3::PyResult<i32>", "i32"),
            ("::pyo3::PyResult<&PyString>", "&PyString"),
            ("::pyo3::PyResult<&'a PyString>", "&'a PyString"),
            // Non-`PyResult` types pass through unchanged.
            ("i32", "i32"),
            ("Result<i32, PyErr>", "Result<i32, PyErr>"),
        ];
        for &(src, expected) in &cases {
            let ty: Type = parse_str(src)?;
            let expected: Type = parse_str(expected)?;
            assert_eq!(unwrap_pyresult(&ty), &expected, "unwrap_pyresult({})", src);
        }
        Ok(())
    }

    #[test]
    fn test_remove_lifetime() -> Result<()> {
        // (input type, expected type after `remove_lifetime`)
        let cases = [
            ("&'a PyString", "&PyString"),
            ("&'a mut PyString", "&mut PyString"),
            ("Bound<'py, PyAny>", "Bound<'_, PyAny>"),
            ("Option<&'a str>", "Option<&str>"),
            ("(&'a str, Vec<&'b str>)", "(&str, Vec<&str>)"),
            ("[&'a str; 2]", "[&str; 2]"),
            ("i32", "i32"),
        ];
        for &(src, expected) in &cases {
            let mut ty: Type = parse_str(src)?;
            remove_lifetime(&mut ty);
            let expected: Type = parse_str(expected)?;
            assert_eq!(ty, expected, "remove_lifetime({})", src);
        }
        Ok(())
    }
}