pyo3_stub_gen_derive/gen_stub/pyclass.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens, TokenStreamExt};
3use syn::{parse_quote, Error, ItemStruct, Result, Type};
4
5use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, MemberInfo, StubType};
6
7pub struct PyClassInfo {
8    pyclass_name: String,
9    struct_type: Type,
10    module: Option<String>,
11    members: Vec<MemberInfo>,
12    doc: String,
13    bases: Vec<Type>,
14}
15
16impl From<&PyClassInfo> for StubType {
17    fn from(info: &PyClassInfo) -> Self {
18        let PyClassInfo {
19            pyclass_name,
20            module,
21            struct_type,
22            ..
23        } = info;
24        Self {
25            ty: struct_type.clone(),
26            name: pyclass_name.clone(),
27            module: module.clone(),
28        }
29    }
30}
31
32impl TryFrom<ItemStruct> for PyClassInfo {
33    type Error = Error;
34    fn try_from(item: ItemStruct) -> Result<Self> {
35        let ItemStruct {
36            ident,
37            attrs,
38            fields,
39            ..
40        } = item;
41        let struct_type: Type = parse_quote!(#ident);
42        let mut pyclass_name = None;
43        let mut module = None;
44        let mut is_get_all = false;
45        let mut bases = Vec::new();
46        for attr in parse_pyo3_attrs(&attrs)? {
47            match attr {
48                Attr::Name(name) => pyclass_name = Some(name),
49                Attr::Module(name) => {
50                    module = Some(name);
51                }
52                Attr::GetAll => is_get_all = true,
53                Attr::Extends(typ) => bases.push(typ),
54                _ => {}
55            }
56        }
57        let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
58        let mut members = Vec::new();
59        for field in fields {
60            if is_get_all || MemberInfo::is_candidate_field(&field)? {
61                members.push(MemberInfo::try_from(field)?)
62            }
63        }
64        let doc = extract_documents(&attrs).join("\n");
65        Ok(Self {
66            struct_type,
67            pyclass_name,
68            members,
69            module,
70            doc,
71            bases,
72        })
73    }
74}
75
76impl ToTokens for PyClassInfo {
77    fn to_tokens(&self, tokens: &mut TokenStream2) {
78        let Self {
79            pyclass_name,
80            struct_type,
81            members,
82            doc,
83            module,
84            bases,
85        } = self;
86        let module = quote_option(module);
87        tokens.append_all(quote! {
88            ::pyo3_stub_gen::type_info::PyClassInfo {
89                pyclass_name: #pyclass_name,
90                struct_id: std::any::TypeId::of::<#struct_type>,
91                members: &[ #( #members),* ],
92                module: #module,
93                doc: #doc,
94                bases: &[ #( <#bases as ::pyo3_stub_gen::PyStubType>::type_output ),* ],
95            }
96        })
97    }
98}
99
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    #[test]
    fn test_pyclass() -> Result<()> {
        let input: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
                #[pyo3(get)]
                pub ndim: usize,
                #[pyo3(get)]
                pub description: Option<String>,
                pub custom_latex: Option<String>,
            }
            "#,
        )?;
        let info = PyClassInfo::try_from(input)?;
        let rendered = format_as_value(info.to_token_stream());
        insta::assert_snapshot!(rendered, @r###"
        ::pyo3_stub_gen::type_info::PyClassInfo {
            pyclass_name: "Placeholder",
            struct_id: std::any::TypeId::of::<PyPlaceholder>,
            members: &[
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "name",
                    r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "ndim",
                    r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "description",
                    r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                },
            ],
            module: Some("my_module"),
            doc: "",
            bases: &[],
        }
        "###);
        Ok(())
    }

    /// Pretty-prints an expression token stream.
    ///
    /// The expression is wrapped in a `const _: () = ...;` item so that
    /// `prettyplease` (which formats whole files) can handle it. The wrapper
    /// is then stripped off again. Panics if parsing or stripping fails.
    fn format_as_value(tt: TokenStream2) -> String {
        let wrapped = quote! { const _: () = #tt; };
        let source = wrapped.to_string();
        let file = syn::parse_file(&source).unwrap();
        let pretty = prettyplease::unparse(&file);
        let without_prefix = pretty.trim().strip_prefix("const _: () = ").unwrap();
        let expr = without_prefix.strip_suffix(';').unwrap();
        expr.to_string()
    }
}