pyo3_stub_gen_derive/gen_stub/
pyclass_complex_enum.rs1use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, PyClassAttr, StubType};
2use crate::gen_stub::variant::VariantInfo;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{parse_quote, Error, ItemEnum, Result, Type};
6
7pub struct PyComplexEnumInfo {
8 pyclass_name: String,
9 enum_type: Type,
10 module: Option<String>,
11 variants: Vec<VariantInfo>,
12 doc: String,
13}
14
15impl From<&PyComplexEnumInfo> for StubType {
16 fn from(info: &PyComplexEnumInfo) -> Self {
17 let PyComplexEnumInfo {
18 pyclass_name,
19 module,
20 enum_type,
21 ..
22 } = info;
23 Self {
24 ty: enum_type.clone(),
25 name: pyclass_name.clone(),
26 module: module.clone(),
27 }
28 }
29}
30
31impl PyComplexEnumInfo {
32 pub fn from_item_with_attr(item: ItemEnum, attr: &PyClassAttr) -> Result<Self> {
34 let ItemEnum {
35 variants,
36 attrs,
37 ident,
38 ..
39 } = item;
40
41 let doc = extract_documents(&attrs).join("\n");
42 let mut pyclass_name = None;
43 let mut pyo3_module = None;
44 let mut gen_stub_standalone_module = None;
45 let mut renaming_rule = None;
46 let mut bases = Vec::new();
47 for attr_item in parse_pyo3_attrs(&attrs)? {
48 match attr_item {
49 Attr::Name(name) => pyclass_name = Some(name),
50 Attr::Module(name) => pyo3_module = Some(name),
51 Attr::GenStubModule(name) => gen_stub_standalone_module = Some(name),
52 Attr::RenameAll(name) => renaming_rule = Some(name),
53 Attr::Extends(typ) => bases.push(typ),
54 _ => {}
55 }
56 }
57
58 if let (Some(inline_mod), Some(standalone_mod)) =
60 (&attr.module, &gen_stub_standalone_module)
61 {
62 if inline_mod != standalone_mod {
63 return Err(Error::new(
64 ident.span(),
65 format!(
66 "Conflicting module specifications: #[gen_stub_pyclass_complex_enum(module = \"{}\")] \
67 and #[gen_stub(module = \"{}\")]. Please use only one.",
68 inline_mod, standalone_mod
69 ),
70 ));
71 }
72 }
73
74 let module = if let Some(inline_mod) = &attr.module {
76 Some(inline_mod.clone()) } else if let Some(standalone_mod) = gen_stub_standalone_module {
78 Some(standalone_mod) } else {
80 pyo3_module };
82
83 let enum_type = parse_quote!(#ident);
84 let pyclass_name = pyclass_name.unwrap_or_else(|| ident.clone().to_string());
85
86 let mut items = Vec::new();
87 for variant in variants {
88 items.push(VariantInfo::from_variant(variant, &renaming_rule)?)
89 }
90
91 Ok(Self {
92 doc,
93 enum_type,
94 pyclass_name,
95 module,
96 variants: items,
97 })
98 }
99}
100
101impl TryFrom<ItemEnum> for PyComplexEnumInfo {
102 type Error = Error;
103
104 fn try_from(item: ItemEnum) -> Result<Self> {
105 Self::from_item_with_attr(item, &PyClassAttr::default())
107 }
108}
109
110impl ToTokens for PyComplexEnumInfo {
111 fn to_tokens(&self, tokens: &mut TokenStream2) {
112 let Self {
113 pyclass_name,
114 enum_type,
115 variants,
116 doc,
117 module,
118 ..
119 } = self;
120 let module = quote_option(module);
121
122 tokens.append_all(quote! {
123 ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
124 pyclass_name: #pyclass_name,
125 enum_id: std::any::TypeId::of::<#enum_type>,
126 variants: &[ #( #variants ),* ],
127 module: #module,
128 doc: #doc,
129 }
130 })
131 }
132}
133
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    /// Snapshot test for a complex enum mixing tuple, struct and unit variants, a
    /// per-variant pyo3 rename and a custom constructor signature with a default.
    /// The item is converted through `TryFrom`, i.e. with a default `PyClassAttr`.
    #[test]
    fn test_complex_enum() -> Result<()> {
        let input: ItemEnum = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub enum PyPlaceholder {
                #[pyo3(name="Name")]
                name(String),
                #[pyo3(constructor = (_0, _1=1.0))]
                twonum(i32,f64),
                ndim{count: usize},
                description,
            }
            "#,
        )?;
        let out = PyComplexEnumInfo::try_from(input)?.to_token_stream();
        // `module: Some("my_module")` below comes from `#[pyclass(module = ...)]`;
        // no `#[gen_stub(module = ...)]` is present on the input.
        insta::assert_snapshot!(format_as_value(out), @r#"
        ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
            pyclass_name: "Placeholder",
            enum_id: std::any::TypeId::of::<PyPlaceholder>,
            variants: &[
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "Name",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_0",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <String as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "twonum",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <i32 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_1",
                            r#type: <f64 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_0",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <i32 as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_1",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <f64 as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
                                fn _fmt() -> String {
                                    {
                                        let v: f64 = 1.0;
                                        let repr = ::pyo3_stub_gen::util::fmt_py_obj(v);
                                        let type_info = <f64 as ::pyo3_stub_gen::PyStubType>::type_output();
                                        let should_add_prefix = if let Some(dot_pos) = type_info
                                            .name
                                            .rfind('.')
                                        {
                                            let module_prefix = &type_info.name[..dot_pos];
                                            type_info
                                                .import
                                                .iter()
                                                .any(|imp| {
                                                    if let ::pyo3_stub_gen::ImportRef::Module(module_ref) = imp {
                                                        if let Some(module_name) = module_ref.get() {
                                                            module_name.ends_with(&format!(".{}", module_prefix))
                                                        } else {
                                                            false
                                                        }
                                                    } else {
                                                        false
                                                    }
                                                })
                                        } else {
                                            false
                                        };
                                        if should_add_prefix {
                                            if let Some(dot_pos) = type_info.name.rfind('.') {
                                                let module_prefix = &type_info.name[..dot_pos];
                                                format!("{}.{}", module_prefix, repr)
                                            } else {
                                                repr
                                            }
                                        } else {
                                            repr
                                        }
                                    }
                                }
                                _fmt
                            }),
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "ndim",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "count",
                            r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Struct,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "count",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <usize as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "description",
                    fields: &[],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Unit,
                    constr_args: &[],
                },
            ],
            module: Some("my_module"),
            doc: "",
        }
        "#);
        Ok(())
    }

    /// Pretty-prints a generated expression so snapshots stay readable.
    ///
    /// The tokens are wrapped as `const _: () = <expr>;` so they form a parsable item.
    /// The item is formatted with `prettyplease`, and the wrapper is then stripped again.
    /// Panics (test-only helper) if the tokens do not parse or the wrapper text is
    /// not found in the formatted output.
    fn format_as_value(tt: TokenStream2) -> String {
        let ttt = quote! { const _: () = #tt; };
        let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
        formatted
            .trim()
            .strip_prefix("const _: () = ")
            .unwrap()
            .strip_suffix(';')
            .unwrap()
            .to_string()
    }
}