pyo3_stub_gen/stub_type/
numpy.rs1use crate::runtime::PyRuntimeType;
2
3use super::{PyStubType, TypeInfo};
4use ::pyo3::prelude::*;
5use maplit::hashset;
6use numpy::{
7 ndarray::Dimension, Element, PyArray, PyArrayDescr, PyReadonlyArray, PyReadwriteArray,
8 PyUntypedArray,
9};
10use std::collections::HashMap;
11
12trait NumPyScalar {
13 fn type_() -> TypeInfo;
14}
15
/// Implements `NumPyScalar` for `$ty`, rendering it in stubs as `numpy.<$name>`.
///
/// The produced `TypeInfo` only requires `import numpy` and has no source module
/// and no type references.
macro_rules! impl_numpy_scalar {
    ($ty:ty, $name:expr) => {
        impl NumPyScalar for $ty {
            fn type_() -> TypeInfo {
                let qualified = format!("numpy.{}", $name);
                TypeInfo {
                    name: qualified,
                    import: hashset!["numpy".into()],
                    source_module: None,
                    type_refs: HashMap::new(),
                }
            }
        }
    };
}
30
31impl_numpy_scalar!(i8, "int8");
32impl_numpy_scalar!(i16, "int16");
33impl_numpy_scalar!(i32, "int32");
34impl_numpy_scalar!(i64, "int64");
35impl_numpy_scalar!(u8, "uint8");
36impl_numpy_scalar!(u16, "uint16");
37impl_numpy_scalar!(u32, "uint32");
38impl_numpy_scalar!(u64, "uint64");
39impl_numpy_scalar!(f32, "float32");
40impl_numpy_scalar!(f64, "float64");
41impl_numpy_scalar!(num_complex::Complex32, "complex64");
42impl_numpy_scalar!(num_complex::Complex64, "complex128");
43
44impl<T: NumPyScalar, D> PyStubType for PyArray<T, D> {
45 fn type_output() -> TypeInfo {
46 let TypeInfo {
47 name, mut import, ..
48 } = T::type_();
49 import.insert("numpy.typing".into());
50 TypeInfo {
51 name: format!("numpy.typing.NDArray[{name}]"),
52 source_module: None,
53 import,
54 type_refs: HashMap::new(), }
56 }
57}
58impl<T, D> PyRuntimeType for PyArray<T, D> {
59 fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
60 let numpy = py.import("numpy")?;
61 numpy.getattr("ndarray")
62 }
63}
64
65impl PyStubType for PyUntypedArray {
66 fn type_output() -> TypeInfo {
67 TypeInfo {
68 name: "numpy.typing.NDArray[typing.Any]".into(),
69 source_module: None,
70 import: hashset!["numpy.typing".into(), "typing".into()],
71 type_refs: HashMap::new(),
72 }
73 }
74}
75impl PyRuntimeType for PyUntypedArray {
76 fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
77 let numpy = py.import("numpy")?;
78 numpy.getattr("ndarray")
79 }
80}
81
82impl<T, D> PyStubType for PyReadonlyArray<'_, T, D>
83where
84 T: NumPyScalar + Element,
85 D: Dimension,
86{
87 fn type_output() -> TypeInfo {
88 PyArray::<T, D>::type_output()
89 }
90}
91impl<T, D> PyRuntimeType for PyReadonlyArray<'_, T, D>
92where
93 T: Element,
94 D: Dimension,
95{
96 fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
97 let numpy = py.import("numpy")?;
98 numpy.getattr("ndarray")
99 }
100}
101
102impl<T, D> PyStubType for PyReadwriteArray<'_, T, D>
103where
104 T: NumPyScalar + Element,
105 D: Dimension,
106{
107 fn type_output() -> TypeInfo {
108 PyArray::<T, D>::type_output()
109 }
110}
111impl<T, D> PyRuntimeType for PyReadwriteArray<'_, T, D>
112where
113 T: Element,
114 D: Dimension,
115{
116 fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
117 let numpy = py.import("numpy")?;
118 numpy.getattr("ndarray")
119 }
120}
121
122impl PyStubType for PyArrayDescr {
123 fn type_output() -> TypeInfo {
124 TypeInfo {
125 name: "numpy.dtype".into(),
126 source_module: None,
127 import: hashset!["numpy".into()],
128 type_refs: HashMap::new(),
129 }
130 }
131}
132impl PyRuntimeType for PyArrayDescr {
133 fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
134 let numpy = py.import("numpy")?;
135 numpy.getattr("dtype")
136 }
137}