pyo3_stub_gen_derive/gen_stub/pyclass.rs

1use super::{
2    extract_documents, member::MemberKind, parse_pyo3_attrs, util::quote_option, Attr, MemberInfo,
3    PyClassAttr, StubType,
4};
5use proc_macro2::TokenStream as TokenStream2;
6use quote::{quote, ToTokens, TokenStreamExt};
7use syn::{parse_quote, Error, ItemStruct, Result, Type};
8
9pub struct PyClassInfo {
10    pyclass_name: String,
11    struct_type: Type,
12    module: Option<String>,
13    getters: Vec<MemberInfo>,
14    setters: Vec<MemberInfo>,
15    doc: String,
16    bases: Vec<Type>,
17    has_eq: bool,
18    has_ord: bool,
19    has_hash: bool,
20    has_str: bool,
21    subclass: bool,
22}
23
24impl From<&PyClassInfo> for StubType {
25    fn from(info: &PyClassInfo) -> Self {
26        let PyClassInfo {
27            pyclass_name,
28            module,
29            struct_type,
30            ..
31        } = info;
32        Self {
33            ty: struct_type.clone(),
34            name: pyclass_name.clone(),
35            module: module.clone(),
36        }
37    }
38}
39
40impl PyClassInfo {
41    /// Create PyClassInfo from ItemStruct with PyClassAttr for module override support
42    pub fn from_item_with_attr(item: ItemStruct, attr: &PyClassAttr) -> Result<Self> {
43        let ItemStruct {
44            ident,
45            attrs,
46            fields,
47            ..
48        } = item;
49        let struct_type: Type = parse_quote!(#ident);
50        let mut pyclass_name = None;
51        let mut pyo3_module = None;
52        let mut gen_stub_standalone_module = None;
53        let mut is_get_all = false;
54        let mut is_set_all = false;
55        let mut bases = Vec::new();
56        let mut has_eq = false;
57        let mut has_ord = false;
58        let mut has_hash = false;
59        let mut has_str = false;
60        let mut subclass = false;
61        for attr_item in parse_pyo3_attrs(&attrs)? {
62            match attr_item {
63                Attr::Name(name) => pyclass_name = Some(name),
64                Attr::Module(name) => {
65                    pyo3_module = Some(name);
66                }
67                Attr::GenStubModule(name) => {
68                    gen_stub_standalone_module = Some(name);
69                }
70                Attr::GetAll => is_get_all = true,
71                Attr::SetAll => is_set_all = true,
72                Attr::Extends(typ) => bases.push(typ),
73                Attr::Eq => has_eq = true,
74                Attr::Ord => has_ord = true,
75                Attr::Hash => has_hash = true,
76                Attr::Str => has_str = true,
77                Attr::Subclass => subclass = true,
78                _ => {}
79            }
80        }
81
82        // Validate: inline and standalone gen_stub modules must not conflict
83        if let (Some(inline_mod), Some(standalone_mod)) =
84            (&attr.module, &gen_stub_standalone_module)
85        {
86            if inline_mod != standalone_mod {
87                return Err(Error::new(
88                    ident.span(),
89                    format!(
90                        "Conflicting module specifications: #[gen_stub_pyclass(module = \"{}\")] \
91                         and #[gen_stub(module = \"{}\")]. Please use only one.",
92                        inline_mod, standalone_mod
93                    ),
94                ));
95            }
96        }
97
98        // Priority: inline > standalone > pyo3 > default
99        let module = if let Some(inline_mod) = &attr.module {
100            Some(inline_mod.clone()) // Priority 1: #[gen_stub_pyclass(module = "...")]
101        } else if let Some(standalone_mod) = gen_stub_standalone_module {
102            Some(standalone_mod) // Priority 2: #[gen_stub(module = "...")]
103        } else {
104            pyo3_module // Priority 3: #[pyo3(module = "...")]
105        };
106
107        let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
108        let mut getters = Vec::new();
109        let mut setters = Vec::new();
110        for field in fields {
111            let has_get = is_get_all || MemberInfo::is_get(&field)?;
112            let has_set = is_set_all || MemberInfo::is_set(&field)?;
113            if has_get {
114                getters.push(MemberInfo::from_field(field.clone(), MemberKind::Getter)?)
115            }
116            if has_set {
117                setters.push(MemberInfo::from_field(field, MemberKind::Setter)?)
118            }
119        }
120        let doc = extract_documents(&attrs).join("\n");
121        Ok(Self {
122            struct_type,
123            pyclass_name,
124            getters,
125            setters,
126            module,
127            doc,
128            bases,
129            has_eq,
130            has_ord,
131            has_hash,
132            has_str,
133            subclass,
134        })
135    }
136}
137
138impl TryFrom<ItemStruct> for PyClassInfo {
139    type Error = Error;
140    fn try_from(item: ItemStruct) -> Result<Self> {
141        // Use the new method with default PyClassAttr
142        Self::from_item_with_attr(item, &PyClassAttr::default())
143    }
144}
145
146impl ToTokens for PyClassInfo {
147    fn to_tokens(&self, tokens: &mut TokenStream2) {
148        let Self {
149            pyclass_name,
150            struct_type,
151            getters,
152            setters,
153            doc,
154            module,
155            bases,
156            has_eq,
157            has_ord,
158            has_hash,
159            has_str,
160            subclass,
161        } = self;
162        let module = quote_option(module);
163        tokens.append_all(quote! {
164            ::pyo3_stub_gen::type_info::PyClassInfo {
165                pyclass_name: #pyclass_name,
166                struct_id: std::any::TypeId::of::<#struct_type>,
167                getters: &[ #( #getters),* ],
168                setters: &[ #( #setters),* ],
169                module: #module,
170                doc: #doc,
171                bases: &[ #( <#bases as ::pyo3_stub_gen::PyStubType>::type_output ),* ],
172                has_eq: #has_eq,
173                has_ord: #has_ord,
174                has_hash: #has_hash,
175                has_str: #has_str,
176                subclass: #subclass,
177            }
178        })
179    }
180}
181
182// `#[gen_stub(xxx)]` is not a valid proc_macro_attribute
183// it's only designed to receive user's setting.
184// We need to remove all `#[gen_stub(xxx)]` before print the item_struct back
185pub fn prune_attrs(item_struct: &mut ItemStruct) {
186    super::attr::prune_attrs(&mut item_struct.attrs);
187    for field in item_struct.fields.iter_mut() {
188        super::attr::prune_attrs(&mut field.attrs);
189    }
190}
191
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    /// Pretty-prints a generated expression so that snapshots are readable.
    ///
    /// The expression is wrapped as `const _: () = <expr>;` to make a parseable file.
    /// The file is formatted with `prettyplease`, then the wrapper is stripped off again.
    fn format_as_value(tt: TokenStream2) -> String {
        const WRAPPER_PREFIX: &str = "const _: () = ";
        let wrapped = quote! { const _: () = #tt; };
        let file = syn::parse_file(&wrapped.to_string()).unwrap();
        let pretty = prettyplease::unparse(&file);
        let body = pretty.trim().strip_prefix(WRAPPER_PREFIX).unwrap();
        body.strip_suffix(';').unwrap().to_string()
    }

    #[test]
    fn test_pyclass() -> Result<()> {
        let input: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
                #[pyo3(get)]
                pub ndim: usize,
                #[pyo3(get)]
                pub description: Option<String>,
                pub custom_latex: Option<String>,
            }
            "#,
        )?;
        let info = PyClassInfo::try_from(input)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyClassInfo {
            pyclass_name: "Placeholder",
            struct_id: std::any::TypeId::of::<PyPlaceholder>,
            getters: &[
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "name",
                    r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "ndim",
                    r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "description",
                    r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
            ],
            setters: &[],
            module: Some("my_module"),
            doc: "",
            bases: &[],
            has_eq: false,
            has_ord: false,
            has_hash: false,
            has_str: false,
            subclass: false,
        }
        "###);
        Ok(())
    }
}