pyo3_stub_gen_derive/gen_stub/
signature.rs

1use std::collections::HashMap;
2
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{
6    parenthesized,
7    parse::{Parse, ParseStream},
8    punctuated::Punctuated,
9    token, Expr, Ident, Result, Token, Type,
10};
11
12use crate::gen_stub::remove_lifetime;
13
14use super::ArgInfo;
15
16#[derive(Debug, Clone, PartialEq)]
17enum SignatureArg {
18    Ident(Ident),
19    Assign(Ident, Token![=], Expr),
20    Star(Token![*]),
21    Args(Token![*], Ident),
22    Keywords(Token![*], Token![*], Ident),
23}
24
25impl Parse for SignatureArg {
26    fn parse(input: ParseStream) -> Result<Self> {
27        if input.peek(Token![*]) {
28            let star = input.parse()?;
29            if input.peek(Token![*]) {
30                Ok(SignatureArg::Keywords(star, input.parse()?, input.parse()?))
31            } else if input.peek(Ident) {
32                Ok(SignatureArg::Args(star, input.parse()?))
33            } else {
34                Ok(SignatureArg::Star(star))
35            }
36        } else if input.peek(Ident) {
37            let ident = Ident::parse(input)?;
38            if input.peek(Token![=]) {
39                Ok(SignatureArg::Assign(ident, input.parse()?, input.parse()?))
40            } else {
41                Ok(SignatureArg::Ident(ident))
42            }
43        } else {
44            dbg!(input);
45            todo!()
46        }
47    }
48}
49
50#[derive(Debug, Clone, PartialEq)]
51pub struct Signature {
52    paren: token::Paren,
53    args: Punctuated<SignatureArg, Token![,]>,
54}
55
56impl Parse for Signature {
57    fn parse(input: ParseStream) -> Result<Self> {
58        let content;
59        let paren = parenthesized!(content in input);
60        let args = content.parse_terminated(SignatureArg::parse, Token![,])?;
61        Ok(Self { paren, args })
62    }
63}
64
/// Borrowed view pairing a function's Rust arguments with its optional
/// explicit signature, so both can be rendered together through `ToTokens`.
pub struct ArgsWithSignature<'a> {
    /// Arguments of the Rust function; their names and types are used when
    /// rendering.
    pub args: &'a Vec<ArgInfo>,
    /// Explicit signature, if any. When `None`, the arguments are rendered
    /// in declaration order without signature information.
    pub sig: &'a Option<Signature>,
}
69
70impl ToTokens for ArgsWithSignature<'_> {
71    fn to_tokens(&self, tokens: &mut TokenStream2) {
72        let arg_infos_res: Result<Vec<TokenStream2>> = if let Some(sig) = self.sig {
73            // record all Type information from rust's args
74            let args_map: HashMap<String, Type> = self
75                .args
76                .iter()
77                .map(|arg| {
78                    let mut ty = arg.r#type.clone();
79                    remove_lifetime(&mut ty);
80                    (arg.name.clone(), ty)
81                })
82                .collect();
83            sig.args.iter().map(|sig_arg| match sig_arg {
84                SignatureArg::Ident(ident) => {
85                    let name = ident.to_string();
86                    if let Some(ty) = args_map.get(&name){
87                        Ok(quote! {
88                        ::pyo3_stub_gen::type_info::ArgInfo {
89                            name: #name,
90                            r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input,
91                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
92                        }})
93                    } else {
94                        Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
95                    }
96                }
97                SignatureArg::Assign(ident, _eq, value) => {
98                    let name = ident.to_string();
99                    if let Some(ty) = args_map.get(&name){
100                        let default = if value.to_token_stream().to_string() == "None" {
101                            quote! {
102                            "None".to_string()
103                            }
104                        } else {
105                            quote! {
106                            ::pyo3::prepare_freethreaded_python();
107                            ::pyo3::Python::with_gil(|py| -> String {
108                                let v: #ty = #value;
109                                ::pyo3_stub_gen::util::fmt_py_obj(py, v)
110                            })
111                            }
112                        };
113                        Ok(quote! {
114                        ::pyo3_stub_gen::type_info::ArgInfo {
115                            name: #name,
116                            r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input,
117                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Assign{
118                                default: {
119                                    static DEFAULT: std::sync::LazyLock<String> = std::sync::LazyLock::new(|| {
120                                        #default
121                                    });
122                                    &DEFAULT
123                                }
124                            }),
125                        }})
126                    } else {
127                        Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
128                    }
129                },
130                SignatureArg::Star(_) =>Ok(quote! {
131                    ::pyo3_stub_gen::type_info::ArgInfo {
132                        name: "",
133                        r#type: <() as ::pyo3_stub_gen::PyStubType>::type_input,
134                        signature: Some(pyo3_stub_gen::type_info::SignatureArg::Star),
135                }}),
136                SignatureArg::Args(_, ident) => {
137                    let name = ident.to_string();
138                    if let Some(ty) = args_map.get(&name){
139                        Ok(quote! {
140                        ::pyo3_stub_gen::type_info::ArgInfo {
141                            name: #name,
142                            r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input,
143                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Args),
144                        }})
145                    } else {
146                        Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
147                    }
148                },
149                SignatureArg::Keywords(_, _, ident) => {
150                    let name = ident.to_string();
151                    if let Some(ty) = args_map.get(&name){
152                        Ok(quote! {
153                        ::pyo3_stub_gen::type_info::ArgInfo {
154                            name: #name,
155                            r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input,
156                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Keywords),
157                        }})
158                    } else {
159                        Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
160                    }
161                }
162            }).collect()
163        } else {
164            self.args
165                .iter()
166                .map(|arg| {
167                    let mut ty = arg.r#type.clone();
168                    remove_lifetime(&mut ty);
169                    let name = &arg.name;
170                    Ok(quote! {
171                        ::pyo3_stub_gen::type_info::ArgInfo {
172                            name: #name,
173                            r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input,
174                            signature: None,
175                        }
176                    })
177                })
178                .collect()
179        };
180        match arg_infos_res {
181            Ok(arg_infos) => tokens.append_all(quote! { &[ #(#arg_infos),* ] }),
182            Err(err) => tokens.extend(err.to_compile_error()),
183        }
184    }
185}
186
187impl Signature {
188    pub fn overriding_operator(sig: &syn::Signature) -> Option<Self> {
189        if sig.ident == "__pow__" {
190            return Some(syn::parse_str("(exponent, modulo=None)").unwrap());
191        }
192        if sig.ident == "__rpow__" {
193            return Some(syn::parse_str("(base, modulo=None)").unwrap());
194        }
195        None
196    }
197}