pyo3_stub_gen/stub_type/
either.rs1use std::collections::{HashMap, HashSet};
2
3use ::pyo3::prelude::*;
4
5use super::{ImportKind, PyStubType, TypeIdentifierRef, TypeInfo};
6use crate::runtime::PyRuntimeType;
7
/// Extracts the bare identifier from a dotted, module-qualified type name.
///
/// The result is the text after the last `.`, e.g. `"sub_mod.ClassA"` yields
/// `Some("ClassA")`.
///
/// Returns `None` when:
/// - `type_name` starts with `typing.` or `collections.`;
/// - `type_name` contains no `.` at all;
/// - the text after the last `.` is not a plain identifier, i.e. it is empty
///   (`"pkg."`) or contains anything other than alphanumerics and `_`
///   (`"list[pkg.Foo]"` would otherwise yield the bogus name `"Foo]"`).
fn extract_type_identifier(type_name: &str) -> Option<&str> {
    if type_name.starts_with("typing.") || type_name.starts_with("collections.") {
        return None;
    }

    let (_, bare_name) = type_name.rsplit_once('.')?;

    // Only accept something that can actually be used as a name: a leading
    // letter or underscore followed by letters, digits or underscores.
    let mut chars = bare_name.chars();
    let first = chars.next()?;
    let is_identifier = (first == '_' || first.is_alphabetic())
        && chars.all(|c| c == '_' || c.is_alphanumeric());

    is_identifier.then_some(bare_name)
}
20
21fn build_type_refs_from_inner(inner: &TypeInfo) -> HashMap<String, TypeIdentifierRef> {
23 let mut type_refs = HashMap::new();
24 if let Some(ref source_module) = inner.source_module {
25 if let Some(_module_name) = source_module.get() {
26 if let Some(bare_name) = extract_type_identifier(&inner.name) {
27 type_refs.insert(
28 bare_name.to_string(),
29 TypeIdentifierRef {
30 module: source_module.clone(),
31 import_kind: ImportKind::Module,
32 },
33 );
34 }
35 }
36 }
37 type_refs.extend(inner.type_refs.clone());
38 type_refs
39}
40
41impl<L: PyStubType, R: PyStubType> PyStubType for either::Either<L, R> {
42 fn type_input() -> TypeInfo {
43 let info_l = L::type_input();
44 let info_r = R::type_input();
45
46 let mut import: HashSet<_> = info_l
47 .import
48 .iter()
49 .cloned()
50 .chain(info_r.import.iter().cloned())
51 .collect();
52 import.insert("typing".into());
53
54 let mut type_refs = build_type_refs_from_inner(&info_l);
55 type_refs.extend(build_type_refs_from_inner(&info_r));
56
57 TypeInfo {
58 name: format!("typing.Union[{}, {}]", info_l.name, info_r.name),
59 source_module: None,
60 import,
61 type_refs,
62 }
63 }
64 fn type_output() -> TypeInfo {
65 let info_l = L::type_output();
66 let info_r = R::type_output();
67
68 let mut import: HashSet<_> = info_l
69 .import
70 .iter()
71 .cloned()
72 .chain(info_r.import.iter().cloned())
73 .collect();
74 import.insert("typing".into());
75
76 let mut type_refs = build_type_refs_from_inner(&info_l);
77 type_refs.extend(build_type_refs_from_inner(&info_r));
78
79 TypeInfo {
80 name: format!("typing.Union[{}, {}]", info_l.name, info_r.name),
81 source_module: None,
82 import,
83 type_refs,
84 }
85 }
86}
87impl<L: PyRuntimeType, R: PyRuntimeType> PyRuntimeType for either::Either<L, R> {
88 fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
89 let l_type = L::runtime_type_object(py)?;
90 let r_type = R::runtime_type_object(py)?;
91 crate::runtime::union_type(py, &[l_type, r_type])
92 }
93}