pyo3_stub_gen/stub_type/
pyo3.rs

1use crate::runtime::PyRuntimeType;
2use crate::stub_type::*;
3use ::pyo3::{
4    basic::CompareOp,
5    pybacked::{PyBackedBytes, PyBackedStr},
6    pyclass::boolean_struct::False,
7    types::*,
8    Bound, Py, PyClass, PyRef, PyRefMut, PyResult, Python,
9};
10use maplit::hashset;
11use std::collections::HashMap;
12
13impl PyStubType for PyAny {
14    fn type_output() -> TypeInfo {
15        TypeInfo {
16            name: "typing.Any".to_string(),
17            source_module: None,
18            import: hashset! { "typing".into() },
19            type_refs: HashMap::new(),
20        }
21    }
22}
23impl PyRuntimeType for PyAny {
24    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, ::pyo3::PyAny>> {
25        // PyAny maps to `object` at runtime
26        Ok(py.get_type::<::pyo3::types::PyAny>().into_any())
27    }
28}
29
30impl<T: PyStubType> PyStubType for Py<T> {
31    fn type_input() -> TypeInfo {
32        T::type_input()
33    }
34    fn type_output() -> TypeInfo {
35        T::type_output()
36    }
37}
38impl<T: PyRuntimeType> PyRuntimeType for Py<T> {
39    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, ::pyo3::PyAny>> {
40        T::runtime_type_object(py)
41    }
42}
43
44impl<T: PyStubType + PyClass> PyStubType for PyRef<'_, T> {
45    fn type_input() -> TypeInfo {
46        T::type_input()
47    }
48    fn type_output() -> TypeInfo {
49        T::type_output()
50    }
51}
52impl<T: PyRuntimeType + PyClass> PyRuntimeType for PyRef<'_, T> {
53    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, ::pyo3::PyAny>> {
54        T::runtime_type_object(py)
55    }
56}
57
58impl<T: PyStubType + PyClass<Frozen = False>> PyStubType for PyRefMut<'_, T> {
59    fn type_input() -> TypeInfo {
60        T::type_input()
61    }
62    fn type_output() -> TypeInfo {
63        T::type_output()
64    }
65}
66impl<T: PyRuntimeType + PyClass<Frozen = False>> PyRuntimeType for PyRefMut<'_, T> {
67    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, ::pyo3::PyAny>> {
68        T::runtime_type_object(py)
69    }
70}
71
72impl<T: PyStubType> PyStubType for Bound<'_, T> {
73    fn type_input() -> TypeInfo {
74        T::type_input()
75    }
76    fn type_output() -> TypeInfo {
77        T::type_output()
78    }
79}
80impl<T: PyRuntimeType> PyRuntimeType for Bound<'_, T> {
81    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, ::pyo3::PyAny>> {
82        T::runtime_type_object(py)
83    }
84}
85
// Generates `PyStubType` and `PyRuntimeType` impls for a type whose stub name
// is used as-is and needs no import.
//
// Arguments: the Rust type to implement the traits for, then the Python type
// name emitted in stubs. The runtime type object is obtained from the Rust
// type itself via `py.get_type`.
macro_rules! impl_builtin {
    ($rust_ty:ty, $py_name:expr) => {
        impl PyStubType for $rust_ty {
            fn type_output() -> TypeInfo {
                let name = $py_name.to_string();
                TypeInfo {
                    name,
                    source_module: None,
                    import: HashSet::new(),
                    type_refs: HashMap::new(),
                }
            }
        }

        impl PyRuntimeType for $rust_ty {
            fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, ::pyo3::PyAny>> {
                let type_object = py.get_type::<$rust_ty>();
                Ok(type_object.into_any())
            }
        }
    };
}
105
106impl_builtin!(PyBool, "bool");
107impl_builtin!(PyInt, "int");
108impl_builtin!(PyFloat, "float");
109impl_builtin!(PyComplex, "complex");
110impl_builtin!(PyList, "list");
111impl_builtin!(PyTuple, "tuple");
112impl_builtin!(PySlice, "slice");
113impl_builtin!(PyDict, "dict");
114impl_builtin!(PySet, "set");
115impl_builtin!(PyString, "str");
116impl_builtin!(PyByteArray, "bytearray");
117impl_builtin!(PyBytes, "bytes");
118impl_builtin!(PyType, "type");
119
120// PyBackedStr and PyBackedBytes don't have PyTypeInfo, use underlying types
121impl PyStubType for PyBackedStr {
122    fn type_output() -> TypeInfo {
123        TypeInfo {
124            name: "str".to_string(),
125            source_module: None,
126            import: HashSet::new(),
127            type_refs: HashMap::new(),
128        }
129    }
130}
131impl PyRuntimeType for PyBackedStr {
132    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, ::pyo3::PyAny>> {
133        Ok(py.get_type::<PyString>().into_any())
134    }
135}
136
137impl PyStubType for PyBackedBytes {
138    fn type_output() -> TypeInfo {
139        TypeInfo {
140            name: "bytes".to_string(),
141            source_module: None,
142            import: HashSet::new(),
143            type_refs: HashMap::new(),
144        }
145    }
146}
147impl PyRuntimeType for PyBackedBytes {
148    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, ::pyo3::PyAny>> {
149        Ok(py.get_type::<PyBytes>().into_any())
150    }
151}
152
153// CompareOp maps to int at stub level but is not a Python type
154impl PyStubType for CompareOp {
155    fn type_output() -> TypeInfo {
156        TypeInfo {
157            name: "int".to_string(),
158            source_module: None,
159            import: HashSet::new(),
160            type_refs: HashMap::new(),
161        }
162    }
163}
164impl PyRuntimeType for CompareOp {
165    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, ::pyo3::PyAny>> {
166        Ok(py.get_type::<PyInt>().into_any())
167    }
168}
169
// Generates `PyStubType` and `PyRuntimeType` impls for a type that lives in a
// Python module.
//
// Arguments: the Rust type, the Python module name, and the type's name
// inside that module. The stub name is `<module>.<name>` (joined at compile
// time with `concat!`), and the module is recorded as a required import. The
// runtime type object is obtained from the Rust type itself via `py.get_type`.
macro_rules! impl_simple {
    ($rust_ty:ty, $module:expr, $py_name:expr) => {
        impl PyStubType for $rust_ty {
            fn type_output() -> TypeInfo {
                let qualified_name = concat!($module, ".", $py_name).to_string();
                TypeInfo {
                    name: qualified_name,
                    source_module: None,
                    import: hashset! { $module.into() },
                    type_refs: HashMap::new(),
                }
            }
        }

        impl PyRuntimeType for $rust_ty {
            fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, ::pyo3::PyAny>> {
                let type_object = py.get_type::<$rust_ty>();
                Ok(type_object.into_any())
            }
        }
    };
}
189
190impl_simple!(PyDate, "datetime", "date");
191impl_simple!(PyDateTime, "datetime", "datetime");
192impl_simple!(PyDelta, "datetime", "timedelta");
193impl_simple!(PyTime, "datetime", "time");
194impl_simple!(PyTzInfo, "datetime", "tzinfo");