// pyo3_stub_gen_derive/gen_stub/pyclass.rs

1use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, MemberInfo, StubType};
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{quote, ToTokens, TokenStreamExt};
4use syn::{parse_quote, Error, ItemStruct, Result, Type};
5
6pub struct PyClassInfo {
7    pyclass_name: String,
8    struct_type: Type,
9    module: Option<String>,
10    getters: Vec<MemberInfo>,
11    setters: Vec<MemberInfo>,
12    doc: String,
13    bases: Vec<Type>,
14}
15
16impl From<&PyClassInfo> for StubType {
17    fn from(info: &PyClassInfo) -> Self {
18        let PyClassInfo {
19            pyclass_name,
20            module,
21            struct_type,
22            ..
23        } = info;
24        Self {
25            ty: struct_type.clone(),
26            name: pyclass_name.clone(),
27            module: module.clone(),
28        }
29    }
30}
31
32impl TryFrom<ItemStruct> for PyClassInfo {
33    type Error = Error;
34    fn try_from(item: ItemStruct) -> Result<Self> {
35        let ItemStruct {
36            ident,
37            attrs,
38            fields,
39            ..
40        } = item;
41        let struct_type: Type = parse_quote!(#ident);
42        let mut pyclass_name = None;
43        let mut module = None;
44        let mut is_get_all = false;
45        let mut is_set_all = false;
46        let mut bases = Vec::new();
47        for attr in parse_pyo3_attrs(&attrs)? {
48            match attr {
49                Attr::Name(name) => pyclass_name = Some(name),
50                Attr::Module(name) => {
51                    module = Some(name);
52                }
53                Attr::GetAll => is_get_all = true,
54                Attr::SetAll => is_set_all = true,
55                Attr::Extends(typ) => bases.push(typ),
56                _ => {}
57            }
58        }
59        let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
60        let mut getters = Vec::new();
61        let mut setters = Vec::new();
62        for field in fields {
63            if is_get_all || MemberInfo::is_get(&field)? {
64                getters.push(MemberInfo::try_from(field.clone())?)
65            }
66            if is_set_all || MemberInfo::is_set(&field)? {
67                setters.push(MemberInfo::try_from(field)?)
68            }
69        }
70        let doc = extract_documents(&attrs).join("\n");
71        Ok(Self {
72            struct_type,
73            pyclass_name,
74            getters,
75            setters,
76            module,
77            doc,
78            bases,
79        })
80    }
81}
82
83impl ToTokens for PyClassInfo {
84    fn to_tokens(&self, tokens: &mut TokenStream2) {
85        let Self {
86            pyclass_name,
87            struct_type,
88            getters,
89            setters,
90            doc,
91            module,
92            bases,
93        } = self;
94        let module = quote_option(module);
95        tokens.append_all(quote! {
96            ::pyo3_stub_gen::type_info::PyClassInfo {
97                pyclass_name: #pyclass_name,
98                struct_id: std::any::TypeId::of::<#struct_type>,
99                getters: &[ #( #getters),* ],
100                setters: &[ #( #setters),* ],
101                module: #module,
102                doc: #doc,
103                bases: &[ #( <#bases as ::pyo3_stub_gen::PyStubType>::type_output ),* ],
104            }
105        })
106    }
107}
108
109// `#[gen_stub(xxx)]` is not a valid proc_macro_attribute
110// it's only designed to receive user's setting.
111// We need to remove all `#[gen_stub(xxx)]` before print the item_struct back
112pub fn prune_attrs(item_struct: &mut ItemStruct) {
113    super::attr::prune_attrs(&mut item_struct.attrs);
114    for field in item_struct.fields.iter_mut() {
115        super::attr::prune_attrs(&mut field.attrs);
116    }
117}
118
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    /// A renamed `#[pyclass]` with a module and three `#[pyo3(get)]` fields
    /// should produce three getters, no setters, no bases and the given module.
    #[test]
    fn test_pyclass() -> Result<()> {
        let input: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
                #[pyo3(get)]
                pub ndim: usize,
                #[pyo3(get)]
                pub description: Option<String>,
                pub custom_latex: Option<String>,
            }
            "#,
        )?;
        let info = PyClassInfo::try_from(input)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyClassInfo {
            pyclass_name: "Placeholder",
            struct_id: std::any::TypeId::of::<PyPlaceholder>,
            getters: &[
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "name",
                    r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "ndim",
                    r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "description",
                    r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
            ],
            setters: &[],
            module: Some("my_module"),
            doc: "",
            bases: &[],
        }
        "###);
        Ok(())
    }

    /// Pretty-prints an expression token stream.
    ///
    /// The expression is wrapped in a dummy `const` item, which is formatted
    /// with `prettyplease`. The expression text is then sliced back out.
    /// Panics if the tokens do not form a parseable expression.
    fn format_as_value(tt: TokenStream2) -> String {
        let wrapped = quote! { const _: () = #tt; };
        let file = syn::parse_file(&wrapped.to_string()).unwrap();
        let pretty = prettyplease::unparse(&file);
        pretty
            .trim()
            .strip_prefix("const _: () = ")
            .and_then(|rest| rest.strip_suffix(';'))
            .unwrap()
            .to_string()
    }
}