pyo3_stub_gen_derive/gen_stub/pyclass_complex_enum.rs

1use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, StubType};
2use crate::gen_stub::variant::VariantInfo;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{parse_quote, Error, ItemEnum, Result, Type};
6
7pub struct PyComplexEnumInfo {
8    pyclass_name: String,
9    enum_type: Type,
10    module: Option<String>,
11    variants: Vec<VariantInfo>,
12    doc: String,
13}
14
15impl From<&PyComplexEnumInfo> for StubType {
16    fn from(info: &PyComplexEnumInfo) -> Self {
17        let PyComplexEnumInfo {
18            pyclass_name,
19            module,
20            enum_type,
21            ..
22        } = info;
23        Self {
24            ty: enum_type.clone(),
25            name: pyclass_name.clone(),
26            module: module.clone(),
27        }
28    }
29}
30
31impl TryFrom<ItemEnum> for PyComplexEnumInfo {
32    type Error = Error;
33
34    fn try_from(item: ItemEnum) -> Result<Self> {
35        let ItemEnum {
36            variants,
37            attrs,
38            ident,
39            ..
40        } = item;
41
42        let doc = extract_documents(&attrs).join("\n");
43        let mut pyclass_name = None;
44        let mut module = None;
45        let mut renaming_rule = None;
46        let mut bases = Vec::new();
47        for attr in parse_pyo3_attrs(&attrs)? {
48            match attr {
49                Attr::Name(name) => pyclass_name = Some(name),
50                Attr::Module(name) => module = Some(name),
51                Attr::RenameAll(name) => renaming_rule = Some(name),
52                Attr::Extends(typ) => bases.push(typ),
53                _ => {}
54            }
55        }
56
57        let enum_type = parse_quote!(#ident);
58        let pyclass_name = pyclass_name.unwrap_or_else(|| ident.clone().to_string());
59
60        let mut items = Vec::new();
61        for variant in variants {
62            items.push(VariantInfo::from_variant(variant, &renaming_rule)?)
63        }
64
65        Ok(Self {
66            doc,
67            enum_type,
68            pyclass_name,
69            module,
70            variants: items,
71        })
72    }
73}
74
75impl ToTokens for PyComplexEnumInfo {
76    fn to_tokens(&self, tokens: &mut TokenStream2) {
77        let Self {
78            pyclass_name,
79            enum_type,
80            variants,
81            doc,
82            module,
83            ..
84        } = self;
85        let module = quote_option(module);
86
87        tokens.append_all(quote! {
88            ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
89                pyclass_name: #pyclass_name,
90                enum_id: std::any::TypeId::of::<#enum_type>,
91                variants: &[ #( #variants ),* ],
92                module: #module,
93                doc: #doc,
94            }
95        })
96    }
97}
98
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    /// Snapshot test covering the main shapes of a complex enum.
    ///
    /// The input has tuple, struct and unit variants, a variant renamed with
    /// `#[pyo3(name = ...)]`, and a variant with a custom `constructor`
    /// signature. The test checks the tokens produced by `PyComplexEnumInfo`.
    #[test]
    fn test_complex_enum() -> Result<()> {
        let source = r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub enum PyPlaceholder {
                #[pyo3(name="Name")]
                name(String),
                #[pyo3(constructor = (_0, _1=1.0))]
                twonum(i32,f64),
                ndim{count: usize},
                description,
            }
            "#;
        let item: ItemEnum = parse_str(source)?;
        let info = PyComplexEnumInfo::try_from(item)?;
        let rendered = format_as_value(info.to_token_stream());
        insta::assert_snapshot!(rendered, @r###"
        ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
            pyclass_name: "Placeholder",
            enum_id: std::any::TypeId::of::<PyPlaceholder>,
            variants: &[
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "Name",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ArgInfo {
                            name: "_0",
                            r#type: <String as ::pyo3_stub_gen::PyStubType>::type_input,
                            signature: None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "twonum",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <i32 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_1",
                            r#type: <f64 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ArgInfo {
                            name: "_0",
                            r#type: <i32 as ::pyo3_stub_gen::PyStubType>::type_input,
                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
                        },
                        ::pyo3_stub_gen::type_info::ArgInfo {
                            name: "_1",
                            r#type: <f64 as ::pyo3_stub_gen::PyStubType>::type_input,
                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Assign {
                                default: {
                                    static DEFAULT: std::sync::LazyLock<String> = std::sync::LazyLock::new(||
                                    {
                                        ::pyo3::prepare_freethreaded_python();
                                        ::pyo3::Python::with_gil(|py| -> String {
                                            let v: f64 = 1.0;
                                            ::pyo3_stub_gen::util::fmt_py_obj(py, v)
                                        })
                                    });
                                    &DEFAULT
                                },
                            }),
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "ndim",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "count",
                            r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Struct,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ArgInfo {
                            name: "count",
                            r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_input,
                            signature: None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "description",
                    fields: &[],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Unit,
                    constr_args: &[],
                },
            ],
            module: Some("my_module"),
            doc: "",
        }
        "###);
        Ok(())
    }

    /// Pretty-prints an expression token stream for snapshot comparison.
    ///
    /// The expression is wrapped in a throwaway `const _: () = ...;` item so
    /// `prettyplease` (which formats whole files) can lay it out. The wrapper
    /// is then removed from the result. Panics if the tokens do not parse or
    /// the wrapper is not found, which is acceptable in a test helper.
    fn format_as_value(tt: TokenStream2) -> String {
        let wrapped = quote! { const _: () = #tt; };
        let file = syn::parse_file(&wrapped.to_string()).unwrap();
        let pretty = prettyplease::unparse(&file);
        let without_prefix = pretty.trim().strip_prefix("const _: () = ").unwrap();
        let value = without_prefix.strip_suffix(';').unwrap();
        value.to_string()
    }
}