pyo3_stub_gen_derive/gen_stub/
signature.rs

1use std::collections::HashMap;
2
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{
6    parenthesized,
7    parse::{Parse, ParseStream},
8    punctuated::Punctuated,
9    token, Expr, Ident, Result, Token,
10};
11
12use crate::gen_stub::{remove_lifetime, util::TypeOrOverride};
13
14use super::ArgInfo;
15
16#[derive(Debug, Clone, PartialEq)]
17enum SignatureArg {
18    Ident(Ident),
19    Assign(Ident, Token![=], Expr),
20    Star(Token![*]),
21    Args(Token![*], Ident),
22    Keywords(Token![*], Token![*], Ident),
23}
24
25impl Parse for SignatureArg {
26    fn parse(input: ParseStream) -> Result<Self> {
27        if input.peek(Token![*]) {
28            let star = input.parse()?;
29            if input.peek(Token![*]) {
30                Ok(SignatureArg::Keywords(star, input.parse()?, input.parse()?))
31            } else if input.peek(Ident) {
32                Ok(SignatureArg::Args(star, input.parse()?))
33            } else {
34                Ok(SignatureArg::Star(star))
35            }
36        } else if input.peek(Ident) {
37            let ident = Ident::parse(input)?;
38            if input.peek(Token![=]) {
39                Ok(SignatureArg::Assign(ident, input.parse()?, input.parse()?))
40            } else {
41                Ok(SignatureArg::Ident(ident))
42            }
43        } else {
44            dbg!(input);
45            todo!()
46        }
47    }
48}
49
50#[derive(Debug, Clone, PartialEq)]
51pub struct Signature {
52    paren: token::Paren,
53    args: Punctuated<SignatureArg, Token![,]>,
54}
55
56impl Parse for Signature {
57    fn parse(input: ParseStream) -> Result<Self> {
58        let content;
59        let paren = parenthesized!(content in input);
60        let args = content.parse_terminated(SignatureArg::parse, Token![,])?;
61        Ok(Self { paren, args })
62    }
63}
64
65pub struct ArgsWithSignature<'a> {
66    pub args: &'a Vec<ArgInfo>,
67    pub sig: &'a Option<Signature>,
68}
69
70impl ToTokens for ArgsWithSignature<'_> {
71    fn to_tokens(&self, tokens: &mut TokenStream2) {
72        let arg_infos_res: Result<Vec<TokenStream2>> = if let Some(sig) = self.sig {
73            // record all Type information from rust's args
74            let args_map: HashMap<String, ArgInfo> = self
75                .args
76                .iter()
77                .map(|arg| match arg {
78                    ArgInfo {
79                        name,
80                        r#type: TypeOrOverride::RustType { r#type },
81                    } => {
82                        let mut ty = r#type.clone();
83                        remove_lifetime(&mut ty);
84                        (
85                            name.clone(),
86                            ArgInfo {
87                                name: name.clone(),
88                                r#type: TypeOrOverride::RustType { r#type: ty },
89                            },
90                        )
91                    }
92                    arg @ ArgInfo { name, .. } => (name.clone(), arg.clone()),
93                })
94                .collect();
95            sig.args.iter().map(|sig_arg| match sig_arg {
96                SignatureArg::Ident(ident) => {
97                    let name = ident.to_string();
98                    match args_map.get(&name) {
99                        Some(ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } }) => Ok(quote! {
100                        ::pyo3_stub_gen::type_info::ArgInfo {
101                            name: #name,
102                            r#type: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
103                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
104                        }}),
105                        Some(ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, .. }}) => {
106                            let imports = imports.iter().collect::<Vec<&String>>();
107                            Ok(quote! {
108                            ::pyo3_stub_gen::type_info::ArgInfo {
109                                name: #name,
110                                r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
111                                signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
112                            }})
113                        },
114                        None => Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
115                    }
116                }
117                SignatureArg::Assign(ident, _eq, value) => {
118                    let name = ident.to_string();
119
120                    match args_map.get(&name) {
121                        Some(ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } }) => {
122                            let default = if value.to_token_stream().to_string() == "None" {
123                                quote! {
124                                "None".to_string()
125                                }
126                            } else {
127                                quote! {
128                                ::pyo3::prepare_freethreaded_python();
129                                ::pyo3::Python::with_gil(|py| -> String {
130                                    let v: #r#type = #value;
131                                    ::pyo3_stub_gen::util::fmt_py_obj(py, v)
132                                })
133                                }
134                            };
135                            Ok(quote! {
136                            ::pyo3_stub_gen::type_info::ArgInfo {
137                                name: #name,
138                                r#type: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
139                                signature: Some(pyo3_stub_gen::type_info::SignatureArg::Assign{
140                                    default: {
141                                        static DEFAULT: std::sync::LazyLock<String> = std::sync::LazyLock::new(|| {
142                                            #default
143                                        });
144                                        &DEFAULT
145                                    }
146                                }),
147                            }})
148                        },
149                        Some(ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, r#type }}) => {
150                            let imports = imports.iter().collect::<Vec<&String>>();
151                            let default = if value.to_token_stream().to_string() == "None" {
152                                quote! {
153                                "None".to_string()
154                                }
155                            } else {
156                                quote! {
157                                ::pyo3::prepare_freethreaded_python();
158                                ::pyo3::Python::with_gil(|py| -> String {
159                                    let v: #r#type = #value;
160                                    ::pyo3_stub_gen::util::fmt_py_obj(py, v)
161                                })
162                                }
163                            };
164                            Ok(quote! {
165                            ::pyo3_stub_gen::type_info::ArgInfo {
166                                name: #name,
167                                r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
168                                signature: Some(pyo3_stub_gen::type_info::SignatureArg::Assign{
169                                    default: {
170                                        static DEFAULT: std::sync::LazyLock<String> = std::sync::LazyLock::new(|| {
171                                            #default
172                                        });
173                                        &DEFAULT
174                                    }
175                                }),
176                            }})
177                        },
178                        None => Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
179                    }
180                },
181                SignatureArg::Star(_) =>Ok(quote! {
182                    ::pyo3_stub_gen::type_info::ArgInfo {
183                        name: "",
184                        r#type: <() as ::pyo3_stub_gen::PyStubType>::type_input,
185                        signature: Some(pyo3_stub_gen::type_info::SignatureArg::Star),
186                }}),
187                SignatureArg::Args(_, ident) => {
188                    let name = ident.to_string();
189                    match args_map.get(&name) {
190                        Some(ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } }) => Ok(quote! {
191                        ::pyo3_stub_gen::type_info::ArgInfo {
192                            name: #name,
193                            r#type: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
194                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
195                        }}),
196                        Some(ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, .. }}) => {
197                            let imports = imports.iter().collect::<Vec<&String>>();
198                            Ok(quote! {
199                            ::pyo3_stub_gen::type_info::ArgInfo {
200                                name: #name,
201                                r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
202                                signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
203                            }})
204                        },
205                        None => Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
206                    }
207                },
208                SignatureArg::Keywords(_, _, ident) => {
209                    let name = ident.to_string();
210                    match args_map.get(&name) {
211                        Some(ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } }) => Ok(quote! {
212                        ::pyo3_stub_gen::type_info::ArgInfo {
213                            name: #name,
214                            r#type: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
215                            signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
216                        }}),
217                        Some(ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, .. }}) => {
218                            let imports = imports.iter().collect::<Vec<&String>>();
219                            Ok(quote! {
220                            ::pyo3_stub_gen::type_info::ArgInfo {
221                                name: #name,
222                                r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
223                                signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
224                            }})
225                        },
226                        None => Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
227                    }
228                }
229            }).collect()
230        } else {
231            self.args
232                .iter()
233                .map(|arg| {
234                    match arg {
235                        ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } } => {
236                            let mut ty = r#type.clone();
237                            remove_lifetime(&mut ty);
238                            Ok(quote! {
239                                ::pyo3_stub_gen::type_info::ArgInfo {
240                                    name: #name,
241                                    r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input,
242                                    signature: None,
243                                }
244                            })
245                        }
246                        ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, .. }} => {
247                            let imports = imports.iter().collect::<Vec<&String>>();
248                            Ok(quote! {
249                            ::pyo3_stub_gen::type_info::ArgInfo {
250                                name: #name,
251                                r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
252                                signature: None,
253                            }})
254                        },
255                    }
256                })
257                .collect()
258        };
259        match arg_infos_res {
260            Ok(arg_infos) => tokens.append_all(quote! { &[ #(#arg_infos),* ] }),
261            Err(err) => tokens.extend(err.to_compile_error()),
262        }
263    }
264}
265
266impl Signature {
267    pub fn overriding_operator(sig: &syn::Signature) -> Option<Self> {
268        if sig.ident == "__pow__" {
269            return Some(syn::parse_str("(exponent, modulo=None)").unwrap());
270        }
271        if sig.ident == "__rpow__" {
272            return Some(syn::parse_str("(base, modulo=None)").unwrap());
273        }
274        None
275    }
276}