pyo3_stub_gen_derive/gen_stub/
pyclass_complex_enum.rs

1use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, StubType};
2use crate::gen_stub::variant::VariantInfo;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{parse_quote, Error, ItemEnum, Result, Type};
6
7pub struct PyComplexEnumInfo {
8    pyclass_name: String,
9    enum_type: Type,
10    module: Option<String>,
11    variants: Vec<VariantInfo>,
12    doc: String,
13}
14
15impl From<&PyComplexEnumInfo> for StubType {
16    fn from(info: &PyComplexEnumInfo) -> Self {
17        let PyComplexEnumInfo {
18            pyclass_name,
19            module,
20            enum_type,
21            ..
22        } = info;
23        Self {
24            ty: enum_type.clone(),
25            name: pyclass_name.clone(),
26            module: module.clone(),
27        }
28    }
29}
30
31impl TryFrom<ItemEnum> for PyComplexEnumInfo {
32    type Error = Error;
33
34    fn try_from(item: ItemEnum) -> Result<Self> {
35        let ItemEnum {
36            variants,
37            attrs,
38            ident,
39            ..
40        } = item;
41
42        let doc = extract_documents(&attrs).join("\n");
43        let mut pyclass_name = None;
44        let mut module = None;
45        let mut renaming_rule = None;
46        let mut bases = Vec::new();
47        for attr in parse_pyo3_attrs(&attrs)? {
48            match attr {
49                Attr::Name(name) => pyclass_name = Some(name),
50                Attr::Module(name) => module = Some(name),
51                Attr::RenameAll(name) => renaming_rule = Some(name),
52                Attr::Extends(typ) => bases.push(typ),
53                _ => {}
54            }
55        }
56
57        let enum_type = parse_quote!(#ident);
58        let pyclass_name = pyclass_name.unwrap_or_else(|| ident.clone().to_string());
59
60        let mut items = Vec::new();
61        for variant in variants {
62            items.push(VariantInfo::from_variant(variant, &renaming_rule)?)
63        }
64
65        Ok(Self {
66            doc,
67            enum_type,
68            pyclass_name,
69            module,
70            variants: items,
71        })
72    }
73}
74
75impl ToTokens for PyComplexEnumInfo {
76    fn to_tokens(&self, tokens: &mut TokenStream2) {
77        let Self {
78            pyclass_name,
79            enum_type,
80            variants,
81            doc,
82            module,
83            ..
84        } = self;
85        let module = quote_option(module);
86
87        tokens.append_all(quote! {
88            ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
89                pyclass_name: #pyclass_name,
90                enum_id: std::any::TypeId::of::<#enum_type>,
91                variants: &[ #( #variants ),* ],
92                module: #module,
93                doc: #doc,
94            }
95        })
96    }
97}
98
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    // Snapshot test: parses an enum with two tuple variants (one renamed via
    // `#[pyo3(name = ..)]`, one with a `constructor` signature that has a default),
    // a struct variant and a unit variant, then pins the generated token stream.
    #[test]
    fn test_complex_enum() -> Result<()> {
        let input: ItemEnum = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub enum PyPlaceholder {
                #[pyo3(name="Name")]
                name(String),
                #[pyo3(constructor = (_0, _1=1.0))]
                twonum(i32,f64),
                ndim{count: usize},
                description,
            }
            "#,
        )?;
        // Run the conversion under test and render it back to tokens.
        let out = PyComplexEnumInfo::try_from(input)?.to_token_stream();
        // The inline snapshot is the expected output; keep it in sync with the generator.
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
            pyclass_name: "Placeholder",
            enum_id: std::any::TypeId::of::<PyPlaceholder>,
            variants: &[
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "Name",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ArgInfo {
                            name: "_0",
                            r#type: <String as ::pyo3_stub_gen::PyStubType>::type_input,
                            signature: None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "twonum",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <i32 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_1",
                            r#type: <f64 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ArgInfo {
                            name: "_0",
                            r#type: <i32 as ::pyo3_stub_gen::PyStubType>::type_input,
                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
                        },
                        ::pyo3_stub_gen::type_info::ArgInfo {
                            name: "_1",
                            r#type: <f64 as ::pyo3_stub_gen::PyStubType>::type_input,
                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Assign {
                                default: {
                                    fn _fmt() -> String {
                                        let v: f64 = 1.0;
                                        ::pyo3_stub_gen::util::fmt_py_obj(v)
                                    }
                                    _fmt
                                },
                            }),
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "ndim",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "count",
                            r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Struct,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ArgInfo {
                            name: "count",
                            r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_input,
                            signature: None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "description",
                    fields: &[],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Unit,
                    constr_args: &[],
                },
            ],
            module: Some("my_module"),
            doc: "",
        }
        "###);
        Ok(())
    }

    // Pretty-prints an expression token stream for snapshotting. The tokens are
    // wrapped in `const _: () = <expr>;` so they parse as a file for `prettyplease`,
    // and the wrapper is stripped from the formatted text afterwards.
    // Panics (via `unwrap`) if parsing fails or the wrapper is not found.
    fn format_as_value(tt: TokenStream2) -> String {
        let ttt = quote! { const _: () = #tt; };
        let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
        formatted
            .trim()
            .strip_prefix("const _: () = ")
            .unwrap()
            .strip_suffix(';')
            .unwrap()
            .to_string()
    }
}