pyo3_stub_gen_derive/gen_stub/
util.rs1use std::collections::HashSet;
2
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens};
5use syn::{
6 Attribute, GenericArgument, PathArguments, PathSegment, Result, ReturnType, Type, TypePath,
7};
8
9use crate::gen_stub::attr::parse_gen_stub_override_return_type;
10
11pub fn quote_option<T: ToTokens>(a: &Option<T>) -> TokenStream2 {
12 if let Some(a) = a {
13 quote! { Some(#a) }
14 } else {
15 quote! { None }
16 }
17}
18
19pub fn remove_lifetime(ty: &mut Type) {
20 match ty {
21 Type::Path(TypePath { path, .. }) => {
22 if let Some(PathSegment {
23 arguments: PathArguments::AngleBracketed(inner),
24 ..
25 }) = path.segments.last_mut()
26 {
27 for arg in &mut inner.args {
28 match arg {
29 GenericArgument::Lifetime(l) => {
30 *l = syn::parse_quote!('_);
32 }
33 GenericArgument::Type(ty) => {
34 remove_lifetime(ty);
35 }
36 _ => {}
37 }
38 }
39 }
40 }
41 Type::Reference(rty) => {
42 rty.lifetime = None;
43 remove_lifetime(rty.elem.as_mut());
44 }
45 Type::Tuple(ty) => {
46 for elem in &mut ty.elems {
47 remove_lifetime(elem);
48 }
49 }
50 Type::Array(ary) => {
51 remove_lifetime(ary.elem.as_mut());
52 }
53 _ => {}
54 }
55}
56
57pub fn extract_return_type(
61 ret: &ReturnType,
62 attrs: &[Attribute],
63) -> Result<Option<TypeOrOverride>> {
64 let ret = if let ReturnType::Type(_, ty) = ret {
65 unwrap_pyresult(ty)
66 } else {
67 return Ok(None);
68 };
69 let mut ret = ret.clone();
70 remove_lifetime(&mut ret);
71 if let Some(attr) = parse_gen_stub_override_return_type(attrs)? {
72 return Ok(Some(TypeOrOverride::OverrideType {
73 r#type: ret.clone(),
74 type_repr: attr.type_repr,
75 imports: attr.imports,
76 }));
77 }
78 Ok(Some(TypeOrOverride::RustType { r#type: ret }))
79}
80
81fn unwrap_pyresult(ty: &Type) -> &Type {
82 if let Type::Path(TypePath { path, .. }) = ty {
83 if let Some(last) = path.segments.last() {
84 if last.ident == "PyResult" {
85 if let PathArguments::AngleBracketed(inner) = &last.arguments {
86 for arg in &inner.args {
87 if let GenericArgument::Type(ty) = arg {
88 return ty;
89 }
90 }
91 }
92 }
93 }
94 }
95 ty
96}
97
98#[derive(Debug, Clone)]
99pub enum TypeOrOverride {
100 RustType {
101 r#type: Type,
102 },
103 OverrideType {
104 r#type: Type,
105 type_repr: String,
106 imports: HashSet<String>,
107 },
108}
109
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Result};

    #[test]
    fn test_unwrap_pyresult() -> Result<()> {
        // (input type, expected unwrapped type): bare and fully-qualified
        // `PyResult`, with owned, borrowed, and explicitly-lifetimed payloads.
        let cases = [
            ("PyResult<i32>", "i32"),
            ("PyResult<&PyString>", "&PyString"),
            ("PyResult<&'a PyString>", "&'a PyString"),
            ("::pyo3::PyResult<i32>", "i32"),
            ("::pyo3::PyResult<&PyString>", "&PyString"),
            ("::pyo3::PyResult<&'a PyString>", "&'a PyString"),
        ];
        for &(input, expected) in cases.iter() {
            let ty: Type = parse_str(input)?;
            let want: Type = parse_str(expected)?;
            assert_eq!(unwrap_pyresult(&ty), &want, "unwrapping `{}`", input);
        }
        Ok(())
    }
}