1use rustpython_parser::{ast, Parse};
4use syn::{Error, LitStr, Result};
5
6use super::{
7 dedent, extract_args, extract_deprecated_from_decorators, extract_docstring,
8 extract_return_type,
9};
10use crate::gen_stub::pyfunction::PyFunctionInfo;
11
12pub struct PythonFunctionStub {
14 pub func_def: ast::StmtFunctionDef,
15 pub imports: Vec<String>,
16 pub is_async: bool,
17}
18
19impl TryFrom<PythonFunctionStub> for PyFunctionInfo {
20 type Error = syn::Error;
21
22 fn try_from(stub: PythonFunctionStub) -> Result<Self> {
23 let func_name = stub.func_def.name.to_string();
24
25 let doc = extract_docstring(&stub.func_def);
27
28 let args = extract_args(&stub.func_def.args, &stub.imports)?;
30
31 let return_type = extract_return_type(&stub.func_def.returns, &stub.imports)?;
33
34 let deprecated = extract_deprecated_from_decorators(&stub.func_def.decorator_list);
36
37 Ok(PyFunctionInfo {
42 name: func_name,
43 args,
44 r#return: return_type,
45 sig: None,
46 doc,
47 module: None,
48 is_async: stub.is_async,
49 deprecated,
50 type_ignored: None,
51 })
52 }
53}
54
55pub fn parse_python_function_stub(input: LitStr) -> Result<PyFunctionInfo> {
57 let stub_content = input.value();
58
59 let dedented_content = dedent(&stub_content);
61
62 let parsed = ast::Suite::parse(&dedented_content, "<stub>")
64 .map_err(|e| Error::new(input.span(), format!("Failed to parse Python stub: {}", e)))?;
65
66 let mut imports = Vec::new();
68 let mut function: Option<(ast::StmtFunctionDef, bool)> = None;
69
70 for stmt in parsed {
71 match stmt {
72 ast::Stmt::Import(import_stmt) => {
73 for alias in &import_stmt.names {
74 imports.push(alias.name.to_string());
75 }
76 }
77 ast::Stmt::ImportFrom(import_from_stmt) => {
78 if let Some(module) = &import_from_stmt.module {
79 imports.push(module.to_string());
80 }
81 }
82 ast::Stmt::FunctionDef(func_def) => {
83 if function.is_some() {
84 return Err(Error::new(
85 input.span(),
86 "Multiple function definitions found. Only one function is allowed per gen_function_from_python! call",
87 ));
88 }
89 function = Some((func_def, false));
90 }
91 ast::Stmt::AsyncFunctionDef(func_def) => {
92 if function.is_some() {
93 return Err(Error::new(
94 input.span(),
95 "Multiple function definitions found. Only one function is allowed per gen_function_from_python! call",
96 ));
97 }
98 let sync_func = ast::StmtFunctionDef {
100 range: func_def.range,
101 name: func_def.name,
102 type_params: func_def.type_params,
103 args: func_def.args,
104 body: func_def.body,
105 decorator_list: func_def.decorator_list,
106 returns: func_def.returns,
107 type_comment: func_def.type_comment,
108 };
109 function = Some((sync_func, true));
110 }
111 _ => {
112 }
114 }
115 }
116
117 let (func_def, is_async) = function
119 .ok_or_else(|| Error::new(input.span(), "No function definition found in Python stub"))?;
120
121 let stub = PythonFunctionStub {
123 func_def,
124 imports,
125 is_async,
126 };
127 PyFunctionInfo::try_from(stub)
128}
129
// Inline-snapshot tests for `parse_python_function_stub`: each test parses a
// Python stub literal and compares the `PyFunctionInfo` tokens it produces,
// pretty-printed by `format_as_value`, against an `insta` snapshot.
#[cfg(test)]
mod test {
    use super::*;
    use proc_macro2::TokenStream as TokenStream2;
    use quote::{quote, ToTokens};

    // Plain `def` with one annotated argument, a return annotation and a
    // docstring. With no import statements in the stub, every import set is empty.
    #[test]
    fn test_basic_function() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            def foo(x: int) -> int:
                """A simple function"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "foo",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "x",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "A simple function",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    // Both `import x` and `from x import y` forms are collected. Every
    // collected module is attached to each argument and to the return type.
    #[test]
    fn test_function_with_imports() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import typing
            from collections.abc import Callable

            def process(func: Callable[[str], int]) -> typing.Optional[int]:
                """Process a callback function"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "process",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "func",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "Callable[[str], int]".to_string(),
                        import: ::std::collections::HashSet::from([
                            "typing".into(),
                            "collections.abc".into(),
                        ]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "typing.Optional[int]".to_string(),
                import: ::std::collections::HashSet::from([
                    "typing".into(),
                    "collections.abc".into(),
                ]),
            },
            doc: "Process a callback function",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    // Fully-qualified dotted annotations are kept verbatim as the type name.
    #[test]
    fn test_complex_types() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import collections.abc
            import typing

            def fn_override_type(cb: collections.abc.Callable[[str], typing.Any]) -> collections.abc.Callable[[str], typing.Any]:
                """Example function with complex types"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "fn_override_type",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "cb",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "collections.abc.Callable[[str], typing.Any]".to_string(),
                        import: ::std::collections::HashSet::from([
                            "collections.abc".into(),
                            "typing".into(),
                        ]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "collections.abc.Callable[[str], typing.Any]".to_string(),
                import: ::std::collections::HashSet::from([
                    "collections.abc".into(),
                    "typing".into(),
                ]),
            },
            doc: "Example function with complex types",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    // Several arguments and an `...` body. Without a docstring, `doc` is the
    // empty string.
    #[test]
    fn test_multiple_args() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            import typing

            def add(a: int, b: int, c: typing.Optional[int]) -> int: ...
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "add",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "a",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    signature: None,
                },
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "b",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    signature: None,
                },
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "c",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "typing.Optional[int]".to_string(),
                        import: ::std::collections::HashSet::from(["typing".into()]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from(["typing".into()]),
            },
            doc: "",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    // A missing return annotation maps to `no_return_type_output` instead of
    // a `TypeInfo` closure.
    #[test]
    fn test_no_return_type() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            def print_hello(name: str):
                """Print a greeting"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "print_hello",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "name",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: ::pyo3_stub_gen::type_info::no_return_type_output,
            doc: "Print a greeting",
            module: None,
            is_async: false,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    // `async def` is accepted and reported with `is_async: true`.
    #[test]
    fn test_async_function() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            async def fetch_data(url: str) -> str:
                """Fetch data from URL"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "fetch_data",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "url",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "str".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "str".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "Fetch data from URL",
            module: None,
            is_async: true,
            deprecated: None,
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    // A bare `@deprecated` decorator yields `DeprecatedInfo` with neither
    // `since` nor `note` set.
    #[test]
    fn test_deprecated_decorator() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            @deprecated
            def old_function(x: int) -> int:
                """This function is deprecated"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "old_function",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "x",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "This function is deprecated",
            module: None,
            is_async: false,
            deprecated: Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
                since: None,
                note: None,
            }),
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    // A positional string argument to `@deprecated(...)` becomes the `note`.
    #[test]
    fn test_deprecated_with_message() -> Result<()> {
        let stub_str: LitStr = syn::parse2(quote! {
            r#"
            @deprecated("Use new_function instead")
            def old_function(x: int) -> int:
                """This function is deprecated"""
            "#
        })?;
        let info = parse_python_function_stub(stub_str)?;
        let out = info.to_token_stream();
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyFunctionInfo {
            name: "old_function",
            args: &[
                ::pyo3_stub_gen::type_info::ArgInfo {
                    name: "x",
                    r#type: || ::pyo3_stub_gen::TypeInfo {
                        name: "int".to_string(),
                        import: ::std::collections::HashSet::from([]),
                    },
                    signature: None,
                },
            ],
            r#return: || ::pyo3_stub_gen::TypeInfo {
                name: "int".to_string(),
                import: ::std::collections::HashSet::from([]),
            },
            doc: "This function is deprecated",
            module: None,
            is_async: false,
            deprecated: Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
                since: None,
                note: Some("Use new_function instead"),
            }),
            type_ignored: None,
        }
        "###);
        Ok(())
    }

    // Pretty-prints an expression token stream. The tokens are wrapped in a
    // `const _: () = ...;` item so `prettyplease` (which formats whole files)
    // can lay them out, then the wrapper text is stripped off again.
    fn format_as_value(tt: TokenStream2) -> String {
        let ttt = quote! { const _: () = #tt; };
        let formatted = prettyplease::unparse(&syn::parse_file(&ttt.to_string()).unwrap());
        formatted
            .trim()
            .strip_prefix("const _: () = ")
            .unwrap()
            .strip_suffix(';')
            .unwrap()
            .to_string()
    }
}