// pyo3_stub_gen/stub_type/numpy.rs

1use super::{PyStubType, TypeInfo};
2use maplit::hashset;
3use numpy::{
4    ndarray::Dimension, Element, PyArray, PyArrayDescr, PyReadonlyArray, PyReadwriteArray,
5    PyUntypedArray,
6};
7use std::collections::HashMap;
8
9trait NumPyScalar {
10    fn type_() -> TypeInfo;
11}
12
/// Implements [`NumPyScalar`] for `$ty`, rendering it in stubs as
/// `numpy.<$name>` (for example `f64` becomes `numpy.float64`).
///
/// The generated type only requires `numpy` to be imported by the stub.
macro_rules! impl_numpy_scalar {
    ($ty:ty, $name:expr) => {
        impl NumPyScalar for $ty {
            fn type_() -> TypeInfo {
                let qualified_name = format!("numpy.{}", $name);
                TypeInfo {
                    name: qualified_name,
                    source_module: None,
                    import: hashset!["numpy".into()],
                    type_refs: HashMap::default(),
                }
            }
        }
    };
}

28impl_numpy_scalar!(i8, "int8");
29impl_numpy_scalar!(i16, "int16");
30impl_numpy_scalar!(i32, "int32");
31impl_numpy_scalar!(i64, "int64");
32impl_numpy_scalar!(u8, "uint8");
33impl_numpy_scalar!(u16, "uint16");
34impl_numpy_scalar!(u32, "uint32");
35impl_numpy_scalar!(u64, "uint64");
36impl_numpy_scalar!(f32, "float32");
37impl_numpy_scalar!(f64, "float64");
38impl_numpy_scalar!(num_complex::Complex32, "complex64");
39impl_numpy_scalar!(num_complex::Complex64, "complex128");
40
41impl<T: NumPyScalar, D> PyStubType for PyArray<T, D> {
42    fn type_output() -> TypeInfo {
43        let TypeInfo {
44            name, mut import, ..
45        } = T::type_();
46        import.insert("numpy.typing".into());
47        TypeInfo {
48            name: format!("numpy.typing.NDArray[{name}]"),
49            source_module: None,
50            import,
51            type_refs: HashMap::new(), // TODO: Track type refs for compound types
52        }
53    }
54}
55
56impl PyStubType for PyUntypedArray {
57    fn type_output() -> TypeInfo {
58        TypeInfo {
59            name: "numpy.typing.NDArray[typing.Any]".into(),
60            source_module: None,
61            import: hashset!["numpy.typing".into(), "typing".into()],
62            type_refs: HashMap::new(),
63        }
64    }
65}
66
67impl<T, D> PyStubType for PyReadonlyArray<'_, T, D>
68where
69    T: NumPyScalar + Element,
70    D: Dimension,
71{
72    fn type_output() -> TypeInfo {
73        PyArray::<T, D>::type_output()
74    }
75}
76
77impl<T, D> PyStubType for PyReadwriteArray<'_, T, D>
78where
79    T: NumPyScalar + Element,
80    D: Dimension,
81{
82    fn type_output() -> TypeInfo {
83        PyArray::<T, D>::type_output()
84    }
85}
86
87impl PyStubType for PyArrayDescr {
88    fn type_output() -> TypeInfo {
89        TypeInfo {
90            name: "numpy.dtype".into(),
91            source_module: None,
92            import: hashset!["numpy".into()],
93            type_refs: HashMap::new(),
94        }
95    }
96}