pyo3_stub_gen_derive/gen_stub/
pyclass_complex_enum.rs1use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, StubType};
2use crate::gen_stub::variant::VariantInfo;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{parse_quote, Error, ItemEnum, Result, Type};
6
7pub struct PyComplexEnumInfo {
8 pyclass_name: String,
9 enum_type: Type,
10 module: Option<String>,
11 variants: Vec<VariantInfo>,
12 doc: String,
13}
14
15impl From<&PyComplexEnumInfo> for StubType {
16 fn from(info: &PyComplexEnumInfo) -> Self {
17 let PyComplexEnumInfo {
18 pyclass_name,
19 module,
20 enum_type,
21 ..
22 } = info;
23 Self {
24 ty: enum_type.clone(),
25 name: pyclass_name.clone(),
26 module: module.clone(),
27 }
28 }
29}
30
31impl TryFrom<ItemEnum> for PyComplexEnumInfo {
32 type Error = Error;
33
34 fn try_from(item: ItemEnum) -> Result<Self> {
35 let ItemEnum {
36 variants,
37 attrs,
38 ident,
39 ..
40 } = item;
41
42 let doc = extract_documents(&attrs).join("\n");
43 let mut pyclass_name = None;
44 let mut module = None;
45 let mut renaming_rule = None;
46 let mut bases = Vec::new();
47 for attr in parse_pyo3_attrs(&attrs)? {
48 match attr {
49 Attr::Name(name) => pyclass_name = Some(name),
50 Attr::Module(name) => module = Some(name),
51 Attr::RenameAll(name) => renaming_rule = Some(name),
52 Attr::Extends(typ) => bases.push(typ),
53 _ => {}
54 }
55 }
56
57 let enum_type = parse_quote!(#ident);
58 let pyclass_name = pyclass_name.unwrap_or_else(|| ident.clone().to_string());
59
60 let mut items = Vec::new();
61 for variant in variants {
62 items.push(VariantInfo::from_variant(variant, &renaming_rule)?)
63 }
64
65 Ok(Self {
66 doc,
67 enum_type,
68 pyclass_name,
69 module,
70 variants: items,
71 })
72 }
73}
74
75impl ToTokens for PyComplexEnumInfo {
76 fn to_tokens(&self, tokens: &mut TokenStream2) {
77 let Self {
78 pyclass_name,
79 enum_type,
80 variants,
81 doc,
82 module,
83 ..
84 } = self;
85 let module = quote_option(module);
86
87 tokens.append_all(quote! {
88 ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
89 pyclass_name: #pyclass_name,
90 enum_id: std::any::TypeId::of::<#enum_type>,
91 variants: &[ #( #variants ),* ],
92 module: #module,
93 doc: #doc,
94 }
95 })
96 }
97}
98
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    /// Snapshot test for the full `ItemEnum` -> tokens pipeline.
    ///
    /// The input enum exercises:
    /// - pyclass `name` and `module` options;
    /// - a per-variant rename (`name` -> `Name`);
    /// - a custom constructor signature `(_0, _1=1.0)`;
    /// - tuple, struct and unit variants.
    #[test]
    fn test_complex_enum() -> Result<()> {
        let input: ItemEnum = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub enum PyPlaceholder {
                #[pyo3(name="Name")]
                name(String),
                #[pyo3(constructor = (_0, _1=1.0))]
                twonum(i32,f64),
                ndim{count: usize},
                description,
            }
            "#,
        )?;
        let out = PyComplexEnumInfo::try_from(input)?.to_token_stream();
        // The `_1=1.0` default is not stored as text. It is emitted as a lazily
        // initialised static that formats the value with `fmt_py_obj` under the GIL.
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
            pyclass_name: "Placeholder",
            enum_id: std::any::TypeId::of::<PyPlaceholder>,
            variants: &[
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "Name",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ArgInfo {
                            name: "_0",
                            r#type: <String as ::pyo3_stub_gen::PyStubType>::type_input,
                            signature: None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "twonum",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <i32 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_1",
                            r#type: <f64 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ArgInfo {
                            name: "_0",
                            r#type: <i32 as ::pyo3_stub_gen::PyStubType>::type_input,
                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
                        },
                        ::pyo3_stub_gen::type_info::ArgInfo {
                            name: "_1",
                            r#type: <f64 as ::pyo3_stub_gen::PyStubType>::type_input,
                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Assign {
                                default: {
                                    static DEFAULT: std::sync::LazyLock<String> = std::sync::LazyLock::new(||
                                    {
                                        ::pyo3::prepare_freethreaded_python();
                                        ::pyo3::Python::with_gil(|py| -> String {
                                            let v: f64 = 1.0;
                                            ::pyo3_stub_gen::util::fmt_py_obj(py, v)
                                        })
                                    });
                                    &DEFAULT
                                },
                            }),
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "ndim",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "count",
                            r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Struct,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ArgInfo {
                            name: "count",
                            r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_input,
                            signature: None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "description",
                    fields: &[],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Unit,
                    constr_args: &[],
                },
            ],
            module: Some("my_module"),
            doc: "",
        }
        "###);
        Ok(())
    }

    /// Pretty-prints an expression token stream for snapshot comparison.
    ///
    /// `prettyplease` formats whole files, so the expression is wrapped as
    /// `const _: () = <expr>;` and that wrapper is stripped off again afterwards.
    /// Panics if the tokens do not parse or the wrapper is not found; this is
    /// acceptable in test-only code.
    fn format_as_value(tt: TokenStream2) -> String {
        let ttt = quote! { const _: () = #tt; };
        let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
        formatted
            .trim()
            .strip_prefix("const _: () = ")
            .unwrap()
            .strip_suffix(';')
            .unwrap()
            .to_string()
    }
}