pyo3_stub_gen/stub_type/
collections.rs

1use crate::stub_type::*;
2use std::collections::{BTreeMap, BTreeSet, HashMap};
3
/// Extract the bare type identifier from a pre-qualified type name.
///
/// A qualified name contains at least one `.`; the identifier is everything
/// after the last dot (e.g. `"sub_mod.ClassA"` yields `Some("ClassA")`).
///
/// Returns `None` when:
/// - the name contains no dot (it is unqualified), or
/// - the name starts with one of the standard prefixes `typing.`,
///   `collections.` or `builtins.`, which never refer to locally defined types.
fn extract_type_identifier(type_name: &str) -> Option<&str> {
    // Prefixes of well-known typing/builtin modules that must not be tracked.
    const STANDARD_PREFIXES: [&str; 3] = ["typing.", "collections.", "builtins."];

    // Unqualified names have no dot and therefore nothing to extract.
    let (_, bare_name) = type_name.rsplit_once('.')?;

    if STANDARD_PREFIXES
        .iter()
        .any(|prefix| type_name.starts_with(prefix))
    {
        return None;
    }

    Some(bare_name)
}
21
22/// Build type_refs HashMap from inner TypeInfo for compound types
23///
24/// If the inner type is locally-defined and qualified, track it for context-aware rendering.
25fn build_type_refs_from_inner(inner: &TypeInfo) -> HashMap<String, TypeIdentifierRef> {
26    let mut type_refs = HashMap::new();
27
28    // If inner type is locally defined with a module, track it
29    if let Some(ref source_module) = inner.source_module {
30        if let Some(_module_name) = source_module.get() {
31            // Extract bare type identifier from the (potentially qualified) name
32            if let Some(bare_name) = extract_type_identifier(&inner.name) {
33                type_refs.insert(
34                    bare_name.to_string(),
35                    TypeIdentifierRef {
36                        module: source_module.clone(),
37                        import_kind: ImportKind::Module,
38                    },
39                );
40            }
41        }
42    }
43
44    // Also inherit any existing type_refs from inner type (for nested compounds)
45    type_refs.extend(inner.type_refs.clone());
46
47    type_refs
48}
49
50impl<T: PyStubType> PyStubType for Option<T> {
51    fn type_input() -> TypeInfo {
52        let inner = T::type_input();
53        let name = inner.name.clone();
54        let mut import = inner.import.clone();
55        import.insert("typing".into());
56
57        let type_refs = build_type_refs_from_inner(&inner);
58
59        TypeInfo {
60            name: format!("typing.Optional[{name}]"),
61            source_module: None,
62            import,
63            type_refs,
64        }
65    }
66    fn type_output() -> TypeInfo {
67        let inner = T::type_output();
68        let name = inner.name.clone();
69        let mut import = inner.import.clone();
70        import.insert("typing".into());
71
72        let type_refs = build_type_refs_from_inner(&inner);
73
74        TypeInfo {
75            name: format!("typing.Optional[{name}]"),
76            source_module: None,
77            import,
78            type_refs,
79        }
80    }
81}
82
83impl<T: PyStubType> PyStubType for Box<T> {
84    fn type_input() -> TypeInfo {
85        T::type_input()
86    }
87    fn type_output() -> TypeInfo {
88        T::type_output()
89    }
90}
91
92impl<T: PyStubType, E> PyStubType for Result<T, E> {
93    fn type_input() -> TypeInfo {
94        T::type_input()
95    }
96    fn type_output() -> TypeInfo {
97        T::type_output()
98    }
99}
100
101impl<T: PyStubType> PyStubType for Vec<T> {
102    fn type_input() -> TypeInfo {
103        let inner = T::type_input();
104        let name = inner.name.clone();
105        let mut import = inner.import.clone();
106        import.insert("typing".into());
107
108        let type_refs = build_type_refs_from_inner(&inner);
109
110        TypeInfo {
111            name: format!("typing.Sequence[{name}]"),
112            source_module: None,
113            import,
114            type_refs,
115        }
116    }
117    fn type_output() -> TypeInfo {
118        TypeInfo::list_of::<T>()
119    }
120}
121
122impl<T: PyStubType, const N: usize> PyStubType for [T; N] {
123    fn type_input() -> TypeInfo {
124        let inner = T::type_input();
125        let name = inner.name.clone();
126        let mut import = inner.import.clone();
127        import.insert("typing".into());
128
129        let type_refs = build_type_refs_from_inner(&inner);
130
131        TypeInfo {
132            name: format!("typing.Sequence[{name}]"),
133            source_module: None,
134            import,
135            type_refs,
136        }
137    }
138    fn type_output() -> TypeInfo {
139        TypeInfo::list_of::<T>()
140    }
141}
142
143impl<T: PyStubType, State> PyStubType for HashSet<T, State> {
144    fn type_output() -> TypeInfo {
145        TypeInfo::set_of::<T>()
146    }
147}
148
149impl<T: PyStubType> PyStubType for BTreeSet<T> {
150    fn type_output() -> TypeInfo {
151        TypeInfo::set_of::<T>()
152    }
153}
154
155impl<T: PyStubType> PyStubType for indexmap::IndexSet<T> {
156    fn type_output() -> TypeInfo {
157        TypeInfo::set_of::<T>()
158    }
159}
160
/// Expands to the `type_input` / `type_output` methods shared by all map-like
/// `PyStubType` impls.
///
/// The expansion refers to the type parameters `Key` and `Value`, so the
/// surrounding `impl` must declare generics with exactly those names.
///
/// - Input is rendered as `typing.Mapping[<key>, <value>]` and imports `typing`.
/// - Output is rendered as `builtins.dict[<key>, <value>]` and imports `builtins`.
///
/// In both cases the imports and `type_refs` of key and value are merged and
/// `source_module` is `None`.
macro_rules! impl_map_inner {
    () => {
        fn type_input() -> TypeInfo {
            let key = Key::type_input();
            let value = Value::type_input();

            // Key references first, then value references (value wins on clashes).
            let mut type_refs = build_type_refs_from_inner(&key);
            type_refs.extend(build_type_refs_from_inner(&value));

            let name = format!("typing.Mapping[{}, {}]", key.name, value.name);

            let mut import = key.import;
            import.extend(value.import);
            import.insert("typing".into());

            TypeInfo {
                name,
                source_module: None,
                import,
                type_refs,
            }
        }

        fn type_output() -> TypeInfo {
            let key = Key::type_output();
            let value = Value::type_output();

            // Key references first, then value references (value wins on clashes).
            let mut type_refs = build_type_refs_from_inner(&key);
            type_refs.extend(build_type_refs_from_inner(&value));

            let name = format!("builtins.dict[{}, {}]", key.name, value.name);

            let mut import = key.import;
            import.extend(value.import);
            import.insert("builtins".into());

            TypeInfo {
                name,
                source_module: None,
                import,
                type_refs,
            }
        }
    };
}
201
202impl<Key: PyStubType, Value: PyStubType> PyStubType for BTreeMap<Key, Value> {
203    impl_map_inner!();
204}
205
206impl<Key: PyStubType, Value: PyStubType, State> PyStubType for HashMap<Key, Value, State> {
207    impl_map_inner!();
208}
209
210impl<Key: PyStubType, Value: PyStubType, State> PyStubType
211    for indexmap::IndexMap<Key, Value, State>
212{
213    impl_map_inner!();
214}
215
/// Implements `PyStubType` for a tuple whose element types are the given
/// identifiers.
///
/// Both directions render as `tuple[<e1>, <e2>, ...]` with the element names
/// joined by `", "`, in declaration order. Imports and `type_refs` of all
/// elements are merged, and `source_module` is `None`.
macro_rules! impl_tuple {
    ($($T:ident),*) => {
        impl<$($T: PyStubType),*> PyStubType for ($($T),* ,) {
            fn type_output() -> TypeInfo {
                let mut import = HashSet::new();
                let mut element_names = Vec::new();
                let mut type_refs = HashMap::new();

                // Element infos are produced and folded in declaration order.
                for element in [$($T::type_output()),*] {
                    type_refs.extend(build_type_refs_from_inner(&element));
                    element_names.push(element.name);
                    import.extend(element.import);
                }

                TypeInfo {
                    name: format!("tuple[{}]", element_names.join(", ")),
                    source_module: None,
                    import,
                    type_refs,
                }
            }

            fn type_input() -> TypeInfo {
                let mut import = HashSet::new();
                let mut element_names = Vec::new();
                let mut type_refs = HashMap::new();

                // Element infos are produced and folded in declaration order.
                for element in [$($T::type_input()),*] {
                    type_refs.extend(build_type_refs_from_inner(&element));
                    element_names.push(element.name);
                    import.extend(element.import);
                }

                TypeInfo {
                    name: format!("tuple[{}]", element_names.join(", ")),
                    source_module: None,
                    import,
                    type_refs,
                }
            }
        }
    };
}
256
257impl_tuple!(T1);
258impl_tuple!(T1, T2);
259impl_tuple!(T1, T2, T3);
260impl_tuple!(T1, T2, T3, T4);
261impl_tuple!(T1, T2, T3, T4, T5);
262impl_tuple!(T1, T2, T3, T4, T5, T6);
263impl_tuple!(T1, T2, T3, T4, T5, T6, T7);
264impl_tuple!(T1, T2, T3, T4, T5, T6, T7, T8);
265impl_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9);