pyo3_stub_gen_derive/gen_stub/
member.rs

1use crate::gen_stub::{attr::parse_gen_stub_default, extract_documents};
2
3use super::{escape_return_type, parse_pyo3_attrs, Attr};
4
5use proc_macro2::TokenStream as TokenStream2;
6use quote::{quote, ToTokens, TokenStreamExt};
7use syn::{Attribute, Error, Expr, Field, FnArg, ImplItemConst, ImplItemFn, Result, Type};
8
/// Intermediate representation of a member (getter, setter, class attribute, or
/// struct field) collected at macro-expansion time.
///
/// Its `ToTokens` impl turns it into a `::pyo3_stub_gen::type_info::MemberInfo`
/// struct literal in the generated code.
#[derive(Debug)]
pub struct MemberInfo {
    /// Documentation extracted from the item's attributes (`extract_documents`),
    /// joined with `\n`.
    doc: String,
    /// Member name written into the generated `type_info::MemberInfo`.
    name: String,
    /// Rust type of the member; the stub type is obtained through
    /// `<T as ::pyo3_stub_gen::PyStubType>::type_output`.
    r#type: Type,
    /// Optional default-value expression. A literal `None` is emitted as the string
    /// `"None"`; any other expression is evaluated lazily in the generated code.
    default: Option<Expr>,
}
16
17impl MemberInfo {
18    pub fn is_getter(attrs: &[Attribute]) -> Result<bool> {
19        let attrs = parse_pyo3_attrs(attrs)?;
20        Ok(attrs.iter().any(|attr| matches!(attr, Attr::Getter(_))))
21    }
22    pub fn is_setter(attrs: &[Attribute]) -> Result<bool> {
23        let attrs = parse_pyo3_attrs(attrs)?;
24        Ok(attrs.iter().any(|attr| matches!(attr, Attr::Setter(_))))
25    }
26
27    pub fn is_classattr(attrs: &[Attribute]) -> Result<bool> {
28        let attrs = parse_pyo3_attrs(attrs)?;
29        Ok(attrs.iter().any(|attr| matches!(attr, Attr::ClassAttr)))
30    }
31    pub fn is_get(field: &Field) -> Result<bool> {
32        let Field { attrs, .. } = field;
33        Ok(parse_pyo3_attrs(attrs)?
34            .iter()
35            .any(|attr| matches!(attr, Attr::Get)))
36    }
37    pub fn is_set(field: &Field) -> Result<bool> {
38        let Field { attrs, .. } = field;
39        Ok(parse_pyo3_attrs(attrs)?
40            .iter()
41            .any(|attr| matches!(attr, Attr::Set)))
42    }
43}
44
45impl MemberInfo {
46    pub fn new_getter(item: ImplItemFn) -> Result<Self> {
47        assert!(Self::is_getter(&item.attrs)?);
48        let ImplItemFn { attrs, sig, .. } = &item;
49        let default = parse_gen_stub_default(attrs)?;
50        let doc = extract_documents(attrs).join("\n");
51        let attrs = parse_pyo3_attrs(attrs)?;
52        for attr in attrs {
53            if let Attr::Getter(name) = attr {
54                let fn_name = sig.ident.to_string();
55                let fn_getter_name = match fn_name.strip_prefix("get_") {
56                    Some(s) => s.to_owned(),
57                    None => fn_name,
58                };
59                return Ok(MemberInfo {
60                    doc,
61                    name: name.unwrap_or(fn_getter_name),
62                    r#type: escape_return_type(&sig.output).expect("Getter must return a type"),
63                    default,
64                });
65            }
66        }
67        unreachable!("Not a getter: {:?}", item)
68    }
69    pub fn new_setter(item: ImplItemFn) -> Result<Self> {
70        assert!(Self::is_setter(&item.attrs)?);
71        let ImplItemFn { attrs, sig, .. } = &item;
72        let default = parse_gen_stub_default(attrs)?;
73        let doc = extract_documents(attrs).join("\n");
74        let attrs = parse_pyo3_attrs(attrs)?;
75        for attr in attrs {
76            if let Attr::Setter(name) = attr {
77                let fn_name = sig.ident.to_string();
78                let fn_setter_name = match fn_name.strip_prefix("set_") {
79                    Some(s) => s.to_owned(),
80                    None => fn_name,
81                };
82                return Ok(MemberInfo {
83                    doc,
84                    name: name.unwrap_or(fn_setter_name),
85                    r#type: sig
86                        .inputs
87                        .get(1)
88                        .and_then(|arg| {
89                            if let FnArg::Typed(t) = arg {
90                                Some(*t.ty.clone())
91                            } else {
92                                None
93                            }
94                        })
95                        .expect("Setter must input a type"),
96                    default,
97                });
98            }
99        }
100        unreachable!("Not a setter: {:?}", item)
101    }
102    pub fn new_classattr_fn(item: ImplItemFn) -> Result<Self> {
103        assert!(Self::is_classattr(&item.attrs)?);
104        let ImplItemFn { attrs, sig, .. } = &item;
105        let default = parse_gen_stub_default(attrs)?;
106        let doc = extract_documents(attrs).join("\n");
107        Ok(MemberInfo {
108            doc,
109            name: sig.ident.to_string(),
110            r#type: escape_return_type(&sig.output).expect("Getter must return a type"),
111            default,
112        })
113    }
114    pub fn new_classattr_const(item: ImplItemConst) -> Result<Self> {
115        assert!(Self::is_classattr(&item.attrs)?);
116        let ImplItemConst {
117            attrs,
118            ident,
119            ty,
120            expr,
121            ..
122        } = item;
123        let doc = extract_documents(&attrs).join("\n");
124        Ok(MemberInfo {
125            doc,
126            name: ident.to_string(),
127            r#type: ty,
128            default: Some(expr),
129        })
130    }
131}
132
133impl TryFrom<Field> for MemberInfo {
134    type Error = Error;
135    fn try_from(field: Field) -> Result<Self> {
136        let Field {
137            ident, ty, attrs, ..
138        } = field;
139        let mut field_name = None;
140        for attr in parse_pyo3_attrs(&attrs)? {
141            if let Attr::Name(name) = attr {
142                field_name = Some(name);
143            }
144        }
145        let doc = extract_documents(&attrs).join("\n");
146        let default = parse_gen_stub_default(&attrs)?;
147        Ok(Self {
148            name: field_name.unwrap_or(ident.unwrap().to_string()),
149            r#type: ty,
150            doc,
151            default,
152        })
153    }
154}
155
156impl ToTokens for MemberInfo {
157    fn to_tokens(&self, tokens: &mut TokenStream2) {
158        let Self {
159            name,
160            r#type: ty,
161            doc,
162            default,
163        } = self;
164        let default = default
165            .as_ref()
166            .map(|value| {
167                if value.to_token_stream().to_string() == "None" {
168                    quote! {
169                        "None".to_string()
170                    }
171                } else {
172                    quote! {
173                        ::pyo3::prepare_freethreaded_python();
174                        ::pyo3::Python::with_gil(|py| -> String {
175                            let v: #ty = #value;
176                            ::pyo3_stub_gen::util::fmt_py_obj(py, v)
177                        })
178                    }
179                }
180            })
181            .map_or(quote! {None}, |default| {
182                quote! {Some({
183                    static DEFAULT: std::sync::LazyLock<String> = std::sync::LazyLock::new(|| {
184                        #default
185                    });
186                    &DEFAULT
187                })}
188            });
189        tokens.append_all(quote! {
190            ::pyo3_stub_gen::type_info::MemberInfo {
191                name: #name,
192                r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_output,
193                doc: #doc,
194                default: #default,
195            }
196        })
197    }
198}