pyo3_stub_gen_derive/gen_stub/
pyclass_complex_enum.rs1use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, StubType};
2use crate::gen_stub::variant::VariantInfo;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{parse_quote, Error, ItemEnum, Result, Type};
6
7pub struct PyComplexEnumInfo {
8 pyclass_name: String,
9 enum_type: Type,
10 module: Option<String>,
11 variants: Vec<VariantInfo>,
12 doc: String,
13}
14
15impl From<&PyComplexEnumInfo> for StubType {
16 fn from(info: &PyComplexEnumInfo) -> Self {
17 let PyComplexEnumInfo {
18 pyclass_name,
19 module,
20 enum_type,
21 ..
22 } = info;
23 Self {
24 ty: enum_type.clone(),
25 name: pyclass_name.clone(),
26 module: module.clone(),
27 }
28 }
29}
30
31impl TryFrom<ItemEnum> for PyComplexEnumInfo {
32 type Error = Error;
33
34 fn try_from(item: ItemEnum) -> Result<Self> {
35 let ItemEnum {
36 variants,
37 attrs,
38 ident,
39 ..
40 } = item;
41
42 let doc = extract_documents(&attrs).join("\n");
43 let mut pyclass_name = None;
44 let mut module = None;
45 let mut renaming_rule = None;
46 let mut bases = Vec::new();
47 for attr in parse_pyo3_attrs(&attrs)? {
48 match attr {
49 Attr::Name(name) => pyclass_name = Some(name),
50 Attr::Module(name) => module = Some(name),
51 Attr::RenameAll(name) => renaming_rule = Some(name),
52 Attr::Extends(typ) => bases.push(typ),
53 _ => {}
54 }
55 }
56
57 let enum_type = parse_quote!(#ident);
58 let pyclass_name = pyclass_name.unwrap_or_else(|| ident.clone().to_string());
59
60 let mut items = Vec::new();
61 for variant in variants {
62 items.push(VariantInfo::from_variant(variant, &renaming_rule)?)
63 }
64
65 Ok(Self {
66 doc,
67 enum_type,
68 pyclass_name,
69 module,
70 variants: items,
71 })
72 }
73}
74
75impl ToTokens for PyComplexEnumInfo {
76 fn to_tokens(&self, tokens: &mut TokenStream2) {
77 let Self {
78 pyclass_name,
79 enum_type,
80 variants,
81 doc,
82 module,
83 ..
84 } = self;
85 let module = quote_option(module);
86
87 tokens.append_all(quote! {
88 ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
89 pyclass_name: #pyclass_name,
90 enum_id: std::any::TypeId::of::<#enum_type>,
91 variants: &[ #( #variants ),* ],
92 module: #module,
93 doc: #doc,
94 }
95 })
96 }
97}
98
99#[cfg(test)]
100mod test {
101 use super::*;
102 use syn::parse_str;
103
// Snapshot test: parses a sample `#[pyclass]` enum with two tuple variants,
// one struct variant and one unit variant, converts it to `PyComplexEnumInfo`,
// and compares the pretty-printed token output against the inline snapshot.
104 #[test]
105 fn test_complex_enum() -> Result<()> {
// The input exercises a per-variant rename (`name` -> "Name") and a
// constructor signature with a default value (`_1=1.0`).
106 let input: ItemEnum = parse_str(
107 r#"
108 #[pyclass(mapping, module = "my_module", name = "Placeholder")]
109 #[derive(
110 Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
111 )]
112 pub enum PyPlaceholder {
113 #[pyo3(name="Name")]
114 name(String),
115 #[pyo3(constructor = (_0, _1=1.0))]
116 twonum(i32,f64),
117 ndim{count: usize},
118 description,
119 }
120 "#,
121 )?;
// Run the conversion under test and render it back to tokens.
122 let out = PyComplexEnumInfo::try_from(input)?.to_token_stream();
// Expected output: class name and module come from the `#[pyclass(...)]`
// arguments; the `_1=1.0` default shows up as `ParameterDefault::Expr`.
123 insta::assert_snapshot!(format_as_value(out), @r###"
124 ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
125 pyclass_name: "Placeholder",
126 enum_id: std::any::TypeId::of::<PyPlaceholder>,
127 variants: &[
128 ::pyo3_stub_gen::type_info::VariantInfo {
129 pyclass_name: "Name",
130 fields: &[
131 ::pyo3_stub_gen::type_info::MemberInfo {
132 name: "_0",
133 r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
134 doc: "",
135 default: None,
136 deprecated: None,
137 },
138 ],
139 module: None,
140 doc: "",
141 form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
142 constr_args: &[
143 ::pyo3_stub_gen::type_info::ParameterInfo {
144 name: "_0",
145 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
146 type_info: <String as ::pyo3_stub_gen::PyStubType>::type_input,
147 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
148 },
149 ],
150 },
151 ::pyo3_stub_gen::type_info::VariantInfo {
152 pyclass_name: "twonum",
153 fields: &[
154 ::pyo3_stub_gen::type_info::MemberInfo {
155 name: "_0",
156 r#type: <i32 as ::pyo3_stub_gen::PyStubType>::type_output,
157 doc: "",
158 default: None,
159 deprecated: None,
160 },
161 ::pyo3_stub_gen::type_info::MemberInfo {
162 name: "_1",
163 r#type: <f64 as ::pyo3_stub_gen::PyStubType>::type_output,
164 doc: "",
165 default: None,
166 deprecated: None,
167 },
168 ],
169 module: None,
170 doc: "",
171 form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
172 constr_args: &[
173 ::pyo3_stub_gen::type_info::ParameterInfo {
174 name: "_0",
175 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
176 type_info: <i32 as ::pyo3_stub_gen::PyStubType>::type_input,
177 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
178 },
179 ::pyo3_stub_gen::type_info::ParameterInfo {
180 name: "_1",
181 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
182 type_info: <f64 as ::pyo3_stub_gen::PyStubType>::type_input,
183 default: ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
184 fn _fmt() -> String {
185 let v: f64 = 1.0;
186 ::pyo3_stub_gen::util::fmt_py_obj(v)
187 }
188 _fmt
189 }),
190 },
191 ],
192 },
193 ::pyo3_stub_gen::type_info::VariantInfo {
194 pyclass_name: "ndim",
195 fields: &[
196 ::pyo3_stub_gen::type_info::MemberInfo {
197 name: "count",
198 r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
199 doc: "",
200 default: None,
201 deprecated: None,
202 },
203 ],
204 module: None,
205 doc: "",
206 form: &pyo3_stub_gen::type_info::VariantForm::Struct,
207 constr_args: &[
208 ::pyo3_stub_gen::type_info::ParameterInfo {
209 name: "count",
210 kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
211 type_info: <usize as ::pyo3_stub_gen::PyStubType>::type_input,
212 default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
213 },
214 ],
215 },
216 ::pyo3_stub_gen::type_info::VariantInfo {
217 pyclass_name: "description",
218 fields: &[],
219 module: None,
220 doc: "",
221 form: &pyo3_stub_gen::type_info::VariantForm::Unit,
222 constr_args: &[],
223 },
224 ],
225 module: Some("my_module"),
226 doc: "",
227 }
228 "###);
229 Ok(())
230 }
231
232 fn format_as_value(tt: TokenStream2) -> String {
233 let ttt = quote! { const _: () = #tt; };
234 let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
235 formatted
236 .trim()
237 .strip_prefix("const _: () = ")
238 .unwrap()
239 .strip_suffix(';')
240 .unwrap()
241 .to_string()
242 }
243}