pyo3_stub_gen/stub_type/
either.rs1use std::collections::{HashMap, HashSet};
2
3use super::{ImportKind, PyStubType, TypeIdentifierRef, TypeInfo};
4
/// Returns the unqualified identifier of a dotted Python type name.
///
/// For `"pkg.module.Class"` this yields `Some("Class")`, the text after the
/// last `.`. It yields `None` in two cases:
/// - the name contains no `.` at all;
/// - the name starts with `typing.` or `collections.`.
fn extract_type_identifier(type_name: &str) -> Option<&str> {
    // Prefixes whose identifiers are never reported.
    let excluded = ["typing.", "collections."]
        .iter()
        .any(|prefix| type_name.starts_with(*prefix));

    match type_name.rsplit_once('.') {
        Some((_qualifier, ident)) if !excluded => Some(ident),
        _ => None,
    }
}
17
18fn build_type_refs_from_inner(inner: &TypeInfo) -> HashMap<String, TypeIdentifierRef> {
20 let mut type_refs = HashMap::new();
21 if let Some(ref source_module) = inner.source_module {
22 if let Some(_module_name) = source_module.get() {
23 if let Some(bare_name) = extract_type_identifier(&inner.name) {
24 type_refs.insert(
25 bare_name.to_string(),
26 TypeIdentifierRef {
27 module: source_module.clone(),
28 import_kind: ImportKind::Module,
29 },
30 );
31 }
32 }
33 }
34 type_refs.extend(inner.type_refs.clone());
35 type_refs
36}
37
38impl<L: PyStubType, R: PyStubType> PyStubType for either::Either<L, R> {
39 fn type_input() -> TypeInfo {
40 let info_l = L::type_input();
41 let info_r = R::type_input();
42
43 let mut import: HashSet<_> = info_l
44 .import
45 .iter()
46 .cloned()
47 .chain(info_r.import.iter().cloned())
48 .collect();
49 import.insert("typing".into());
50
51 let mut type_refs = build_type_refs_from_inner(&info_l);
52 type_refs.extend(build_type_refs_from_inner(&info_r));
53
54 TypeInfo {
55 name: format!("typing.Union[{}, {}]", info_l.name, info_r.name),
56 source_module: None,
57 import,
58 type_refs,
59 }
60 }
61 fn type_output() -> TypeInfo {
62 let info_l = L::type_output();
63 let info_r = R::type_output();
64
65 let mut import: HashSet<_> = info_l
66 .import
67 .iter()
68 .cloned()
69 .chain(info_r.import.iter().cloned())
70 .collect();
71 import.insert("typing".into());
72
73 let mut type_refs = build_type_refs_from_inner(&info_l);
74 type_refs.extend(build_type_refs_from_inner(&info_r));
75
76 TypeInfo {
77 name: format!("typing.Union[{}, {}]", info_l.name, info_r.name),
78 source_module: None,
79 import,
80 type_refs,
81 }
82 }
83}