// pyo3_stub_gen_derive/gen_stub/member.rs

use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens, TokenStreamExt};
use syn::{Attribute, Error, Expr, Field, FnArg, ImplItemConst, ImplItemFn, Result};

use super::{extract_return_type, parse_pyo3_attrs, Attr};
use crate::gen_stub::arg::ArgInfo;
use crate::gen_stub::{
    attr::{parse_gen_stub_default, parse_gen_stub_override_type, OverrideTypeAttribute},
    extract_documents,
    util::TypeOrOverride,
};

/// Macro-side description of a Python class member: a struct field, a `#[getter]`/`#[setter]`
/// property, or a `#[classattr]` function/constant.
///
/// The `ToTokens` impl in this file turns it into a `::pyo3_stub_gen::type_info::MemberInfo`
/// struct expression.
#[derive(Debug, Clone)]
pub struct MemberInfo {
    /// Documentation extracted from the item's attributes via `extract_documents`, joined with `\n`.
    doc: String,
    /// Python-visible member name.
    name: String,
    /// The member's Rust type, or a type-override carrying an explicit Python type string.
    r#type: TypeOrOverride,
    /// Rust expression rendered (through `::pyo3_stub_gen::util::fmt_py_obj`) as the member's
    /// default value in the stub, if any.
    default: Option<Expr>,
    /// Deprecation information extracted from the item's attributes, if present.
    deprecated: Option<crate::gen_stub::attr::DeprecatedInfo>,
}

23impl MemberInfo {
24    pub fn is_getter(attrs: &[Attribute]) -> Result<bool> {
25        let attrs = parse_pyo3_attrs(attrs)?;
26        Ok(attrs.iter().any(|attr| matches!(attr, Attr::Getter(_))))
27    }
28    pub fn is_setter(attrs: &[Attribute]) -> Result<bool> {
29        let attrs = parse_pyo3_attrs(attrs)?;
30        Ok(attrs.iter().any(|attr| matches!(attr, Attr::Setter(_))))
31    }
32
33    pub fn is_classattr(attrs: &[Attribute]) -> Result<bool> {
34        let attrs = parse_pyo3_attrs(attrs)?;
35        Ok(attrs.iter().any(|attr| matches!(attr, Attr::ClassAttr)))
36    }
37    pub fn is_get(field: &Field) -> Result<bool> {
38        let Field { attrs, .. } = field;
39        Ok(parse_pyo3_attrs(attrs)?
40            .iter()
41            .any(|attr| matches!(attr, Attr::Get)))
42    }
43    pub fn is_set(field: &Field) -> Result<bool> {
44        let Field { attrs, .. } = field;
45        Ok(parse_pyo3_attrs(attrs)?
46            .iter()
47            .any(|attr| matches!(attr, Attr::Set)))
48    }
49}
50
51impl MemberInfo {
52    /// Create a new `MemberInfo` from a getter function.
53    ///
54    /// The property name is determined by the following precedence:
55    /// 1. `#[pyo3(name = "...")]` - explicit name via pyo3 attribute
56    /// 2. `#[getter(name)]` - explicit name via getter attribute
57    /// 3. Function name with `get_` prefix stripped
58    ///
59    /// Note: pyo3 does not allow specifying both `#[getter(name)]` and
60    /// `#[pyo3(name = "...")]` at the same time (compile error: "name may only be specified once").
61    pub fn new_getter(item: ImplItemFn) -> Result<Self> {
62        assert!(Self::is_getter(&item.attrs)?);
63        let ImplItemFn { attrs, sig, .. } = &item;
64        let default = parse_gen_stub_default(attrs)?;
65        let doc = extract_documents(attrs).join("\n");
66        let pyo3_attrs = parse_pyo3_attrs(attrs)?;
67
68        // First, get the name from #[getter] or #[getter(name)]
69        let mut name = None;
70        for attr in &pyo3_attrs {
71            if let Attr::Getter(getter_name) = attr {
72                let fn_name = sig.ident.to_string();
73                let fn_getter_name = match fn_name.strip_prefix("get_") {
74                    Some(s) => s.to_owned(),
75                    None => fn_name,
76                };
77                name = Some(getter_name.clone().unwrap_or(fn_getter_name));
78                break;
79            }
80        }
81
82        // Then, check for #[pyo3(name = "...")] which takes precedence
83        for attr in &pyo3_attrs {
84            if let Attr::Name(pyo3_name) = attr {
85                name = Some(pyo3_name.clone());
86                break;
87            }
88        }
89
90        let name = name.ok_or_else(|| Error::new_spanned(&item, "Not a getter"))?;
91        let r#type = extract_return_type(&sig.output, attrs)?
92            .ok_or_else(|| Error::new_spanned(&item, "Getter must return a type"))?;
93        Ok(MemberInfo {
94            doc,
95            name,
96            r#type,
97            default,
98            deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
99        })
100    }
101    /// Create a new `MemberInfo` from a setter function.
102    ///
103    /// The property name is determined by the following precedence:
104    /// 1. `#[pyo3(name = "...")]` - explicit name via pyo3 attribute
105    /// 2. `#[setter(name)]` - explicit name via setter attribute
106    /// 3. Function name with `set_` prefix stripped
107    ///
108    /// Note: pyo3 does not allow specifying both `#[setter(name)]` and
109    /// `#[pyo3(name = "...")]` at the same time (compile error: "name may only be specified once").
110    pub fn new_setter(item: ImplItemFn) -> Result<Self> {
111        assert!(Self::is_setter(&item.attrs)?);
112        let ImplItemFn { attrs, sig, .. } = &item;
113        let default = parse_gen_stub_default(attrs)?;
114        let doc = extract_documents(attrs).join("\n");
115        let pyo3_attrs = parse_pyo3_attrs(attrs)?;
116
117        // First, get the name from #[setter] or #[setter(name)]
118        let mut name = None;
119        let mut r#type = None;
120        for attr in &pyo3_attrs {
121            if let Attr::Setter(setter_name) = attr {
122                let fn_name = sig.ident.to_string();
123                let fn_setter_name = match fn_name.strip_prefix("set_") {
124                    Some(s) => s.to_owned(),
125                    None => fn_name,
126                };
127                name = Some(setter_name.clone().unwrap_or(fn_setter_name));
128                r#type = Some(
129                    sig.inputs
130                        .get(1)
131                        .ok_or(syn::Error::new_spanned(&item, "Setter must input a type"))
132                        .and_then(|arg| {
133                            if let FnArg::Typed(t) = arg {
134                                Ok(match parse_gen_stub_override_type(&t.attrs)? {
135                                    Some(OverrideTypeAttribute { type_repr, imports }) => {
136                                        TypeOrOverride::OverrideType {
137                                            r#type: *t.ty.clone(),
138                                            type_repr,
139                                            imports,
140                                            rust_type_markers: vec![],
141                                        }
142                                    }
143                                    _ => TypeOrOverride::RustType {
144                                        r#type: *t.ty.clone(),
145                                    },
146                                })
147                            } else {
148                                Err(syn::Error::new_spanned(&item, "Setter must input a type"))
149                            }
150                        })?,
151                );
152                break;
153            }
154        }
155
156        // Then, check for #[pyo3(name = "...")] which takes precedence
157        for attr in &pyo3_attrs {
158            if let Attr::Name(pyo3_name) = attr {
159                name = Some(pyo3_name.clone());
160                break;
161            }
162        }
163
164        let name = name.ok_or_else(|| Error::new_spanned(&item, "Not a setter"))?;
165        let r#type = r#type.ok_or_else(|| Error::new_spanned(&item, "Setter type not found"))?;
166        Ok(MemberInfo {
167            doc,
168            name,
169            r#type,
170            default,
171            deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
172        })
173    }
174    pub fn new_classattr_fn(item: ImplItemFn) -> Result<Self> {
175        assert!(Self::is_classattr(&item.attrs)?);
176        let ImplItemFn { attrs, sig, .. } = &item;
177        let default = parse_gen_stub_default(attrs)?;
178        let doc = extract_documents(attrs).join("\n");
179        let mut name = sig.ident.to_string();
180        for attr in parse_pyo3_attrs(attrs)? {
181            if let Attr::Name(_name) = attr {
182                name = _name;
183            }
184        }
185        Ok(MemberInfo {
186            doc,
187            name,
188            r#type: extract_return_type(&sig.output, attrs)?.expect("Getter must return a type"),
189            default,
190            deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
191        })
192    }
193    pub fn new_classattr_const(item: ImplItemConst) -> Result<Self> {
194        assert!(Self::is_classattr(&item.attrs)?);
195        let ImplItemConst {
196            attrs,
197            ident,
198            ty,
199            expr,
200            ..
201        } = item;
202        let doc = extract_documents(&attrs).join("\n");
203        let mut name = ident.to_string();
204        for attr in parse_pyo3_attrs(&attrs)? {
205            if let Attr::Name(_name) = attr {
206                name = _name;
207            }
208        }
209        Ok(MemberInfo {
210            doc,
211            name,
212            r#type: TypeOrOverride::RustType { r#type: ty },
213            default: Some(expr),
214            deprecated: crate::gen_stub::attr::extract_deprecated(&attrs),
215        })
216    }
217}
218
219impl TryFrom<Field> for MemberInfo {
220    type Error = Error;
221    fn try_from(field: Field) -> Result<Self> {
222        let Field {
223            ident, ty, attrs, ..
224        } = field;
225        let mut field_name = None;
226        for attr in parse_pyo3_attrs(&attrs)? {
227            if let Attr::Name(name) = attr {
228                field_name = Some(name);
229            }
230        }
231        let doc = extract_documents(&attrs).join("\n");
232        let default = parse_gen_stub_default(&attrs)?;
233        let deprecated = crate::gen_stub::attr::extract_deprecated(&attrs);
234        Ok(Self {
235            name: field_name.unwrap_or(ident.unwrap().to_string()),
236            r#type: TypeOrOverride::RustType { r#type: ty },
237            doc,
238            default,
239            deprecated,
240        })
241    }
242}
243
244impl ToTokens for MemberInfo {
245    fn to_tokens(&self, tokens: &mut TokenStream2) {
246        let Self {
247            name,
248            r#type,
249            doc,
250            default,
251            deprecated,
252        } = self;
253        let default = default
254            .as_ref()
255            .map(|value| {
256                if value.to_token_stream().to_string() == "None" {
257                    quote! {
258                        "None".to_string()
259                    }
260                } else {
261                    let (TypeOrOverride::RustType { r#type: ty }
262                    | TypeOrOverride::OverrideType { r#type: ty, .. }) = r#type;
263                    quote! {
264                    let v: #ty = #value;
265                    ::pyo3_stub_gen::util::fmt_py_obj(v)
266                    }
267                }
268            })
269            .map_or(quote! {None}, |default| {
270                quote! {Some({
271                    fn _fmt() -> String {
272                        #default
273                    }
274                    _fmt
275                })}
276            });
277        let deprecated_info = deprecated
278            .as_ref()
279            .map(|deprecated| {
280                quote! {
281                    Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
282                        since: #deprecated.since,
283                        note: #deprecated.note,
284                    })
285                }
286            })
287            .unwrap_or_else(|| quote! { None });
288        match r#type {
289            TypeOrOverride::RustType { r#type: ty } => tokens.append_all(quote! {
290                ::pyo3_stub_gen::type_info::MemberInfo {
291                    name: #name,
292                    r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_output,
293                    doc: #doc,
294                    default: #default,
295                    deprecated: #deprecated_info,
296                }
297            }),
298            TypeOrOverride::OverrideType {
299                type_repr,
300                imports,
301                rust_type_markers,
302                ..
303            } => {
304                let imports = imports.iter().collect::<Vec<&String>>();
305
306                // Generate code to process RustType markers
307                let (type_name_code, type_refs_code) = if rust_type_markers.is_empty() {
308                    (
309                        quote! { #type_repr.to_string() },
310                        quote! { ::std::collections::HashMap::new() },
311                    )
312                } else {
313                    // Parse rust_type_markers as syn::Type
314                    let marker_types: Vec<syn::Type> = rust_type_markers
315                        .iter()
316                        .filter_map(|s| syn::parse_str(s).ok())
317                        .collect();
318
319                    let rust_names = rust_type_markers.iter().collect::<Vec<_>>();
320
321                    (
322                        quote! {
323                            {
324                                let mut type_name = #type_repr.to_string();
325                                #(
326                                    let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
327                                    type_name = type_name.replace(#rust_names, &type_info.name);
328                                )*
329                                type_name
330                            }
331                        },
332                        quote! {
333                            {
334                                let mut type_refs = ::std::collections::HashMap::new();
335                                #(
336                                    let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
337                                    if let Some(module) = type_info.source_module {
338                                        type_refs.insert(
339                                            type_info.name.split('[').next().unwrap_or(&type_info.name).split('.').last().unwrap_or(&type_info.name).to_string(),
340                                            ::pyo3_stub_gen::TypeIdentifierRef {
341                                                module: module.into(),
342                                                import_kind: ::pyo3_stub_gen::ImportKind::Module,
343                                            }
344                                        );
345                                    }
346                                    type_refs.extend(type_info.type_refs);
347                                )*
348                                type_refs
349                            }
350                        },
351                    )
352                };
353
354                tokens.append_all(quote! {
355                    ::pyo3_stub_gen::type_info::MemberInfo {
356                        name: #name,
357                        r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_name_code, source_module: None, import: ::std::collections::HashSet::from([#(#imports.into(),)*]), type_refs: #type_refs_code },
358                        doc: #doc,
359                        default: #default,
360                        deprecated: #deprecated_info,
361                    }
362                })
363            }
364        }
365    }
366}
367
368impl From<MemberInfo> for ArgInfo {
369    fn from(value: MemberInfo) -> Self {
370        let MemberInfo { name, r#type, .. } = value;
371
372        Self { name, r#type }
373    }
374}