1mod builtins;
2mod collections;
3mod pyo3;
4
5#[cfg(feature = "numpy")]
6mod numpy;
7
8#[cfg(feature = "either")]
9mod either;
10
11use maplit::hashset;
12use std::{collections::HashSet, fmt, ops};
13
/// Reference to the Python module that a type annotation has to import.
///
/// The variant order is significant: the derived `PartialOrd`/`Ord`
/// implementations sort every [`ModuleRef::Named`] before
/// [`ModuleRef::Default`].
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum ModuleRef {
    /// A module identified explicitly by its name, e.g. `"typing"`.
    Named(String),

    /// A module without an explicit name; this is the value produced by
    /// `ModuleRef::default()`.
    // NOTE(review): presumably resolved to the crate's own Python module by
    // the stub generator -- confirm against the code that consumes it.
    #[default]
    Default,
}
29
30impl ModuleRef {
31 pub fn get(&self) -> Option<&str> {
32 match self {
33 Self::Named(name) => Some(name),
34 Self::Default => None,
35 }
36 }
37}
38
39impl From<&str> for ModuleRef {
40 fn from(s: &str) -> Self {
41 Self::Named(s.to_string())
42 }
43}
44
45#[derive(Debug, Clone, PartialEq, Eq)]
47pub struct TypeInfo {
48 pub name: String,
50
51 pub import: HashSet<ModuleRef>,
56}
57
58impl fmt::Display for TypeInfo {
59 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
60 write!(f, "{}", self.name)
61 }
62}
63
64impl TypeInfo {
65 pub fn none() -> Self {
67 Self {
70 name: "None".to_string(),
71 import: HashSet::new(),
72 }
73 }
74
75 pub fn any() -> Self {
77 Self {
78 name: "typing.Any".to_string(),
79 import: hashset! { "typing".into() },
80 }
81 }
82
83 pub fn list_of<T: PyStubType>() -> Self {
85 let TypeInfo { name, mut import } = T::type_output();
86 import.insert("builtins".into());
87 TypeInfo {
88 name: format!("builtins.list[{}]", name),
89 import,
90 }
91 }
92
93 pub fn set_of<T: PyStubType>() -> Self {
95 let TypeInfo { name, mut import } = T::type_output();
96 import.insert("builtins".into());
97 TypeInfo {
98 name: format!("builtins.set[{}]", name),
99 import,
100 }
101 }
102
103 pub fn dict_of<K: PyStubType, V: PyStubType>() -> Self {
105 let TypeInfo {
106 name: name_k,
107 mut import,
108 } = K::type_output();
109 let TypeInfo {
110 name: name_v,
111 import: import_v,
112 } = V::type_output();
113 import.extend(import_v);
114 import.insert("builtins".into());
115 TypeInfo {
116 name: format!("builtins.set[{}, {}]", name_k, name_v),
117 import,
118 }
119 }
120
121 pub fn builtin(name: &str) -> Self {
123 Self {
124 name: format!("builtins.{name}"),
125 import: hashset! { "builtins".into() },
126 }
127 }
128
129 pub fn unqualified(name: &str) -> Self {
131 Self {
132 name: name.to_string(),
133 import: hashset! {},
134 }
135 }
136
137 pub fn with_module(name: &str, module: ModuleRef) -> Self {
143 let mut import = HashSet::new();
144 import.insert(module);
145 Self {
146 name: name.to_string(),
147 import,
148 }
149 }
150}
151
152impl ops::BitOr for TypeInfo {
153 type Output = Self;
154
155 fn bitor(mut self, rhs: Self) -> Self {
156 self.import.extend(rhs.import);
157 Self {
158 name: format!("{} | {}", self.name, rhs.name),
159 import: self.import,
160 }
161 }
162}
163
/// Implements `PyStubType` for a type by delegating to one or more base
/// types.
///
/// With a single base the annotations of that base are reused as-is; with
/// several bases separated by `|` the annotations are combined using the
/// `|` operator of `TypeInfo`, in the order written.
///
/// ```ignore
/// impl_stub_type!(Alias = String);
/// impl_stub_type!(Either = String | u32);
/// ```
///
/// The expansion names the trait through its full path, so the invoking
/// module does not need `PyStubType` in scope.
#[macro_export]
macro_rules! impl_stub_type {
    // One arm is enough: a single base is simply a one-element `|` list.
    ($ty:ty = $($base:ty)|+) => {
        impl ::pyo3_stub_gen::PyStubType for $ty {
            fn type_output() -> ::pyo3_stub_gen::TypeInfo {
                $(<$base as ::pyo3_stub_gen::PyStubType>::type_output()) | *
            }
            fn type_input() -> ::pyo3_stub_gen::TypeInfo {
                $(<$base as ::pyo3_stub_gen::PyStubType>::type_input()) | *
            }
        }
    };
}
216
217pub trait PyStubType {
219 fn type_output() -> TypeInfo;
221
222 fn type_input() -> TypeInfo {
227 Self::type_output()
228 }
229}
230
#[cfg(test)]
mod test {
    use super::*;
    use maplit::hashset;
    use std::collections::HashMap;
    use test_case::test_case;

    // Table-driven check of the annotations produced for common Rust types.
    // Each row supplies: the `TypeInfo` under test, the expected annotation
    // text, and the expected import set; the trailing string is the case name.
    #[test_case(bool::type_input(), "builtins.bool", hashset! { "builtins".into() } ; "bool_input")]
    #[test_case(<&str>::type_input(), "builtins.str", hashset! { "builtins".into() } ; "str_input")]
    #[test_case(Vec::<u32>::type_input(), "typing.Sequence[builtins.int]", hashset! { "typing".into(), "builtins".into() } ; "Vec_u32_input")]
    #[test_case(Vec::<u32>::type_output(), "builtins.list[builtins.int]", hashset! { "builtins".into() } ; "Vec_u32_output")]
    #[test_case(HashMap::<u32, String>::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "HashMap_u32_String_input")]
    #[test_case(HashMap::<u32, String>::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "HashMap_u32_String_output")]
    #[test_case(indexmap::IndexMap::<u32, String>::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "IndexMap_u32_String_input")]
    #[test_case(indexmap::IndexMap::<u32, String>::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "IndexMap_u32_String_output")]
    #[test_case(HashMap::<u32, Vec<u32>>::type_input(), "typing.Mapping[builtins.int, typing.Sequence[builtins.int]]", hashset! { "builtins".into(), "typing".into() } ; "HashMap_u32_Vec_u32_input")]
    #[test_case(HashMap::<u32, Vec<u32>>::type_output(), "builtins.dict[builtins.int, builtins.list[builtins.int]]", hashset! { "builtins".into() } ; "HashMap_u32_Vec_u32_output")]
    #[test_case(HashSet::<u32>::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "HashSet_u32_input")]
    #[test_case(indexmap::IndexSet::<u32>::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "IndexSet_u32_input")]
    fn test(tinfo: TypeInfo, name: &str, import: HashSet<ModuleRef>) {
        assert_eq!(tinfo.name, name);

        // A non-empty expectation is compared set-for-set.
        if !import.is_empty() {
            assert_eq!(tinfo.import, import);
            return;
        }

        // An empty expectation means the annotation must need no imports.
        assert!(tinfo.import.is_empty());
    }
}