pyo3_stub_gen_derive/
gen_stub.rs

1//! Code generation for embedding metadata for generating Python stub file.
2//!
//! This metadata is embedded as an `inventory::submit!` block like:
4//!
5//! ```rust
6//! # use pyo3::*;
7//! # use pyo3_stub_gen::type_info::*;
8//! # struct PyPlaceholder;
9//! inventory::submit!{
10//!     PyClassInfo {
11//!         pyclass_name: "Placeholder",
12//!         module: Some("my_module"),
13//!         struct_id: std::any::TypeId::of::<PyPlaceholder>,
14//!         getters: &[
15//!             MemberInfo {
16//!                 name: "name",
17//!                 r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
18//!                 doc: "",
19//!                 default: None,
20//!                 deprecated: None,
21//!             },
22//!             MemberInfo {
23//!                 name: "ndim",
24//!                 r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
25//!                 doc: "",
26//!                 default: None,
27//!                 deprecated: None,
28//!             },
29//!             MemberInfo {
30//!                 name: "description",
31//!                 r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
32//!                 doc: "",
33//!                 default: None,
34//!                 deprecated: None,
35//!             },
36//!         ],
37//!         setters: &[],
38//!         doc: "",
39//!         bases: &[],
40//!         has_eq: false,
41//!         has_ord: false,
42//!         has_hash: false,
43//!         has_str: false,
44//!         subclass: false,
45//!     }
46//! }
47//! ```
48//!
//! and this submodule is responsible for generating such code from Rust code like
50//!
51//! ```rust
52//! # use pyo3::*;
53//! #[pyclass(mapping, module = "my_module", name = "Placeholder")]
54//! #[derive(Debug, Clone)]
55//! pub struct PyPlaceholder {
56//!     #[pyo3(get)]
57//!     pub name: String,
58//!     #[pyo3(get)]
59//!     pub ndim: usize,
60//!     #[pyo3(get)]
61//!     pub description: Option<String>,
62//!     pub custom_latex: Option<String>,
63//! }
64//! ```
65//!
66//! Mechanism
67//! ----------
68//! Code generation will take three steps:
69//!
//! 1. Parse the input [proc_macro2::TokenStream] into the corresponding syntax tree component in [syn].
//!    - e.g. [ItemStruct] for `#[pyclass]`, [ItemImpl] for `#[pymethods]`, and so on.
//! 2. Convert the syntax tree components into `*Info` structs using [TryFrom].
//!    - e.g. [PyClassInfo] is converted from [ItemStruct], [PyMethodsInfo] is converted from [ItemImpl], and so on.
//! 3. Generate token streams via the [quote::ToTokens] implementations of the `*Info` structs.
//!    - The [quote::quote!] macro relies on this trait.
76//!
77
78mod arg;
79mod attr;
80mod member;
81mod method;
82mod parameter;
83mod parse_python;
84mod pyclass;
85mod pyclass_complex_enum;
86mod pyclass_enum;
87mod pyfunction;
88mod pymethods;
89mod renaming;
90mod signature;
91mod stub_type;
92mod util;
93mod variant;
94
95use arg::*;
96use attr::*;
97use member::*;
98use method::*;
99use pyclass::*;
100use pyclass_complex_enum::*;
101use pyclass_enum::*;
102use pymethods::*;
103use renaming::*;
104use signature::*;
105use stub_type::*;
106use util::*;
107
108use proc_macro2::TokenStream as TokenStream2;
109use quote::quote;
110use syn::{parse2, ItemEnum, ItemFn, ItemImpl, ItemStruct, LitStr, Result};
111
112pub fn pyclass(attr: TokenStream2, item: TokenStream2) -> Result<TokenStream2> {
113    let attr = parse2::<attr::PyClassAttr>(attr)?;
114    let mut item_struct = parse2::<ItemStruct>(item)?;
115    let inner = PyClassInfo::try_from(item_struct.clone())?;
116    pyclass::prune_attrs(&mut item_struct);
117
118    if attr.skip_stub_type {
119        Ok(quote! {
120            #item_struct
121            pyo3_stub_gen::inventory::submit! {
122                #inner
123            }
124        })
125    } else {
126        let derive_stub_type = StubType::from(&inner);
127        Ok(quote! {
128            #item_struct
129            #derive_stub_type
130            pyo3_stub_gen::inventory::submit! {
131                #inner
132            }
133        })
134    }
135}
136
137pub fn pyclass_enum(attr: TokenStream2, item: TokenStream2) -> Result<TokenStream2> {
138    let attr = parse2::<attr::PyClassAttr>(attr)?;
139    let inner = PyEnumInfo::try_from(parse2::<ItemEnum>(item.clone())?)?;
140
141    if attr.skip_stub_type {
142        Ok(quote! {
143            #item
144            pyo3_stub_gen::inventory::submit! {
145                #inner
146            }
147        })
148    } else {
149        let derive_stub_type = StubType::from(&inner);
150        Ok(quote! {
151            #item
152            #derive_stub_type
153            pyo3_stub_gen::inventory::submit! {
154                #inner
155            }
156        })
157    }
158}
159
160pub fn pyclass_complex_enum(attr: TokenStream2, item: TokenStream2) -> Result<TokenStream2> {
161    let attr = parse2::<attr::PyClassAttr>(attr)?;
162    let inner = PyComplexEnumInfo::try_from(parse2::<ItemEnum>(item.clone())?)?;
163
164    if attr.skip_stub_type {
165        Ok(quote! {
166            #item
167            pyo3_stub_gen::inventory::submit! {
168                #inner
169            }
170        })
171    } else {
172        let derive_stub_type = StubType::from(&inner);
173        Ok(quote! {
174            #item
175            #derive_stub_type
176            pyo3_stub_gen::inventory::submit! {
177                #inner
178            }
179        })
180    }
181}
182
183pub fn pymethods(item: TokenStream2) -> Result<TokenStream2> {
184    let mut item_impl = parse2::<ItemImpl>(item)?;
185    let inner = PyMethodsInfo::try_from(item_impl.clone())?;
186    pymethods::prune_attrs(&mut item_impl);
187    Ok(quote! {
188        #item_impl
189        #[automatically_derived]
190        pyo3_stub_gen::inventory::submit! {
191            #inner
192        }
193    })
194}
195
196pub fn pyfunction(attr: TokenStream2, item: TokenStream2) -> Result<TokenStream2> {
197    // Step 1: Parse TokenStream to syn types
198    let item_fn = parse2::<ItemFn>(item)?;
199    let attr = parse2::<pyfunction::PyFunctionAttr>(attr)?;
200
201    // Step 2: Convert to intermediate representation
202    let infos = pyfunction::PyFunctionInfos::from_parts(item_fn, attr)?;
203
204    // Step 3: Generate output TokenStream via ToTokens
205    Ok(quote! { #infos })
206}
207
208pub fn gen_function_from_python_impl(input: TokenStream2) -> Result<TokenStream2> {
209    let parsed: parse_python::GenFunctionFromPythonInput = parse2(input)?;
210    let inner = parse_python::parse_gen_function_from_python_input(parsed)?;
211    Ok(quote! { #inner })
212}
213
214pub fn gen_methods_from_python_impl(input: TokenStream2) -> Result<TokenStream2> {
215    let stub_str: LitStr = parse2(input)?;
216    let inner = parse_python::parse_python_methods_stub(&stub_str)?;
217    Ok(quote! { #inner })
218}
219
220pub fn prune_gen_stub(item: TokenStream2) -> Result<TokenStream2> {
221    fn prune_attrs<T: syn::parse::Parse + quote::ToTokens>(
222        item: &TokenStream2,
223        fn_prune_attrs: fn(&mut T),
224    ) -> Result<TokenStream2> {
225        parse2::<T>(item.clone()).map(|mut item| {
226            fn_prune_attrs(&mut item);
227            quote! { #item }
228        })
229    }
230    prune_attrs::<ItemStruct>(&item, pyclass::prune_attrs)
231        .or_else(|_| prune_attrs::<ItemImpl>(&item, pymethods::prune_attrs))
232        .or_else(|_| prune_attrs::<ItemFn>(&item, pyfunction::prune_attrs))
233}
234
// Snapshot tests for the `pyfunction` expansion.
//
// Each test feeds an attribute token stream and a function item to `pyfunction`,
// pretty-prints the result, and compares it against an insta snapshot. insta names
// unnamed snapshots after the enclosing test function, so renaming a test (or editing
// the raw-string inputs below, whitespace included) invalidates its stored snapshot.
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    // Renders `tokens` as formatted Rust source so snapshots are readable and stable.
    // The tokens are stringified, re-parsed as a whole file with `syn::parse_file`,
    // and printed with `prettyplease`. Panics (via `unwrap`) if the expansion is not
    // a syntactically valid file.
    fn format_tokens(tokens: TokenStream2) -> String {
        let formatted = prettyplease::unparse(&syn::parse_file(&tokens.to_string()).unwrap());
        formatted.trim().to_string()
    }

    #[test]
    fn test_overload_example_1_expansion() {
        // Test the overload_example_1 case: python_overload + auto-generated
        // This should generate TWO PyFunctionInfo:
        // 1. From python_overload: int -> int with is_overload: true
        // 2. From Rust signature: f64 -> f64 with is_overload: true
        let attr = quote! {
            python_overload = r#"
            @overload
            def overload_example_1(x: int) -> int: ...
            "#
        };

        let item = quote! {
            #[pyfunction]
            pub fn overload_example_1(x: f64) -> f64 {
                x + 1.0
            }
        };

        let result = pyfunction(attr, item).unwrap();
        let formatted = format_tokens(result);

        insta::assert_snapshot!(formatted);
    }

    #[test]
    fn test_overload_example_2_expansion() {
        // Test the overload_example_2 case: python_overload with no_default_overload
        // This should generate TWO PyFunctionInfo (both from python_overload):
        // 1. int -> int with is_overload: true
        // 2. float -> float with is_overload: true
        // Should NOT generate overload from Rust signature (Bound<PyAny>)
        let attr = quote! {
            python_overload = r#"
            @overload
            def overload_example_2(ob: int) -> int:
                """Increments integer by 1"""

            @overload
            def overload_example_2(ob: float) -> float:
                """Increments float by 1"""
            "#,
            no_default_overload = true
        };

        let item = quote! {
            #[pyfunction]
            pub fn overload_example_2(ob: Bound<PyAny>) -> PyResult<PyObject> {
                let py = ob.py();
                Ok(ob.into_py_any(py)?)
            }
        };

        let result = pyfunction(attr, item).unwrap();
        let formatted = format_tokens(result);

        insta::assert_snapshot!(formatted);
    }

    #[test]
    fn test_regular_function_no_overload() {
        // Test a regular function without python_overload
        // This should generate ONE PyFunctionInfo with is_overload: false
        let attr = quote! {};

        let item = quote! {
            #[pyfunction]
            pub fn regular_function(x: i32) -> i32 {
                x + 1
            }
        };

        let result = pyfunction(attr, item).unwrap();
        let formatted = format_tokens(result);

        insta::assert_snapshot!(formatted);
    }
}