pyo3_stub_gen/stub_type/either.rs

1use std::collections::{HashMap, HashSet};
2
3use super::{ImportKind, PyStubType, TypeIdentifierRef, TypeInfo};
4
/// Extract the bare type identifier from a module-qualified type name.
///
/// For a name such as `"my_module.MyClass"` this returns the component after
/// the last `.` (here `"MyClass"`). It returns `None` when:
/// - the name contains no `.` (it is not qualified),
/// - the name starts with `typing.` or `collections.` (these are skipped), or
/// - the name ends with `.`, which would leave an empty identifier.
fn extract_type_identifier(type_name: &str) -> Option<&str> {
    if type_name.starts_with("typing.") || type_name.starts_with("collections.") {
        return None;
    }
    let (_, bare_name) = type_name.rsplit_once('.')?;
    // An empty tail (e.g. "mod.") is not an identifier; never register it as a type ref key.
    if bare_name.is_empty() {
        None
    } else {
        Some(bare_name)
    }
}
17
18/// Build type_refs HashMap from inner TypeInfo for compound types
19fn build_type_refs_from_inner(inner: &TypeInfo) -> HashMap<String, TypeIdentifierRef> {
20    let mut type_refs = HashMap::new();
21    if let Some(ref source_module) = inner.source_module {
22        if let Some(_module_name) = source_module.get() {
23            if let Some(bare_name) = extract_type_identifier(&inner.name) {
24                type_refs.insert(
25                    bare_name.to_string(),
26                    TypeIdentifierRef {
27                        module: source_module.clone(),
28                        import_kind: ImportKind::Module,
29                    },
30                );
31            }
32        }
33    }
34    type_refs.extend(inner.type_refs.clone());
35    type_refs
36}
37
38impl<L: PyStubType, R: PyStubType> PyStubType for either::Either<L, R> {
39    fn type_input() -> TypeInfo {
40        let info_l = L::type_input();
41        let info_r = R::type_input();
42
43        let mut import: HashSet<_> = info_l
44            .import
45            .iter()
46            .cloned()
47            .chain(info_r.import.iter().cloned())
48            .collect();
49        import.insert("typing".into());
50
51        let mut type_refs = build_type_refs_from_inner(&info_l);
52        type_refs.extend(build_type_refs_from_inner(&info_r));
53
54        TypeInfo {
55            name: format!("typing.Union[{}, {}]", info_l.name, info_r.name),
56            source_module: None,
57            import,
58            type_refs,
59        }
60    }
61    fn type_output() -> TypeInfo {
62        let info_l = L::type_output();
63        let info_r = R::type_output();
64
65        let mut import: HashSet<_> = info_l
66            .import
67            .iter()
68            .cloned()
69            .chain(info_r.import.iter().cloned())
70            .collect();
71        import.insert("typing".into());
72
73        let mut type_refs = build_type_refs_from_inner(&info_l);
74        type_refs.extend(build_type_refs_from_inner(&info_r));
75
76        TypeInfo {
77            name: format!("typing.Union[{}, {}]", info_l.name, info_r.name),
78            source_module: None,
79            import,
80            type_refs,
81        }
82    }
83}