pyo3_stub_gen_derive/gen_stub/
variant.rs

1use crate::gen_stub::arg::ArgInfo;
2use crate::gen_stub::attr::{extract_documents, parse_pyo3_attrs, Attr};
3use crate::gen_stub::member::MemberInfo;
4use crate::gen_stub::renaming::RenamingRule;
5use crate::gen_stub::signature::{ArgsWithSignature, Signature};
6use crate::gen_stub::util::quote_option;
7use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
8use quote::{quote, ToTokens, TokenStreamExt};
9use syn::spanned::Spanned;
10use syn::{Fields, Result, Variant};
11
/// Shape of an enum variant's payload, mirrored into the generated stub metadata.
#[derive(Clone, Copy, Debug)]
pub enum VariantForm {
    /// Variant with named fields: `Variant { a: T, b: U }`.
    Struct,
    /// Variant with positional fields: `Variant(T, U)`.
    Tuple,
    /// Variant with no fields: `Variant`.
    Unit,
}
18
19impl ToTokens for VariantForm {
20    fn to_tokens(&self, tokens: &mut TokenStream2) {
21        let token = match self {
22            VariantForm::Struct => Ident::new("Struct", Span::call_site()),
23            VariantForm::Tuple => Ident::new("Tuple", Span::call_site()),
24            VariantForm::Unit => Ident::new("Unit", Span::call_site()),
25        };
26
27        tokens.append(token);
28    }
29}
30
31pub struct VariantInfo {
32    pyclass_name: String,
33    module: Option<String>,
34    fields: Vec<MemberInfo>,
35    doc: String,
36    form: VariantForm,
37    constr_args: Vec<ArgInfo>,
38    constr_sig: Option<Signature>,
39}
40
41impl VariantInfo {
42    pub fn from_variant(variant: Variant, renaming_rule: &Option<RenamingRule>) -> Result<Self> {
43        let Variant {
44            ident,
45            fields,
46            attrs,
47            ..
48        } = variant;
49
50        let mut pyclass_name = None;
51        let mut module = None;
52        let mut constr_sig = None;
53        for attr in parse_pyo3_attrs(&attrs)? {
54            match attr {
55                Attr::Name(name) => pyclass_name = Some(name),
56                Attr::Module(name) => {
57                    module = Some(name);
58                }
59                Attr::Constructor(sig) => {
60                    constr_sig = Some(sig);
61                }
62                _ => {}
63            }
64        }
65
66        let mut pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
67        if let Some(renaming_rule) = renaming_rule {
68            pyclass_name = renaming_rule.apply(&pyclass_name);
69        }
70
71        let mut members = Vec::new();
72
73        let form = match fields {
74            Fields::Unit => VariantForm::Unit,
75            Fields::Named(fields) => {
76                for field in fields.named {
77                    members.push(MemberInfo::try_from(field)?)
78                }
79                VariantForm::Struct
80            }
81            Fields::Unnamed(fields) => {
82                for (i, field) in fields.unnamed.iter().enumerate() {
83                    let mut named_field = field.clone();
84                    named_field.ident = Some(Ident::new(&format!("_{i}"), field.ident.span()));
85                    members.push(MemberInfo::try_from(named_field)?)
86                }
87                VariantForm::Tuple
88            }
89        };
90
91        let constr_args = members.iter().map(|f| f.clone().into()).collect();
92
93        let doc = extract_documents(&attrs).join("\n");
94        Ok(Self {
95            pyclass_name,
96            fields: members,
97            module,
98            doc,
99            form,
100            constr_args,
101            constr_sig,
102        })
103    }
104}
105
106impl ToTokens for VariantInfo {
107    fn to_tokens(&self, tokens: &mut TokenStream2) {
108        let Self {
109            pyclass_name,
110            fields,
111            doc,
112            module,
113            form,
114            constr_args,
115            constr_sig,
116        } = self;
117
118        let args_with_sig = ArgsWithSignature {
119            args: constr_args,
120            sig: constr_sig,
121        };
122
123        let module = quote_option(module);
124        tokens.append_all(quote! {
125            ::pyo3_stub_gen::type_info::VariantInfo {
126                pyclass_name: #pyclass_name,
127                fields: &[ #( #fields),* ],
128                module: #module,
129                doc: #doc,
130                form: &pyo3_stub_gen::type_info::VariantForm::#form,
131                constr_args: #args_with_sig,
132            }
133        })
134    }
135}