// pyo3_stub_gen_derive/gen_stub/pyclass_complex_enum.rs

1use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, PyClassAttr, StubType};
2use crate::gen_stub::variant::VariantInfo;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{parse_quote, Error, ItemEnum, Result, Type};
6
7pub struct PyComplexEnumInfo {
8    pyclass_name: String,
9    enum_type: Type,
10    module: Option<String>,
11    variants: Vec<VariantInfo>,
12    doc: String,
13}
14
15impl From<&PyComplexEnumInfo> for StubType {
16    fn from(info: &PyComplexEnumInfo) -> Self {
17        let PyComplexEnumInfo {
18            pyclass_name,
19            module,
20            enum_type,
21            ..
22        } = info;
23        Self {
24            ty: enum_type.clone(),
25            name: pyclass_name.clone(),
26            module: module.clone(),
27        }
28    }
29}
30
31impl PyComplexEnumInfo {
32    /// Create PyComplexEnumInfo from ItemEnum with PyClassAttr for module override support
33    pub fn from_item_with_attr(item: ItemEnum, attr: &PyClassAttr) -> Result<Self> {
34        let ItemEnum {
35            variants,
36            attrs,
37            ident,
38            ..
39        } = item;
40
41        let doc = extract_documents(&attrs).join("\n");
42        let mut pyclass_name = None;
43        let mut pyo3_module = None;
44        let mut gen_stub_standalone_module = None;
45        let mut renaming_rule = None;
46        let mut bases = Vec::new();
47        for attr_item in parse_pyo3_attrs(&attrs)? {
48            match attr_item {
49                Attr::Name(name) => pyclass_name = Some(name),
50                Attr::Module(name) => pyo3_module = Some(name),
51                Attr::GenStubModule(name) => gen_stub_standalone_module = Some(name),
52                Attr::RenameAll(name) => renaming_rule = Some(name),
53                Attr::Extends(typ) => bases.push(typ),
54                _ => {}
55            }
56        }
57
58        // Validate: inline and standalone gen_stub modules must not conflict
59        if let (Some(inline_mod), Some(standalone_mod)) =
60            (&attr.module, &gen_stub_standalone_module)
61        {
62            if inline_mod != standalone_mod {
63                return Err(Error::new(
64                    ident.span(),
65                    format!(
66                        "Conflicting module specifications: #[gen_stub_pyclass_complex_enum(module = \"{}\")] \
67                         and #[gen_stub(module = \"{}\")]. Please use only one.",
68                        inline_mod, standalone_mod
69                    ),
70                ));
71            }
72        }
73
74        // Priority: inline > standalone > pyo3 > default
75        let module = if let Some(inline_mod) = &attr.module {
76            Some(inline_mod.clone()) // Priority 1: #[gen_stub_pyclass_complex_enum(module = "...")]
77        } else if let Some(standalone_mod) = gen_stub_standalone_module {
78            Some(standalone_mod) // Priority 2: #[gen_stub(module = "...")]
79        } else {
80            pyo3_module // Priority 3: #[pyo3(module = "...")]
81        };
82
83        let enum_type = parse_quote!(#ident);
84        let pyclass_name = pyclass_name.unwrap_or_else(|| ident.clone().to_string());
85
86        let mut items = Vec::new();
87        for variant in variants {
88            items.push(VariantInfo::from_variant(variant, &renaming_rule)?)
89        }
90
91        Ok(Self {
92            doc,
93            enum_type,
94            pyclass_name,
95            module,
96            variants: items,
97        })
98    }
99}
100
101impl TryFrom<ItemEnum> for PyComplexEnumInfo {
102    type Error = Error;
103
104    fn try_from(item: ItemEnum) -> Result<Self> {
105        // Use the new method with default PyClassAttr
106        Self::from_item_with_attr(item, &PyClassAttr::default())
107    }
108}
109
110impl ToTokens for PyComplexEnumInfo {
111    fn to_tokens(&self, tokens: &mut TokenStream2) {
112        let Self {
113            pyclass_name,
114            enum_type,
115            variants,
116            doc,
117            module,
118            ..
119        } = self;
120        let module = quote_option(module);
121
122        tokens.append_all(quote! {
123            ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
124                pyclass_name: #pyclass_name,
125                enum_id: std::any::TypeId::of::<#enum_type>,
126                variants: &[ #( #variants ),* ],
127                module: #module,
128                doc: #doc,
129            }
130        })
131    }
132}
133
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    /// Pretty-prints a token stream that represents a single expression.
    ///
    /// The expression is wrapped in `const _: () = <expr>;` so it forms a
    /// parseable file, formatted with `prettyplease`, and then unwrapped again.
    fn format_as_value(tt: TokenStream2) -> String {
        let wrapped = quote! { const _: () = #tt; };
        let file = syn::parse_file(&wrapped.to_string()).unwrap();
        let rendered = prettyplease::unparse(&file);
        let without_prefix = rendered.trim().strip_prefix("const _: () = ").unwrap();
        let value = without_prefix.strip_suffix(';').unwrap();
        value.to_string()
    }

    /// Snapshot of the tokens generated for an enum mixing tuple, struct and
    /// unit variants, a renamed variant, and a constructor signature with a
    /// default value.
    #[test]
    fn test_complex_enum() -> Result<()> {
        let item: ItemEnum = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub enum PyPlaceholder {
                #[pyo3(name="Name")]
                name(String),
                #[pyo3(constructor = (_0, _1=1.0))]
                twonum(i32,f64),
                ndim{count: usize},
                description,
            }
            "#,
        )?;
        let info = PyComplexEnumInfo::try_from(item)?;
        let generated = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(generated), @r###"
        ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
            pyclass_name: "Placeholder",
            enum_id: std::any::TypeId::of::<PyPlaceholder>,
            variants: &[
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "Name",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_0",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <String as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "twonum",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <i32 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_1",
                            r#type: <f64 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_0",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <i32 as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_1",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <f64 as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::Expr {
                                value: {
                                    fn _fmt() -> String {
                                        {
                                            let v: f64 = 1.0;
                                            ::pyo3_stub_gen::util::fmt_py_obj(v)
                                        }
                                    }
                                    _fmt
                                },
                                source_module: Some({
                                    fn _get_module() -> Option<::pyo3_stub_gen::ModuleRef> {
                                        <f64 as ::pyo3_stub_gen::PyStubType>::type_output()
                                            .source_module
                                    }
                                    _get_module
                                }),
                            },
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "ndim",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "count",
                            r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Struct,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "count",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <usize as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "description",
                    fields: &[],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Unit,
                    constr_args: &[],
                },
            ],
            module: Some("my_module"),
            doc: "",
        }
        "###);
        Ok(())
    }
}