pyo3_stub_gen_derive/gen_stub/pyclass.rs

1use super::{
2    extract_documents, parse_pyo3_attrs, util::quote_option, Attr, MemberInfo, PyClassAttr,
3    StubType,
4};
5use proc_macro2::TokenStream as TokenStream2;
6use quote::{quote, ToTokens, TokenStreamExt};
7use syn::{parse_quote, Error, ItemStruct, Result, Type};
8
9pub struct PyClassInfo {
10    pyclass_name: String,
11    struct_type: Type,
12    module: Option<String>,
13    getters: Vec<MemberInfo>,
14    setters: Vec<MemberInfo>,
15    doc: String,
16    bases: Vec<Type>,
17    has_eq: bool,
18    has_ord: bool,
19    has_hash: bool,
20    has_str: bool,
21    subclass: bool,
22}
23
24impl From<&PyClassInfo> for StubType {
25    fn from(info: &PyClassInfo) -> Self {
26        let PyClassInfo {
27            pyclass_name,
28            module,
29            struct_type,
30            ..
31        } = info;
32        Self {
33            ty: struct_type.clone(),
34            name: pyclass_name.clone(),
35            module: module.clone(),
36        }
37    }
38}
39
40impl PyClassInfo {
41    /// Create PyClassInfo from ItemStruct with PyClassAttr for module override support
42    pub fn from_item_with_attr(item: ItemStruct, attr: &PyClassAttr) -> Result<Self> {
43        let ItemStruct {
44            ident,
45            attrs,
46            fields,
47            ..
48        } = item;
49        let struct_type: Type = parse_quote!(#ident);
50        let mut pyclass_name = None;
51        let mut pyo3_module = None;
52        let mut gen_stub_standalone_module = None;
53        let mut is_get_all = false;
54        let mut is_set_all = false;
55        let mut bases = Vec::new();
56        let mut has_eq = false;
57        let mut has_ord = false;
58        let mut has_hash = false;
59        let mut has_str = false;
60        let mut subclass = false;
61        for attr_item in parse_pyo3_attrs(&attrs)? {
62            match attr_item {
63                Attr::Name(name) => pyclass_name = Some(name),
64                Attr::Module(name) => {
65                    pyo3_module = Some(name);
66                }
67                Attr::GenStubModule(name) => {
68                    gen_stub_standalone_module = Some(name);
69                }
70                Attr::GetAll => is_get_all = true,
71                Attr::SetAll => is_set_all = true,
72                Attr::Extends(typ) => bases.push(typ),
73                Attr::Eq => has_eq = true,
74                Attr::Ord => has_ord = true,
75                Attr::Hash => has_hash = true,
76                Attr::Str => has_str = true,
77                Attr::Subclass => subclass = true,
78                _ => {}
79            }
80        }
81
82        // Validate: inline and standalone gen_stub modules must not conflict
83        if let (Some(inline_mod), Some(standalone_mod)) =
84            (&attr.module, &gen_stub_standalone_module)
85        {
86            if inline_mod != standalone_mod {
87                return Err(Error::new(
88                    ident.span(),
89                    format!(
90                        "Conflicting module specifications: #[gen_stub_pyclass(module = \"{}\")] \
91                         and #[gen_stub(module = \"{}\")]. Please use only one.",
92                        inline_mod, standalone_mod
93                    ),
94                ));
95            }
96        }
97
98        // Priority: inline > standalone > pyo3 > default
99        let module = if let Some(inline_mod) = &attr.module {
100            Some(inline_mod.clone()) // Priority 1: #[gen_stub_pyclass(module = "...")]
101        } else if let Some(standalone_mod) = gen_stub_standalone_module {
102            Some(standalone_mod) // Priority 2: #[gen_stub(module = "...")]
103        } else {
104            pyo3_module // Priority 3: #[pyo3(module = "...")]
105        };
106
107        let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
108        let mut getters = Vec::new();
109        let mut setters = Vec::new();
110        for field in fields {
111            if is_get_all || MemberInfo::is_get(&field)? {
112                getters.push(MemberInfo::try_from(field.clone())?)
113            }
114            if is_set_all || MemberInfo::is_set(&field)? {
115                setters.push(MemberInfo::try_from(field)?)
116            }
117        }
118        let doc = extract_documents(&attrs).join("\n");
119        Ok(Self {
120            struct_type,
121            pyclass_name,
122            getters,
123            setters,
124            module,
125            doc,
126            bases,
127            has_eq,
128            has_ord,
129            has_hash,
130            has_str,
131            subclass,
132        })
133    }
134}
135
136impl TryFrom<ItemStruct> for PyClassInfo {
137    type Error = Error;
138    fn try_from(item: ItemStruct) -> Result<Self> {
139        // Use the new method with default PyClassAttr
140        Self::from_item_with_attr(item, &PyClassAttr::default())
141    }
142}
143
144impl ToTokens for PyClassInfo {
145    fn to_tokens(&self, tokens: &mut TokenStream2) {
146        let Self {
147            pyclass_name,
148            struct_type,
149            getters,
150            setters,
151            doc,
152            module,
153            bases,
154            has_eq,
155            has_ord,
156            has_hash,
157            has_str,
158            subclass,
159        } = self;
160        let module = quote_option(module);
161        tokens.append_all(quote! {
162            ::pyo3_stub_gen::type_info::PyClassInfo {
163                pyclass_name: #pyclass_name,
164                struct_id: std::any::TypeId::of::<#struct_type>,
165                getters: &[ #( #getters),* ],
166                setters: &[ #( #setters),* ],
167                module: #module,
168                doc: #doc,
169                bases: &[ #( <#bases as ::pyo3_stub_gen::PyStubType>::type_output ),* ],
170                has_eq: #has_eq,
171                has_ord: #has_ord,
172                has_hash: #has_hash,
173                has_str: #has_str,
174                subclass: #subclass,
175            }
176        })
177    }
178}
179
180// `#[gen_stub(xxx)]` is not a valid proc_macro_attribute
181// it's only designed to receive user's setting.
182// We need to remove all `#[gen_stub(xxx)]` before print the item_struct back
183pub fn prune_attrs(item_struct: &mut ItemStruct) {
184    super::attr::prune_attrs(&mut item_struct.attrs);
185    for field in item_struct.fields.iter_mut() {
186        super::attr::prune_attrs(&mut field.attrs);
187    }
188}
189
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    #[test]
    fn test_pyclass() -> Result<()> {
        // The `name` and `module` options come from `#[pyclass]`. Only the fields
        // marked `#[pyo3(get)]` become getters, so `custom_latex` is absent below.
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
                #[pyo3(get)]
                pub ndim: usize,
                #[pyo3(get)]
                pub description: Option<String>,
                pub custom_latex: Option<String>,
            }
            "#,
        )?;
        let info = PyClassInfo::try_from(item)?;
        let rendered = format_as_value(info.to_token_stream());
        insta::assert_snapshot!(rendered, @r###"
        ::pyo3_stub_gen::type_info::PyClassInfo {
            pyclass_name: "Placeholder",
            struct_id: std::any::TypeId::of::<PyPlaceholder>,
            getters: &[
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "name",
                    r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "ndim",
                    r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "description",
                    r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
            ],
            setters: &[],
            module: Some("my_module"),
            doc: "",
            bases: &[],
            has_eq: false,
            has_ord: false,
            has_hash: false,
            has_str: false,
            subclass: false,
        }
        "###);
        Ok(())
    }

    /// Pretty-prints an expression token stream.
    ///
    /// `prettyplease` formats whole files, so the expression is wrapped in
    /// `const _: () = <expr>;`, formatted, and the wrapper is then stripped again.
    /// Panics if parsing fails or the wrapper cannot be found in the output.
    fn format_as_value(tt: TokenStream2) -> String {
        let wrapped = quote! { const _: () = #tt; };
        let file = syn::parse_file(&wrapped.to_string()).unwrap();
        let pretty = prettyplease::unparse(&file);
        pretty
            .trim()
            .strip_prefix("const _: () = ")
            .and_then(|rest| rest.strip_suffix(';'))
            .unwrap()
            .to_string()
    }
}