pyo3_stub_gen_derive/gen_stub/
signature.rs

1use std::collections::HashMap;
2
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{
6    parenthesized,
7    parse::{Parse, ParseStream},
8    punctuated::Punctuated,
9    token, Expr, Ident, Result, Token, Type,
10};
11
12use crate::gen_stub::remove_lifetime;
13
14use super::ArgInfo;
15
16#[derive(Debug, Clone, PartialEq)]
17enum SignatureArg {
18    Ident(Ident),
19    Assign(Ident, Token![=], Expr),
20    Star(Token![*]),
21    Args(Token![*], Ident),
22    Keywords(Token![*], Token![*], Ident),
23}
24
25impl Parse for SignatureArg {
26    fn parse(input: ParseStream) -> Result<Self> {
27        if input.peek(Token![*]) {
28            let star = input.parse()?;
29            if input.peek(Token![*]) {
30                Ok(SignatureArg::Keywords(star, input.parse()?, input.parse()?))
31            } else if input.peek(Ident) {
32                Ok(SignatureArg::Args(star, input.parse()?))
33            } else {
34                Ok(SignatureArg::Star(star))
35            }
36        } else if input.peek(Ident) {
37            let ident = Ident::parse(input)?;
38            if input.peek(Token![=]) {
39                Ok(SignatureArg::Assign(ident, input.parse()?, input.parse()?))
40            } else {
41                Ok(SignatureArg::Ident(ident))
42            }
43        } else {
44            dbg!(input);
45            todo!()
46        }
47    }
48}
49
50#[derive(Debug, Clone, PartialEq)]
51pub struct Signature {
52    paren: token::Paren,
53    args: Punctuated<SignatureArg, Token![,]>,
54}
55
56impl Parse for Signature {
57    fn parse(input: ParseStream) -> Result<Self> {
58        let content;
59        let paren = parenthesized!(content in input);
60        let args = content.parse_terminated(SignatureArg::parse, Token![,])?;
61        Ok(Self { paren, args })
62    }
63}
64
65pub struct ArgsWithSignature<'a> {
66    pub args: &'a Vec<ArgInfo>,
67    pub sig: &'a Option<Signature>,
68}
69
70impl ToTokens for ArgsWithSignature<'_> {
71    fn to_tokens(&self, tokens: &mut TokenStream2) {
72        let arg_infos: Vec<TokenStream2> = if let Some(sig) = self.sig {
73            // record all Type information from rust's args
74            let args_map: HashMap<String, Type> = self
75                .args
76                .iter()
77                .map(|arg| {
78                    let mut ty = arg.r#type.clone();
79                    remove_lifetime(&mut ty);
80                    (arg.name.clone(), ty)
81                })
82                .collect();
83            sig.args.iter().map(|sig_arg| match sig_arg {
84                SignatureArg::Ident(ident) => {
85                    let name = ident.to_string();
86                    let ty = args_map.get(&name).unwrap();
87                    quote! {
88                        ::pyo3_stub_gen::type_info::ArgInfo {
89                            name: #name,
90                            r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input,
91                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
92                        }
93                    }
94                }
95                SignatureArg::Assign(ident, _eq, value) => {
96                    let name = ident.to_string();
97                    let ty = args_map.get(&name).unwrap();
98                    let default = if value.to_token_stream().to_string() == "None" {
99                        quote! {
100                            "None".to_string()
101                        }
102                    } else {
103                        quote! {
104                            ::pyo3::prepare_freethreaded_python();
105                            ::pyo3::Python::with_gil(|py| -> String {
106                                let v: #ty = #value;
107                                ::pyo3_stub_gen::util::fmt_py_obj(py, v)
108                            })
109                        }
110                    };
111                    quote! {
112                        ::pyo3_stub_gen::type_info::ArgInfo {
113                            name: #name,
114                            r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input,
115                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Assign{
116                                default: {
117                                    static DEFAULT: std::sync::LazyLock<String> = std::sync::LazyLock::new(|| {
118                                        #default
119                                    });
120                                    &DEFAULT
121                                }
122                            }),
123                        }
124                    }
125                },
126                SignatureArg::Star(_) => quote! {
127                    ::pyo3_stub_gen::type_info::ArgInfo {
128                        name: "",
129                        r#type: <() as ::pyo3_stub_gen::PyStubType>::type_input,
130                        signature: Some(pyo3_stub_gen::type_info::SignatureArg::Star),
131                    }
132                },
133                SignatureArg::Args(_, ident) => {
134                    let name = ident.to_string();
135                    let ty = args_map.get(&name).unwrap();
136                    quote! {
137                        ::pyo3_stub_gen::type_info::ArgInfo {
138                            name: #name,
139                            r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input,
140                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Args),
141                        }
142                    }
143                },
144                SignatureArg::Keywords(_, _, ident) => {
145                    let name = ident.to_string();
146                    let ty = args_map.get(&name).unwrap();
147                    quote! {
148                        ::pyo3_stub_gen::type_info::ArgInfo {
149                            name: #name,
150                            r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input,
151                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Keywords),
152                        }
153                    }
154                }
155            }).collect()
156        } else {
157            self.args
158                .iter()
159                .map(|arg| {
160                    let mut ty = arg.r#type.clone();
161                    remove_lifetime(&mut ty);
162                    let name = &arg.name;
163                    quote! {
164                        ::pyo3_stub_gen::type_info::ArgInfo {
165                            name: #name,
166                            r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input,
167                            signature: None,
168                        }
169                    }
170                })
171                .collect()
172        };
173        tokens.append_all(quote! { &[ #(#arg_infos),* ] });
174    }
175}
176
177impl Signature {
178    pub fn overriding_operator(sig: &syn::Signature) -> Option<Self> {
179        if sig.ident == "__pow__" {
180            return Some(syn::parse_str("(exponent, modulo=None)").unwrap());
181        }
182        if sig.ident == "__rpow__" {
183            return Some(syn::parse_str("(base, modulo=None)").unwrap());
184        }
185        None
186    }
187}