pyo3_stub_gen_derive/gen_stub/
member.rs

1use crate::gen_stub::{
2    attr::{parse_gen_stub_default, parse_gen_stub_override_type, OverrideTypeAttribute},
3    extract_documents,
4    util::TypeOrOverride,
5};
6
7use super::{extract_return_type, parse_pyo3_attrs, Attr};
8
9use crate::gen_stub::arg::ArgInfo;
10use proc_macro2::TokenStream as TokenStream2;
11use quote::{quote, ToTokens, TokenStreamExt};
12use syn::{Attribute, Error, Expr, Field, FnArg, ImplItemConst, ImplItemFn, Result};
13
14#[derive(Debug, Clone)]
15pub struct MemberInfo {
16    doc: String,
17    name: String,
18    r#type: TypeOrOverride,
19    default: Option<Expr>,
20    deprecated: Option<crate::gen_stub::attr::DeprecatedInfo>,
21}
22
23impl MemberInfo {
24    pub fn is_getter(attrs: &[Attribute]) -> Result<bool> {
25        let attrs = parse_pyo3_attrs(attrs)?;
26        Ok(attrs.iter().any(|attr| matches!(attr, Attr::Getter(_))))
27    }
28    pub fn is_setter(attrs: &[Attribute]) -> Result<bool> {
29        let attrs = parse_pyo3_attrs(attrs)?;
30        Ok(attrs.iter().any(|attr| matches!(attr, Attr::Setter(_))))
31    }
32
33    pub fn is_classattr(attrs: &[Attribute]) -> Result<bool> {
34        let attrs = parse_pyo3_attrs(attrs)?;
35        Ok(attrs.iter().any(|attr| matches!(attr, Attr::ClassAttr)))
36    }
37    pub fn is_get(field: &Field) -> Result<bool> {
38        let Field { attrs, .. } = field;
39        Ok(parse_pyo3_attrs(attrs)?
40            .iter()
41            .any(|attr| matches!(attr, Attr::Get)))
42    }
43    pub fn is_set(field: &Field) -> Result<bool> {
44        let Field { attrs, .. } = field;
45        Ok(parse_pyo3_attrs(attrs)?
46            .iter()
47            .any(|attr| matches!(attr, Attr::Set)))
48    }
49}
50
51impl MemberInfo {
52    pub fn new_getter(item: ImplItemFn) -> Result<Self> {
53        assert!(Self::is_getter(&item.attrs)?);
54        let ImplItemFn { attrs, sig, .. } = &item;
55        let default = parse_gen_stub_default(attrs)?;
56        let doc = extract_documents(attrs).join("\n");
57        let pyo3_attrs = parse_pyo3_attrs(attrs)?;
58        for attr in pyo3_attrs {
59            if let Attr::Getter(name) = attr {
60                let fn_name = sig.ident.to_string();
61                let fn_getter_name = match fn_name.strip_prefix("get_") {
62                    Some(s) => s.to_owned(),
63                    None => fn_name,
64                };
65                return Ok(MemberInfo {
66                    doc,
67                    name: name.unwrap_or(fn_getter_name),
68                    r#type: extract_return_type(&sig.output, attrs)?
69                        .expect("Getter must return a type"),
70                    default,
71                    deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
72                });
73            }
74        }
75        unreachable!("Not a getter: {:?}", item)
76    }
77    pub fn new_setter(item: ImplItemFn) -> Result<Self> {
78        assert!(Self::is_setter(&item.attrs)?);
79        let ImplItemFn { attrs, sig, .. } = &item;
80        let default = parse_gen_stub_default(attrs)?;
81        let doc = extract_documents(attrs).join("\n");
82        let pyo3_attrs = parse_pyo3_attrs(attrs)?;
83        for attr in pyo3_attrs {
84            if let Attr::Setter(name) = attr {
85                let fn_name = sig.ident.to_string();
86                let fn_setter_name = match fn_name.strip_prefix("set_") {
87                    Some(s) => s.to_owned(),
88                    None => fn_name,
89                };
90                let r#type = sig
91                    .inputs
92                    .get(1)
93                    .ok_or(syn::Error::new_spanned(&item, "Setter must input a type"))
94                    .and_then(|arg| {
95                        if let FnArg::Typed(t) = arg {
96                            Ok(match parse_gen_stub_override_type(&t.attrs)? {
97                                Some(OverrideTypeAttribute { type_repr, imports }) => {
98                                    TypeOrOverride::OverrideType {
99                                        r#type: *t.ty.clone(),
100                                        type_repr,
101                                        imports,
102                                    }
103                                }
104                                _ => TypeOrOverride::RustType {
105                                    r#type: *t.ty.clone(),
106                                },
107                            })
108                        } else {
109                            Err(syn::Error::new_spanned(&item, "Setter must input a type"))
110                        }
111                    })?;
112                return Ok(MemberInfo {
113                    doc,
114                    name: name.unwrap_or(fn_setter_name),
115                    r#type,
116                    default,
117                    deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
118                });
119            }
120        }
121        unreachable!("Not a setter: {:?}", item)
122    }
123    pub fn new_classattr_fn(item: ImplItemFn) -> Result<Self> {
124        assert!(Self::is_classattr(&item.attrs)?);
125        let ImplItemFn { attrs, sig, .. } = &item;
126        let default = parse_gen_stub_default(attrs)?;
127        let doc = extract_documents(attrs).join("\n");
128        Ok(MemberInfo {
129            doc,
130            name: sig.ident.to_string(),
131            r#type: extract_return_type(&sig.output, attrs)?.expect("Getter must return a type"),
132            default,
133            deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
134        })
135    }
136    pub fn new_classattr_const(item: ImplItemConst) -> Result<Self> {
137        assert!(Self::is_classattr(&item.attrs)?);
138        let ImplItemConst {
139            attrs,
140            ident,
141            ty,
142            expr,
143            ..
144        } = item;
145        let doc = extract_documents(&attrs).join("\n");
146        Ok(MemberInfo {
147            doc,
148            name: ident.to_string(),
149            r#type: TypeOrOverride::RustType { r#type: ty },
150            default: Some(expr),
151            deprecated: crate::gen_stub::attr::extract_deprecated(&attrs),
152        })
153    }
154}
155
156impl TryFrom<Field> for MemberInfo {
157    type Error = Error;
158    fn try_from(field: Field) -> Result<Self> {
159        let Field {
160            ident, ty, attrs, ..
161        } = field;
162        let mut field_name = None;
163        for attr in parse_pyo3_attrs(&attrs)? {
164            if let Attr::Name(name) = attr {
165                field_name = Some(name);
166            }
167        }
168        let doc = extract_documents(&attrs).join("\n");
169        let default = parse_gen_stub_default(&attrs)?;
170        let deprecated = crate::gen_stub::attr::extract_deprecated(&attrs);
171        Ok(Self {
172            name: field_name.unwrap_or(ident.unwrap().to_string()),
173            r#type: TypeOrOverride::RustType { r#type: ty },
174            doc,
175            default,
176            deprecated,
177        })
178    }
179}
180
181impl ToTokens for MemberInfo {
182    fn to_tokens(&self, tokens: &mut TokenStream2) {
183        let Self {
184            name,
185            r#type,
186            doc,
187            default,
188            deprecated,
189        } = self;
190        let default = default
191            .as_ref()
192            .map(|value| {
193                if value.to_token_stream().to_string() == "None" {
194                    quote! {
195                        "None".to_string()
196                    }
197                } else {
198                    let (TypeOrOverride::RustType { r#type: ty }
199                    | TypeOrOverride::OverrideType { r#type: ty, .. }) = r#type;
200                    quote! {
201                        ::pyo3::prepare_freethreaded_python();
202                        ::pyo3::Python::with_gil(|py| -> String {
203                            let v: #ty = #value;
204                            ::pyo3_stub_gen::util::fmt_py_obj(py, v)
205                        })
206                    }
207                }
208            })
209            .map_or(quote! {None}, |default| {
210                quote! {Some({
211                    static DEFAULT: std::sync::LazyLock<String> = std::sync::LazyLock::new(|| {
212                        #default
213                    });
214                    &DEFAULT
215                })}
216            });
217        let deprecated_info = deprecated
218            .as_ref()
219            .map(|deprecated| {
220                quote! {
221                    Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
222                        since: #deprecated.since,
223                        note: #deprecated.note,
224                    })
225                }
226            })
227            .unwrap_or_else(|| quote! { None });
228        match r#type {
229            TypeOrOverride::RustType { r#type: ty } => tokens.append_all(quote! {
230                ::pyo3_stub_gen::type_info::MemberInfo {
231                    name: #name,
232                    r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_output,
233                    doc: #doc,
234                    default: #default,
235                    deprecated: #deprecated_info,
236                }
237            }),
238            TypeOrOverride::OverrideType {
239                type_repr, imports, ..
240            } => {
241                let imports = imports.iter().collect::<Vec<&String>>();
242                tokens.append_all(quote! {
243                    ::pyo3_stub_gen::type_info::MemberInfo {
244                        name: #name,
245                        r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
246                        doc: #doc,
247                        default: #default,
248                        deprecated: #deprecated_info,
249                    }
250                })
251            }
252        }
253    }
254}
255
256impl From<MemberInfo> for ArgInfo {
257    fn from(value: MemberInfo) -> Self {
258        let MemberInfo { name, r#type, .. } = value;
259
260        Self { name, r#type }
261    }
262}