pyo3_stub_gen_derive/gen_stub/
member.rs

1use crate::gen_stub::{
2    attr::{parse_gen_stub_default, parse_gen_stub_override_type, OverrideTypeAttribute},
3    extract_documents,
4    util::TypeOrOverride,
5};
6
7use super::{extract_return_type, parse_pyo3_attrs, Attr};
8
9use crate::gen_stub::arg::ArgInfo;
10use proc_macro2::TokenStream as TokenStream2;
11use quote::{quote, ToTokens, TokenStreamExt};
12use syn::{Attribute, Error, Expr, Field, FnArg, ImplItemConst, ImplItemFn, Result};
13
14#[derive(Debug, Clone)]
15pub struct MemberInfo {
16    doc: String,
17    name: String,
18    r#type: TypeOrOverride,
19    default: Option<Expr>,
20    deprecated: Option<crate::gen_stub::attr::DeprecatedInfo>,
21}
22
23impl MemberInfo {
24    pub fn is_getter(attrs: &[Attribute]) -> Result<bool> {
25        let attrs = parse_pyo3_attrs(attrs)?;
26        Ok(attrs.iter().any(|attr| matches!(attr, Attr::Getter(_))))
27    }
28    pub fn is_setter(attrs: &[Attribute]) -> Result<bool> {
29        let attrs = parse_pyo3_attrs(attrs)?;
30        Ok(attrs.iter().any(|attr| matches!(attr, Attr::Setter(_))))
31    }
32
33    pub fn is_classattr(attrs: &[Attribute]) -> Result<bool> {
34        let attrs = parse_pyo3_attrs(attrs)?;
35        Ok(attrs.iter().any(|attr| matches!(attr, Attr::ClassAttr)))
36    }
37    pub fn is_get(field: &Field) -> Result<bool> {
38        let Field { attrs, .. } = field;
39        Ok(parse_pyo3_attrs(attrs)?
40            .iter()
41            .any(|attr| matches!(attr, Attr::Get)))
42    }
43    pub fn is_set(field: &Field) -> Result<bool> {
44        let Field { attrs, .. } = field;
45        Ok(parse_pyo3_attrs(attrs)?
46            .iter()
47            .any(|attr| matches!(attr, Attr::Set)))
48    }
49}
50
51impl MemberInfo {
52    pub fn new_getter(item: ImplItemFn) -> Result<Self> {
53        assert!(Self::is_getter(&item.attrs)?);
54        let ImplItemFn { attrs, sig, .. } = &item;
55        let default = parse_gen_stub_default(attrs)?;
56        let doc = extract_documents(attrs).join("\n");
57        let pyo3_attrs = parse_pyo3_attrs(attrs)?;
58        for attr in pyo3_attrs {
59            if let Attr::Getter(name) = attr {
60                let fn_name = sig.ident.to_string();
61                let fn_getter_name = match fn_name.strip_prefix("get_") {
62                    Some(s) => s.to_owned(),
63                    None => fn_name,
64                };
65                return Ok(MemberInfo {
66                    doc,
67                    name: name.unwrap_or(fn_getter_name),
68                    r#type: extract_return_type(&sig.output, attrs)?
69                        .expect("Getter must return a type"),
70                    default,
71                    deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
72                });
73            }
74        }
75        unreachable!("Not a getter: {:?}", item)
76    }
77    pub fn new_setter(item: ImplItemFn) -> Result<Self> {
78        assert!(Self::is_setter(&item.attrs)?);
79        let ImplItemFn { attrs, sig, .. } = &item;
80        let default = parse_gen_stub_default(attrs)?;
81        let doc = extract_documents(attrs).join("\n");
82        let pyo3_attrs = parse_pyo3_attrs(attrs)?;
83        for attr in pyo3_attrs {
84            if let Attr::Setter(name) = attr {
85                let fn_name = sig.ident.to_string();
86                let fn_setter_name = match fn_name.strip_prefix("set_") {
87                    Some(s) => s.to_owned(),
88                    None => fn_name,
89                };
90                let r#type = sig
91                    .inputs
92                    .get(1)
93                    .ok_or(syn::Error::new_spanned(&item, "Setter must input a type"))
94                    .and_then(|arg| {
95                        if let FnArg::Typed(t) = arg {
96                            Ok(match parse_gen_stub_override_type(&t.attrs)? {
97                                Some(OverrideTypeAttribute { type_repr, imports }) => {
98                                    TypeOrOverride::OverrideType {
99                                        r#type: *t.ty.clone(),
100                                        type_repr,
101                                        imports,
102                                        rust_type_markers: vec![],
103                                    }
104                                }
105                                _ => TypeOrOverride::RustType {
106                                    r#type: *t.ty.clone(),
107                                },
108                            })
109                        } else {
110                            Err(syn::Error::new_spanned(&item, "Setter must input a type"))
111                        }
112                    })?;
113                return Ok(MemberInfo {
114                    doc,
115                    name: name.unwrap_or(fn_setter_name),
116                    r#type,
117                    default,
118                    deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
119                });
120            }
121        }
122        unreachable!("Not a setter: {:?}", item)
123    }
124    pub fn new_classattr_fn(item: ImplItemFn) -> Result<Self> {
125        assert!(Self::is_classattr(&item.attrs)?);
126        let ImplItemFn { attrs, sig, .. } = &item;
127        let default = parse_gen_stub_default(attrs)?;
128        let doc = extract_documents(attrs).join("\n");
129        let mut name = sig.ident.to_string();
130        for attr in parse_pyo3_attrs(attrs)? {
131            if let Attr::Name(_name) = attr {
132                name = _name;
133            }
134        }
135        Ok(MemberInfo {
136            doc,
137            name,
138            r#type: extract_return_type(&sig.output, attrs)?.expect("Getter must return a type"),
139            default,
140            deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
141        })
142    }
143    pub fn new_classattr_const(item: ImplItemConst) -> Result<Self> {
144        assert!(Self::is_classattr(&item.attrs)?);
145        let ImplItemConst {
146            attrs,
147            ident,
148            ty,
149            expr,
150            ..
151        } = item;
152        let doc = extract_documents(&attrs).join("\n");
153        let mut name = ident.to_string();
154        for attr in parse_pyo3_attrs(&attrs)? {
155            if let Attr::Name(_name) = attr {
156                name = _name;
157            }
158        }
159        Ok(MemberInfo {
160            doc,
161            name,
162            r#type: TypeOrOverride::RustType { r#type: ty },
163            default: Some(expr),
164            deprecated: crate::gen_stub::attr::extract_deprecated(&attrs),
165        })
166    }
167}
168
169impl TryFrom<Field> for MemberInfo {
170    type Error = Error;
171    fn try_from(field: Field) -> Result<Self> {
172        let Field {
173            ident, ty, attrs, ..
174        } = field;
175        let mut field_name = None;
176        for attr in parse_pyo3_attrs(&attrs)? {
177            if let Attr::Name(name) = attr {
178                field_name = Some(name);
179            }
180        }
181        let doc = extract_documents(&attrs).join("\n");
182        let default = parse_gen_stub_default(&attrs)?;
183        let deprecated = crate::gen_stub::attr::extract_deprecated(&attrs);
184        Ok(Self {
185            name: field_name.unwrap_or(ident.unwrap().to_string()),
186            r#type: TypeOrOverride::RustType { r#type: ty },
187            doc,
188            default,
189            deprecated,
190        })
191    }
192}
193
194impl ToTokens for MemberInfo {
195    fn to_tokens(&self, tokens: &mut TokenStream2) {
196        let Self {
197            name,
198            r#type,
199            doc,
200            default,
201            deprecated,
202        } = self;
203        let default = default
204            .as_ref()
205            .map(|value| {
206                if value.to_token_stream().to_string() == "None" {
207                    quote! {
208                        "None".to_string()
209                    }
210                } else {
211                    let (TypeOrOverride::RustType { r#type: ty }
212                    | TypeOrOverride::OverrideType { r#type: ty, .. }) = r#type;
213                    quote! {
214                    let v: #ty = #value;
215                    ::pyo3_stub_gen::util::fmt_py_obj(v)
216                    }
217                }
218            })
219            .map_or(quote! {None}, |default| {
220                quote! {Some({
221                    fn _fmt() -> String {
222                        #default
223                    }
224                    _fmt
225                })}
226            });
227        let deprecated_info = deprecated
228            .as_ref()
229            .map(|deprecated| {
230                quote! {
231                    Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
232                        since: #deprecated.since,
233                        note: #deprecated.note,
234                    })
235                }
236            })
237            .unwrap_or_else(|| quote! { None });
238        match r#type {
239            TypeOrOverride::RustType { r#type: ty } => tokens.append_all(quote! {
240                ::pyo3_stub_gen::type_info::MemberInfo {
241                    name: #name,
242                    r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_output,
243                    doc: #doc,
244                    default: #default,
245                    deprecated: #deprecated_info,
246                }
247            }),
248            TypeOrOverride::OverrideType {
249                type_repr,
250                imports,
251                rust_type_markers,
252                ..
253            } => {
254                let imports = imports.iter().collect::<Vec<&String>>();
255
256                // Generate code to process RustType markers
257                let (type_name_code, type_refs_code) = if rust_type_markers.is_empty() {
258                    (
259                        quote! { #type_repr.to_string() },
260                        quote! { ::std::collections::HashMap::new() },
261                    )
262                } else {
263                    // Parse rust_type_markers as syn::Type
264                    let marker_types: Vec<syn::Type> = rust_type_markers
265                        .iter()
266                        .filter_map(|s| syn::parse_str(s).ok())
267                        .collect();
268
269                    let rust_names = rust_type_markers.iter().collect::<Vec<_>>();
270
271                    (
272                        quote! {
273                            {
274                                let mut type_name = #type_repr.to_string();
275                                #(
276                                    let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
277                                    type_name = type_name.replace(#rust_names, &type_info.name);
278                                )*
279                                type_name
280                            }
281                        },
282                        quote! {
283                            {
284                                let mut type_refs = ::std::collections::HashMap::new();
285                                #(
286                                    let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
287                                    if let Some(module) = type_info.source_module {
288                                        type_refs.insert(
289                                            type_info.name.split('[').next().unwrap_or(&type_info.name).split('.').last().unwrap_or(&type_info.name).to_string(),
290                                            ::pyo3_stub_gen::TypeIdentifierRef {
291                                                module: module.into(),
292                                                import_kind: ::pyo3_stub_gen::ImportKind::Module,
293                                            }
294                                        );
295                                    }
296                                    type_refs.extend(type_info.type_refs);
297                                )*
298                                type_refs
299                            }
300                        },
301                    )
302                };
303
304                tokens.append_all(quote! {
305                    ::pyo3_stub_gen::type_info::MemberInfo {
306                        name: #name,
307                        r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_name_code, source_module: None, import: ::std::collections::HashSet::from([#(#imports.into(),)*]), type_refs: #type_refs_code },
308                        doc: #doc,
309                        default: #default,
310                        deprecated: #deprecated_info,
311                    }
312                })
313            }
314        }
315    }
316}
317
318impl From<MemberInfo> for ArgInfo {
319    fn from(value: MemberInfo) -> Self {
320        let MemberInfo { name, r#type, .. } = value;
321
322        Self { name, r#type }
323    }
324}