pyo3_stub_gen_derive/gen_stub/
util.rs1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens};
3use syn::{GenericArgument, PathArguments, PathSegment, ReturnType, Type, TypePath};
4
5pub fn quote_option<T: ToTokens>(a: &Option<T>) -> TokenStream2 {
6 if let Some(a) = a {
7 quote! { Some(#a) }
8 } else {
9 quote! { None }
10 }
11}
12
13pub fn remove_lifetime(ty: &mut Type) {
14 match ty {
15 Type::Path(TypePath { path, .. }) => {
16 if let Some(PathSegment {
17 arguments: PathArguments::AngleBracketed(inner),
18 ..
19 }) = path.segments.last_mut()
20 {
21 for arg in &mut inner.args {
22 match arg {
23 GenericArgument::Lifetime(l) => {
24 *l = syn::parse_quote!('_);
26 }
27 GenericArgument::Type(ty) => {
28 remove_lifetime(ty);
29 }
30 _ => {}
31 }
32 }
33 }
34 }
35 Type::Reference(rty) => {
36 rty.lifetime = None;
37 remove_lifetime(rty.elem.as_mut());
38 }
39 Type::Tuple(ty) => {
40 for elem in &mut ty.elems {
41 remove_lifetime(elem);
42 }
43 }
44 Type::Array(ary) => {
45 remove_lifetime(ary.elem.as_mut());
46 }
47 _ => {}
48 }
49}
50
51pub fn escape_return_type(ret: &ReturnType) -> Option<Type> {
55 let ret = if let ReturnType::Type(_, ty) = ret {
56 unwrap_pyresult(ty)
57 } else {
58 return None;
59 };
60 let mut ret = ret.clone();
61 remove_lifetime(&mut ret);
62 Some(ret)
63}
64
65fn unwrap_pyresult(ty: &Type) -> &Type {
66 if let Type::Path(TypePath { path, .. }) = ty {
67 if let Some(last) = path.segments.last() {
68 if last.ident == "PyResult" {
69 if let PathArguments::AngleBracketed(inner) = &last.arguments {
70 for arg in &inner.args {
71 if let GenericArgument::Type(ty) = arg {
72 return ty;
73 }
74 }
75 }
76 }
77 }
78 }
79 ty
80}
81
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Result};

    #[test]
    fn test_unwrap_pyresult() -> Result<()> {
        // (input, expected) pairs. Plain and fully-qualified `PyResult` must both be
        // unwrapped; lifetimes are kept here (erasing them is `remove_lifetime`'s job).
        let cases = [
            ("PyResult<i32>", "i32"),
            ("PyResult<&PyString>", "&PyString"),
            ("PyResult<&'a PyString>", "&'a PyString"),
            ("::pyo3::PyResult<i32>", "i32"),
            ("::pyo3::PyResult<&PyString>", "&PyString"),
            ("::pyo3::PyResult<&'a PyString>", "&'a PyString"),
            // Anything that is not `PyResult` passes through unchanged.
            ("i32", "i32"),
            ("Vec<i32>", "Vec<i32>"),
        ];
        for &(input, expected) in &cases {
            let ty: Type = parse_str(input)?;
            let out = unwrap_pyresult(&ty);
            assert_eq!(out, &parse_str::<Type>(expected)?, "input: {}", input);
        }
        Ok(())
    }

    #[test]
    fn test_remove_lifetime() -> Result<()> {
        let cases = [
            ("&'a PyString", "&PyString"),
            ("Bound<'py, PyAny>", "Bound<'_, PyAny>"),
            ("(&'a str, Vec<&'b str>)", "(&str, Vec<&str>)"),
            ("[&'a str; 4]", "[&str; 4]"),
            ("i32", "i32"),
        ];
        for &(input, expected) in &cases {
            let mut ty: Type = parse_str(input)?;
            remove_lifetime(&mut ty);
            assert_eq!(ty, parse_str::<Type>(expected)?, "input: {}", input);
        }
        Ok(())
    }

    #[test]
    fn test_escape_return_type() -> Result<()> {
        // `PyResult` is unwrapped first, then lifetimes inside the payload are erased.
        let ret: ReturnType = parse_str("-> PyResult<Vec<&'a str>>")?;
        assert_eq!(
            escape_return_type(&ret),
            Some(parse_str::<Type>("Vec<&str>")?)
        );

        // An omitted return type has nothing to describe.
        let ret: ReturnType = parse_str("")?;
        assert_eq!(escape_return_type(&ret), None);
        Ok(())
    }
}