pyo3_stub_gen_derive/gen_stub/
pyclass.rs1use super::{
2 extract_documents, member::MemberKind, parse_pyo3_attrs, util::quote_option, Attr, MemberInfo,
3 PyClassAttr, StubType,
4};
5use proc_macro2::TokenStream as TokenStream2;
6use quote::{quote, ToTokens, TokenStreamExt};
7use syn::{parse_quote, Error, ItemStruct, Result, Type};
8
9pub struct PyClassInfo {
10 pyclass_name: String,
11 struct_type: Type,
12 module: Option<String>,
13 getters: Vec<MemberInfo>,
14 setters: Vec<MemberInfo>,
15 doc: String,
16 bases: Vec<Type>,
17 has_eq: bool,
18 has_ord: bool,
19 has_hash: bool,
20 has_str: bool,
21 subclass: bool,
22}
23
24impl From<&PyClassInfo> for StubType {
25 fn from(info: &PyClassInfo) -> Self {
26 let PyClassInfo {
27 pyclass_name,
28 module,
29 struct_type,
30 ..
31 } = info;
32 Self {
33 ty: struct_type.clone(),
34 name: pyclass_name.clone(),
35 module: module.clone(),
36 }
37 }
38}
39
40impl PyClassInfo {
41 pub fn from_item_with_attr(item: ItemStruct, attr: &PyClassAttr) -> Result<Self> {
43 let ItemStruct {
44 ident,
45 attrs,
46 fields,
47 ..
48 } = item;
49 let struct_type: Type = parse_quote!(#ident);
50 let mut pyclass_name = None;
51 let mut pyo3_module = None;
52 let mut gen_stub_standalone_module = None;
53 let mut is_get_all = false;
54 let mut is_set_all = false;
55 let mut bases = Vec::new();
56 let mut has_eq = false;
57 let mut has_ord = false;
58 let mut has_hash = false;
59 let mut has_str = false;
60 let mut subclass = false;
61 for attr_item in parse_pyo3_attrs(&attrs)? {
62 match attr_item {
63 Attr::Name(name) => pyclass_name = Some(name),
64 Attr::Module(name) => {
65 pyo3_module = Some(name);
66 }
67 Attr::GenStubModule(name) => {
68 gen_stub_standalone_module = Some(name);
69 }
70 Attr::GetAll => is_get_all = true,
71 Attr::SetAll => is_set_all = true,
72 Attr::Extends(typ) => bases.push(typ),
73 Attr::Eq => has_eq = true,
74 Attr::Ord => has_ord = true,
75 Attr::Hash => has_hash = true,
76 Attr::Str => has_str = true,
77 Attr::Subclass => subclass = true,
78 _ => {}
79 }
80 }
81
82 if let (Some(inline_mod), Some(standalone_mod)) =
84 (&attr.module, &gen_stub_standalone_module)
85 {
86 if inline_mod != standalone_mod {
87 return Err(Error::new(
88 ident.span(),
89 format!(
90 "Conflicting module specifications: #[gen_stub_pyclass(module = \"{}\")] \
91 and #[gen_stub(module = \"{}\")]. Please use only one.",
92 inline_mod, standalone_mod
93 ),
94 ));
95 }
96 }
97
98 let module = if let Some(inline_mod) = &attr.module {
100 Some(inline_mod.clone()) } else if let Some(standalone_mod) = gen_stub_standalone_module {
102 Some(standalone_mod) } else {
104 pyo3_module };
106
107 let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
108 let mut getters = Vec::new();
109 let mut setters = Vec::new();
110 for field in fields {
111 let has_get = is_get_all || MemberInfo::is_get(&field)?;
112 let has_set = is_set_all || MemberInfo::is_set(&field)?;
113 if has_get {
114 getters.push(MemberInfo::from_field(field.clone(), MemberKind::Getter)?)
115 }
116 if has_set {
117 setters.push(MemberInfo::from_field(field, MemberKind::Setter)?)
118 }
119 }
120 let doc = extract_documents(&attrs).join("\n");
121 Ok(Self {
122 struct_type,
123 pyclass_name,
124 getters,
125 setters,
126 module,
127 doc,
128 bases,
129 has_eq,
130 has_ord,
131 has_hash,
132 has_str,
133 subclass,
134 })
135 }
136}
137
138impl TryFrom<ItemStruct> for PyClassInfo {
139 type Error = Error;
140 fn try_from(item: ItemStruct) -> Result<Self> {
141 Self::from_item_with_attr(item, &PyClassAttr::default())
143 }
144}
145
146impl ToTokens for PyClassInfo {
147 fn to_tokens(&self, tokens: &mut TokenStream2) {
148 let Self {
149 pyclass_name,
150 struct_type,
151 getters,
152 setters,
153 doc,
154 module,
155 bases,
156 has_eq,
157 has_ord,
158 has_hash,
159 has_str,
160 subclass,
161 } = self;
162 let module = quote_option(module);
163 tokens.append_all(quote! {
164 ::pyo3_stub_gen::type_info::PyClassInfo {
165 pyclass_name: #pyclass_name,
166 struct_id: std::any::TypeId::of::<#struct_type>,
167 getters: &[ #( #getters),* ],
168 setters: &[ #( #setters),* ],
169 module: #module,
170 doc: #doc,
171 bases: &[ #( <#bases as ::pyo3_stub_gen::PyStubType>::type_output ),* ],
172 has_eq: #has_eq,
173 has_ord: #has_ord,
174 has_hash: #has_hash,
175 has_str: #has_str,
176 subclass: #subclass,
177 }
178 })
179 }
180}
181
182pub fn prune_attrs(item_struct: &mut ItemStruct) {
186 super::attr::prune_attrs(&mut item_struct.attrs);
187 for field in item_struct.fields.iter_mut() {
188 super::attr::prune_attrs(&mut field.attrs);
189 }
190}
191
192#[cfg(test)]
193mod test {
194 use super::*;
195 use syn::parse_str;
196
197 #[test]
198 fn test_pyclass() -> Result<()> {
199 let input: ItemStruct = parse_str(
200 r#"
201 #[pyclass(mapping, module = "my_module", name = "Placeholder")]
202 #[derive(
203 Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
204 )]
205 pub struct PyPlaceholder {
206 #[pyo3(get)]
207 pub name: String,
208 #[pyo3(get)]
209 pub ndim: usize,
210 #[pyo3(get)]
211 pub description: Option<String>,
212 pub custom_latex: Option<String>,
213 }
214 "#,
215 )?;
216 let out = PyClassInfo::try_from(input)?.to_token_stream();
217 insta::assert_snapshot!(format_as_value(out), @r###"
218 ::pyo3_stub_gen::type_info::PyClassInfo {
219 pyclass_name: "Placeholder",
220 struct_id: std::any::TypeId::of::<PyPlaceholder>,
221 getters: &[
222 ::pyo3_stub_gen::type_info::MemberInfo {
223 name: "name",
224 r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
225 doc: "",
226 default: None,
227 deprecated: None,
228 },
229 ::pyo3_stub_gen::type_info::MemberInfo {
230 name: "ndim",
231 r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
232 doc: "",
233 default: None,
234 deprecated: None,
235 },
236 ::pyo3_stub_gen::type_info::MemberInfo {
237 name: "description",
238 r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
239 doc: "",
240 default: None,
241 deprecated: None,
242 },
243 ],
244 setters: &[],
245 module: Some("my_module"),
246 doc: "",
247 bases: &[],
248 has_eq: false,
249 has_ord: false,
250 has_hash: false,
251 has_str: false,
252 subclass: false,
253 }
254 "###);
255 Ok(())
256 }
257
258 fn format_as_value(tt: TokenStream2) -> String {
259 let ttt = quote! { const _: () = #tt; };
260 let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
261 formatted
262 .trim()
263 .strip_prefix("const _: () = ")
264 .unwrap()
265 .strip_suffix(';')
266 .unwrap()
267 .to_string()
268 }
269}