// pyo3_stub_gen_derive/gen_stub/pyclass.rs

1use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, MemberInfo, StubType};
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{quote, ToTokens, TokenStreamExt};
4use syn::{parse_quote, Error, ItemStruct, Result, Type};
5
6pub struct PyClassInfo {
7    pyclass_name: String,
8    struct_type: Type,
9    module: Option<String>,
10    getters: Vec<MemberInfo>,
11    setters: Vec<MemberInfo>,
12    doc: String,
13    bases: Vec<Type>,
14    has_eq: bool,
15    has_ord: bool,
16    has_hash: bool,
17    has_str: bool,
18}
19
20impl From<&PyClassInfo> for StubType {
21    fn from(info: &PyClassInfo) -> Self {
22        let PyClassInfo {
23            pyclass_name,
24            module,
25            struct_type,
26            ..
27        } = info;
28        Self {
29            ty: struct_type.clone(),
30            name: pyclass_name.clone(),
31            module: module.clone(),
32        }
33    }
34}
35
36impl TryFrom<ItemStruct> for PyClassInfo {
37    type Error = Error;
38    fn try_from(item: ItemStruct) -> Result<Self> {
39        let ItemStruct {
40            ident,
41            attrs,
42            fields,
43            ..
44        } = item;
45        let struct_type: Type = parse_quote!(#ident);
46        let mut pyclass_name = None;
47        let mut module = None;
48        let mut is_get_all = false;
49        let mut is_set_all = false;
50        let mut bases = Vec::new();
51        let mut has_eq = false;
52        let mut has_ord = false;
53        let mut has_hash = false;
54        let mut has_str = false;
55        for attr in parse_pyo3_attrs(&attrs)? {
56            match attr {
57                Attr::Name(name) => pyclass_name = Some(name),
58                Attr::Module(name) => {
59                    module = Some(name);
60                }
61                Attr::GetAll => is_get_all = true,
62                Attr::SetAll => is_set_all = true,
63                Attr::Extends(typ) => bases.push(typ),
64                Attr::Eq => has_eq = true,
65                Attr::Ord => has_ord = true,
66                Attr::Hash => has_hash = true,
67                Attr::Str => has_str = true,
68                _ => {}
69            }
70        }
71        let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
72        let mut getters = Vec::new();
73        let mut setters = Vec::new();
74        for field in fields {
75            if is_get_all || MemberInfo::is_get(&field)? {
76                getters.push(MemberInfo::try_from(field.clone())?)
77            }
78            if is_set_all || MemberInfo::is_set(&field)? {
79                setters.push(MemberInfo::try_from(field)?)
80            }
81        }
82        let doc = extract_documents(&attrs).join("\n");
83        Ok(Self {
84            struct_type,
85            pyclass_name,
86            getters,
87            setters,
88            module,
89            doc,
90            bases,
91            has_eq,
92            has_ord,
93            has_hash,
94            has_str,
95        })
96    }
97}
98
99impl ToTokens for PyClassInfo {
100    fn to_tokens(&self, tokens: &mut TokenStream2) {
101        let Self {
102            pyclass_name,
103            struct_type,
104            getters,
105            setters,
106            doc,
107            module,
108            bases,
109            has_eq,
110            has_ord,
111            has_hash,
112            has_str,
113        } = self;
114        let module = quote_option(module);
115        tokens.append_all(quote! {
116            ::pyo3_stub_gen::type_info::PyClassInfo {
117                pyclass_name: #pyclass_name,
118                struct_id: std::any::TypeId::of::<#struct_type>,
119                getters: &[ #( #getters),* ],
120                setters: &[ #( #setters),* ],
121                module: #module,
122                doc: #doc,
123                bases: &[ #( <#bases as ::pyo3_stub_gen::PyStubType>::type_output ),* ],
124                has_eq: #has_eq,
125                has_ord: #has_ord,
126                has_hash: #has_hash,
127                has_str: #has_str,
128            }
129        })
130    }
131}
132
133// `#[gen_stub(xxx)]` is not a valid proc_macro_attribute
134// it's only designed to receive user's setting.
135// We need to remove all `#[gen_stub(xxx)]` before print the item_struct back
136pub fn prune_attrs(item_struct: &mut ItemStruct) {
137    super::attr::prune_attrs(&mut item_struct.attrs);
138    for field in item_struct.fields.iter_mut() {
139        super::attr::prune_attrs(&mut field.attrs);
140    }
141}
142
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    /// Snapshots the tokens generated for a representative `#[pyclass]` struct.
    ///
    /// The struct has an explicit `name`/`module`, three `#[pyo3(get)]` fields
    /// and one unexposed field. Only the three marked fields may appear as
    /// getters, and no setters may appear.
    #[test]
    fn test_pyclass() -> Result<()> {
        let source = r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
                #[pyo3(get)]
                pub ndim: usize,
                #[pyo3(get)]
                pub description: Option<String>,
                pub custom_latex: Option<String>,
            }
            "#;
        let input: ItemStruct = parse_str(source)?;
        let info = PyClassInfo::try_from(input)?;
        let rendered = format_as_value(info.to_token_stream());
        insta::assert_snapshot!(rendered, @r###"
        ::pyo3_stub_gen::type_info::PyClassInfo {
            pyclass_name: "Placeholder",
            struct_id: std::any::TypeId::of::<PyPlaceholder>,
            getters: &[
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "name",
                    r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "ndim",
                    r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "description",
                    r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
            ],
            setters: &[],
            module: Some("my_module"),
            doc: "",
            bases: &[],
            has_eq: false,
            has_ord: false,
            has_hash: false,
            has_str: false,
        }
        "###);
        Ok(())
    }

    /// Pretty-prints an expression token stream.
    ///
    /// `prettyplease` only formats whole files, so the expression is wrapped
    /// as `const _: () = <expr>;`, formatted, and then the wrapper is sliced
    /// back off. Panics if the tokens do not form a parseable expression.
    fn format_as_value(tt: TokenStream2) -> String {
        let wrapped = quote! { const _: () = #tt; };
        let file = syn::parse_file(&wrapped.to_string()).unwrap();
        let pretty = prettyplease::unparse(&file);
        let body = pretty.trim();
        let body = body.strip_prefix("const _: () = ").unwrap();
        let body = body.strip_suffix(';').unwrap();
        body.to_owned()
    }
}