pyo3_stub_gen_derive/gen_stub/pyclass_complex_enum.rs

1use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, PyClassAttr, StubType};
2use crate::gen_stub::variant::VariantInfo;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{parse_quote, Error, ItemEnum, Result, Type};
6
7pub struct PyComplexEnumInfo {
8    pyclass_name: String,
9    enum_type: Type,
10    module: Option<String>,
11    variants: Vec<VariantInfo>,
12    doc: String,
13}
14
15impl From<&PyComplexEnumInfo> for StubType {
16    fn from(info: &PyComplexEnumInfo) -> Self {
17        let PyComplexEnumInfo {
18            pyclass_name,
19            module,
20            enum_type,
21            ..
22        } = info;
23        Self {
24            ty: enum_type.clone(),
25            name: pyclass_name.clone(),
26            module: module.clone(),
27        }
28    }
29}
30
31impl PyComplexEnumInfo {
32    /// Create PyComplexEnumInfo from ItemEnum with PyClassAttr for module override support
33    pub fn from_item_with_attr(item: ItemEnum, attr: &PyClassAttr) -> Result<Self> {
34        let ItemEnum {
35            variants,
36            attrs,
37            ident,
38            ..
39        } = item;
40
41        let doc = extract_documents(&attrs).join("\n");
42        let mut pyclass_name = None;
43        let mut pyo3_module = None;
44        let mut gen_stub_standalone_module = None;
45        let mut renaming_rule = None;
46        let mut bases = Vec::new();
47        for attr_item in parse_pyo3_attrs(&attrs)? {
48            match attr_item {
49                Attr::Name(name) => pyclass_name = Some(name),
50                Attr::Module(name) => pyo3_module = Some(name),
51                Attr::GenStubModule(name) => gen_stub_standalone_module = Some(name),
52                Attr::RenameAll(name) => renaming_rule = Some(name),
53                Attr::Extends(typ) => bases.push(typ),
54                _ => {}
55            }
56        }
57
58        // Validate: inline and standalone gen_stub modules must not conflict
59        if let (Some(inline_mod), Some(standalone_mod)) =
60            (&attr.module, &gen_stub_standalone_module)
61        {
62            if inline_mod != standalone_mod {
63                return Err(Error::new(
64                    ident.span(),
65                    format!(
66                        "Conflicting module specifications: #[gen_stub_pyclass_complex_enum(module = \"{}\")] \
67                         and #[gen_stub(module = \"{}\")]. Please use only one.",
68                        inline_mod, standalone_mod
69                    ),
70                ));
71            }
72        }
73
74        // Priority: inline > standalone > pyo3 > default
75        let module = if let Some(inline_mod) = &attr.module {
76            Some(inline_mod.clone()) // Priority 1: #[gen_stub_pyclass_complex_enum(module = "...")]
77        } else if let Some(standalone_mod) = gen_stub_standalone_module {
78            Some(standalone_mod) // Priority 2: #[gen_stub(module = "...")]
79        } else {
80            pyo3_module // Priority 3: #[pyo3(module = "...")]
81        };
82
83        let enum_type = parse_quote!(#ident);
84        let pyclass_name = pyclass_name.unwrap_or_else(|| ident.clone().to_string());
85
86        let mut items = Vec::new();
87        for variant in variants {
88            items.push(VariantInfo::from_variant(variant, &renaming_rule)?)
89        }
90
91        Ok(Self {
92            doc,
93            enum_type,
94            pyclass_name,
95            module,
96            variants: items,
97        })
98    }
99}
100
101impl TryFrom<ItemEnum> for PyComplexEnumInfo {
102    type Error = Error;
103
104    fn try_from(item: ItemEnum) -> Result<Self> {
105        // Use the new method with default PyClassAttr
106        Self::from_item_with_attr(item, &PyClassAttr::default())
107    }
108}
109
110impl ToTokens for PyComplexEnumInfo {
111    fn to_tokens(&self, tokens: &mut TokenStream2) {
112        let Self {
113            pyclass_name,
114            enum_type,
115            variants,
116            doc,
117            module,
118            ..
119        } = self;
120        let module = quote_option(module);
121
122        tokens.append_all(quote! {
123            ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
124                pyclass_name: #pyclass_name,
125                enum_id: std::any::TypeId::of::<#enum_type>,
126                variants: &[ #( #variants ),* ],
127                module: #module,
128                doc: #doc,
129            }
130        })
131    }
132}
133
// Unit tests for `PyComplexEnumInfo` token generation.
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    // Snapshot test: an enum with a pyo3 `name`/`module` override, a renamed tuple
    // variant, a tuple variant with a `constructor` default, a struct variant and a
    // unit variant is converted via `try_from`, and the generated
    // `PyComplexEnumInfo` tokens are compared against the inline snapshot.
    #[test]
    fn test_complex_enum() -> Result<()> {
        let input: ItemEnum = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub enum PyPlaceholder {
                #[pyo3(name="Name")]
                name(String),
                #[pyo3(constructor = (_0, _1=1.0))]
                twonum(i32,f64),
                ndim{count: usize},
                description,
            }
            "#,
        )?;
        let out = PyComplexEnumInfo::try_from(input)?.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r#"
        ::pyo3_stub_gen::type_info::PyComplexEnumInfo {
            pyclass_name: "Placeholder",
            enum_id: std::any::TypeId::of::<PyPlaceholder>,
            variants: &[
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "Name",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_0",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <String as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "twonum",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_0",
                            r#type: <i32 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "_1",
                            r#type: <f64 as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Tuple,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_0",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <i32 as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "_1",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <f64 as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::Expr({
                                fn _fmt() -> String {
                                    {
                                        let v: f64 = 1.0;
                                        let repr = ::pyo3_stub_gen::util::fmt_py_obj(v);
                                        let type_info = <f64 as ::pyo3_stub_gen::PyStubType>::type_output();
                                        let should_add_prefix = if let Some(dot_pos) = type_info
                                            .name
                                            .rfind('.')
                                        {
                                            let module_prefix = &type_info.name[..dot_pos];
                                            type_info
                                                .import
                                                .iter()
                                                .any(|imp| {
                                                    if let ::pyo3_stub_gen::ImportRef::Module(module_ref) = imp {
                                                        if let Some(module_name) = module_ref.get() {
                                                            module_name.ends_with(&format!(".{}", module_prefix))
                                                        } else {
                                                            false
                                                        }
                                                    } else {
                                                        false
                                                    }
                                                })
                                        } else {
                                            false
                                        };
                                        if should_add_prefix {
                                            if let Some(dot_pos) = type_info.name.rfind('.') {
                                                let module_prefix = &type_info.name[..dot_pos];
                                                format!("{}.{}", module_prefix, repr)
                                            } else {
                                                repr
                                            }
                                        } else {
                                            repr
                                        }
                                    }
                                }
                                _fmt
                            }),
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "ndim",
                    fields: &[
                        ::pyo3_stub_gen::type_info::MemberInfo {
                            name: "count",
                            r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                            doc: "",
                            default: None,
                            deprecated: None,
                        },
                    ],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Struct,
                    constr_args: &[
                        ::pyo3_stub_gen::type_info::ParameterInfo {
                            name: "count",
                            kind: ::pyo3_stub_gen::type_info::ParameterKind::PositionalOrKeyword,
                            type_info: <usize as ::pyo3_stub_gen::PyStubType>::type_input,
                            default: ::pyo3_stub_gen::type_info::ParameterDefault::None,
                        },
                    ],
                },
                ::pyo3_stub_gen::type_info::VariantInfo {
                    pyclass_name: "description",
                    fields: &[],
                    module: None,
                    doc: "",
                    form: &pyo3_stub_gen::type_info::VariantForm::Unit,
                    constr_args: &[],
                },
            ],
            module: Some("my_module"),
            doc: "",
        }
        "#);
        Ok(())
    }

    // Pretty-print a token stream that represents an expression.
    // `prettyplease::unparse` formats whole files, so the expression is wrapped in a
    // `const _: () = ...;` item, formatted, and the wrapper is stripped again.
    // Panics (via `unwrap`) if the tokens do not parse or the wrapper is missing.
    fn format_as_value(tt: TokenStream2) -> String {
        let ttt = quote! { const _: () = #tt; };
        let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
        formatted
            .trim()
            .strip_prefix("const _: () = ")
            .unwrap()
            .strip_suffix(';')
            .unwrap()
            .to_string()
    }
}