1mod builtins;
2mod collections;
3mod pyo3;
4
5#[cfg(feature = "numpy")]
6mod numpy;
7
8use maplit::hashset;
9use std::{collections::HashSet, fmt, ops};
10
/// A Python module that a generated stub needs to import.
#[derive(Debug, Clone, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum ModuleRef {
    /// An explicitly named module, e.g. `"typing"` or `"builtins"`.
    Named(String),

    /// No explicit module name; this is the `Default` value of the enum.
    // NOTE(review): presumably stands for the module the stub is generated
    // into — confirm against the stub writer.
    #[default]
    Default,
}
26
27impl ModuleRef {
28 pub fn get(&self) -> Option<&str> {
29 match self {
30 Self::Named(name) => Some(name),
31 Self::Default => None,
32 }
33 }
34}
35
36impl From<&str> for ModuleRef {
37 fn from(s: &str) -> Self {
38 Self::Named(s.to_string())
39 }
40}
41
42#[derive(Debug, Clone, PartialEq, Eq)]
44pub struct TypeInfo {
45 pub name: String,
47
48 pub import: HashSet<ModuleRef>,
53}
54
55impl fmt::Display for TypeInfo {
56 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
57 write!(f, "{}", self.name)
58 }
59}
60
61impl TypeInfo {
62 pub fn none() -> Self {
64 Self {
67 name: "None".to_string(),
68 import: HashSet::new(),
69 }
70 }
71
72 pub fn any() -> Self {
74 Self {
75 name: "typing.Any".to_string(),
76 import: hashset! { "typing".into() },
77 }
78 }
79
80 pub fn list_of<T: PyStubType>() -> Self {
82 let TypeInfo { name, mut import } = T::type_output();
83 import.insert("builtins".into());
84 TypeInfo {
85 name: format!("builtins.list[{}]", name),
86 import,
87 }
88 }
89
90 pub fn set_of<T: PyStubType>() -> Self {
92 let TypeInfo { name, mut import } = T::type_output();
93 import.insert("builtins".into());
94 TypeInfo {
95 name: format!("builtins.set[{}]", name),
96 import,
97 }
98 }
99
100 pub fn dict_of<K: PyStubType, V: PyStubType>() -> Self {
102 let TypeInfo {
103 name: name_k,
104 mut import,
105 } = K::type_output();
106 let TypeInfo {
107 name: name_v,
108 import: import_v,
109 } = V::type_output();
110 import.extend(import_v);
111 import.insert("builtins".into());
112 TypeInfo {
113 name: format!("builtins.set[{}, {}]", name_k, name_v),
114 import,
115 }
116 }
117
118 pub fn builtin(name: &str) -> Self {
120 Self {
121 name: format!("builtins.{name}"),
122 import: hashset! { "builtins".into() },
123 }
124 }
125
126 pub fn unqualified(name: &str) -> Self {
128 Self {
129 name: name.to_string(),
130 import: hashset! {},
131 }
132 }
133
134 pub fn with_module(name: &str, module: ModuleRef) -> Self {
140 let mut import = HashSet::new();
141 import.insert(module);
142 Self {
143 name: name.to_string(),
144 import,
145 }
146 }
147}
148
149impl ops::BitOr for TypeInfo {
150 type Output = Self;
151
152 fn bitor(mut self, rhs: Self) -> Self {
153 self.import.extend(rhs.import);
154 Self {
155 name: format!("{} | {}", self.name, rhs.name),
156 import: self.import,
157 }
158 }
159}
160
/// Implements `PyStubType` for `$ty` by delegating to one or more existing
/// stub types.
///
/// `impl_stub_type!(Foo = Bar)` gives `Foo` exactly `Bar`'s input and output
/// annotations. `impl_stub_type!(Foo = A | B | C)` gives `Foo` the union
/// `A | B | C` for both input and output. The union is built with
/// `TypeInfo`'s `BitOr` implementation, which joins the names and merges
/// the imports.
#[macro_export]
macro_rules! impl_stub_type {
    // One or more `|`-separated base types. A single base needs no separate
    // arm: with one repetition no `|` is emitted, so the expansion is a plain
    // delegation to that base.
    ($ty: ty = $($base:ty)|+) => {
        impl ::pyo3_stub_gen::PyStubType for $ty {
            fn type_output() -> ::pyo3_stub_gen::TypeInfo {
                $(<$base>::type_output()) | *
            }
            fn type_input() -> ::pyo3_stub_gen::TypeInfo {
                $(<$base>::type_input()) | *
            }
        }
    };
}
213
214pub trait PyStubType {
216 fn type_output() -> TypeInfo;
218
219 fn type_input() -> TypeInfo {
224 Self::type_output()
225 }
226}
227
#[cfg(test)]
mod test {
    use super::*;
    use maplit::hashset;
    use std::collections::HashMap;
    use test_case::test_case;

    // Each case: (computed TypeInfo, expected annotation text, expected imports).
    #[test_case(bool::type_input(), "builtins.bool", hashset! { "builtins".into() } ; "bool_input")]
    #[test_case(<&str>::type_input(), "builtins.str", hashset! { "builtins".into() } ; "str_input")]
    #[test_case(Vec::<u32>::type_input(), "typing.Sequence[builtins.int]", hashset! { "typing".into(), "builtins".into() } ; "Vec_u32_input")]
    #[test_case(Vec::<u32>::type_output(), "builtins.list[builtins.int]", hashset! { "builtins".into() } ; "Vec_u32_output")]
    #[test_case(HashMap::<u32, String>::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "HashMap_u32_String_input")]
    #[test_case(HashMap::<u32, String>::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "HashMap_u32_String_output")]
    #[test_case(indexmap::IndexMap::<u32, String>::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "IndexMap_u32_String_input")]
    #[test_case(indexmap::IndexMap::<u32, String>::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "IndexMap_u32_String_output")]
    #[test_case(HashMap::<u32, Vec<u32>>::type_input(), "typing.Mapping[builtins.int, typing.Sequence[builtins.int]]", hashset! { "builtins".into(), "typing".into() } ; "HashMap_u32_Vec_u32_input")]
    #[test_case(HashMap::<u32, Vec<u32>>::type_output(), "builtins.dict[builtins.int, builtins.list[builtins.int]]", hashset! { "builtins".into() } ; "HashMap_u32_Vec_u32_output")]
    #[test_case(HashSet::<u32>::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "HashSet_u32_input")]
    #[test_case(indexmap::IndexSet::<u32>::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "IndexSet_u32_input")]
    fn test(tinfo: TypeInfo, name: &str, import: HashSet<ModuleRef>) {
        let TypeInfo {
            name: actual_name,
            import: actual_import,
        } = tinfo;
        assert_eq!(actual_name, name);
        // Comparing the sets directly also covers the "no imports expected" case.
        assert_eq!(actual_import, import);
    }
}