pyo3_stub_gen_derive/gen_stub/
pyclass_complex_enum.rs

1use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, StubType};
2use crate::gen_stub::variant::VariantInfo;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{parse_quote, Error, ItemEnum, Result, Type};
6
7pub struct PyComplexEnumInfo {
8    pyclass_name: String,
9    enum_type: Type,
10    module: Option<String>,
11    variants: Vec<VariantInfo>,
12    doc: String,
13}
14
15impl From<&PyComplexEnumInfo> for StubType {
16    fn from(info: &PyComplexEnumInfo) -> Self {
17        let PyComplexEnumInfo {
18            pyclass_name,
19            module,
20            enum_type,
21            ..
22        } = info;
23        Self {
24            ty: enum_type.clone(),
25            name: pyclass_name.clone(),
26            module: module.clone(),
27        }
28    }
29}
30
31impl TryFrom<ItemEnum> for PyComplexEnumInfo {
32    type Error = Error;
33
34    fn try_from(item: ItemEnum) -> Result<Self> {
35        let ItemEnum {
36            variants,
37            attrs,
38            ident,
39            ..
40        } = item;
41
42        let doc = extract_documents(&attrs).join("\n");
43        let mut pyclass_name = None;
44        let mut module = None;
45        let mut renaming_rule = None;
46        let mut bases = Vec::new();
47        for attr in parse_pyo3_attrs(&attrs)? {
48            match attr {
49                Attr::Name(name) => pyclass_name = Some(name),
50                Attr::Module(name) => module = Some(name),
51                Attr::RenameAll(name) => renaming_rule = Some(name),
52                Attr::Extends(typ) => bases.push(typ),
53                _ => {}
54            }
55        }
56
57        let enum_type = parse_quote!(#ident);
58        let pyclass_name = pyclass_name.unwrap_or_else(|| ident.clone().to_string());
59
60        let mut items = Vec::new();
61        for variant in variants {
62            items.push(VariantInfo::from_variant(variant, &renaming_rule)?)
63        }
64
65        Ok(Self {
66            doc,
67            enum_type,
68            pyclass_name,
69            module,
70            variants: items,
71        })
72    }
73}
74
75impl ToTokens for PyComplexEnumInfo {
76    fn to_tokens(&self, tokens: &mut TokenStream2) {
77        let Self {
78            pyclass_name,
79            enum_type,
80            variants,
81            doc,
82            module,
83            ..
84        } = self;
85        let module = quote_option(module);
86
87        tokens.append_all(quote! {
88            ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
89                pyclass_name: #pyclass_name,
90                enum_id: std::any::TypeId::of::<#enum_type>,
91                variants: &[ #( #variants ),* ],
92                module: #module,
93                doc: #doc,
94            }
95        })
96    }
97}
98
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    /// Converts a representative enum and compares the generated tokens with an inline
    /// snapshot.
    ///
    /// The enum covers tuple, struct and unit variants, a renamed variant and a
    /// constructor default.
    #[test]
    fn test_complex_enum() -> Result<()> {
        let source = r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub enum PyPlaceholder {
                #[pyo3(name="Name")]
                name(String),
                #[pyo3(constructor = (_0, _1=1.0))]
                twonum(i32,f64),
                ndim{count: usize},
                description,
            }
            "#;
        let input: ItemEnum = parse_str(source)?;
        let info = PyComplexEnumInfo::try_from(input)?;
        insta::assert_snapshot!(format_as_value(info.to_token_stream()), @r###"
        ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
            pyclass_name: "Placeholder",
            enum_id: std::any::TypeId::of::<PyPlaceholder>,
            variants: &[
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "Name",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_0",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <String as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "twonum",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <i32 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_1",
                            r#type: <f64 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_0",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <i32 as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_1",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <f64 as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
                                fn _fmt() -> String {
                                    let v: f64 = 1.0;
                                    ::pyo3_stub_gen::util::fmt_py_obj(v)
                                }
                                _fmt
                            }),
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "ndim",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "count",
                            r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Struct,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "count",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <usize as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "description",
                    fields: &[],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Unit,
                    constr_args: &[],
                },
            ],
            module: Some("my_module"),
            doc: "",
        }
        "###);
        Ok(())
    }

    /// Pretty-prints a token stream that forms an expression.
    ///
    /// The tokens are wrapped in `const _: () = ...;`, formatted with `prettyplease`, and
    /// then the wrapper is stripped off again.
    ///
    /// # Panics
    ///
    /// Panics if the wrapped tokens do not parse as a file or the wrapper cannot be stripped.
    fn format_as_value(tt: TokenStream2) -> String {
        let wrapped = quote! { const _: () = #tt; };
        let file = syn::parse_file(&wrapped.to_string()).unwrap();
        let pretty = prettyplease::unparse(&file);
        let without_prefix = pretty.trim().strip_prefix("const _: () = ").unwrap();
        let expr = without_prefix.strip_suffix(';').unwrap();
        expr.to_string()
    }
}