pyo3_stub_gen/stub_type/
either.rs

1use std::collections::{HashMap, HashSet};
2
3use ::pyo3::prelude::*;
4
5use super::{ImportKind, PyStubType, TypeIdentifierRef, TypeInfo};
6use crate::runtime::PyRuntimeType;
7
/// Returns the bare identifier of a dotted, pre-qualified type name.
///
/// The identifier is everything after the last `.` in `type_name`.
/// `None` is returned when `type_name` contains no `.` at all, or when it
/// starts with `typing.` or `collections.`.
fn extract_type_identifier(type_name: &str) -> Option<&str> {
    // Prefixes whose members never yield an identifier here.
    const SKIPPED_PREFIXES: [&str; 2] = ["typing.", "collections."];

    // No dot means the name is not module-qualified.
    let (_, identifier) = type_name.rsplit_once('.')?;

    let is_skipped = SKIPPED_PREFIXES
        .iter()
        .any(|prefix| type_name.starts_with(prefix));

    (!is_skipped).then_some(identifier)
}
20
21/// Build type_refs HashMap from inner TypeInfo for compound types
22fn build_type_refs_from_inner(inner: &TypeInfo) -> HashMap<String, TypeIdentifierRef> {
23    let mut type_refs = HashMap::new();
24    if let Some(ref source_module) = inner.source_module {
25        if let Some(_module_name) = source_module.get() {
26            if let Some(bare_name) = extract_type_identifier(&inner.name) {
27                type_refs.insert(
28                    bare_name.to_string(),
29                    TypeIdentifierRef {
30                        module: source_module.clone(),
31                        import_kind: ImportKind::Module,
32                    },
33                );
34            }
35        }
36    }
37    type_refs.extend(inner.type_refs.clone());
38    type_refs
39}
40
41impl<L: PyStubType, R: PyStubType> PyStubType for either::Either<L, R> {
42    fn type_input() -> TypeInfo {
43        let info_l = L::type_input();
44        let info_r = R::type_input();
45
46        let mut import: HashSet<_> = info_l
47            .import
48            .iter()
49            .cloned()
50            .chain(info_r.import.iter().cloned())
51            .collect();
52        import.insert("typing".into());
53
54        let mut type_refs = build_type_refs_from_inner(&info_l);
55        type_refs.extend(build_type_refs_from_inner(&info_r));
56
57        TypeInfo {
58            name: format!("typing.Union[{}, {}]", info_l.name, info_r.name),
59            source_module: None,
60            import,
61            type_refs,
62        }
63    }
64    fn type_output() -> TypeInfo {
65        let info_l = L::type_output();
66        let info_r = R::type_output();
67
68        let mut import: HashSet<_> = info_l
69            .import
70            .iter()
71            .cloned()
72            .chain(info_r.import.iter().cloned())
73            .collect();
74        import.insert("typing".into());
75
76        let mut type_refs = build_type_refs_from_inner(&info_l);
77        type_refs.extend(build_type_refs_from_inner(&info_r));
78
79        TypeInfo {
80            name: format!("typing.Union[{}, {}]", info_l.name, info_r.name),
81            source_module: None,
82            import,
83            type_refs,
84        }
85    }
86}
87impl<L: PyRuntimeType, R: PyRuntimeType> PyRuntimeType for either::Either<L, R> {
88    fn runtime_type_object(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
89        let l_type = L::runtime_type_object(py)?;
90        let r_type = R::runtime_type_object(py)?;
91        crate::runtime::union_type(py, &[l_type, r_type])
92    }
93}