pyo3_stub_gen_derive/gen_stub/
util.rs

1use indexmap::IndexSet;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{quote, ToTokens};
4use syn::{
5    Attribute, GenericArgument, PathArguments, PathSegment, Result, ReturnType, Type, TypePath,
6};
7
8use crate::gen_stub::attr::parse_gen_stub_override_return_type;
9
10pub fn quote_option<T: ToTokens>(a: &Option<T>) -> TokenStream2 {
11    if let Some(a) = a {
12        quote! { Some(#a) }
13    } else {
14        quote! { None }
15    }
16}
17
18pub fn remove_lifetime(ty: &mut Type) {
19    match ty {
20        Type::Path(TypePath { path, .. }) => {
21            if let Some(PathSegment {
22                arguments: PathArguments::AngleBracketed(inner),
23                ..
24            }) = path.segments.last_mut()
25            {
26                for arg in &mut inner.args {
27                    match arg {
28                        GenericArgument::Lifetime(l) => {
29                            // `T::<'a, S>` becomes `T::<'_, S>`
30                            *l = syn::parse_quote!('_);
31                        }
32                        GenericArgument::Type(ty) => {
33                            remove_lifetime(ty);
34                        }
35                        _ => {}
36                    }
37                }
38            }
39        }
40        Type::Reference(rty) => {
41            rty.lifetime = None;
42            remove_lifetime(rty.elem.as_mut());
43        }
44        Type::Tuple(ty) => {
45            for elem in &mut ty.elems {
46                remove_lifetime(elem);
47            }
48        }
49        Type::Array(ary) => {
50            remove_lifetime(ary.elem.as_mut());
51        }
52        _ => {}
53    }
54}
55
56/// Extract `T` from `PyResult<T>` and apply `override_type` attribute if present.
57///
58/// For `PyResult<&'a T>` case, `'a` will be removed, i.e. returns `&T` for this case.
59pub fn extract_return_type(
60    ret: &ReturnType,
61    attrs: &[Attribute],
62) -> Result<Option<TypeOrOverride>> {
63    let ret = if let ReturnType::Type(_, ty) = ret {
64        unwrap_pyresult(ty)
65    } else {
66        return Ok(None);
67    };
68    let mut ret = ret.clone();
69    remove_lifetime(&mut ret);
70    if let Some(attr) = parse_gen_stub_override_return_type(attrs)? {
71        return Ok(Some(TypeOrOverride::OverrideType {
72            r#type: ret.clone(),
73            type_repr: attr.type_repr,
74            imports: attr.imports,
75            rust_type_markers: vec![],
76        }));
77    }
78    Ok(Some(TypeOrOverride::RustType { r#type: ret }))
79}
80
81fn unwrap_pyresult(ty: &Type) -> &Type {
82    if let Type::Path(TypePath { path, .. }) = ty {
83        if let Some(last) = path.segments.last() {
84            if last.ident == "PyResult" {
85                if let PathArguments::AngleBracketed(inner) = &last.arguments {
86                    for arg in &inner.args {
87                        if let GenericArgument::Type(ty) = arg {
88                            return ty;
89                        }
90                    }
91                }
92            }
93        }
94    }
95    ty
96}
97
/// A type to describe in a stub: either a plain Rust type, or a Rust type paired with
/// an override parsed from the item's attributes.
///
/// `extract_return_type` builds this for a function's return type. It uses the
/// `OverrideType` variant when `parse_gen_stub_override_return_type` finds an override,
/// and `RustType` otherwise.
#[derive(Debug, Clone)]
pub enum TypeOrOverride {
    /// No override: only the Rust type is known.
    RustType {
        /// The Rust type. From `extract_return_type` this is the `PyResult`-unwrapped type
        /// with named lifetimes erased by `remove_lifetime`.
        r#type: Type,
    },
    /// A Rust type together with override data taken from a parsed attribute.
    /// NOTE(review): presumably `type_repr` is written to the stub in place of the Rust
    /// type; confirm against the consumer of this enum.
    OverrideType {
        /// The underlying Rust type, processed the same way as in `RustType`.
        r#type: Type,
        /// Type expression text taken from the parsed attribute's `type_repr`.
        type_repr: String,
        /// Entries taken from the parsed attribute's `imports`.
        /// Presumably the modules `type_repr` depends on; confirm.
        imports: IndexSet<String>,
        /// List of Rust type names found in RustType markers within this type expression.
        /// Used to generate code that collects type_refs from those types.
        /// `extract_return_type` always leaves this empty.
        rust_type_markers: Vec<String>,
    },
}
112
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Result};

    /// `unwrap_pyresult` strips one `PyResult` layer, with or without a leading path,
    /// and leaves any lifetime on the inner type in place.
    #[test]
    fn test_unwrap_pyresult() -> Result<()> {
        // (input type, expected output of `unwrap_pyresult`)
        let cases = [
            ("PyResult<i32>", "i32"),
            ("PyResult<&PyString>", "&PyString"),
            ("PyResult<&'a PyString>", "&'a PyString"),
            ("::pyo3::PyResult<i32>", "i32"),
            ("::pyo3::PyResult<&PyString>", "&PyString"),
            ("::pyo3::PyResult<&'a PyString>", "&'a PyString"),
        ];
        for &(input, expected) in &cases {
            let ty: Type = parse_str(input)?;
            let want: Type = parse_str(expected)?;
            assert_eq!(unwrap_pyresult(&ty), &want, "input: {}", input);
        }
        Ok(())
    }
}