// pyo3_stub_gen_derive/gen_stub/pyclass.rs
use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, MemberInfo, StubType};
use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens, TokenStreamExt};
use syn::{parse_quote, Error, ItemStruct, Result, Type};
5
6pub struct PyClassInfo {
7 pyclass_name: String,
8 struct_type: Type,
9 module: Option<String>,
10 getters: Vec<MemberInfo>,
11 setters: Vec<MemberInfo>,
12 doc: String,
13 bases: Vec<Type>,
14 has_eq: bool,
15 has_ord: bool,
16 has_hash: bool,
17 has_str: bool,
18 subclass: bool,
19}
20
21impl From<&PyClassInfo> for StubType {
22 fn from(info: &PyClassInfo) -> Self {
23 let PyClassInfo {
24 pyclass_name,
25 module,
26 struct_type,
27 ..
28 } = info;
29 Self {
30 ty: struct_type.clone(),
31 name: pyclass_name.clone(),
32 module: module.clone(),
33 }
34 }
35}
36
37impl TryFrom<ItemStruct> for PyClassInfo {
38 type Error = Error;
39 fn try_from(item: ItemStruct) -> Result<Self> {
40 let ItemStruct {
41 ident,
42 attrs,
43 fields,
44 ..
45 } = item;
46 let struct_type: Type = parse_quote!(#ident);
47 let mut pyclass_name = None;
48 let mut module = None;
49 let mut is_get_all = false;
50 let mut is_set_all = false;
51 let mut bases = Vec::new();
52 let mut has_eq = false;
53 let mut has_ord = false;
54 let mut has_hash = false;
55 let mut has_str = false;
56 let mut subclass = false;
57 for attr in parse_pyo3_attrs(&attrs)? {
58 match attr {
59 Attr::Name(name) => pyclass_name = Some(name),
60 Attr::Module(name) => {
61 module = Some(name);
62 }
63 Attr::GetAll => is_get_all = true,
64 Attr::SetAll => is_set_all = true,
65 Attr::Extends(typ) => bases.push(typ),
66 Attr::Eq => has_eq = true,
67 Attr::Ord => has_ord = true,
68 Attr::Hash => has_hash = true,
69 Attr::Str => has_str = true,
70 Attr::Subclass => subclass = true,
71 _ => {}
72 }
73 }
74 let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
75 let mut getters = Vec::new();
76 let mut setters = Vec::new();
77 for field in fields {
78 if is_get_all || MemberInfo::is_get(&field)? {
79 getters.push(MemberInfo::try_from(field.clone())?)
80 }
81 if is_set_all || MemberInfo::is_set(&field)? {
82 setters.push(MemberInfo::try_from(field)?)
83 }
84 }
85 let doc = extract_documents(&attrs).join("\n");
86 Ok(Self {
87 struct_type,
88 pyclass_name,
89 getters,
90 setters,
91 module,
92 doc,
93 bases,
94 has_eq,
95 has_ord,
96 has_hash,
97 has_str,
98 subclass,
99 })
100 }
101}
102
103impl ToTokens for PyClassInfo {
104 fn to_tokens(&self, tokens: &mut TokenStream2) {
105 let Self {
106 pyclass_name,
107 struct_type,
108 getters,
109 setters,
110 doc,
111 module,
112 bases,
113 has_eq,
114 has_ord,
115 has_hash,
116 has_str,
117 subclass,
118 } = self;
119 let module = quote_option(module);
120 tokens.append_all(quote! {
121 ::pyo3_stub_gen::type_info::PyClassInfo {
122 pyclass_name: #pyclass_name,
123 struct_id: std::any::TypeId::of::<#struct_type>,
124 getters: &[ #( #getters),* ],
125 setters: &[ #( #setters),* ],
126 module: #module,
127 doc: #doc,
128 bases: &[ #( <#bases as ::pyo3_stub_gen::PyStubType>::type_output ),* ],
129 has_eq: #has_eq,
130 has_ord: #has_ord,
131 has_hash: #has_hash,
132 has_str: #has_str,
133 subclass: #subclass,
134 }
135 })
136 }
137}
138
139pub fn prune_attrs(item_struct: &mut ItemStruct) {
143 super::attr::prune_attrs(&mut item_struct.attrs);
144 for field in item_struct.fields.iter_mut() {
145 super::attr::prune_attrs(&mut field.attrs);
146 }
147}
148
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    /// Parses a struct with class-level and per-field attributes and checks the
    /// generated `PyClassInfo` expression against an inline snapshot.
    #[test]
    fn test_pyclass() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
                #[pyo3(get)]
                pub ndim: usize,
                #[pyo3(get)]
                pub description: Option<String>,
                pub custom_latex: Option<String>,
            }
            "#,
        )?;
        let info = PyClassInfo::try_from(item)?;
        let rendered = format_as_value(info.to_token_stream());
        insta::assert_snapshot!(rendered, @r###"
        ::pyo3_stub_gen::type_info::PyClassInfo {
            pyclass_name: "Placeholder",
            struct_id: std::any::TypeId::of::<PyPlaceholder>,
            getters: &[
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "name",
                    r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "ndim",
                    r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "description",
                    r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
            ],
            setters: &[],
            module: Some("my_module"),
            doc: "",
            bases: &[],
            has_eq: false,
            has_ord: false,
            has_hash: false,
            has_str: false,
            subclass: false,
        }
        "###);
        Ok(())
    }

    /// Pretty-prints a token stream that forms a single expression.
    ///
    /// The expression is wrapped as `const _: () = <expr>;` so it can be parsed
    /// as a file and formatted; the wrapper is then stripped from the result.
    fn format_as_value(tt: TokenStream2) -> String {
        let wrapped = quote! { const _: () = #tt; };
        let file = syn::parse_file(&wrapped.to_string()).unwrap();
        let pretty = prettyplease::unparse(&file);
        let value = pretty
            .trim()
            .strip_prefix("const _: () = ")
            .unwrap()
            .strip_suffix(';')
            .unwrap();
        value.to_string()
    }
}