pyo3_stub_gen_derive/gen_stub/
util.rs1use indexmap::IndexSet;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{quote, ToTokens};
4use syn::{
5 Attribute, GenericArgument, PathArguments, PathSegment, Result, ReturnType, Type, TypePath,
6};
7
8use crate::gen_stub::attr::parse_gen_stub_override_return_type;
9
10pub fn quote_option<T: ToTokens>(a: &Option<T>) -> TokenStream2 {
11 if let Some(a) = a {
12 quote! { Some(#a) }
13 } else {
14 quote! { None }
15 }
16}
17
18pub fn remove_lifetime(ty: &mut Type) {
19 match ty {
20 Type::Path(TypePath { path, .. }) => {
21 if let Some(PathSegment {
22 arguments: PathArguments::AngleBracketed(inner),
23 ..
24 }) = path.segments.last_mut()
25 {
26 for arg in &mut inner.args {
27 match arg {
28 GenericArgument::Lifetime(l) => {
29 *l = syn::parse_quote!('_);
31 }
32 GenericArgument::Type(ty) => {
33 remove_lifetime(ty);
34 }
35 _ => {}
36 }
37 }
38 }
39 }
40 Type::Reference(rty) => {
41 rty.lifetime = None;
42 remove_lifetime(rty.elem.as_mut());
43 }
44 Type::Tuple(ty) => {
45 for elem in &mut ty.elems {
46 remove_lifetime(elem);
47 }
48 }
49 Type::Array(ary) => {
50 remove_lifetime(ary.elem.as_mut());
51 }
52 _ => {}
53 }
54}
55
56pub fn extract_return_type(
60 ret: &ReturnType,
61 attrs: &[Attribute],
62) -> Result<Option<TypeOrOverride>> {
63 let ret = if let ReturnType::Type(_, ty) = ret {
64 unwrap_pyresult(ty)
65 } else {
66 return Ok(None);
67 };
68 let mut ret = ret.clone();
69 remove_lifetime(&mut ret);
70 if let Some(attr) = parse_gen_stub_override_return_type(attrs)? {
71 return Ok(Some(TypeOrOverride::OverrideType {
72 r#type: ret.clone(),
73 type_repr: attr.type_repr,
74 imports: attr.imports,
75 }));
76 }
77 Ok(Some(TypeOrOverride::RustType { r#type: ret }))
78}
79
80fn unwrap_pyresult(ty: &Type) -> &Type {
81 if let Type::Path(TypePath { path, .. }) = ty {
82 if let Some(last) = path.segments.last() {
83 if last.ident == "PyResult" {
84 if let PathArguments::AngleBracketed(inner) = &last.arguments {
85 for arg in &inner.args {
86 if let GenericArgument::Type(ty) = arg {
87 return ty;
88 }
89 }
90 }
91 }
92 }
93 }
94 ty
95}
96
97#[derive(Debug, Clone)]
98pub enum TypeOrOverride {
99 RustType {
100 r#type: Type,
101 },
102 OverrideType {
103 r#type: Type,
104 type_repr: String,
105 imports: IndexSet<String>,
106 },
107}
108
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Result};

    /// `unwrap_pyresult` must peel `PyResult<..>` for both bare and fully qualified
    /// paths, leaving references and their lifetimes in the inner type untouched.
    #[test]
    fn test_unwrap_pyresult() -> Result<()> {
        let cases = [
            ("PyResult<i32>", "i32"),
            ("PyResult<&PyString>", "&PyString"),
            ("PyResult<&'a PyString>", "&'a PyString"),
            ("::pyo3::PyResult<i32>", "i32"),
            ("::pyo3::PyResult<&PyString>", "&PyString"),
            ("::pyo3::PyResult<&'a PyString>", "&'a PyString"),
        ];
        for &(input, expected) in cases.iter() {
            let wrapped: Type = parse_str(input)?;
            let want: Type = parse_str(expected)?;
            assert_eq!(unwrap_pyresult(&wrapped), &want, "unwrapping `{}`", input);
        }
        Ok(())
    }
}