pyo3_stub_gen_derive/gen_stub/
signature.rs1use std::collections::HashMap;
2
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens, TokenStreamExt};
5use syn::{
6 parenthesized,
7 parse::{Parse, ParseStream},
8 punctuated::Punctuated,
9 token, Expr, Ident, Result, Token,
10};
11
12use crate::gen_stub::{remove_lifetime, util::TypeOrOverride};
13
14use super::ArgInfo;
15
16#[derive(Debug, Clone, PartialEq)]
17enum SignatureArg {
18 Ident(Ident),
19 Assign(Ident, Token![=], Expr),
20 Star(Token![*]),
21 Args(Token![*], Ident),
22 Keywords(Token![*], Token![*], Ident),
23}
24
25impl Parse for SignatureArg {
26 fn parse(input: ParseStream) -> Result<Self> {
27 if input.peek(Token![*]) {
28 let star = input.parse()?;
29 if input.peek(Token![*]) {
30 Ok(SignatureArg::Keywords(star, input.parse()?, input.parse()?))
31 } else if input.peek(Ident) {
32 Ok(SignatureArg::Args(star, input.parse()?))
33 } else {
34 Ok(SignatureArg::Star(star))
35 }
36 } else if input.peek(Ident) {
37 let ident = Ident::parse(input)?;
38 if input.peek(Token![=]) {
39 Ok(SignatureArg::Assign(ident, input.parse()?, input.parse()?))
40 } else {
41 Ok(SignatureArg::Ident(ident))
42 }
43 } else {
44 dbg!(input);
45 todo!()
46 }
47 }
48}
49
50#[derive(Debug, Clone, PartialEq)]
51pub struct Signature {
52 paren: token::Paren,
53 args: Punctuated<SignatureArg, Token![,]>,
54}
55
56impl Parse for Signature {
57 fn parse(input: ParseStream) -> Result<Self> {
58 let content;
59 let paren = parenthesized!(content in input);
60 let args = content.parse_terminated(SignatureArg::parse, Token![,])?;
61 Ok(Self { paren, args })
62 }
63}
64
65pub struct ArgsWithSignature<'a> {
66 pub args: &'a Vec<ArgInfo>,
67 pub sig: &'a Option<Signature>,
68}
69
70impl ToTokens for ArgsWithSignature<'_> {
71 fn to_tokens(&self, tokens: &mut TokenStream2) {
72 let arg_infos_res: Result<Vec<TokenStream2>> = if let Some(sig) = self.sig {
73 let args_map: HashMap<String, ArgInfo> = self
75 .args
76 .iter()
77 .map(|arg| match arg {
78 ArgInfo {
79 name,
80 r#type: TypeOrOverride::RustType { r#type },
81 } => {
82 let mut ty = r#type.clone();
83 remove_lifetime(&mut ty);
84 (
85 name.clone(),
86 ArgInfo {
87 name: name.clone(),
88 r#type: TypeOrOverride::RustType { r#type: ty },
89 },
90 )
91 }
92 arg @ ArgInfo { name, .. } => (name.clone(), arg.clone()),
93 })
94 .collect();
95 sig.args.iter().map(|sig_arg| match sig_arg {
96 SignatureArg::Ident(ident) => {
97 let name = ident.to_string();
98 match args_map.get(&name) {
99 Some(ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } }) => Ok(quote! {
100 ::pyo3_stub_gen::type_info::ArgInfo {
101 name: #name,
102 r#type: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
103 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
104 }}),
105 Some(ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, .. }}) => {
106 let imports = imports.iter().collect::<Vec<&String>>();
107 Ok(quote! {
108 ::pyo3_stub_gen::type_info::ArgInfo {
109 name: #name,
110 r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
111 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
112 }})
113 },
114 None => Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
115 }
116 }
117 SignatureArg::Assign(ident, _eq, value) => {
118 let name = ident.to_string();
119
120 match args_map.get(&name) {
121 Some(ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } }) => {
122 let default = if value.to_token_stream().to_string() == "None" {
123 quote! {
124 "None".to_string()
125 }
126 } else {
127 quote! {
128 ::pyo3::prepare_freethreaded_python();
129 ::pyo3::Python::with_gil(|py| -> String {
130 let v: #r#type = #value;
131 ::pyo3_stub_gen::util::fmt_py_obj(py, v)
132 })
133 }
134 };
135 Ok(quote! {
136 ::pyo3_stub_gen::type_info::ArgInfo {
137 name: #name,
138 r#type: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
139 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Assign{
140 default: {
141 static DEFAULT: std::sync::LazyLock<String> = std::sync::LazyLock::new(|| {
142 #default
143 });
144 &DEFAULT
145 }
146 }),
147 }})
148 },
149 Some(ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, r#type }}) => {
150 let imports = imports.iter().collect::<Vec<&String>>();
151 let default = if value.to_token_stream().to_string() == "None" {
152 quote! {
153 "None".to_string()
154 }
155 } else {
156 quote! {
157 ::pyo3::prepare_freethreaded_python();
158 ::pyo3::Python::with_gil(|py| -> String {
159 let v: #r#type = #value;
160 ::pyo3_stub_gen::util::fmt_py_obj(py, v)
161 })
162 }
163 };
164 Ok(quote! {
165 ::pyo3_stub_gen::type_info::ArgInfo {
166 name: #name,
167 r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
168 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Assign{
169 default: {
170 static DEFAULT: std::sync::LazyLock<String> = std::sync::LazyLock::new(|| {
171 #default
172 });
173 &DEFAULT
174 }
175 }),
176 }})
177 },
178 None => Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
179 }
180 },
181 SignatureArg::Star(_) =>Ok(quote! {
182 ::pyo3_stub_gen::type_info::ArgInfo {
183 name: "",
184 r#type: <() as ::pyo3_stub_gen::PyStubType>::type_input,
185 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Star),
186 }}),
187 SignatureArg::Args(_, ident) => {
188 let name = ident.to_string();
189 match args_map.get(&name) {
190 Some(ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } }) => Ok(quote! {
191 ::pyo3_stub_gen::type_info::ArgInfo {
192 name: #name,
193 r#type: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
194 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
195 }}),
196 Some(ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, .. }}) => {
197 let imports = imports.iter().collect::<Vec<&String>>();
198 Ok(quote! {
199 ::pyo3_stub_gen::type_info::ArgInfo {
200 name: #name,
201 r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
202 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
203 }})
204 },
205 None => Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
206 }
207 },
208 SignatureArg::Keywords(_, _, ident) => {
209 let name = ident.to_string();
210 match args_map.get(&name) {
211 Some(ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } }) => Ok(quote! {
212 ::pyo3_stub_gen::type_info::ArgInfo {
213 name: #name,
214 r#type: <#r#type as ::pyo3_stub_gen::PyStubType>::type_input,
215 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
216 }}),
217 Some(ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, .. }}) => {
218 let imports = imports.iter().collect::<Vec<&String>>();
219 Ok(quote! {
220 ::pyo3_stub_gen::type_info::ArgInfo {
221 name: #name,
222 r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
223 signature: Some(pyo3_stub_gen::type_info::SignatureArg::Ident),
224 }})
225 },
226 None => Err(syn::Error::new(ident.span(), format!("can not find argument: {ident}")))
227 }
228 }
229 }).collect()
230 } else {
231 self.args
232 .iter()
233 .map(|arg| {
234 match arg {
235 ArgInfo { name, r#type: TypeOrOverride::RustType { r#type } } => {
236 let mut ty = r#type.clone();
237 remove_lifetime(&mut ty);
238 Ok(quote! {
239 ::pyo3_stub_gen::type_info::ArgInfo {
240 name: #name,
241 r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input,
242 signature: None,
243 }
244 })
245 }
246 ArgInfo { name, r#type: TypeOrOverride::OverrideType{ type_repr, imports, .. }} => {
247 let imports = imports.iter().collect::<Vec<&String>>();
248 Ok(quote! {
249 ::pyo3_stub_gen::type_info::ArgInfo {
250 name: #name,
251 r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) },
252 signature: None,
253 }})
254 },
255 }
256 })
257 .collect()
258 };
259 match arg_infos_res {
260 Ok(arg_infos) => tokens.append_all(quote! { &[ #(#arg_infos),* ] }),
261 Err(err) => tokens.extend(err.to_compile_error()),
262 }
263 }
264}
265
266impl Signature {
267 pub fn overriding_operator(sig: &syn::Signature) -> Option<Self> {
268 if sig.ident == "__pow__" {
269 return Some(syn::parse_str("(exponent, modulo=None)").unwrap());
270 }
271 if sig.ident == "__rpow__" {
272 return Some(syn::parse_str("(base, modulo=None)").unwrap());
273 }
274 None
275 }
276}