pyo3_stub_gen_derive/gen_stub/pyclass.rs

1use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, MemberInfo, StubType};
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{quote, ToTokens, TokenStreamExt};
4use syn::{parse_quote, Error, ItemStruct, Result, Type};
5
6pub struct PyClassInfo {
7    pyclass_name: String,
8    struct_type: Type,
9    module: Option<String>,
10    getters: Vec<MemberInfo>,
11    setters: Vec<MemberInfo>,
12    doc: String,
13    bases: Vec<Type>,
14    has_eq: bool,
15    has_ord: bool,
16    has_hash: bool,
17    has_str: bool,
18    subclass: bool,
19}
20
21impl From<&PyClassInfo> for StubType {
22    fn from(info: &PyClassInfo) -> Self {
23        let PyClassInfo {
24            pyclass_name,
25            module,
26            struct_type,
27            ..
28        } = info;
29        Self {
30            ty: struct_type.clone(),
31            name: pyclass_name.clone(),
32            module: module.clone(),
33        }
34    }
35}
36
37impl TryFrom<ItemStruct> for PyClassInfo {
38    type Error = Error;
39    fn try_from(item: ItemStruct) -> Result<Self> {
40        let ItemStruct {
41            ident,
42            attrs,
43            fields,
44            ..
45        } = item;
46        let struct_type: Type = parse_quote!(#ident);
47        let mut pyclass_name = None;
48        let mut module = None;
49        let mut is_get_all = false;
50        let mut is_set_all = false;
51        let mut bases = Vec::new();
52        let mut has_eq = false;
53        let mut has_ord = false;
54        let mut has_hash = false;
55        let mut has_str = false;
56        let mut subclass = false;
57        for attr in parse_pyo3_attrs(&attrs)? {
58            match attr {
59                Attr::Name(name) => pyclass_name = Some(name),
60                Attr::Module(name) => {
61                    module = Some(name);
62                }
63                Attr::GetAll => is_get_all = true,
64                Attr::SetAll => is_set_all = true,
65                Attr::Extends(typ) => bases.push(typ),
66                Attr::Eq => has_eq = true,
67                Attr::Ord => has_ord = true,
68                Attr::Hash => has_hash = true,
69                Attr::Str => has_str = true,
70                Attr::Subclass => subclass = true,
71                _ => {}
72            }
73        }
74        let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
75        let mut getters = Vec::new();
76        let mut setters = Vec::new();
77        for field in fields {
78            if is_get_all || MemberInfo::is_get(&field)? {
79                getters.push(MemberInfo::try_from(field.clone())?)
80            }
81            if is_set_all || MemberInfo::is_set(&field)? {
82                setters.push(MemberInfo::try_from(field)?)
83            }
84        }
85        let doc = extract_documents(&attrs).join("\n");
86        Ok(Self {
87            struct_type,
88            pyclass_name,
89            getters,
90            setters,
91            module,
92            doc,
93            bases,
94            has_eq,
95            has_ord,
96            has_hash,
97            has_str,
98            subclass,
99        })
100    }
101}
102
103impl ToTokens for PyClassInfo {
104    fn to_tokens(&self, tokens: &mut TokenStream2) {
105        let Self {
106            pyclass_name,
107            struct_type,
108            getters,
109            setters,
110            doc,
111            module,
112            bases,
113            has_eq,
114            has_ord,
115            has_hash,
116            has_str,
117            subclass,
118        } = self;
119        let module = quote_option(module);
120        tokens.append_all(quote! {
121            ::pyo3_stub_gen::type_info::PyClassInfo {
122                pyclass_name: #pyclass_name,
123                struct_id: std::any::TypeId::of::<#struct_type>,
124                getters: &[ #( #getters),* ],
125                setters: &[ #( #setters),* ],
126                module: #module,
127                doc: #doc,
128                bases: &[ #( <#bases as ::pyo3_stub_gen::PyStubType>::type_output ),* ],
129                has_eq: #has_eq,
130                has_ord: #has_ord,
131                has_hash: #has_hash,
132                has_str: #has_str,
133                subclass: #subclass,
134            }
135        })
136    }
137}
138
139// `#[gen_stub(xxx)]` is not a valid proc_macro_attribute
140// it's only designed to receive user's setting.
141// We need to remove all `#[gen_stub(xxx)]` before print the item_struct back
142pub fn prune_attrs(item_struct: &mut ItemStruct) {
143    super::attr::prune_attrs(&mut item_struct.attrs);
144    for field in item_struct.fields.iter_mut() {
145        super::attr::prune_attrs(&mut field.attrs);
146    }
147}
148
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    #[test]
    fn test_pyclass() -> Result<()> {
        let source = r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
                #[pyo3(get)]
                pub ndim: usize,
                #[pyo3(get)]
                pub description: Option<String>,
                pub custom_latex: Option<String>,
            }
            "#;
        let item: ItemStruct = parse_str(source)?;
        let info = PyClassInfo::try_from(item)?;
        let rendered = format_as_value(info.to_token_stream());
        insta::assert_snapshot!(rendered, @r###"
        ::pyo3_stub_gen::type_info::PyClassInfo {
            pyclass_name: "Placeholder",
            struct_id: std::any::TypeId::of::<PyPlaceholder>,
            getters: &[
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "name",
                    r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "ndim",
                    r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "description",
                    r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
            ],
            setters: &[],
            module: Some("my_module"),
            doc: "",
            bases: &[],
            has_eq: false,
            has_ord: false,
            has_hash: false,
            has_str: false,
            subclass: false,
        }
        "###);
        Ok(())
    }

    /// Pretty-prints an expression token stream by wrapping it in `const _: () = ...;`,
    /// formatting that item with `prettyplease`, and then peeling the wrapper back off.
    fn format_as_value(tt: TokenStream2) -> String {
        let wrapped = quote! { const _: () = #tt; };
        let file = syn::parse_file(&wrapped.to_string()).unwrap();
        let pretty = prettyplease::unparse(&file);
        let without_prefix = pretty.trim().strip_prefix("const _: () = ").unwrap();
        let value = without_prefix.strip_suffix(';').unwrap();
        value.to_string()
    }
}