1mod builtins;
2mod collections;
3mod pyo3;
4
5#[cfg(feature = "numpy")]
6mod numpy;
7
8#[cfg(feature = "either")]
9mod either;
10
11#[cfg(feature = "rust_decimal")]
12mod rust_decimal;
13
14use maplit::hashset;
15use std::cmp::Ordering;
16use std::{
17 collections::{HashMap, HashSet},
18 fmt, ops,
19};
20
/// One import that a generated stub needs: either a whole module or a specific type.
///
/// `Ord` is implemented manually below: every `Type` sorts before every `Module`,
/// and modules are compared by their `ModuleRef::get()` value.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ImportRef {
    /// Import of a module; `From<&str>` builds this variant (e.g. `"typing".into()`).
    Module(ModuleRef),
    /// Import of a specific type from a module.
    Type(TypeRef),
}
29
30impl From<&str> for ImportRef {
31 fn from(value: &str) -> Self {
32 ImportRef::Module(value.into())
33 }
34}
35
36impl PartialOrd for ImportRef {
37 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
38 Some(self.cmp(other))
39 }
40}
41
42impl Ord for ImportRef {
43 fn cmp(&self, other: &Self) -> Ordering {
44 match (self, other) {
45 (ImportRef::Module(a), ImportRef::Module(b)) => a.get().cmp(&b.get()),
46 (ImportRef::Type(a), ImportRef::Type(b)) => a.cmp(b),
47 (ImportRef::Module(_), ImportRef::Type(_)) => Ordering::Greater,
48 (ImportRef::Type(_), ImportRef::Module(_)) => Ordering::Less,
49 }
50 }
51}
52
/// Reference to the Python module that a type or import comes from.
///
/// The derived ordering follows declaration order, so `Named` sorts before `Default`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
pub enum ModuleRef {
    /// A module given by its (possibly dotted) name, e.g. `"typing"`.
    Named(String),

    /// Placeholder module with no name yet.
    ///
    /// `TypeInfo::qualified_name` treats it as the target module.
    /// `TypeInfo::resolve_default_module` replaces it with a concrete `Named` module.
    #[default]
    Default,
}
68
69impl ModuleRef {
70 pub fn get(&self) -> Option<&str> {
71 match self {
72 Self::Named(name) => Some(name),
73 Self::Default => None,
74 }
75 }
76}
77
78impl From<&str> for ModuleRef {
79 fn from(s: &str) -> Self {
80 Self::Named(s.to_string())
81 }
82}
83
/// A type name paired with the module that defines it.
///
/// The derived ordering compares `module` first, then `name`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
pub struct TypeRef {
    /// Module that defines the type.
    pub module: ModuleRef,
    /// Name of the type within `module`.
    pub name: String,
}
92
93impl TypeRef {
94 pub fn new(module_ref: ModuleRef, name: String) -> Self {
95 Self {
96 module: module_ref,
97 name,
98 }
99 }
100}
101
/// How an identifier referenced by a type expression is brought into scope in a stub.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImportKind {
    /// NOTE(review): presumably `from module import Name` style — confirm against the qualifier.
    ByName,
    /// The identifier's module is imported and the identifier is written module-qualified.
    /// This is the kind produced by the `TypeInfo` constructors in this file.
    Module,
    /// NOTE(review): presumably the identifier is defined in the module being generated,
    /// so no import is needed — confirm against the qualifier.
    SameModule,
}
115
/// Origin of one bare identifier that appears inside a `TypeInfo::name` expression.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypeIdentifierRef {
    /// Module that defines the identifier.
    pub module: ModuleRef,
    /// How the identifier is expected to be imported or qualified.
    pub import_kind: ImportKind,
}
125
/// A Python type annotation as emitted in a generated stub, plus what it needs to resolve.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypeInfo {
    /// Type expression text, e.g. `"builtins.list[builtins.int]"` or `"typing.Any"`.
    pub name: String,

    /// Module that defines this type, when it is a single named type.
    ///
    /// The builtin and composite constructors (`builtin`, `list_of`, `set_of`, `dict_of`,
    /// and the `|` operator) leave this `None`.
    pub source_module: Option<ModuleRef>,

    /// Imports the stub needs for `name` to resolve.
    pub import: HashSet<ImportRef>,

    /// Maps bare identifiers occurring in `name` to where they come from.
    ///
    /// When non-empty, `qualified_for_module` uses it to rewrite the expression for a
    /// target module.
    pub type_refs: HashMap<String, TypeIdentifierRef>,
}
155
156impl fmt::Display for TypeInfo {
157 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
158 write!(f, "{}", self.name)
159 }
160}
161
162impl TypeInfo {
163 pub fn none() -> Self {
165 Self {
168 name: "None".to_string(),
169 source_module: None,
170 import: HashSet::new(),
171 type_refs: HashMap::new(),
172 }
173 }
174
175 pub fn any() -> Self {
177 Self {
178 name: "typing.Any".to_string(),
179 source_module: None,
180 import: hashset! { "typing".into() },
181 type_refs: HashMap::new(),
182 }
183 }
184
185 pub fn list_of<T: PyStubType>() -> Self {
187 let inner = T::type_output();
188 let mut import = inner.import.clone();
189 import.insert("builtins".into());
190
191 let mut type_refs = HashMap::new();
193 if let Some(ref source_module) = inner.source_module {
194 if let Some(_module_name) = source_module.get() {
195 let bare_name = inner
197 .name
198 .split('[')
199 .next()
200 .unwrap_or(&inner.name)
201 .split('.')
202 .next_back()
203 .unwrap_or(&inner.name);
204 type_refs.insert(
205 bare_name.to_string(),
206 TypeIdentifierRef {
207 module: source_module.clone(),
208 import_kind: ImportKind::Module,
209 },
210 );
211 }
212 }
213 type_refs.extend(inner.type_refs);
214
215 TypeInfo {
216 name: format!("builtins.list[{}]", inner.name),
217 source_module: None,
218 import,
219 type_refs,
220 }
221 }
222
223 pub fn set_of<T: PyStubType>() -> Self {
225 let inner = T::type_output();
226 let mut import = inner.import.clone();
227 import.insert("builtins".into());
228
229 let mut type_refs = HashMap::new();
231 if let Some(ref source_module) = inner.source_module {
232 if let Some(_module_name) = source_module.get() {
233 let bare_name = inner
234 .name
235 .split('[')
236 .next()
237 .unwrap_or(&inner.name)
238 .split('.')
239 .next_back()
240 .unwrap_or(&inner.name);
241 type_refs.insert(
242 bare_name.to_string(),
243 TypeIdentifierRef {
244 module: source_module.clone(),
245 import_kind: ImportKind::Module,
246 },
247 );
248 }
249 }
250 type_refs.extend(inner.type_refs);
251
252 TypeInfo {
253 name: format!("builtins.set[{}]", inner.name),
254 source_module: None,
255 import,
256 type_refs,
257 }
258 }
259
260 pub fn dict_of<K: PyStubType, V: PyStubType>() -> Self {
262 let inner_k = K::type_output();
263 let inner_v = V::type_output();
264 let mut import = inner_k.import.clone();
265 import.extend(inner_v.import.clone());
266 import.insert("builtins".into());
267
268 let mut type_refs = HashMap::new();
270 for inner in [&inner_k, &inner_v] {
271 if let Some(ref source_module) = inner.source_module {
272 if let Some(_module_name) = source_module.get() {
273 let bare_name = inner
274 .name
275 .split('[')
276 .next()
277 .unwrap_or(&inner.name)
278 .split('.')
279 .next_back()
280 .unwrap_or(&inner.name);
281 type_refs.insert(
282 bare_name.to_string(),
283 TypeIdentifierRef {
284 module: source_module.clone(),
285 import_kind: ImportKind::Module,
286 },
287 );
288 }
289 }
290 type_refs.extend(inner.type_refs.clone());
291 }
292
293 TypeInfo {
294 name: format!("builtins.dict[{}, {}]", inner_k.name, inner_v.name),
295 source_module: None,
296 import,
297 type_refs,
298 }
299 }
300
301 pub fn builtin(name: &str) -> Self {
303 Self {
304 name: format!("builtins.{name}"),
305 source_module: None,
306 import: hashset! { "builtins".into() },
307 type_refs: HashMap::new(),
308 }
309 }
310
311 pub fn unqualified(name: &str) -> Self {
313 Self {
314 name: name.to_string(),
315 source_module: None,
316 import: hashset! {},
317 type_refs: HashMap::new(),
318 }
319 }
320
321 pub fn with_module(name: &str, module: ModuleRef) -> Self {
327 let mut import = HashSet::new();
328 import.insert(ImportRef::Module(module.clone()));
329 Self {
330 name: name.to_string(),
331 source_module: Some(module),
332 import,
333 type_refs: HashMap::new(),
334 }
335 }
336
337 pub fn locally_defined(type_name: &str, module: ModuleRef) -> Self {
349 let mut import = HashSet::new();
350 let mut type_refs = HashMap::new();
351
352 let qualified_name = match module.get() {
355 Some(module_name) if !module_name.is_empty() => {
356 let module_component = module_name.rsplit('.').next().unwrap_or(module_name);
359 import.insert(ImportRef::Module(module.clone()));
361
362 type_refs.insert(
364 type_name.to_string(),
365 TypeIdentifierRef {
366 module: module.clone(),
367 import_kind: ImportKind::Module,
368 },
369 );
370
371 format!("{}.{}", module_component, type_name)
372 }
373 _ => {
374 import.insert(ImportRef::Module(module.clone()));
377 type_refs.insert(
378 type_name.to_string(),
379 TypeIdentifierRef {
380 module: module.clone(),
381 import_kind: ImportKind::Module,
382 },
383 );
384 type_name.to_string()
385 }
386 };
387
388 Self {
389 name: qualified_name,
390 source_module: Some(module),
391 import,
392 type_refs,
393 }
394 }
395
396 pub fn qualified_name(&self, target_module: &str) -> String {
407 match &self.source_module {
408 None => self.name.clone(),
409 Some(module_ref) => {
410 let source = module_ref.get().unwrap_or(target_module);
411 if source == target_module {
412 let module_component = source.rsplit('.').next().unwrap_or(source);
415 let prefix = format!("{}.", module_component);
416 if let Some(stripped) = self.name.strip_prefix(&prefix) {
417 stripped.to_string()
418 } else {
419 self.name.clone()
420 }
421 } else {
422 let module_component = source.rsplit('.').next().unwrap_or(source);
424 let prefix = format!("{}.", module_component);
426 let base_name = if let Some(stripped) = self.name.strip_prefix(&prefix) {
427 stripped
428 } else {
429 &self.name
430 };
431 format!("{}.{}", module_component, base_name)
432 }
433 }
434 }
435 }
436
437 pub fn is_same_module(&self, target_module: &str) -> bool {
439 self.source_module.as_ref().and_then(|m| m.get()) == Some(target_module)
440 }
441
442 pub fn is_internal_to_package(&self, package_root: &str) -> bool {
444 match &self.source_module {
445 Some(ModuleRef::Named(path)) => path.starts_with(package_root),
446 Some(ModuleRef::Default) => true,
447 None => false,
448 }
449 }
450
451 pub fn qualified_for_module(&self, target_module: &str) -> String {
464 if self.type_refs.is_empty() {
466 return self.qualified_name(target_module);
467 }
468
469 use crate::generate::qualifier::TypeExpressionQualifier;
471 TypeExpressionQualifier::qualify_expression(&self.name, &self.type_refs, target_module)
472 }
473
474 pub fn resolve_default_module(&mut self, default_module_name: &str) {
477 if let Some(ModuleRef::Default) = &self.source_module {
479 self.source_module = Some(ModuleRef::Named(default_module_name.to_string()));
480
481 let module_component = default_module_name
483 .rsplit('.')
484 .next()
485 .unwrap_or(default_module_name);
486 if !self.name.contains('.') {
487 self.name = format!("{}.{}", module_component, self.name);
488 }
489 }
490
491 let mut new_import = std::collections::HashSet::new();
493 for import_ref in &self.import {
494 match import_ref {
495 ImportRef::Module(ModuleRef::Default) => {
496 new_import.insert(ImportRef::Module(ModuleRef::Named(
497 default_module_name.to_string(),
498 )));
499 }
500 other => {
501 new_import.insert(other.clone());
502 }
503 }
504 }
505 self.import = new_import;
506
507 for type_ref in self.type_refs.values_mut() {
509 if let ModuleRef::Default = &type_ref.module {
510 type_ref.module = ModuleRef::Named(default_module_name.to_string());
511 }
512 }
513 }
514}
515
516impl ops::BitOr for TypeInfo {
517 type Output = Self;
518
519 fn bitor(mut self, rhs: Self) -> Self {
520 self.import.extend(rhs.import);
521 let mut merged_type_refs = self.type_refs.clone();
523 merged_type_refs.extend(rhs.type_refs);
524 Self {
525 name: format!("{} | {}", self.name, rhs.name),
526 source_module: None, import: self.import,
528 type_refs: merged_type_refs,
529 }
530 }
531}
532
/// Implements `PyStubType` for a type by delegating to existing implementations.
///
/// - `impl_stub_type!(Foo = Bar)` makes `Foo` use exactly `Bar`'s `type_output` and
///   `type_input`.
/// - `impl_stub_type!(Foo = A | B | C)` makes `Foo` use the union of the listed types,
///   built with `TypeInfo`'s `|` operator, for both `type_output` and `type_input`.
#[macro_export]
macro_rules! impl_stub_type {
    // Single base type: plain delegation. This arm must come first. The union arm below
    // also accepts a one-element list, so with the arms reversed this one could never be
    // selected.
    ($ty:ty = $base:ty) => {
        impl ::pyo3_stub_gen::PyStubType for $ty {
            fn type_output() -> ::pyo3_stub_gen::TypeInfo {
                <$base>::type_output()
            }
            fn type_input() -> ::pyo3_stub_gen::TypeInfo {
                <$base>::type_input()
            }
        }
    };
    // Base types separated by `|`: the union of their stub types.
    ($ty: ty = $($base:ty)|+) => {
        impl ::pyo3_stub_gen::PyStubType for $ty {
            fn type_output() -> ::pyo3_stub_gen::TypeInfo {
                $(<$base>::type_output()) | *
            }
            fn type_input() -> ::pyo3_stub_gen::TypeInfo {
                $(<$base>::type_input()) | *
            }
        }
    };
}
585
586pub trait PyStubType {
588 fn type_output() -> TypeInfo;
590
591 fn type_input() -> TypeInfo {
596 Self::type_output()
597 }
598}
599
#[cfg(test)]
mod test {
    use super::*;
    use maplit::hashset;
    use std::collections::HashMap;
    use test_case::test_case;

    #[test_case(bool::type_input(), "builtins.bool", hashset! { "builtins".into() } ; "bool_input")]
    #[test_case(<&str>::type_input(), "builtins.str", hashset! { "builtins".into() } ; "str_input")]
    #[test_case(Vec::<u32>::type_input(), "typing.Sequence[builtins.int]", hashset! { "typing".into(), "builtins".into() } ; "Vec_u32_input")]
    #[test_case(Vec::<u32>::type_output(), "builtins.list[builtins.int]", hashset! { "builtins".into() } ; "Vec_u32_output")]
    #[test_case(HashMap::<u32, String>::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "HashMap_u32_String_input")]
    #[test_case(HashMap::<u32, String>::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "HashMap_u32_String_output")]
    #[test_case(indexmap::IndexMap::<u32, String>::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "IndexMap_u32_String_input")]
    #[test_case(indexmap::IndexMap::<u32, String>::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "IndexMap_u32_String_output")]
    #[test_case(HashMap::<u32, Vec<u32>>::type_input(), "typing.Mapping[builtins.int, typing.Sequence[builtins.int]]", hashset! { "builtins".into(), "typing".into() } ; "HashMap_u32_Vec_u32_input")]
    #[test_case(HashMap::<u32, Vec<u32>>::type_output(), "builtins.dict[builtins.int, builtins.list[builtins.int]]", hashset! { "builtins".into() } ; "HashMap_u32_Vec_u32_output")]
    #[test_case(HashSet::<u32>::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "HashSet_u32_input")]
    #[test_case(indexmap::IndexSet::<u32>::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "IndexSet_u32_input")]
    #[test_case(TypeInfo::dict_of::<u32, String>(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "dict_of_u32_String")]
    /// Checks the rendered type expression and, when expected imports are given, the
    /// exact import set; an empty expectation requires no imports at all.
    fn test(tinfo: TypeInfo, name: &str, import: HashSet<ImportRef>) {
        assert_eq!(tinfo.name, name);
        if !import.is_empty() {
            assert_eq!(tinfo.import, import);
        } else {
            assert!(tinfo.import.is_empty());
        }
    }
}