// pyo3_stub_gen/stub_type/numpy.rs

1use super::{PyStubType, TypeInfo};
2use maplit::hashset;
3use numpy::{
4    ndarray::Dimension, Element, PyArray, PyArrayDescr, PyReadonlyArray, PyReadwriteArray,
5    PyUntypedArray,
6};
7
8trait NumPyScalar {
9    fn type_() -> TypeInfo;
10}
11
/// Implements [`NumPyScalar`] for a Rust type.
///
/// The generated `type_()` reports `numpy.<dtype>` and requires `import numpy`.
///
/// Arguments: the Rust scalar type, then the NumPy scalar name without the
/// `numpy.` prefix (e.g. `"float64"`).
macro_rules! impl_numpy_scalar {
    ($scalar:ty, $dtype:expr) => {
        impl NumPyScalar for $scalar {
            fn type_() -> TypeInfo {
                let name = format!("numpy.{}", $dtype);
                let import = hashset! { "numpy".into() };
                TypeInfo { name, import }
            }
        }
    };
}
24
25impl_numpy_scalar!(i8, "int8");
26impl_numpy_scalar!(i16, "int16");
27impl_numpy_scalar!(i32, "int32");
28impl_numpy_scalar!(i64, "int64");
29impl_numpy_scalar!(u8, "uint8");
30impl_numpy_scalar!(u16, "uint16");
31impl_numpy_scalar!(u32, "uint32");
32impl_numpy_scalar!(u64, "uint64");
33impl_numpy_scalar!(f32, "float32");
34impl_numpy_scalar!(f64, "float64");
35impl_numpy_scalar!(num_complex::Complex32, "complex64");
36impl_numpy_scalar!(num_complex::Complex64, "complex128");
37
38impl<T: NumPyScalar, D> PyStubType for PyArray<T, D> {
39    fn type_output() -> TypeInfo {
40        let TypeInfo { name, mut import } = T::type_();
41        import.insert("numpy.typing".into());
42        TypeInfo {
43            name: format!("numpy.typing.NDArray[{name}]"),
44            import,
45        }
46    }
47}
48
49impl PyStubType for PyUntypedArray {
50    fn type_output() -> TypeInfo {
51        TypeInfo {
52            name: "numpy.typing.NDArray[typing.Any]".into(),
53            import: hashset!["numpy.typing".into(), "typing".into()],
54        }
55    }
56}
57
58impl<T, D> PyStubType for PyReadonlyArray<'_, T, D>
59where
60    T: NumPyScalar + Element,
61    D: Dimension,
62{
63    fn type_output() -> TypeInfo {
64        PyArray::<T, D>::type_output()
65    }
66}
67
68impl<T, D> PyStubType for PyReadwriteArray<'_, T, D>
69where
70    T: NumPyScalar + Element,
71    D: Dimension,
72{
73    fn type_output() -> TypeInfo {
74        PyArray::<T, D>::type_output()
75    }
76}
77
78impl PyStubType for PyArrayDescr {
79    fn type_output() -> TypeInfo {
80        TypeInfo {
81            name: "numpy.dtype".into(),
82            import: hashset!["numpy".into()],
83        }
84    }
85}