pyo3_stub_gen/stub_type/
numpy.rs1use super::{PyStubType, TypeInfo};
2use maplit::hashset;
3use numpy::{
4 ndarray::Dimension, Element, PyArray, PyArrayDescr, PyReadonlyArray, PyReadwriteArray,
5 PyUntypedArray,
6};
7use std::collections::HashMap;
8
9trait NumPyScalar {
10 fn type_() -> TypeInfo;
11}
12
/// Implements [`NumPyScalar`] for `$rust_ty`, annotating it as `numpy.<$numpy_name>`.
///
/// The generated annotation only requires the top-level `numpy` module to be imported.
macro_rules! impl_numpy_scalar {
    ($rust_ty:ty, $numpy_name:expr) => {
        impl NumPyScalar for $rust_ty {
            fn type_() -> TypeInfo {
                // e.g. `f64` with "float64" becomes `numpy.float64`.
                let name = format!("numpy.{}", $numpy_name);
                TypeInfo {
                    name,
                    source_module: None,
                    import: hashset!["numpy".into()],
                    type_refs: HashMap::new(),
                }
            }
        }
    };
}
27
28impl_numpy_scalar!(i8, "int8");
29impl_numpy_scalar!(i16, "int16");
30impl_numpy_scalar!(i32, "int32");
31impl_numpy_scalar!(i64, "int64");
32impl_numpy_scalar!(u8, "uint8");
33impl_numpy_scalar!(u16, "uint16");
34impl_numpy_scalar!(u32, "uint32");
35impl_numpy_scalar!(u64, "uint64");
36impl_numpy_scalar!(f32, "float32");
37impl_numpy_scalar!(f64, "float64");
38impl_numpy_scalar!(num_complex::Complex32, "complex64");
39impl_numpy_scalar!(num_complex::Complex64, "complex128");
40
41impl<T: NumPyScalar, D> PyStubType for PyArray<T, D> {
42 fn type_output() -> TypeInfo {
43 let TypeInfo {
44 name, mut import, ..
45 } = T::type_();
46 import.insert("numpy.typing".into());
47 TypeInfo {
48 name: format!("numpy.typing.NDArray[{name}]"),
49 source_module: None,
50 import,
51 type_refs: HashMap::new(), }
53 }
54}
55
56impl PyStubType for PyUntypedArray {
57 fn type_output() -> TypeInfo {
58 TypeInfo {
59 name: "numpy.typing.NDArray[typing.Any]".into(),
60 source_module: None,
61 import: hashset!["numpy.typing".into(), "typing".into()],
62 type_refs: HashMap::new(),
63 }
64 }
65}
66
67impl<T, D> PyStubType for PyReadonlyArray<'_, T, D>
68where
69 T: NumPyScalar + Element,
70 D: Dimension,
71{
72 fn type_output() -> TypeInfo {
73 PyArray::<T, D>::type_output()
74 }
75}
76
77impl<T, D> PyStubType for PyReadwriteArray<'_, T, D>
78where
79 T: NumPyScalar + Element,
80 D: Dimension,
81{
82 fn type_output() -> TypeInfo {
83 PyArray::<T, D>::type_output()
84 }
85}
86
87impl PyStubType for PyArrayDescr {
88 fn type_output() -> TypeInfo {
89 TypeInfo {
90 name: "numpy.dtype".into(),
91 source_module: None,
92 import: hashset!["numpy".into()],
93 type_refs: HashMap::new(),
94 }
95 }
96}