pyo3_stub_gen_derive/gen_stub/pyclass_enum.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens, TokenStreamExt};
3use syn::{parse_quote, Error, ItemEnum, Result, Type};
4
5use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, PyClassAttr, StubType};
6
7pub struct PyEnumInfo {
8    pyclass_name: String,
9    enum_type: Type,
10    module: Option<String>,
11    variants: Vec<(String, String)>,
12    doc: String,
13}
14
15impl From<&PyEnumInfo> for StubType {
16    fn from(info: &PyEnumInfo) -> Self {
17        let PyEnumInfo {
18            pyclass_name,
19            module,
20            enum_type,
21            ..
22        } = info;
23        Self {
24            ty: enum_type.clone(),
25            name: pyclass_name.clone(),
26            module: module.clone(),
27        }
28    }
29}
30
31impl PyEnumInfo {
32    /// Create PyEnumInfo from ItemEnum with PyClassAttr for module override support
33    pub fn from_item_with_attr(
34        ItemEnum {
35            variants,
36            attrs,
37            ident,
38            ..
39        }: ItemEnum,
40        attr: &PyClassAttr,
41    ) -> Result<Self> {
42        let doc = extract_documents(&attrs).join("\n");
43        let mut pyclass_name = None;
44        let mut pyo3_module = None;
45        let mut gen_stub_standalone_module = None;
46        let mut renaming_rule = None;
47        for attr_item in parse_pyo3_attrs(&attrs)? {
48            match attr_item {
49                Attr::Name(name) => pyclass_name = Some(name),
50                Attr::Module(name) => pyo3_module = Some(name),
51                Attr::GenStubModule(name) => gen_stub_standalone_module = Some(name),
52                Attr::RenameAll(name) => renaming_rule = Some(name),
53                _ => {}
54            }
55        }
56
57        // Validate: inline and standalone gen_stub modules must not conflict
58        if let (Some(inline_mod), Some(standalone_mod)) =
59            (&attr.module, &gen_stub_standalone_module)
60        {
61            if inline_mod != standalone_mod {
62                return Err(Error::new(
63                    ident.span(),
64                    format!(
65                        "Conflicting module specifications: #[gen_stub_pyclass_enum(module = \"{}\")] \
66                         and #[gen_stub(module = \"{}\")]. Please use only one.",
67                        inline_mod, standalone_mod
68                    ),
69                ));
70            }
71        }
72
73        // Priority: inline > standalone > pyo3 > default
74        let module = if let Some(inline_mod) = &attr.module {
75            Some(inline_mod.clone()) // Priority 1: #[gen_stub_pyclass_enum(module = "...")]
76        } else if let Some(standalone_mod) = gen_stub_standalone_module {
77            Some(standalone_mod) // Priority 2: #[gen_stub(module = "...")]
78        } else {
79            pyo3_module // Priority 3: #[pyo3(module = "...")]
80        };
81
82        let struct_type = parse_quote!(#ident);
83        let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
84        let variants = variants
85            .into_iter()
86            .map(|var| -> Result<(String, String)> {
87                let mut var_name = None;
88                for attr in parse_pyo3_attrs(&var.attrs)? {
89                    if let Attr::Name(name) = attr {
90                        var_name = Some(name);
91                    }
92                }
93                let mut var_name = var_name.unwrap_or_else(|| var.ident.to_string());
94                if let Some(renaming_rule) = renaming_rule {
95                    var_name = renaming_rule.apply(&var_name);
96                }
97                let var_doc = extract_documents(&var.attrs).join("\n");
98                Ok((var_name, var_doc))
99            })
100            .collect::<Result<Vec<(String, String)>>>()?;
101        Ok(Self {
102            doc,
103            enum_type: struct_type,
104            pyclass_name,
105            module,
106            variants,
107        })
108    }
109}
110
111impl TryFrom<ItemEnum> for PyEnumInfo {
112    type Error = Error;
113    fn try_from(item: ItemEnum) -> Result<Self> {
114        // Use the new method with default PyClassAttr
115        Self::from_item_with_attr(item, &PyClassAttr::default())
116    }
117}
118
119impl ToTokens for PyEnumInfo {
120    fn to_tokens(&self, tokens: &mut TokenStream2) {
121        let Self {
122            pyclass_name,
123            enum_type,
124            variants,
125            doc,
126            module,
127        } = self;
128        let module = quote_option(module);
129        let variants: Vec<_> = variants
130            .iter()
131            .map(|(name, doc)| quote! {(#name,#doc)})
132            .collect();
133        tokens.append_all(quote! {
134            ::pyo3_stub_gen::type_info::PyEnumInfo {
135                pyclass_name: #pyclass_name,
136                enum_id: std::any::TypeId::of::<#enum_type>,
137                variants: &[ #(#variants),* ],
138                module: #module,
139                doc: #doc,
140            }
141        })
142    }
143}