pyo3_stub_gen_derive/gen_stub/
signature.rs1use std::collections::HashMap;
2
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{
6 parenthesized,
7 parse::{Parse, ParseStream},
8 punctuated::Punctuated,
9 token, Expr, Ident, Result, Token,
10};
11
12use crate::gen_stub::{remove_lifetime, util::TypeOrOverride};
13
14use super::ArgInfo;
15
16#[derive(Debug, Clone, PartialEq)]
17enum SignatureArg {
18 Ident(Ident),
19 Assign(Ident, Token![=], Expr),
20 Star(Token![*]),
21 Args(Token![*], Ident),
22 Keywords(Token![*], Token![*], Ident),
23}
24
25impl Parse for SignatureArg {
26 fn parse(input: ParseStream) -> Result<Self> {
27 if input.peek(Token![*]) {
28 let star = input.parse()?;
29 if input.peek(Token![*]) {
30 Ok(SignatureArg::Keywords(star, input.parse()?, input.parse()?))
31 } else if input.peek(Ident) {
32 Ok(SignatureArg::Args(star, input.parse()?))
33 } else {
34 Ok(SignatureArg::Star(star))
35 }
36 } else if input.peek(Ident) {
37 let ident = Ident::parse(input)?;
38 if input.peek(Token![=]) {
39 Ok(SignatureArg::Assign(ident, input.parse()?, input.parse()?))
40 } else {
41 Ok(SignatureArg::Ident(ident))
42 }
43 } else {
44 dbg!(input);
45 todo!()
46 }
47 }
48}
49
50#[derive(Debug, Clone, PartialEq)]
51pub struct Signature {
52 paren: token::Paren,
53 args: Punctuated<SignatureArg, Token![,]>,
54}
55
56impl Parse for Signature {
57 fn parse(input: ParseStream) -> Result<Self> {
58 let content;
59 let paren = parenthesized!(content in input);
60 let args = content.parse_terminated(SignatureArg::parse, Token![,])?;
61 Ok(Self { paren, args })
62 }
63}
64
65pub struct ArgsWithSignature<'a> {
66 pub args: &'a Vec<ArgInfo>,
67 pub sig: &'a Option<Signature>,
68}
69
70impl ToTokens for ArgsWithSignature<'_> {
71 fn to_tokens(&self, tokens: &mut TokenStream2) {
72 let arg_infos_res: Result<Vec<TokenStream2>> = if let Some(sig) = self.sig {
73 let args_map: HashMap<String, ArgInfo> = self
75 .args
76 .iter()
77 .map(|arg| match arg {
78 ArgInfo {
79 name,
80 r#type: TypeOrOverride::RustType { r#type },
81 } => {
82 let mut ty = r#type.clone();
83 remove_lifetime(&mut ty);
84 (
85 name.clone(),
86 ArgInfo {
87 name: name.clone(),
88 r#type: TypeOrOverride::RustType { r#type: ty },
89 },
90 )
91 }
92 arg @ ArgInfo { name, .. } => (name.clone(), arg.clone()),
93 })
94 .collect();
95 sig.args.iter().map(|sig_arg| match sig_arg {
96 SignatureArg::Ident(ident) => {
97 let name = ident.to_string();
98 match args_map.get(&name) {
99 Some(ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } }) => Ok(quote! {
100 ::pyo3_stub_gen::type_info::ArgInfo {
101 name: #name,
102 r#type: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
103 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
104 }}),
105 Some(ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, .. }}) => {
106 let imports = imports.iter().collect::<Vec<&String>>();
107 Ok(quote! {
108 ::pyo3_stub_gen::type_info::ArgInfo {
109 name: #name,
110 r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
111 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
112 }})
113 },
114 None => Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
115 }
116 }
117 SignatureArg::Assign(ident, _eq, value) => {
118 let name = ident.to_string();
119
120 match args_map.get(&name) {
121 Some(ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } }) => {
122 let default = if value.to_token_stream().to_string() == "None" {
123 quote! {
124 "None".to_string()
125 }
126 } else {
127 quote! {
128 let v: #r#type = #value;
129 ::pyo3_stub_gen::util::fmt_py_obj(v)
130 }
131 };
132 Ok(quote! {
133 ::pyo3_stub_gen::type_info::ArgInfo {
134 name: #name,
135 r#type: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
136 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Assign{
137 default: {
138 fn _fmt() -> String {
139 #default
140 }
141 _fmt
142 }
143 }),
144 }})
145 },
146 Some(ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, r#type }}) => {
147 let imports = imports.iter().collect::<Vec<&String>>();
148 let default = if value.to_token_stream().to_string() == "None" {
149 quote! {
150 "None".to_string()
151 }
152 } else {
153 quote! {
154 let v: #r#type = #value;
155 ::pyo3_stub_gen::util::fmt_py_obj(v)
156 }
157 };
158 Ok(quote! {
159 ::pyo3_stub_gen::type_info::ArgInfo {
160 name: #name,
161 r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
162 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Assign{
163 default: {
164 fn _fmt() -> String {
165 #default
166 }
167 _fmt
168 }
169 }),
170 }})
171 },
172 None => Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
173 }
174 },
175 SignatureArg::Star(_) =>Ok(quote! {
176 ::pyo3_stub_gen::type_info::ArgInfo {
177 name: "",
178 r#type: <() as ::pyo3_stub_gen::PyStubType>::type_input,
179 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Star),
180 }}),
181 SignatureArg::Args(_, ident) => {
182 let name = ident.to_string();
183 match args_map.get(&name) {
184 Some(ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } }) => Ok(quote! {
185 ::pyo3_stub_gen::type_info::ArgInfo {
186 name: #name,
187 r#type: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
188 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Args),
189 }}),
190 Some(ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, .. }}) => {
191 let imports = imports.iter().collect::<Vec<&String>>();
192 Ok(quote! {
193 ::pyo3_stub_gen::type_info::ArgInfo {
194 name: #name,
195 r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
196 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Args),
197 }})
198 },
199 None => Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
200 }
201 },
202 SignatureArg::Keywords(_, _, ident) => {
203 let name = ident.to_string();
204 match args_map.get(&name) {
205 Some(ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } }) => Ok(quote! {
206 ::pyo3_stub_gen::type_info::ArgInfo {
207 name: #name,
208 r#type: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
209 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Keywords),
210 }}),
211 Some(ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, .. }}) => {
212 let imports = imports.iter().collect::<Vec<&String>>();
213 Ok(quote! {
214 ::pyo3_stub_gen::type_info::ArgInfo {
215 name: #name,
216 r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
217 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Keywords),
218 }})
219 },
220 None => Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
221 }
222 }
223 }).collect()
224 } else {
225 self.args
226 .iter()
227 .map(|arg| {
228 match arg {
229 ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } } => {
230 let mut ty = r#type.clone();
231 remove_lifetime(&mut ty);
232 Ok(quote! {
233 ::pyo3_stub_gen::type_info::ArgInfo {
234 name: #name,
235 r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input,
236 signature: None,
237 }
238 })
239 }
240 ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, .. }} => {
241 let imports = imports.iter().collect::<Vec<&String>>();
242 Ok(quote! {
243 ::pyo3_stub_gen::type_info::ArgInfo {
244 name: #name,
245 r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
246 signature: None,
247 }})
248 },
249 }
250 })
251 .collect()
252 };
253 match arg_infos_res {
254 Ok(arg_infos) => tokens.append_all(quote! { &[ #(#arg_infos),* ] }),
255 Err(err) => tokens.extend(err.to_compile_error()),
256 }
257 }
258}
259
260impl Signature {
261 pub fn overriding_operator(sig: &syn::Signature) -> Option<Self> {
262 if sig.ident == "__pow__" {
263 return Some(syn::parse_str("(exponent, modulo=None)").unwrap());
264 }
265 if sig.ident == "__rpow__" {
266 return Some(syn::parse_str("(base, modulo=None)").unwrap());
267 }
268 None
269 }
270}