pyo3_stub_gen_derive/
gen_stub.rs1mod arg;
79mod attr;
80mod member;
81mod method;
82mod parameter;
83mod parse_python;
84mod pyclass;
85mod pyclass_complex_enum;
86mod pyclass_enum;
87mod pyfunction;
88mod pymethods;
89mod renaming;
90mod signature;
91mod stub_type;
92mod util;
93mod variant;
94
95use arg::*;
96use attr::*;
97use member::*;
98use method::*;
99use pyclass::*;
100use pyclass_complex_enum::*;
101use pyclass_enum::*;
102use pymethods::*;
103use renaming::*;
104use signature::*;
105use stub_type::*;
106use util::*;
107
108use proc_macro2::TokenStream as TokenStream2;
109use quote::quote;
110use syn::{parse2, ItemEnum, ItemFn, ItemImpl, ItemStruct, LitStr, Result};
111
112pub fn pyclass(item: TokenStream2) -> Result<TokenStream2> {
113 let mut item_struct = parse2::<ItemStruct>(item)?;
114 let inner = PyClassInfo::try_from(item_struct.clone())?;
115 let derive_stub_type = StubType::from(&inner);
116 pyclass::prune_attrs(&mut item_struct);
117 Ok(quote! {
118 #item_struct
119 #derive_stub_type
120 pyo3_stub_gen::inventory::submit! {
121 #inner
122 }
123 })
124}
125
126pub fn pyclass_enum(item: TokenStream2) -> Result<TokenStream2> {
127 let inner = PyEnumInfo::try_from(parse2::<ItemEnum>(item.clone())?)?;
128 let derive_stub_type = StubType::from(&inner);
129 Ok(quote! {
130 #item
131 #derive_stub_type
132 pyo3_stub_gen::inventory::submit! {
133 #inner
134 }
135 })
136}
137
138pub fn pyclass_complex_enum(item: TokenStream2) -> Result<TokenStream2> {
139 let inner = PyComplexEnumInfo::try_from(parse2::<ItemEnum>(item.clone())?)?;
140 let derive_stub_type = StubType::from(&inner);
141 Ok(quote! {
142 #item
143 #derive_stub_type
144 pyo3_stub_gen::inventory::submit! {
145 #inner
146 }
147 })
148}
149
150pub fn pymethods(item: TokenStream2) -> Result<TokenStream2> {
151 let mut item_impl = parse2::<ItemImpl>(item)?;
152 let inner = PyMethodsInfo::try_from(item_impl.clone())?;
153 pymethods::prune_attrs(&mut item_impl);
154 Ok(quote! {
155 #item_impl
156 #[automatically_derived]
157 pyo3_stub_gen::inventory::submit! {
158 #inner
159 }
160 })
161}
162
163pub fn pyfunction(attr: TokenStream2, item: TokenStream2) -> Result<TokenStream2> {
164 let item_fn = parse2::<ItemFn>(item)?;
166 let attr = parse2::<pyfunction::PyFunctionAttr>(attr)?;
167
168 let infos = pyfunction::PyFunctionInfos::from_parts(item_fn, attr)?;
170
171 Ok(quote! { #infos })
173}
174
175pub fn gen_function_from_python_impl(input: TokenStream2) -> Result<TokenStream2> {
176 let parsed: parse_python::GenFunctionFromPythonInput = parse2(input)?;
177 let inner = parse_python::parse_gen_function_from_python_input(parsed)?;
178 Ok(quote! { #inner })
179}
180
181pub fn gen_methods_from_python_impl(input: TokenStream2) -> Result<TokenStream2> {
182 let stub_str: LitStr = parse2(input)?;
183 let inner = parse_python::parse_python_methods_stub(&stub_str)?;
184 Ok(quote! { #inner })
185}
186
187pub fn prune_gen_stub(item: TokenStream2) -> Result<TokenStream2> {
188 fn prune_attrs<T: syn::parse::Parse + quote::ToTokens>(
189 item: &TokenStream2,
190 fn_prune_attrs: fn(&mut T),
191 ) -> Result<TokenStream2> {
192 parse2::<T>(item.clone()).map(|mut item| {
193 fn_prune_attrs(&mut item);
194 quote! { #item }
195 })
196 }
197 prune_attrs::<ItemStruct>(&item, pyclass::prune_attrs)
198 .or_else(|_| prune_attrs::<ItemImpl>(&item, pymethods::prune_attrs))
199 .or_else(|_| prune_attrs::<ItemFn>(&item, pyfunction::prune_attrs))
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    /// Renders a token stream as pretty-printed Rust source with surrounding
    /// whitespace trimmed. Panics if the tokens do not form a valid Rust file.
    fn format_tokens(tokens: TokenStream2) -> String {
        let source = tokens.to_string();
        let file = syn::parse_file(&source).unwrap();
        prettyplease::unparse(&file).trim().to_string()
    }

    /// Snapshot of the expansion when a single Python overload is supplied.
    #[test]
    fn test_overload_example_1_expansion() {
        let attr_args = quote! {
            python_overload = r#"
            @overload
            def overload_example_1(x: int) -> int: ...
            "#
        };

        let function_item = quote! {
            #[pyfunction]
            pub fn overload_example_1(x: f64) -> f64 {
                x + 1.0
            }
        };

        let formatted = format_tokens(pyfunction(attr_args, function_item).unwrap());

        insta::assert_snapshot!(formatted);
    }

    /// Snapshot of the expansion when two Python overloads are supplied together
    /// with `no_default_overload = true`.
    #[test]
    fn test_overload_example_2_expansion() {
        let attr_args = quote! {
            python_overload = r#"
            @overload
            def overload_example_2(ob: int) -> int:
                """Increments integer by 1"""

            @overload
            def overload_example_2(ob: float) -> float:
                """Increments float by 1"""
            "#,
            no_default_overload = true
        };

        let function_item = quote! {
            #[pyfunction]
            pub fn overload_example_2(ob: Bound<PyAny>) -> PyResult<PyObject> {
                let py = ob.py();
                Ok(ob.into_py_any(py)?)
            }
        };

        let formatted = format_tokens(pyfunction(attr_args, function_item).unwrap());

        insta::assert_snapshot!(formatted);
    }

    /// Snapshot of the expansion for a plain function with empty attribute arguments.
    #[test]
    fn test_regular_function_no_overload() {
        let attr_args = quote! {};

        let function_item = quote! {
            #[pyfunction]
            pub fn regular_function(x: i32) -> i32 {
                x + 1
            }
        };

        let formatted = format_tokens(pyfunction(attr_args, function_item).unwrap());

        insta::assert_snapshot!(formatted);
    }
}