pyo3_stub_gen_derive/gen_stub/
pyclass.rs1use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, MemberInfo, StubType};
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{quote, ToTokens, TokenStreamExt};
4use syn::{parse_quote, Error, ItemStruct, Result, Type};
5
6pub struct PyClassInfo {
7 pyclass_name: String,
8 struct_type: Type,
9 module: Option<String>,
10 getters: Vec<MemberInfo>,
11 setters: Vec<MemberInfo>,
12 doc: String,
13 bases: Vec<Type>,
14}
15
16impl From<&PyClassInfo> for StubType {
17 fn from(info: &PyClassInfo) -> Self {
18 let PyClassInfo {
19 pyclass_name,
20 module,
21 struct_type,
22 ..
23 } = info;
24 Self {
25 ty: struct_type.clone(),
26 name: pyclass_name.clone(),
27 module: module.clone(),
28 }
29 }
30}
31
32impl TryFrom<ItemStruct> for PyClassInfo {
33 type Error = Error;
34 fn try_from(item: ItemStruct) -> Result<Self> {
35 let ItemStruct {
36 ident,
37 attrs,
38 fields,
39 ..
40 } = item;
41 let struct_type: Type = parse_quote!(#ident);
42 let mut pyclass_name = None;
43 let mut module = None;
44 let mut is_get_all = false;
45 let mut is_set_all = false;
46 let mut bases = Vec::new();
47 for attr in parse_pyo3_attrs(&attrs)? {
48 match attr {
49 Attr::Name(name) => pyclass_name = Some(name),
50 Attr::Module(name) => {
51 module = Some(name);
52 }
53 Attr::GetAll => is_get_all = true,
54 Attr::SetAll => is_set_all = true,
55 Attr::Extends(typ) => bases.push(typ),
56 _ => {}
57 }
58 }
59 let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
60 let mut getters = Vec::new();
61 let mut setters = Vec::new();
62 for field in fields {
63 if is_get_all || MemberInfo::is_get(&field)? {
64 getters.push(MemberInfo::try_from(field.clone())?)
65 }
66 if is_set_all || MemberInfo::is_set(&field)? {
67 setters.push(MemberInfo::try_from(field)?)
68 }
69 }
70 let doc = extract_documents(&attrs).join("\n");
71 Ok(Self {
72 struct_type,
73 pyclass_name,
74 getters,
75 setters,
76 module,
77 doc,
78 bases,
79 })
80 }
81}
82
83impl ToTokens for PyClassInfo {
84 fn to_tokens(&self, tokens: &mut TokenStream2) {
85 let Self {
86 pyclass_name,
87 struct_type,
88 getters,
89 setters,
90 doc,
91 module,
92 bases,
93 } = self;
94 let module = quote_option(module);
95 tokens.append_all(quote! {
96 ::pyo3_stub_gen::type_info::PyClassInfo {
97 pyclass_name: #pyclass_name,
98 struct_id: std::any::TypeId::of::<#struct_type>,
99 getters: &[ #( #getters),* ],
100 setters: &[ #( #setters),* ],
101 module: #module,
102 doc: #doc,
103 bases: &[ #( <#bases as ::pyo3_stub_gen::PyStubType>::type_output ),* ],
104 }
105 })
106 }
107}
108
109pub fn prune_attrs(item_struct: &mut ItemStruct) {
113 super::attr::prune_attrs(&mut item_struct.attrs);
114 for field in item_struct.fields.iter_mut() {
115 super::attr::prune_attrs(&mut field.attrs);
116 }
117}
118
119#[cfg(test)]
120mod test {
121 use super::*;
122 use syn::parse_str;
123
124 #[test]
125 fn test_pyclass() -> Result<()> {
126 let input: ItemStruct = parse_str(
127 r#"
128 #[pyclass(mapping, module = "my_module", name = "Placeholder")]
129 #[derive(
130 Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
131 )]
132 pub struct PyPlaceholder {
133 #[pyo3(get)]
134 pub name: String,
135 #[pyo3(get)]
136 pub ndim: usize,
137 #[pyo3(get)]
138 pub description: Option<String>,
139 pub custom_latex: Option<String>,
140 }
141 "#,
142 )?;
143 let out = PyClassInfo::try_from(input)?.to_token_stream();
144 insta::assert_snapshot!(format_as_value(out), @r###"
145 ::pyo3_stub_gen::type_info::PyClassInfo {
146 pyclass_name: "Placeholder",
147 struct_id: std::any::TypeId::of::<PyPlaceholder>,
148 getters: &[
149 ::pyo3_stub_gen::type_info::MemberInfo {
150 name: "name",
151 r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
152 doc: "",
153 default: None,
154 deprecated: None,
155 },
156 ::pyo3_stub_gen::type_info::MemberInfo {
157 name: "ndim",
158 r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
159 doc: "",
160 default: None,
161 deprecated: None,
162 },
163 ::pyo3_stub_gen::type_info::MemberInfo {
164 name: "description",
165 r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
166 doc: "",
167 default: None,
168 deprecated: None,
169 },
170 ],
171 setters: &[],
172 module: Some("my_module"),
173 doc: "",
174 bases: &[],
175 }
176 "###);
177 Ok(())
178 }
179
180 fn format_as_value(tt: TokenStream2) -> String {
181 let ttt = quote! { const _: () = #tt; };
182 let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
183 formatted
184 .trim()
185 .strip_prefix("const _: () = ")
186 .unwrap()
187 .strip_suffix(';')
188 .unwrap()
189 .to_string()
190 }
191}