1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens};
use syn::{GenericArgument, PathArguments, PathSegment, ReturnType, Type, TypePath};

/// Emit the tokens of an `Option` expression for the given value.
///
/// `Some(x)` is rendered as `Some(#x)` (with `x` interpolated via [`ToTokens`]),
/// and `None` is rendered as the bare `None` token.
pub fn quote_option<T: ToTokens>(a: &Option<T>) -> TokenStream2 {
    match a {
        Some(value) => quote! { Some(#value) },
        None => quote! { None },
    }
}

pub fn remove_lifetime(ty: &mut Type) {
    match ty {
        Type::Path(TypePath { path, .. }) => {
            if let Some(PathSegment {
                arguments: PathArguments::AngleBracketed(inner),
                ..
            }) = path.segments.last_mut()
            {
                for arg in &mut inner.args {
                    match arg {
                        GenericArgument::Lifetime(l) => {
                            // `T::<'a, S>` becomes `T::<'_, S>`
                            *l = syn::parse_quote!('_);
                        }
                        GenericArgument::Type(ty) => {
                            remove_lifetime(ty);
                        }
                        _ => {}
                    }
                }
            }
        }
        Type::Reference(rty) => {
            rty.lifetime = None;
            remove_lifetime(rty.elem.as_mut());
        }
        Type::Tuple(ty) => {
            for elem in &mut ty.elems {
                remove_lifetime(elem);
            }
        }
        Type::Array(ary) => {
            remove_lifetime(ary.elem.as_mut());
        }
        _ => {}
    }
}

/// Extract `T` from a `PyResult<T>` return type, with lifetimes stripped.
///
/// Returns `None` when the function has no explicit return type. An explicit
/// return type that is not `PyResult<..>` is kept as-is, apart from lifetime
/// removal via [`remove_lifetime`]. For example, `PyResult<&'a T>` yields `&T`.
pub fn escape_return_type(ret: &ReturnType) -> Option<Type> {
    match ret {
        ReturnType::Default => None,
        ReturnType::Type(_, ty) => {
            let mut escaped = unwrap_pyresult(ty).clone();
            remove_lifetime(&mut escaped);
            Some(escaped)
        }
    }
}

/// Return the first type argument of a `PyResult<..>` type, or `ty` itself.
///
/// Only the last path segment is checked. Both `PyResult<T>` and
/// `::pyo3::PyResult<T>` therefore yield `T`. Anything that is not a `PyResult`
/// with an angle-bracketed type argument is returned unchanged. Lifetimes
/// inside `T` are preserved.
fn unwrap_pyresult(ty: &Type) -> &Type {
    let last_segment = match ty {
        Type::Path(TypePath { path, .. }) => path.segments.last(),
        _ => None,
    };
    last_segment
        .filter(|segment| segment.ident == "PyResult")
        .and_then(|segment| match &segment.arguments {
            PathArguments::AngleBracketed(generics) => {
                generics.args.iter().find_map(|arg| match arg {
                    GenericArgument::Type(inner) => Some(inner),
                    _ => None,
                })
            }
            _ => None,
        })
        .unwrap_or(ty)
}

#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Result};

    #[test]
    fn test_unwrap_pyresult() -> Result<()> {
        // (input, expected). Bare and fully qualified `PyResult` are both unwrapped.
        // The inner type is returned untouched (lifetimes included), and a
        // non-`PyResult` type passes through unchanged.
        let cases = [
            ("PyResult<i32>", "i32"),
            ("PyResult<&PyString>", "&PyString"),
            ("PyResult<&'a PyString>", "&'a PyString"),
            ("::pyo3::PyResult<i32>", "i32"),
            ("::pyo3::PyResult<&PyString>", "&PyString"),
            ("::pyo3::PyResult<&'a PyString>", "&'a PyString"),
            ("Option<i32>", "Option<i32>"),
        ];
        for &(input, expected) in cases.iter() {
            let ty: Type = parse_str(input)?;
            let expected: Type = parse_str(expected)?;
            assert_eq!(unwrap_pyresult(&ty), &expected, "input: {}", input);
        }
        Ok(())
    }

    #[test]
    fn test_escape_return_type() -> Result<()> {
        // No explicit return type: nothing to escape.
        assert_eq!(escape_return_type(&ReturnType::Default), None);

        // (return type, expected escaped type): `PyResult` is unwrapped and
        // lifetimes are removed (`&'a T` -> `&T`, `Foo<'a>` -> `Foo<'_>`).
        let cases = [
            ("-> PyResult<i32>", "i32"),
            ("-> PyResult<&'a PyString>", "&PyString"),
            ("-> ::pyo3::PyResult<Vec<&'a str>>", "Vec<&str>"),
            ("-> Foo<'a, u8>", "Foo<'_, u8>"),
            ("-> (&'a i32, [&'b str; 3])", "(&i32, [&str; 3])"),
        ];
        for &(input, expected) in cases.iter() {
            let ret: ReturnType = parse_str(input)?;
            let expected: Type = parse_str(expected)?;
            assert_eq!(escape_return_type(&ret), Some(expected), "input: {}", input);
        }
        Ok(())
    }
}