pyo3_stub_gen/stub_type/
collections.rs

1use crate::runtime::PyRuntimeType;
2use crate::stub_type::*;
3use ::pyo3::types::{PyList, PyNone};
4use ::pyo3::{Bound, PyAny, PyResult, Python};
5use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
6
/// Return the bare identifier of a dotted (module-qualified) type name.
///
/// The identifier is everything after the last `.` in `type_name`, e.g.
/// `"sub_mod.ClassA"` yields `Some("ClassA")`.
///
/// Returns `None` when `type_name` contains no `.` at all, or when it starts
/// with `typing.` or `collections.` (those names are never tracked).
fn extract_type_identifier(type_name: &str) -> Option<&str> {
    // Module prefixes whose members are not treated as locally defined types.
    const SKIPPED_PREFIXES: [&str; 2] = ["typing.", "collections."];

    // No dot means the name is unqualified: nothing to extract.
    let (_, identifier) = type_name.rsplit_once('.')?;

    if SKIPPED_PREFIXES
        .iter()
        .any(|prefix| type_name.starts_with(prefix))
    {
        return None;
    }
    Some(identifier)
}
24
25/// Build type_refs HashMap from inner TypeInfo for compound types
26///
27/// If the inner type is locally-defined and qualified, track it for context-aware rendering.
28fn build_type_refs_from_inner(inner: &TypeInfo) -> HashMap<String, TypeIdentifierRef> {
29    let mut type_refs = HashMap::new();
30
31    // If inner type is locally defined with a module, track it
32    if let Some(ref source_module) = inner.source_module {
33        if let Some(_module_name) = source_module.get() {
34            // Extract bare type identifier from the (potentially qualified) name
35            if let Some(bare_name) = extract_type_identifier(&inner.name) {
36                type_refs.insert(
37                    bare_name.to_string(),
38                    TypeIdentifierRef {
39                        module: source_module.clone(),
40                        import_kind: ImportKind::Module,
41                    },
42                );
43            }
44        }
45    }
46
47    // Also inherit any existing type_refs from inner type (for nested compounds)
48    type_refs.extend(inner.type_refs.clone());
49
50    type_refs
51}
52
53impl<T: PyStubType> PyStubType for Option<T> {
54    fn type_input() -> TypeInfo {
55        let inner = T::type_input();
56        let name = inner.name.clone();
57        let mut import = inner.import.clone();
58        import.insert("typing".into());
59
60        let type_refs = build_type_refs_from_inner(&inner);
61
62        TypeInfo {
63            name: format!("typing.Optional[{name}]"),
64            source_module: None,
65            import,
66            type_refs,
67        }
68    }
69    fn type_output() -> TypeInfo {
70        let inner = T::type_output();
71        let name = inner.name.clone();
72        let mut import = inner.import.clone();
73        import.insert("typing".into());
74
75        let type_refs = build_type_refs_from_inner(&inner);
76
77        TypeInfo {
78            name: format!("typing.Optional[{name}]"),
79            source_module: None,
80            import,
81            type_refs,
82        }
83    }
84}
85impl<T: PyRuntimeType> PyRuntimeType for Option<T> {
86    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
87        // Option<T> maps to T | None at runtime
88        let inner_type = T::runtime_type_object(py)?;
89        let none_type = py.get_type::<PyNone>().into_any();
90        crate::runtime::union_type(py, &[inner_type, none_type])
91    }
92}
93
94impl<T: PyStubType> PyStubType for Box<T> {
95    fn type_input() -> TypeInfo {
96        T::type_input()
97    }
98    fn type_output() -> TypeInfo {
99        T::type_output()
100    }
101}
102impl<T: PyRuntimeType> PyRuntimeType for Box<T> {
103    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
104        T::runtime_type_object(py)
105    }
106}
107
108impl<T: PyStubType, E> PyStubType for Result<T, E> {
109    fn type_input() -> TypeInfo {
110        T::type_input()
111    }
112    fn type_output() -> TypeInfo {
113        T::type_output()
114    }
115}
116impl<T: PyRuntimeType, E> PyRuntimeType for Result<T, E> {
117    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
118        T::runtime_type_object(py)
119    }
120}
121
122impl<T: PyStubType> PyStubType for Vec<T> {
123    fn type_input() -> TypeInfo {
124        let inner = T::type_input();
125        let name = inner.name.clone();
126        let mut import = inner.import.clone();
127        import.insert("typing".into());
128
129        let type_refs = build_type_refs_from_inner(&inner);
130
131        TypeInfo {
132            name: format!("typing.Sequence[{name}]"),
133            source_module: None,
134            import,
135            type_refs,
136        }
137    }
138    fn type_output() -> TypeInfo {
139        TypeInfo::list_of::<T>()
140    }
141}
142impl<T> PyRuntimeType for Vec<T> {
143    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
144        // Vec<T> maps to list at runtime (without generic parameter)
145        Ok(py.get_type::<PyList>().into_any())
146    }
147}
148
149impl<T: PyStubType, const N: usize> PyStubType for [T; N] {
150    fn type_input() -> TypeInfo {
151        let inner = T::type_input();
152        let name = inner.name.clone();
153        let mut import = inner.import.clone();
154        import.insert("typing".into());
155
156        let type_refs = build_type_refs_from_inner(&inner);
157
158        TypeInfo {
159            name: format!("typing.Sequence[{name}]"),
160            source_module: None,
161            import,
162            type_refs,
163        }
164    }
165    fn type_output() -> TypeInfo {
166        TypeInfo::list_of::<T>()
167    }
168}
169impl<T, const N: usize> PyRuntimeType for [T; N] {
170    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
171        Ok(py.get_type::<PyList>().into_any())
172    }
173}
174
175impl<T: PyStubType, State> PyStubType for HashSet<T, State> {
176    fn type_output() -> TypeInfo {
177        TypeInfo::set_of::<T>()
178    }
179}
180impl<T, State> PyRuntimeType for HashSet<T, State> {
181    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
182        Ok(py.get_type::<::pyo3::types::PySet>().into_any())
183    }
184}
185
186impl<T: PyStubType> PyStubType for BTreeSet<T> {
187    fn type_output() -> TypeInfo {
188        TypeInfo::set_of::<T>()
189    }
190}
191impl<T> PyRuntimeType for BTreeSet<T> {
192    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
193        Ok(py.get_type::<::pyo3::types::PySet>().into_any())
194    }
195}
196
197impl<T: PyStubType> PyStubType for indexmap::IndexSet<T> {
198    fn type_output() -> TypeInfo {
199        TypeInfo::set_of::<T>()
200    }
201}
202impl<T> PyRuntimeType for indexmap::IndexSet<T> {
203    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
204        Ok(py.get_type::<::pyo3::types::PySet>().into_any())
205    }
206}
207
/// Expands to the `type_input` / `type_output` methods shared by the map
/// `PyStubType` impls.
///
/// Must be invoked inside an `impl` block whose type parameters are named
/// `Key` and `Value`, both implementing `PyStubType`.
///
/// * `type_input` renders `typing.Mapping[K, V]` and adds the `typing` import.
/// * `type_output` renders `builtins.dict[K, V]` and adds the `builtins` import.
///
/// In both cases the imports and type refs of the key and the value are merged
/// (key first, then value).
macro_rules! impl_map_stub_type {
    () => {
        fn type_input() -> TypeInfo {
            let key = Key::type_input();
            let value = Value::type_input();

            // Refs are gathered while both infos can still be borrowed whole.
            let mut type_refs = build_type_refs_from_inner(&key);
            type_refs.extend(build_type_refs_from_inner(&value));

            let name = format!("typing.Mapping[{}, {}]", key.name, value.name);

            let mut import = key.import;
            import.extend(value.import);
            import.insert("typing".into());

            TypeInfo {
                name,
                source_module: None,
                import,
                type_refs,
            }
        }
        fn type_output() -> TypeInfo {
            let key = Key::type_output();
            let value = Value::type_output();

            // Refs are gathered while both infos can still be borrowed whole.
            let mut type_refs = build_type_refs_from_inner(&key);
            type_refs.extend(build_type_refs_from_inner(&value));

            let name = format!("builtins.dict[{}, {}]", key.name, value.name);

            let mut import = key.import;
            import.extend(value.import);
            import.insert("builtins".into());

            TypeInfo {
                name,
                source_module: None,
                import,
                type_refs,
            }
        }
    };
}
248
/// `BTreeMap<Key, Value>` takes both of its `PyStubType` methods from
/// `impl_map_stub_type!`: `typing.Mapping[...]` as an input type and
/// `builtins.dict[...]` as an output type.
impl<Key: PyStubType, Value: PyStubType> PyStubType for BTreeMap<Key, Value> {
    impl_map_stub_type!();
}
252impl<Key, Value> PyRuntimeType for BTreeMap<Key, Value> {
253    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
254        Ok(py.get_type::<::pyo3::types::PyDict>().into_any())
255    }
256}
257
/// `HashMap<Key, Value, State>` takes both of its `PyStubType` methods from
/// `impl_map_stub_type!`: `typing.Mapping[...]` as an input type and
/// `builtins.dict[...]` as an output type. `State` is unconstrained.
impl<Key: PyStubType, Value: PyStubType, State> PyStubType for HashMap<Key, Value, State> {
    impl_map_stub_type!();
}
261impl<Key, Value, State> PyRuntimeType for HashMap<Key, Value, State> {
262    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
263        Ok(py.get_type::<::pyo3::types::PyDict>().into_any())
264    }
265}
266
/// `indexmap::IndexMap<Key, Value, State>` takes both of its `PyStubType`
/// methods from `impl_map_stub_type!`: `typing.Mapping[...]` as an input type
/// and `builtins.dict[...]` as an output type. `State` is unconstrained.
impl<Key: PyStubType, Value: PyStubType, State> PyStubType
    for indexmap::IndexMap<Key, Value, State>
{
    impl_map_stub_type!();
}
272impl<Key, Value, State> PyRuntimeType for indexmap::IndexMap<Key, Value, State> {
273    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
274        Ok(py.get_type::<::pyo3::types::PyDict>().into_any())
275    }
276}
277
/// Implements `PyStubType` and `PyRuntimeType` for the tuple made of the given
/// type parameters, e.g. `impl_tuple_stub_type!(T1, T2)` covers `(T1, T2,)`.
///
/// The stub name is `tuple[...]` with the element names joined by `", "`, in
/// declaration order. Imports and type refs of all elements are merged; no
/// import is added for `tuple` itself. The runtime type is the type object
/// obtained for `PyTuple`, independent of the element types.
macro_rules! impl_tuple_stub_type {
    ($($T:ident),*) => {
        impl<$($T: PyStubType),*> PyStubType for ($($T),* ,) {
            fn type_input() -> TypeInfo {
                let mut element_names = Vec::new();
                let mut import = HashSet::new();
                let mut type_refs = HashMap::new();
                // One scoped step per element, evaluated in declaration order.
                $(
                {
                    let element = $T::type_input();
                    type_refs.extend(build_type_refs_from_inner(&element));
                    import.extend(element.import);
                    element_names.push(element.name);
                }
                )*
                let joined = element_names.join(", ");
                TypeInfo {
                    name: format!("tuple[{}]", joined),
                    source_module: None,
                    import,
                    type_refs,
                }
            }
            fn type_output() -> TypeInfo {
                let mut element_names = Vec::new();
                let mut import = HashSet::new();
                let mut type_refs = HashMap::new();
                // One scoped step per element, evaluated in declaration order.
                $(
                {
                    let element = $T::type_output();
                    type_refs.extend(build_type_refs_from_inner(&element));
                    import.extend(element.import);
                    element_names.push(element.name);
                }
                )*
                let joined = element_names.join(", ");
                TypeInfo {
                    name: format!("tuple[{}]", joined),
                    source_module: None,
                    import,
                    type_refs,
                }
            }
        }
        impl<$($T),*> PyRuntimeType for ($($T),* ,) {
            fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
                let tuple_type = py.get_type::<::pyo3::types::PyTuple>();
                Ok(tuple_type.into_any())
            }
        }
    };
}
323
// Instantiate the tuple impls for every arity from 1 to 9; tuples with more
// elements get no impl from this list.
impl_tuple_stub_type!(T1);
impl_tuple_stub_type!(T1, T2);
impl_tuple_stub_type!(T1, T2, T3);
impl_tuple_stub_type!(T1, T2, T3, T4);
impl_tuple_stub_type!(T1, T2, T3, T4, T5);
impl_tuple_stub_type!(T1, T2, T3, T4, T5, T6);
impl_tuple_stub_type!(T1, T2, T3, T4, T5, T6, T7);
impl_tuple_stub_type!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_tuple_stub_type!(T1, T2, T3, T4, T5, T6, T7, T8, T9);