pyo3_stub_gen_derive/gen_stub/
pyclass.rs1use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, MemberInfo, StubType};
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{quote, ToTokens, TokenStreamExt};
4use syn::{parse_quote, Error, ItemStruct, Result, Type};
5
6pub struct PyClassInfo {
7 pyclass_name: String,
8 struct_type: Type,
9 module: Option<String>,
10 getters: Vec<MemberInfo>,
11 setters: Vec<MemberInfo>,
12 doc: String,
13 bases: Vec<Type>,
14 has_eq: bool,
15 has_ord: bool,
16 has_hash: bool,
17 has_str: bool,
18}
19
20impl From<&PyClassInfo> for StubType {
21 fn from(info: &PyClassInfo) -> Self {
22 let PyClassInfo {
23 pyclass_name,
24 module,
25 struct_type,
26 ..
27 } = info;
28 Self {
29 ty: struct_type.clone(),
30 name: pyclass_name.clone(),
31 module: module.clone(),
32 }
33 }
34}
35
36impl TryFrom<ItemStruct> for PyClassInfo {
37 type Error = Error;
38 fn try_from(item: ItemStruct) -> Result<Self> {
39 let ItemStruct {
40 ident,
41 attrs,
42 fields,
43 ..
44 } = item;
45 let struct_type: Type = parse_quote!(#ident);
46 let mut pyclass_name = None;
47 let mut module = None;
48 let mut is_get_all = false;
49 let mut is_set_all = false;
50 let mut bases = Vec::new();
51 let mut has_eq = false;
52 let mut has_ord = false;
53 let mut has_hash = false;
54 let mut has_str = false;
55 for attr in parse_pyo3_attrs(&attrs)? {
56 match attr {
57 Attr::Name(name) => pyclass_name = Some(name),
58 Attr::Module(name) => {
59 module = Some(name);
60 }
61 Attr::GetAll => is_get_all = true,
62 Attr::SetAll => is_set_all = true,
63 Attr::Extends(typ) => bases.push(typ),
64 Attr::Eq => has_eq = true,
65 Attr::Ord => has_ord = true,
66 Attr::Hash => has_hash = true,
67 Attr::Str => has_str = true,
68 _ => {}
69 }
70 }
71 let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
72 let mut getters = Vec::new();
73 let mut setters = Vec::new();
74 for field in fields {
75 if is_get_all || MemberInfo::is_get(&field)? {
76 getters.push(MemberInfo::try_from(field.clone())?)
77 }
78 if is_set_all || MemberInfo::is_set(&field)? {
79 setters.push(MemberInfo::try_from(field)?)
80 }
81 }
82 let doc = extract_documents(&attrs).join("\n");
83 Ok(Self {
84 struct_type,
85 pyclass_name,
86 getters,
87 setters,
88 module,
89 doc,
90 bases,
91 has_eq,
92 has_ord,
93 has_hash,
94 has_str,
95 })
96 }
97}
98
99impl ToTokens for PyClassInfo {
100 fn to_tokens(&self, tokens: &mut TokenStream2) {
101 let Self {
102 pyclass_name,
103 struct_type,
104 getters,
105 setters,
106 doc,
107 module,
108 bases,
109 has_eq,
110 has_ord,
111 has_hash,
112 has_str,
113 } = self;
114 let module = quote_option(module);
115 tokens.append_all(quote! {
116 ::pyo3_stub_gen::type_info::PyClassInfo {
117 pyclass_name: #pyclass_name,
118 struct_id: std::any::TypeId::of::<#struct_type>,
119 getters: &[ #( #getters),* ],
120 setters: &[ #( #setters),* ],
121 module: #module,
122 doc: #doc,
123 bases: &[ #( <#bases as ::pyo3_stub_gen::PyStubType>::type_output ),* ],
124 has_eq: #has_eq,
125 has_ord: #has_ord,
126 has_hash: #has_hash,
127 has_str: #has_str,
128 }
129 })
130 }
131}
132
133pub fn prune_attrs(item_struct: &mut ItemStruct) {
137 super::attr::prune_attrs(&mut item_struct.attrs);
138 for field in item_struct.fields.iter_mut() {
139 super::attr::prune_attrs(&mut field.attrs);
140 }
141}
142
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    /// Parses a representative `#[pyclass]` struct and snapshots the emitted
    /// `PyClassInfo` literal.
    #[test]
    fn test_pyclass() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
                #[pyo3(get)]
                pub ndim: usize,
                #[pyo3(get)]
                pub description: Option<String>,
                pub custom_latex: Option<String>,
            }
            "#,
        )?;
        let info = PyClassInfo::try_from(item)?;
        let rendered = format_as_value(info.to_token_stream());
        insta::assert_snapshot!(rendered, @r###"
        ::pyo3_stub_gen::type_info::PyClassInfo {
            pyclass_name: "Placeholder",
            struct_id: std::any::TypeId::of::<PyPlaceholder>,
            getters: &[
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "name",
                    r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "ndim",
                    r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "description",
                    r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
                    doc: "",
                    default: None,
                    deprecated: None,
                },
            ],
            setters: &[],
            module: Some("my_module"),
            doc: "",
            bases: &[],
            has_eq: false,
            has_ord: false,
            has_hash: false,
            has_str: false,
        }
        "###);
        Ok(())
    }

    /// Pretty-prints an expression token stream.
    ///
    /// `prettyplease` only formats whole files, so the expression is wrapped in a
    /// dummy `const _: () = …;` item. That wrapper is stripped again after
    /// formatting.
    fn format_as_value(value: TokenStream2) -> String {
        let wrapped = quote! { const _: () = #value; };
        let file = syn::parse_file(&wrapped.to_string()).unwrap();
        let pretty = prettyplease::unparse(&file);
        let without_prefix = pretty.trim().strip_prefix("const _: () = ").unwrap();
        without_prefix.strip_suffix(';').unwrap().to_string()
    }
}