pyo3_stub_gen_derive/gen_stub/
pyclass_complex_enum.rs1use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, PyClassAttr, StubType};
2use crate::gen_stub::variant::VariantInfo;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{parse_quote, Error, ItemEnum, Result, Type};
6
7pub struct PyComplexEnumInfo {
8 pyclass_name: String,
9 enum_type: Type,
10 module: Option<String>,
11 variants: Vec<VariantInfo>,
12 doc: String,
13}
14
15impl From<&PyComplexEnumInfo> for StubType {
16 fn from(info: &PyComplexEnumInfo) -> Self {
17 let PyComplexEnumInfo {
18 pyclass_name,
19 module,
20 enum_type,
21 ..
22 } = info;
23 Self {
24 ty: enum_type.clone(),
25 name: pyclass_name.clone(),
26 module: module.clone(),
27 }
28 }
29}
30
31impl PyComplexEnumInfo {
32 pub fn from_item_with_attr(item: ItemEnum, attr: &PyClassAttr) -> Result<Self> {
34 let ItemEnum {
35 variants,
36 attrs,
37 ident,
38 ..
39 } = item;
40
41 let doc = extract_documents(&attrs).join("\n");
42 let mut pyclass_name = None;
43 let mut pyo3_module = None;
44 let mut gen_stub_standalone_module = None;
45 let mut renaming_rule = None;
46 let mut bases = Vec::new();
47 for attr_item in parse_pyo3_attrs(&attrs)? {
48 match attr_item {
49 Attr::Name(name) => pyclass_name = Some(name),
50 Attr::Module(name) => pyo3_module = Some(name),
51 Attr::GenStubModule(name) => gen_stub_standalone_module = Some(name),
52 Attr::RenameAll(name) => renaming_rule = Some(name),
53 Attr::Extends(typ) => bases.push(typ),
54 _ => {}
55 }
56 }
57
58 if let (Some(inline_mod), Some(standalone_mod)) =
60 (&attr.module, &gen_stub_standalone_module)
61 {
62 if inline_mod != standalone_mod {
63 return Err(Error::new(
64 ident.span(),
65 format!(
66 "Conflicting module specifications: #[gen_stub_pyclass_complex_enum(module = \"{}\")] \
67 and #[gen_stub(module = \"{}\")]. Please use only one.",
68 inline_mod, standalone_mod
69 ),
70 ));
71 }
72 }
73
74 let module = if let Some(inline_mod) = &attr.module {
76 Some(inline_mod.clone()) } else if let Some(standalone_mod) = gen_stub_standalone_module {
78 Some(standalone_mod) } else {
80 pyo3_module };
82
83 let enum_type = parse_quote!(#ident);
84 let pyclass_name = pyclass_name.unwrap_or_else(|| ident.clone().to_string());
85
86 let mut items = Vec::new();
87 for variant in variants {
88 items.push(VariantInfo::from_variant(variant, &renaming_rule)?)
89 }
90
91 Ok(Self {
92 doc,
93 enum_type,
94 pyclass_name,
95 module,
96 variants: items,
97 })
98 }
99}
100
101impl TryFrom<ItemEnum> for PyComplexEnumInfo {
102 type Error = Error;
103
104 fn try_from(item: ItemEnum) -> Result<Self> {
105 Self::from_item_with_attr(item, &PyClassAttr::default())
107 }
108}
109
110impl ToTokens for PyComplexEnumInfo {
111 fn to_tokens(&self, tokens: &mut TokenStream2) {
112 let Self {
113 pyclass_name,
114 enum_type,
115 variants,
116 doc,
117 module,
118 ..
119 } = self;
120 let module = quote_option(module);
121
122 tokens.append_all(quote! {
123 ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
124 pyclass_name: #pyclass_name,
125 enum_id: std::any::TypeId::of::<#enum_type>,
126 variants: &[ #( #variants ),* ],
127 module: #module,
128 doc: #doc,
129 }
130 })
131 }
132}
133
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    /// Snapshot test for the tokens generated from a complex enum.
    ///
    /// The enum exercises:
    /// - a tuple variant renamed via `#[pyo3(name = ...)]`;
    /// - a tuple variant whose constructor gives `_1` a default;
    /// - a struct variant;
    /// - a unit variant.
    #[test]
    fn test_complex_enum() -> Result<()> {
        let input: ItemEnum = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub enum PyPlaceholder {
                #[pyo3(name="Name")]
                name(String),
                #[pyo3(constructor = (_0, _1=1.0))]
                twonum(i32,f64),
                ndim{count: usize},
                description,
            }
            "#,
        )?;
        // `try_from` passes a default `PyClassAttr`. The expected `module` below comes
        // from the `#[pyclass(module = ...)]` attribute on the input.
        let out = PyComplexEnumInfo::try_from(input)?.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
            pyclass_name: "Placeholder",
            enum_id: std::any::TypeId::of::<PyPlaceholder>,
            variants: &[
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "Name",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_0",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <String as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "twonum",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <i32 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_1",
                            r#type: <f64 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_0",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <i32 as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_1",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <f64 as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::Expr {
                                value: {
                                    fn _fmt() -> String {
                                        {
                                            let v: f64 = 1.0;
                                            ::pyo3_stub_gen::util::fmt_py_obj(v)
                                        }
                                    }
                                    _fmt
                                },
                                source_module: Some({
                                    fn _get_module() -> Option<::pyo3_stub_gen::ModuleRef> {
                                        <f64 as ::pyo3_stub_gen::PyStubType>::type_output()
                                            .source_module
                                    }
                                    _get_module
                                }),
                            },
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "ndim",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "count",
                            r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Struct,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "count",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <usize as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "description",
                    fields: &[],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Unit,
                    constr_args: &[],
                },
            ],
            module: Some("my_module"),
            doc: "",
        }
        "###);
        Ok(())
    }

    /// Pretty-prints an expression token stream so snapshots are readable.
    ///
    /// `prettyplease::unparse` formats whole files, so the expression is wrapped in a
    /// `const _: () = ...;` item, formatted, and the wrapper is stripped again.
    /// Panics (via `unwrap`) if the tokens do not parse or the wrapper is missing,
    /// which is acceptable in a test helper.
    fn format_as_value(tt: TokenStream2) -> String {
        let ttt = quote! { const _: () = #tt; };
        let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
        formatted
            .trim()
            .strip_prefix("const _: () = ")
            .unwrap()
            .strip_suffix(';')
            .unwrap()
            .to_string()
    }
}