pyo3_stub_gen_derive/gen_stub/
pyclass.rs1use super::{
2 extract_documents, parse_pyo3_attrs, util::quote_option, Attr, MemberInfo, PyClassAttr,
3 StubType,
4};
5use proc_macro2::TokenStream as TokenStream2;
6use quote::{quote, ToTokens, TokenStreamExt};
7use syn::{parse_quote, Error, ItemStruct, Result, Type};
8
9pub struct PyClassInfo {
10 pyclass_name: String,
11 struct_type: Type,
12 module: Option<String>,
13 getters: Vec<MemberInfo>,
14 setters: Vec<MemberInfo>,
15 doc: String,
16 bases: Vec<Type>,
17 has_eq: bool,
18 has_ord: bool,
19 has_hash: bool,
20 has_str: bool,
21 subclass: bool,
22}
23
24impl From<&PyClassInfo> for StubType {
25 fn from(info: &PyClassInfo) -> Self {
26 let PyClassInfo {
27 pyclass_name,
28 module,
29 struct_type,
30 ..
31 } = info;
32 Self {
33 ty: struct_type.clone(),
34 name: pyclass_name.clone(),
35 module: module.clone(),
36 }
37 }
38}
39
40impl PyClassInfo {
41 pub fn from_item_with_attr(item: ItemStruct, attr: &PyClassAttr) -> Result<Self> {
43 let ItemStruct {
44 ident,
45 attrs,
46 fields,
47 ..
48 } = item;
49 let struct_type: Type = parse_quote!(#ident);
50 let mut pyclass_name = None;
51 let mut pyo3_module = None;
52 let mut gen_stub_standalone_module = None;
53 let mut is_get_all = false;
54 let mut is_set_all = false;
55 let mut bases = Vec::new();
56 let mut has_eq = false;
57 let mut has_ord = false;
58 let mut has_hash = false;
59 let mut has_str = false;
60 let mut subclass = false;
61 for attr_item in parse_pyo3_attrs(&attrs)? {
62 match attr_item {
63 Attr::Name(name) => pyclass_name = Some(name),
64 Attr::Module(name) => {
65 pyo3_module = Some(name);
66 }
67 Attr::GenStubModule(name) => {
68 gen_stub_standalone_module = Some(name);
69 }
70 Attr::GetAll => is_get_all = true,
71 Attr::SetAll => is_set_all = true,
72 Attr::Extends(typ) => bases.push(typ),
73 Attr::Eq => has_eq = true,
74 Attr::Ord => has_ord = true,
75 Attr::Hash => has_hash = true,
76 Attr::Str => has_str = true,
77 Attr::Subclass => subclass = true,
78 _ => {}
79 }
80 }
81
82 if let (Some(inline_mod), Some(standalone_mod)) =
84 (&attr.module, &gen_stub_standalone_module)
85 {
86 if inline_mod != standalone_mod {
87 return Err(Error::new(
88 ident.span(),
89 format!(
90 "Conflicting module specifications: #[gen_stub_pyclass(module = \"{}\")] \
91 and #[gen_stub(module = \"{}\")]. Please use only one.",
92 inline_mod, standalone_mod
93 ),
94 ));
95 }
96 }
97
98 let module = if let Some(inline_mod) = &attr.module {
100 Some(inline_mod.clone()) } else if let Some(standalone_mod) = gen_stub_standalone_module {
102 Some(standalone_mod) } else {
104 pyo3_module };
106
107 let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
108 let mut getters = Vec::new();
109 let mut setters = Vec::new();
110 for field in fields {
111 if is_get_all || MemberInfo::is_get(&field)? {
112 getters.push(MemberInfo::try_from(field.clone())?)
113 }
114 if is_set_all || MemberInfo::is_set(&field)? {
115 setters.push(MemberInfo::try_from(field)?)
116 }
117 }
118 let doc = extract_documents(&attrs).join("\n");
119 Ok(Self {
120 struct_type,
121 pyclass_name,
122 getters,
123 setters,
124 module,
125 doc,
126 bases,
127 has_eq,
128 has_ord,
129 has_hash,
130 has_str,
131 subclass,
132 })
133 }
134}
135
136impl TryFrom<ItemStruct> for PyClassInfo {
137 type Error = Error;
138 fn try_from(item: ItemStruct) -> Result<Self> {
139 Self::from_item_with_attr(item, &PyClassAttr::default())
141 }
142}
143
144impl ToTokens for PyClassInfo {
145 fn to_tokens(&self, tokens: &mut TokenStream2) {
146 let Self {
147 pyclass_name,
148 struct_type,
149 getters,
150 setters,
151 doc,
152 module,
153 bases,
154 has_eq,
155 has_ord,
156 has_hash,
157 has_str,
158 subclass,
159 } = self;
160 let module = quote_option(module);
161 tokens.append_all(quote! {
162 ::pyo3_stub_gen::type_info::PyClassInfo {
163 pyclass_name: #pyclass_name,
164 struct_id: std::any::TypeId::of::<#struct_type>,
165 getters: &[ #( #getters),* ],
166 setters: &[ #( #setters),* ],
167 module: #module,
168 doc: #doc,
169 bases: &[ #( <#bases as ::pyo3_stub_gen::PyStubType>::type_output ),* ],
170 has_eq: #has_eq,
171 has_ord: #has_ord,
172 has_hash: #has_hash,
173 has_str: #has_str,
174 subclass: #subclass,
175 }
176 })
177 }
178}
179
180pub fn prune_attrs(item_struct: &mut ItemStruct) {
184 super::attr::prune_attrs(&mut item_struct.attrs);
185 for field in item_struct.fields.iter_mut() {
186 super::attr::prune_attrs(&mut field.attrs);
187 }
188}
189
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    #[test]
    fn test_pyclass() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
                #[pyo3(get)]
                pub ndim: usize,
                #[pyo3(get)]
                pub description: Option<String>,
                pub custom_latex: Option<String>,
            }
            "#,
        )?;
        let tokens = PyClassInfo::try_from(item)?.to_token_stream();
        insta::assert_snapshot!(format_as_value(tokens), @r###"
        ::pyo3_stub_gen::type_info::PyClassInfo {
            pyclass_name: "Placeholder",
            struct_id: std::any::TypeId::of::<PyPlaceholder>,
            getters: &[
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "name",
                    r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "ndim",
                    r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "description",
                    r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
            ],
            setters: &[],
            module: Some("my_module"),
            doc: "",
            bases: &[],
            has_eq: false,
            has_ord: false,
            has_hash: false,
            has_str: false,
            subclass: false,
        }
        "###);
        Ok(())
    }

    /// Pretty-prints an expression token stream by wrapping it in `const _: () = ..;`,
    /// formatting the resulting file with `prettyplease`, and stripping the wrapper again.
    fn format_as_value(tt: TokenStream2) -> String {
        let wrapped = quote! { const _: () = #tt; };
        let file = syn::parse_file(&wrapped.to_string()).unwrap();
        let pretty = prettyplease::unparse(&file);
        let without_prefix = pretty.trim().strip_prefix("const _: () = ").unwrap();
        without_prefix.strip_suffix(';').unwrap().to_string()
    }
}