pyo3_stub_gen_derive/
gen_stub.rs1mod arg;
79mod attr;
80mod member;
81mod method;
82mod parameter;
83mod parse_python;
84mod pyclass;
85mod pyclass_complex_enum;
86mod pyclass_enum;
87mod pyfunction;
88mod pymethods;
89mod renaming;
90mod signature;
91mod stub_type;
92mod util;
93mod variant;
94
95use arg::*;
96use attr::*;
97use member::*;
98use method::*;
99use pyclass::*;
100use pyclass_complex_enum::*;
101use pyclass_enum::*;
102use pymethods::*;
103use renaming::*;
104use signature::*;
105use stub_type::*;
106use util::*;
107
108use proc_macro2::TokenStream as TokenStream2;
109use quote::quote;
110use syn::{parse2, ItemEnum, ItemFn, ItemImpl, ItemStruct, LitStr, Result};
111
112pub fn pyclass(attr: TokenStream2, item: TokenStream2) -> Result<TokenStream2> {
113 let attr = parse2::<attr::PyClassAttr>(attr)?;
114 let mut item_struct = parse2::<ItemStruct>(item)?;
115 let inner = PyClassInfo::try_from(item_struct.clone())?;
116 pyclass::prune_attrs(&mut item_struct);
117
118 if attr.skip_stub_type {
119 Ok(quote! {
120 #item_struct
121 pyo3_stub_gen::inventory::submit! {
122 #inner
123 }
124 })
125 } else {
126 let derive_stub_type = StubType::from(&inner);
127 Ok(quote! {
128 #item_struct
129 #derive_stub_type
130 pyo3_stub_gen::inventory::submit! {
131 #inner
132 }
133 })
134 }
135}
136
137pub fn pyclass_enum(attr: TokenStream2, item: TokenStream2) -> Result<TokenStream2> {
138 let attr = parse2::<attr::PyClassAttr>(attr)?;
139 let inner = PyEnumInfo::try_from(parse2::<ItemEnum>(item.clone())?)?;
140
141 if attr.skip_stub_type {
142 Ok(quote! {
143 #item
144 pyo3_stub_gen::inventory::submit! {
145 #inner
146 }
147 })
148 } else {
149 let derive_stub_type = StubType::from(&inner);
150 Ok(quote! {
151 #item
152 #derive_stub_type
153 pyo3_stub_gen::inventory::submit! {
154 #inner
155 }
156 })
157 }
158}
159
160pub fn pyclass_complex_enum(attr: TokenStream2, item: TokenStream2) -> Result<TokenStream2> {
161 let attr = parse2::<attr::PyClassAttr>(attr)?;
162 let inner = PyComplexEnumInfo::try_from(parse2::<ItemEnum>(item.clone())?)?;
163
164 if attr.skip_stub_type {
165 Ok(quote! {
166 #item
167 pyo3_stub_gen::inventory::submit! {
168 #inner
169 }
170 })
171 } else {
172 let derive_stub_type = StubType::from(&inner);
173 Ok(quote! {
174 #item
175 #derive_stub_type
176 pyo3_stub_gen::inventory::submit! {
177 #inner
178 }
179 })
180 }
181}
182
183pub fn pymethods(item: TokenStream2) -> Result<TokenStream2> {
184 let mut item_impl = parse2::<ItemImpl>(item)?;
185 let inner = PyMethodsInfo::try_from(item_impl.clone())?;
186 pymethods::prune_attrs(&mut item_impl);
187 Ok(quote! {
188 #item_impl
189 #[automatically_derived]
190 pyo3_stub_gen::inventory::submit! {
191 #inner
192 }
193 })
194}
195
196pub fn pyfunction(attr: TokenStream2, item: TokenStream2) -> Result<TokenStream2> {
197 let item_fn = parse2::<ItemFn>(item)?;
199 let attr = parse2::<pyfunction::PyFunctionAttr>(attr)?;
200
201 let infos = pyfunction::PyFunctionInfos::from_parts(item_fn, attr)?;
203
204 Ok(quote! { #infos })
206}
207
208pub fn gen_function_from_python_impl(input: TokenStream2) -> Result<TokenStream2> {
209 let parsed: parse_python::GenFunctionFromPythonInput = parse2(input)?;
210 let inner = parse_python::parse_gen_function_from_python_input(parsed)?;
211 Ok(quote! { #inner })
212}
213
214pub fn gen_methods_from_python_impl(input: TokenStream2) -> Result<TokenStream2> {
215 let stub_str: LitStr = parse2(input)?;
216 let inner = parse_python::parse_python_methods_stub(&stub_str)?;
217 Ok(quote! { #inner })
218}
219
220pub fn prune_gen_stub(item: TokenStream2) -> Result<TokenStream2> {
221 fn prune_attrs<T: syn::parse::Parse + quote::ToTokens>(
222 item: &TokenStream2,
223 fn_prune_attrs: fn(&mut T),
224 ) -> Result<TokenStream2> {
225 parse2::<T>(item.clone()).map(|mut item| {
226 fn_prune_attrs(&mut item);
227 quote! { #item }
228 })
229 }
230 prune_attrs::<ItemStruct>(&item, pyclass::prune_attrs)
231 .or_else(|_| prune_attrs::<ItemImpl>(&item, pymethods::prune_attrs))
232 .or_else(|_| prune_attrs::<ItemFn>(&item, pyfunction::prune_attrs))
233}
234
// Snapshot tests for the `pyfunction` expansion.
//
// Each test feeds attribute and item tokens to `pyfunction`, pretty-prints the expansion,
// and compares it with an `insta` snapshot. `assert_snapshot!` is called without an explicit
// name, so (by insta's default naming) snapshots are keyed by the test name. Renaming a
// test therefore requires re-accepting its snapshot.
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    /// Pretty-prints `tokens` as a Rust source file with `prettyplease` and trims surrounding
    /// whitespace. Panics if the tokens do not parse as a file.
    fn format_tokens(tokens: TokenStream2) -> String {
        let formatted = prettyplease::unparse(&syn::parse_file(&tokens.to_string()).unwrap());
        formatted.trim().to_string()
    }

    /// A single `@overload` stub whose Python signature (`int`) differs from the Rust
    /// signature (`f64`).
    #[test]
    fn test_overload_example_1_expansion() {
        let attr = quote! {
            python_overload = r#"
            @overload
            def overload_example_1(x: int) -> int: ...
            "#
        };

        let item = quote! {
            #[pyfunction]
            pub fn overload_example_1(x: f64) -> f64 {
                x + 1.0
            }
        };

        let result = pyfunction(attr, item).unwrap();
        let formatted = format_tokens(result);

        insta::assert_snapshot!(formatted);
    }

    /// Two `@overload` stubs that carry docstrings, combined with `no_default_overload = true`.
    /// NOTE(review): that flag presumably suppresses the stub derived from the Rust signature;
    /// confirm in the `pyfunction` module.
    #[test]
    fn test_overload_example_2_expansion() {
        let attr = quote! {
            python_overload = r#"
            @overload
            def overload_example_2(ob: int) -> int:
                """Increments integer by 1"""

            @overload
            def overload_example_2(ob: float) -> float:
                """Increments float by 1"""
            "#,
            no_default_overload = true
        };

        let item = quote! {
            #[pyfunction]
            pub fn overload_example_2(ob: Bound<PyAny>) -> PyResult<PyObject> {
                let py = ob.py();
                Ok(ob.into_py_any(py)?)
            }
        };

        let result = pyfunction(attr, item).unwrap();
        let formatted = format_tokens(result);

        insta::assert_snapshot!(formatted);
    }

    /// Baseline: an empty attribute and no overloads.
    #[test]
    fn test_regular_function_no_overload() {
        let attr = quote! {};

        let item = quote! {
            #[pyfunction]
            pub fn regular_function(x: i32) -> i32 {
                x + 1
            }
        };

        let result = pyfunction(attr, item).unwrap();
        let formatted = format_tokens(result);

        insta::assert_snapshot!(formatted);
    }
}