pyo3_stub_gen/stub_type/
numpy.rs

1use crate::runtime::PyRuntimeType;
2
3use super::{PyStubType, TypeInfo};
4use ::pyo3::prelude::*;
5use maplit::hashset;
6use numpy::{
7    ndarray::Dimension, Element, PyArray, PyArrayDescr, PyReadonlyArray, PyReadwriteArray,
8    PyUntypedArray,
9};
10use std::collections::HashMap;
11
12trait NumPyScalar {
13    fn type_() -> TypeInfo;
14}
15
/// Implements [`NumPyScalar`] for `$ty`.
///
/// `$name` is the scalar's name inside the `numpy` namespace; the generated
/// `type_()` reports it as `numpy.<$name>` and requests an import of `numpy`.
macro_rules! impl_numpy_scalar {
    ($ty:ty, $name:expr) => {
        impl NumPyScalar for $ty {
            fn type_() -> TypeInfo {
                let name = format!("numpy.{}", $name);
                let import = hashset!["numpy".into()];
                TypeInfo {
                    name,
                    source_module: None,
                    import,
                    type_refs: HashMap::new(),
                }
            }
        }
    };
}
30
31impl_numpy_scalar!(i8, "int8");
32impl_numpy_scalar!(i16, "int16");
33impl_numpy_scalar!(i32, "int32");
34impl_numpy_scalar!(i64, "int64");
35impl_numpy_scalar!(u8, "uint8");
36impl_numpy_scalar!(u16, "uint16");
37impl_numpy_scalar!(u32, "uint32");
38impl_numpy_scalar!(u64, "uint64");
39impl_numpy_scalar!(f32, "float32");
40impl_numpy_scalar!(f64, "float64");
41impl_numpy_scalar!(num_complex::Complex32, "complex64");
42impl_numpy_scalar!(num_complex::Complex64, "complex128");
43
44impl<T: NumPyScalar, D> PyStubType for PyArray<T, D> {
45    fn type_output() -> TypeInfo {
46        let TypeInfo {
47            name, mut import, ..
48        } = T::type_();
49        import.insert("numpy.typing".into());
50        TypeInfo {
51            name: format!("numpy.typing.NDArray[{name}]"),
52            source_module: None,
53            import,
54            type_refs: HashMap::new(), // TODO: Track type refs for compound types
55        }
56    }
57}
58impl<T, D> PyRuntimeType for PyArray<T, D> {
59    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
60        let numpy = py.import("numpy")?;
61        numpy.getattr("ndarray")
62    }
63}
64
65impl PyStubType for PyUntypedArray {
66    fn type_output() -> TypeInfo {
67        TypeInfo {
68            name: "numpy.typing.NDArray[typing.Any]".into(),
69            source_module: None,
70            import: hashset!["numpy.typing".into(), "typing".into()],
71            type_refs: HashMap::new(),
72        }
73    }
74}
75impl PyRuntimeType for PyUntypedArray {
76    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
77        let numpy = py.import("numpy")?;
78        numpy.getattr("ndarray")
79    }
80}
81
82impl<T, D> PyStubType for PyReadonlyArray<'_, T, D>
83where
84    T: NumPyScalar + Element,
85    D: Dimension,
86{
87    fn type_output() -> TypeInfo {
88        PyArray::<T, D>::type_output()
89    }
90}
91impl<T, D> PyRuntimeType for PyReadonlyArray<'_, T, D>
92where
93    T: Element,
94    D: Dimension,
95{
96    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
97        let numpy = py.import("numpy")?;
98        numpy.getattr("ndarray")
99    }
100}
101
102impl<T, D> PyStubType for PyReadwriteArray<'_, T, D>
103where
104    T: NumPyScalar + Element,
105    D: Dimension,
106{
107    fn type_output() -> TypeInfo {
108        PyArray::<T, D>::type_output()
109    }
110}
111impl<T, D> PyRuntimeType for PyReadwriteArray<'_, T, D>
112where
113    T: Element,
114    D: Dimension,
115{
116    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
117        let numpy = py.import("numpy")?;
118        numpy.getattr("ndarray")
119    }
120}
121
122impl PyStubType for PyArrayDescr {
123    fn type_output() -> TypeInfo {
124        TypeInfo {
125            name: "numpy.dtype".into(),
126            source_module: None,
127            import: hashset!["numpy".into()],
128            type_refs: HashMap::new(),
129        }
130    }
131}
132impl PyRuntimeType for PyArrayDescr {
133    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
134        let numpy = py.import("numpy")?;
135        numpy.getattr("dtype")
136    }
137}