pyo3_stub_gen/stub_type/
numpy.rs1use super::{PyStubType, TypeInfo};
2use maplit::hashset;
3use numpy::{
4 ndarray::Dimension, Element, PyArray, PyArrayDescr, PyReadonlyArray, PyReadwriteArray,
5 PyUntypedArray,
6};
7
8trait NumPyScalar {
9 fn type_() -> TypeInfo;
10}
11
/// Implements `NumPyScalar` for `$ty`, rendering it as `numpy.<$name>` and
/// recording `numpy` as the module the stub must import.
macro_rules! impl_numpy_scalar {
    ($ty:ty, $name:expr) => {
        impl NumPyScalar for $ty {
            fn type_() -> TypeInfo {
                let qualified = format!("numpy.{scalar}", scalar = $name);
                TypeInfo {
                    name: qualified,
                    import: hashset!["numpy".into()],
                }
            }
        }
    };
}
24
25impl_numpy_scalar!(i8, "int8");
26impl_numpy_scalar!(i16, "int16");
27impl_numpy_scalar!(i32, "int32");
28impl_numpy_scalar!(i64, "int64");
29impl_numpy_scalar!(u8, "uint8");
30impl_numpy_scalar!(u16, "uint16");
31impl_numpy_scalar!(u32, "uint32");
32impl_numpy_scalar!(u64, "uint64");
33impl_numpy_scalar!(f32, "float32");
34impl_numpy_scalar!(f64, "float64");
35impl_numpy_scalar!(num_complex::Complex32, "complex64");
36impl_numpy_scalar!(num_complex::Complex64, "complex128");
37
38impl<T: NumPyScalar, D> PyStubType for PyArray<T, D> {
39 fn type_output() -> TypeInfo {
40 let TypeInfo { name, mut import } = T::type_();
41 import.insert("numpy.typing".into());
42 TypeInfo {
43 name: format!("numpy.typing.NDArray[{name}]"),
44 import,
45 }
46 }
47}
48
49impl PyStubType for PyUntypedArray {
50 fn type_output() -> TypeInfo {
51 TypeInfo {
52 name: "numpy.typing.NDArray[typing.Any]".into(),
53 import: hashset!["numpy.typing".into(), "typing".into()],
54 }
55 }
56}
57
58impl<T, D> PyStubType for PyReadonlyArray<'_, T, D>
59where
60 T: NumPyScalar + Element,
61 D: Dimension,
62{
63 fn type_output() -> TypeInfo {
64 PyArray::<T, D>::type_output()
65 }
66}
67
68impl<T, D> PyStubType for PyReadwriteArray<'_, T, D>
69where
70 T: NumPyScalar + Element,
71 D: Dimension,
72{
73 fn type_output() -> TypeInfo {
74 PyArray::<T, D>::type_output()
75 }
76}
77
78impl PyStubType for PyArrayDescr {
79 fn type_output() -> TypeInfo {
80 TypeInfo {
81 name: "numpy.dtype".into(),
82 import: hashset!["numpy".into()],
83 }
84 }
85}