// pyo3_stub_gen_derive/gen_stub/variant.rs

1use crate::gen_stub::arg::ArgInfo;
2use crate::gen_stub::attr::{extract_documents, parse_pyo3_attrs, Attr};
3use crate::gen_stub::member::MemberInfo;
4use crate::gen_stub::parameter::Parameters;
5use crate::gen_stub::renaming::RenamingRule;
6use crate::gen_stub::signature::Signature;
7use crate::gen_stub::util::quote_option;
8use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
9use quote::{quote, ToTokens, TokenStreamExt};
10use syn::spanned::Spanned;
11use syn::{Fields, Result, Variant};
12
/// Payload shape of a single enum variant, recorded in the generated stub metadata.
///
/// Its [`ToTokens`] impl emits just the bare variant name (e.g. `Struct`), which the
/// caller splices after a `VariantForm::` path.
#[derive(Clone, Copy, Debug)]
pub enum VariantForm {
    /// Variant with named fields, e.g. `Variant { a: T }`.
    Struct,
    /// Variant with positional fields, e.g. `Variant(T)`.
    Tuple,
    /// Variant with no fields at all.
    Unit,
}
19
20impl ToTokens for VariantForm {
21    fn to_tokens(&self, tokens: &mut TokenStream2) {
22        let token = match self {
23            VariantForm::Struct => Ident::new("Struct", Span::call_site()),
24            VariantForm::Tuple => Ident::new("Tuple", Span::call_site()),
25            VariantForm::Unit => Ident::new("Unit", Span::call_site()),
26        };
27
28        tokens.append(token);
29    }
30}
31
32pub struct VariantInfo {
33    pyclass_name: String,
34    module: Option<String>,
35    fields: Vec<MemberInfo>,
36    doc: String,
37    form: VariantForm,
38    constr_args: Vec<ArgInfo>,
39    constr_sig: Option<Signature>,
40}
41
42impl VariantInfo {
43    pub fn from_variant(variant: Variant, renaming_rule: &Option<RenamingRule>) -> Result<Self> {
44        let Variant {
45            ident,
46            fields,
47            attrs,
48            ..
49        } = variant;
50
51        let mut pyclass_name = None;
52        let mut module = None;
53        let mut constr_sig = None;
54        for attr in parse_pyo3_attrs(&attrs)? {
55            match attr {
56                Attr::Name(name) => pyclass_name = Some(name),
57                Attr::Module(name) => {
58                    module = Some(name);
59                }
60                Attr::Constructor(sig) => {
61                    constr_sig = Some(sig);
62                }
63                _ => {}
64            }
65        }
66
67        let mut pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
68        if let Some(renaming_rule) = renaming_rule {
69            pyclass_name = renaming_rule.apply(&pyclass_name);
70        }
71
72        let mut members = Vec::new();
73
74        let form = match fields {
75            Fields::Unit => VariantForm::Unit,
76            Fields::Named(fields) => {
77                for field in fields.named {
78                    members.push(MemberInfo::try_from(field)?)
79                }
80                VariantForm::Struct
81            }
82            Fields::Unnamed(fields) => {
83                for (i, field) in fields.unnamed.iter().enumerate() {
84                    let mut named_field = field.clone();
85                    named_field.ident = Some(Ident::new(&format!("_{i}"), field.ident.span()));
86                    members.push(MemberInfo::try_from(named_field)?)
87                }
88                VariantForm::Tuple
89            }
90        };
91
92        let constr_args = members.iter().map(|f| f.clone().into()).collect();
93
94        let doc = extract_documents(&attrs).join("\n");
95        Ok(Self {
96            pyclass_name,
97            fields: members,
98            module,
99            doc,
100            form,
101            constr_args,
102            constr_sig,
103        })
104    }
105}
106
107impl ToTokens for VariantInfo {
108    fn to_tokens(&self, tokens: &mut TokenStream2) {
109        let Self {
110            pyclass_name,
111            fields,
112            doc,
113            module,
114            form,
115            constr_args,
116            constr_sig,
117        } = self;
118
119        let parameters = if let Some(sig) = constr_sig {
120            match Parameters::new_with_sig(constr_args, sig) {
121                Ok(params) => params,
122                Err(err) => {
123                    tokens.extend(err.to_compile_error());
124                    return;
125                }
126            }
127        } else {
128            Parameters::new(constr_args)
129        };
130
131        let module = quote_option(module);
132        tokens.append_all(quote! {
133            ::pyo3_stub_gen::type_info::VariantInfo {
134                pyclass_name: #pyclass_name,
135                fields: &[ #( #fields),* ],
136                module: #module,
137                doc: #doc,
138                form: &pyo3_stub_gen::type_info::VariantForm::#form,
139                constr_args: #parameters,
140            }
141        })
142    }
143}