pyo3_stub_gen_derive/gen_stub/
member.rs

1use crate::gen_stub::{
2    attr::{parse_gen_stub_default, parse_gen_stub_override_type, OverrideTypeAttribute},
3    extract_documents,
4    util::TypeOrOverride,
5};
6
7use super::{extract_return_type, parse_pyo3_attrs, Attr};
8
9use crate::gen_stub::arg::ArgInfo;
10use proc_macro2::TokenStream as TokenStream2;
11use quote::{quote, ToTokens, TokenStreamExt};
12use syn::{Attribute, Error, Expr, Field, FnArg, ImplItemConst, ImplItemFn, Result};
13
/// Selects which `PyStubType` method supplies a member's type annotation.
#[derive(Clone, Copy, Debug)]
pub enum MemberKind {
    /// Getters and class attributes: annotated via `type_output` (the produced type).
    Getter,
    /// Setters: annotated via `type_input` (the accepted parameter type).
    Setter,
}

impl MemberKind {
    /// `true` when the annotation must come from `PyStubType::type_input`
    /// rather than `PyStubType::type_output`.
    fn use_type_input(self) -> bool {
        match self {
            MemberKind::Getter => false,
            MemberKind::Setter => true,
        }
    }
}
28
/// Macro-side description of a class member (getter, setter, class attribute, or a
/// struct field), turned into a `::pyo3_stub_gen::type_info::MemberInfo` expression by
/// this type's `ToTokens` implementation.
#[derive(Debug, Clone)]
pub struct MemberInfo {
    /// Rust doc-comment lines of the item, joined with `"\n"`.
    doc: String,
    /// Name used on the Python side (after `#[pyo3(name = "...")]` / accessor-name resolution).
    name: String,
    /// Type used for the annotation; for setters this is the value parameter's type.
    r#type: TypeOrOverride,
    /// Optional default-value expression: the initializer for `const` class attributes,
    /// otherwise whatever `parse_gen_stub_default` returned for the item.
    default: Option<Expr>,
    /// Deprecation info extracted from the item's attributes, if any.
    deprecated: Option<crate::gen_stub::attr::DeprecatedInfo>,
    /// Decides whether the annotation uses `type_output` (getter-like) or `type_input` (setter).
    kind: MemberKind,
}
38
39impl MemberInfo {
40    pub fn is_getter(attrs: &[Attribute]) -> Result<bool> {
41        let attrs = parse_pyo3_attrs(attrs)?;
42        Ok(attrs.iter().any(|attr| matches!(attr, Attr::Getter(_))))
43    }
44    pub fn is_setter(attrs: &[Attribute]) -> Result<bool> {
45        let attrs = parse_pyo3_attrs(attrs)?;
46        Ok(attrs.iter().any(|attr| matches!(attr, Attr::Setter(_))))
47    }
48
49    pub fn is_classattr(attrs: &[Attribute]) -> Result<bool> {
50        let attrs = parse_pyo3_attrs(attrs)?;
51        Ok(attrs.iter().any(|attr| matches!(attr, Attr::ClassAttr)))
52    }
53    pub fn is_get(field: &Field) -> Result<bool> {
54        let Field { attrs, .. } = field;
55        Ok(parse_pyo3_attrs(attrs)?
56            .iter()
57            .any(|attr| matches!(attr, Attr::Get)))
58    }
59    pub fn is_set(field: &Field) -> Result<bool> {
60        let Field { attrs, .. } = field;
61        Ok(parse_pyo3_attrs(attrs)?
62            .iter()
63            .any(|attr| matches!(attr, Attr::Set)))
64    }
65}
66
67impl MemberInfo {
68    /// Create a new `MemberInfo` from a getter function.
69    ///
70    /// The property name is determined by the following precedence:
71    /// 1. `#[pyo3(name = "...")]` - explicit name via pyo3 attribute
72    /// 2. `#[getter(name)]` - explicit name via getter attribute
73    /// 3. Function name with `get_` prefix stripped
74    ///
75    /// Note: pyo3 does not allow specifying both `#[getter(name)]` and
76    /// `#[pyo3(name = "...")]` at the same time (compile error: "name may only be specified once").
77    pub fn new_getter(item: ImplItemFn) -> Result<Self> {
78        assert!(Self::is_getter(&item.attrs)?);
79        let ImplItemFn { attrs, sig, .. } = &item;
80        let default = parse_gen_stub_default(attrs)?;
81        let doc = extract_documents(attrs).join("\n");
82        let pyo3_attrs = parse_pyo3_attrs(attrs)?;
83
84        // First, get the name from #[getter] or #[getter(name)]
85        let mut name = None;
86        for attr in &pyo3_attrs {
87            if let Attr::Getter(getter_name) = attr {
88                let fn_name = sig.ident.to_string();
89                let fn_getter_name = match fn_name.strip_prefix("get_") {
90                    Some(s) => s.to_owned(),
91                    None => fn_name,
92                };
93                name = Some(getter_name.clone().unwrap_or(fn_getter_name));
94                break;
95            }
96        }
97
98        // Then, check for #[pyo3(name = "...")] which takes precedence
99        for attr in &pyo3_attrs {
100            if let Attr::Name(pyo3_name) = attr {
101                name = Some(pyo3_name.clone());
102                break;
103            }
104        }
105
106        let name = name.ok_or_else(|| Error::new_spanned(&item, "Not a getter"))?;
107        let r#type = extract_return_type(&sig.output, attrs)?
108            .ok_or_else(|| Error::new_spanned(&item, "Getter must return a type"))?;
109        Ok(MemberInfo {
110            doc,
111            name,
112            r#type,
113            default,
114            deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
115            kind: MemberKind::Getter,
116        })
117    }
118    /// Create a new `MemberInfo` from a setter function.
119    ///
120    /// The property name is determined by the following precedence:
121    /// 1. `#[pyo3(name = "...")]` - explicit name via pyo3 attribute
122    /// 2. `#[setter(name)]` - explicit name via setter attribute
123    /// 3. Function name with `set_` prefix stripped
124    ///
125    /// Note: pyo3 does not allow specifying both `#[setter(name)]` and
126    /// `#[pyo3(name = "...")]` at the same time (compile error: "name may only be specified once").
127    pub fn new_setter(item: ImplItemFn) -> Result<Self> {
128        assert!(Self::is_setter(&item.attrs)?);
129        let ImplItemFn { attrs, sig, .. } = &item;
130        let default = parse_gen_stub_default(attrs)?;
131        let doc = extract_documents(attrs).join("\n");
132        let pyo3_attrs = parse_pyo3_attrs(attrs)?;
133
134        // First, get the name from #[setter] or #[setter(name)]
135        let mut name = None;
136        let mut r#type = None;
137        for attr in &pyo3_attrs {
138            if let Attr::Setter(setter_name) = attr {
139                let fn_name = sig.ident.to_string();
140                let fn_setter_name = match fn_name.strip_prefix("set_") {
141                    Some(s) => s.to_owned(),
142                    None => fn_name,
143                };
144                name = Some(setter_name.clone().unwrap_or(fn_setter_name));
145                r#type = Some(
146                    sig.inputs
147                        .get(1)
148                        .ok_or(syn::Error::new_spanned(&item, "Setter must input a type"))
149                        .and_then(|arg| {
150                            if let FnArg::Typed(t) = arg {
151                                Ok(match parse_gen_stub_override_type(&t.attrs)? {
152                                    Some(OverrideTypeAttribute { type_repr, imports }) => {
153                                        TypeOrOverride::OverrideType {
154                                            r#type: *t.ty.clone(),
155                                            type_repr,
156                                            imports,
157                                            rust_type_markers: vec![],
158                                        }
159                                    }
160                                    _ => TypeOrOverride::RustType {
161                                        r#type: *t.ty.clone(),
162                                    },
163                                })
164                            } else {
165                                Err(syn::Error::new_spanned(&item, "Setter must input a type"))
166                            }
167                        })?,
168                );
169                break;
170            }
171        }
172
173        // Then, check for #[pyo3(name = "...")] which takes precedence
174        for attr in &pyo3_attrs {
175            if let Attr::Name(pyo3_name) = attr {
176                name = Some(pyo3_name.clone());
177                break;
178            }
179        }
180
181        let name = name.ok_or_else(|| Error::new_spanned(&item, "Not a setter"))?;
182        let r#type = r#type.ok_or_else(|| Error::new_spanned(&item, "Setter type not found"))?;
183        Ok(MemberInfo {
184            doc,
185            name,
186            r#type,
187            default,
188            deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
189            kind: MemberKind::Setter,
190        })
191    }
192    pub fn new_classattr_fn(item: ImplItemFn) -> Result<Self> {
193        assert!(Self::is_classattr(&item.attrs)?);
194        let ImplItemFn { attrs, sig, .. } = &item;
195        let default = parse_gen_stub_default(attrs)?;
196        let doc = extract_documents(attrs).join("\n");
197        let mut name = sig.ident.to_string();
198        for attr in parse_pyo3_attrs(attrs)? {
199            if let Attr::Name(_name) = attr {
200                name = _name;
201            }
202        }
203        Ok(MemberInfo {
204            doc,
205            name,
206            r#type: extract_return_type(&sig.output, attrs)?.expect("Getter must return a type"),
207            default,
208            deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
209            kind: MemberKind::Getter,
210        })
211    }
212    pub fn new_classattr_const(item: ImplItemConst) -> Result<Self> {
213        assert!(Self::is_classattr(&item.attrs)?);
214        let ImplItemConst {
215            attrs,
216            ident,
217            ty,
218            expr,
219            ..
220        } = item;
221        let doc = extract_documents(&attrs).join("\n");
222        let mut name = ident.to_string();
223        for attr in parse_pyo3_attrs(&attrs)? {
224            if let Attr::Name(_name) = attr {
225                name = _name;
226            }
227        }
228        Ok(MemberInfo {
229            doc,
230            name,
231            r#type: TypeOrOverride::RustType { r#type: ty },
232            default: Some(expr),
233            deprecated: crate::gen_stub::attr::extract_deprecated(&attrs),
234            kind: MemberKind::Getter,
235        })
236    }
237}
238
239impl MemberInfo {
240    pub fn from_field(field: Field, kind: MemberKind) -> Result<Self> {
241        let Field {
242            ident, ty, attrs, ..
243        } = field;
244        let mut field_name = None;
245        for attr in parse_pyo3_attrs(&attrs)? {
246            if let Attr::Name(name) = attr {
247                field_name = Some(name);
248            }
249        }
250        let doc = extract_documents(&attrs).join("\n");
251        let default = parse_gen_stub_default(&attrs)?;
252        let deprecated = crate::gen_stub::attr::extract_deprecated(&attrs);
253        Ok(Self {
254            name: field_name.unwrap_or(ident.unwrap().to_string()),
255            r#type: TypeOrOverride::RustType { r#type: ty },
256            doc,
257            default,
258            deprecated,
259            kind,
260        })
261    }
262}
263
264impl ToTokens for MemberInfo {
265    fn to_tokens(&self, tokens: &mut TokenStream2) {
266        let Self {
267            name,
268            r#type,
269            doc,
270            default,
271            deprecated,
272            kind,
273        } = self;
274        let use_type_input = kind.use_type_input();
275        let default = default
276            .as_ref()
277            .map(|value| {
278                if value.to_token_stream().to_string() == "None" {
279                    quote! {
280                        "None".to_string()
281                    }
282                } else {
283                    let (TypeOrOverride::RustType { r#type: ty }
284                    | TypeOrOverride::OverrideType { r#type: ty, .. }) = r#type;
285                    quote! {
286                    let v: #ty = #value;
287                    ::pyo3_stub_gen::util::fmt_py_obj(v)
288                    }
289                }
290            })
291            .map_or(quote! {None}, |default| {
292                quote! {Some({
293                    fn _fmt() -> String {
294                        #default
295                    }
296                    _fmt
297                })}
298            });
299        let deprecated_info = deprecated
300            .as_ref()
301            .map(|deprecated| {
302                quote! {
303                    Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
304                        since: #deprecated.since,
305                        note: #deprecated.note,
306                    })
307                }
308            })
309            .unwrap_or_else(|| quote! { None });
310        let type_fn = if use_type_input {
311            quote! { type_input }
312        } else {
313            quote! { type_output }
314        };
315        match r#type {
316            TypeOrOverride::RustType { r#type: ty } => tokens.append_all(quote! {
317                ::pyo3_stub_gen::type_info::MemberInfo {
318                    name: #name,
319                    r#type: <#ty as ::pyo3_stub_gen::PyStubType>::#type_fn,
320                    doc: #doc,
321                    default: #default,
322                    deprecated: #deprecated_info,
323                }
324            }),
325            TypeOrOverride::OverrideType {
326                type_repr,
327                imports,
328                rust_type_markers,
329                ..
330            } => {
331                let imports = imports.iter().collect::<Vec<&String>>();
332
333                // Generate code to process RustType markers
334                let (type_name_code, type_refs_code) = if rust_type_markers.is_empty() {
335                    (
336                        quote! { #type_repr.to_string() },
337                        quote! { ::std::collections::HashMap::new() },
338                    )
339                } else {
340                    // Parse rust_type_markers as syn::Type
341                    let marker_types: Vec<syn::Type> = rust_type_markers
342                        .iter()
343                        .filter_map(|s| syn::parse_str(s).ok())
344                        .collect();
345
346                    let rust_names = rust_type_markers.iter().collect::<Vec<_>>();
347
348                    (
349                        quote! {
350                            {
351                                let mut type_name = #type_repr.to_string();
352                                #(
353                                    let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::#type_fn();
354                                    type_name = type_name.replace(#rust_names, &type_info.name);
355                                )*
356                                type_name
357                            }
358                        },
359                        quote! {
360                            {
361                                let mut type_refs = ::std::collections::HashMap::new();
362                                #(
363                                    let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::#type_fn();
364                                    if let Some(module) = type_info.source_module {
365                                        type_refs.insert(
366                                            type_info.name.split('[').next().unwrap_or(&type_info.name).split('.').last().unwrap_or(&type_info.name).to_string(),
367                                            ::pyo3_stub_gen::TypeIdentifierRef {
368                                                module: module.into(),
369                                                import_kind: ::pyo3_stub_gen::ImportKind::Module,
370                                            }
371                                        );
372                                    }
373                                    type_refs.extend(type_info.type_refs);
374                                )*
375                                type_refs
376                            }
377                        },
378                    )
379                };
380
381                tokens.append_all(quote! {
382                    ::pyo3_stub_gen::type_info::MemberInfo {
383                        name: #name,
384                        r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_name_code, source_module: None, import: ::std::collections::HashSet::from([#(#imports.into(),)*]), type_refs: #type_refs_code },
385                        doc: #doc,
386                        default: #default,
387                        deprecated: #deprecated_info,
388                    }
389                })
390            }
391        }
392    }
393}
394
395impl From<MemberInfo> for ArgInfo {
396    fn from(value: MemberInfo) -> Self {
397        let MemberInfo { name, r#type, .. } = value;
398
399        Self { name, r#type }
400    }
401}