pyo3_stub_gen_derive/gen_stub/
member.rs

1use crate::gen_stub::{
2    attr::{parse_gen_stub_default, parse_gen_stub_override_type, OverrideTypeAttribute},
3    extract_documents,
4    util::TypeOrOverride,
5};
6
7use super::{extract_return_type, parse_pyo3_attrs, Attr};
8
9use crate::gen_stub::arg::ArgInfo;
10use proc_macro2::TokenStream as TokenStream2;
11use quote::{quote, ToTokens, TokenStreamExt};
12use syn::{Attribute, Error, Expr, Field, FnArg, ImplItemConst, ImplItemFn, Result};
13
14#[derive(Debug, Clone)]
15pub struct MemberInfo {
16    doc: String,
17    name: String,
18    r#type: TypeOrOverride,
19    default: Option<Expr>,
20    deprecated: Option<crate::gen_stub::attr::DeprecatedInfo>,
21}
22
23impl MemberInfo {
24    pub fn is_getter(attrs: &[Attribute]) -> Result<bool> {
25        let attrs = parse_pyo3_attrs(attrs)?;
26        Ok(attrs.iter().any(|attr| matches!(attr, Attr::Getter(_))))
27    }
28    pub fn is_setter(attrs: &[Attribute]) -> Result<bool> {
29        let attrs = parse_pyo3_attrs(attrs)?;
30        Ok(attrs.iter().any(|attr| matches!(attr, Attr::Setter(_))))
31    }
32
33    pub fn is_classattr(attrs: &[Attribute]) -> Result<bool> {
34        let attrs = parse_pyo3_attrs(attrs)?;
35        Ok(attrs.iter().any(|attr| matches!(attr, Attr::ClassAttr)))
36    }
37    pub fn is_get(field: &Field) -> Result<bool> {
38        let Field { attrs, .. } = field;
39        Ok(parse_pyo3_attrs(attrs)?
40            .iter()
41            .any(|attr| matches!(attr, Attr::Get)))
42    }
43    pub fn is_set(field: &Field) -> Result<bool> {
44        let Field { attrs, .. } = field;
45        Ok(parse_pyo3_attrs(attrs)?
46            .iter()
47            .any(|attr| matches!(attr, Attr::Set)))
48    }
49}
50
51impl MemberInfo {
52    pub fn new_getter(item: ImplItemFn) -> Result<Self> {
53        assert!(Self::is_getter(&item.attrs)?);
54        let ImplItemFn { attrs, sig, .. } = &item;
55        let default = parse_gen_stub_default(attrs)?;
56        let doc = extract_documents(attrs).join("\n");
57        let pyo3_attrs = parse_pyo3_attrs(attrs)?;
58        for attr in pyo3_attrs {
59            if let Attr::Getter(name) = attr {
60                let fn_name = sig.ident.to_string();
61                let fn_getter_name = match fn_name.strip_prefix("get_") {
62                    Some(s) => s.to_owned(),
63                    None => fn_name,
64                };
65                return Ok(MemberInfo {
66                    doc,
67                    name: name.unwrap_or(fn_getter_name),
68                    r#type: extract_return_type(&sig.output, attrs)?
69                        .expect("Getter must return a type"),
70                    default,
71                    deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
72                });
73            }
74        }
75        unreachable!("Not a getter: {:?}", item)
76    }
77    pub fn new_setter(item: ImplItemFn) -> Result<Self> {
78        assert!(Self::is_setter(&item.attrs)?);
79        let ImplItemFn { attrs, sig, .. } = &item;
80        let default = parse_gen_stub_default(attrs)?;
81        let doc = extract_documents(attrs).join("\n");
82        let pyo3_attrs = parse_pyo3_attrs(attrs)?;
83        for attr in pyo3_attrs {
84            if let Attr::Setter(name) = attr {
85                let fn_name = sig.ident.to_string();
86                let fn_setter_name = match fn_name.strip_prefix("set_") {
87                    Some(s) => s.to_owned(),
88                    None => fn_name,
89                };
90                let r#type = sig
91                    .inputs
92                    .get(1)
93                    .ok_or(syn::Error::new_spanned(&item, "Setter must input a type"))
94                    .and_then(|arg| {
95                        if let FnArg::Typed(t) = arg {
96                            Ok(match parse_gen_stub_override_type(&t.attrs)? {
97                                Some(OverrideTypeAttribute { type_repr, imports }) => {
98                                    TypeOrOverride::OverrideType {
99                                        r#type: *t.ty.clone(),
100                                        type_repr,
101                                        imports,
102                                    }
103                                }
104                                _ => TypeOrOverride::RustType {
105                                    r#type: *t.ty.clone(),
106                                },
107                            })
108                        } else {
109                            Err(syn::Error::new_spanned(&item, "Setter must input a type"))
110                        }
111                    })?;
112                return Ok(MemberInfo {
113                    doc,
114                    name: name.unwrap_or(fn_setter_name),
115                    r#type,
116                    default,
117                    deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
118                });
119            }
120        }
121        unreachable!("Not a setter: {:?}", item)
122    }
123    pub fn new_classattr_fn(item: ImplItemFn) -> Result<Self> {
124        assert!(Self::is_classattr(&item.attrs)?);
125        let ImplItemFn { attrs, sig, .. } = &item;
126        let default = parse_gen_stub_default(attrs)?;
127        let doc = extract_documents(attrs).join("\n");
128        let mut name = sig.ident.to_string();
129        for attr in parse_pyo3_attrs(attrs)? {
130            if let Attr::Name(_name) = attr {
131                name = _name;
132            }
133        }
134        Ok(MemberInfo {
135            doc,
136            name,
137            r#type: extract_return_type(&sig.output, attrs)?.expect("Getter must return a type"),
138            default,
139            deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
140        })
141    }
142    pub fn new_classattr_const(item: ImplItemConst) -> Result<Self> {
143        assert!(Self::is_classattr(&item.attrs)?);
144        let ImplItemConst {
145            attrs,
146            ident,
147            ty,
148            expr,
149            ..
150        } = item;
151        let doc = extract_documents(&attrs).join("\n");
152        let mut name = ident.to_string();
153        for attr in parse_pyo3_attrs(&attrs)? {
154            if let Attr::Name(_name) = attr {
155                name = _name;
156            }
157        }
158        Ok(MemberInfo {
159            doc,
160            name,
161            r#type: TypeOrOverride::RustType { r#type: ty },
162            default: Some(expr),
163            deprecated: crate::gen_stub::attr::extract_deprecated(&attrs),
164        })
165    }
166}
167
168impl TryFrom<Field> for MemberInfo {
169    type Error = Error;
170    fn try_from(field: Field) -> Result<Self> {
171        let Field {
172            ident, ty, attrs, ..
173        } = field;
174        let mut field_name = None;
175        for attr in parse_pyo3_attrs(&attrs)? {
176            if let Attr::Name(name) = attr {
177                field_name = Some(name);
178            }
179        }
180        let doc = extract_documents(&attrs).join("\n");
181        let default = parse_gen_stub_default(&attrs)?;
182        let deprecated = crate::gen_stub::attr::extract_deprecated(&attrs);
183        Ok(Self {
184            name: field_name.unwrap_or(ident.unwrap().to_string()),
185            r#type: TypeOrOverride::RustType { r#type: ty },
186            doc,
187            default,
188            deprecated,
189        })
190    }
191}
192
193impl ToTokens for MemberInfo {
194    fn to_tokens(&self, tokens: &mut TokenStream2) {
195        let Self {
196            name,
197            r#type,
198            doc,
199            default,
200            deprecated,
201        } = self;
202        let default = default
203            .as_ref()
204            .map(|value| {
205                if value.to_token_stream().to_string() == "None" {
206                    quote! {
207                        "None".to_string()
208                    }
209                } else {
210                    let (TypeOrOverride::RustType { r#type: ty }
211                    | TypeOrOverride::OverrideType { r#type: ty, .. }) = r#type;
212                    quote! {
213                    let v: #ty = #value;
214                    ::pyo3_stub_gen::util::fmt_py_obj(v)
215                    }
216                }
217            })
218            .map_or(quote! {None}, |default| {
219                quote! {Some({
220                    fn _fmt() -> String {
221                        #default
222                    }
223                    _fmt
224                })}
225            });
226        let deprecated_info = deprecated
227            .as_ref()
228            .map(|deprecated| {
229                quote! {
230                    Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
231                        since: #deprecated.since,
232                        note: #deprecated.note,
233                    })
234                }
235            })
236            .unwrap_or_else(|| quote! { None });
237        match r#type {
238            TypeOrOverride::RustType { r#type: ty } => tokens.append_all(quote! {
239                ::pyo3_stub_gen::type_info::MemberInfo {
240                    name: #name,
241                    r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_output,
242                    doc: #doc,
243                    default: #default,
244                    deprecated: #deprecated_info,
245                }
246            }),
247            TypeOrOverride::OverrideType {
248                type_repr, imports, ..
249            } => {
250                let imports = imports.iter().collect::<Vec<&String>>();
251                tokens.append_all(quote! {
252                    ::pyo3_stub_gen::type_info::MemberInfo {
253                        name: #name,
254                        r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
255                        doc: #doc,
256                        default: #default,
257                        deprecated: #deprecated_info,
258                    }
259                })
260            }
261        }
262    }
263}
264
265impl From<MemberInfo> for ArgInfo {
266    fn from(value: MemberInfo) -> Self {
267        let MemberInfo { name, r#type, .. } = value;
268
269        Self { name, r#type }
270    }
271}