pyo3_stub_gen/
stub_type.rs

1mod builtins;
2mod collections;
3mod pyo3;
4
5#[cfg(feature = "numpy")]
6mod numpy;
7
8#[cfg(feature = "either")]
9mod either;
10
11#[cfg(feature = "rust_decimal")]
12mod rust_decimal;
13
14use maplit::hashset;
15use std::cmp::Ordering;
16use std::{
17    collections::{HashMap, HashSet},
18    fmt, ops,
19};
20
/// Indicates what to import.
///
/// - `Module`: import the entire module (e.g. `import builtins`).
/// - `Type`: import a single type from a module (e.g. `from moduleX import typeX`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ImportRef {
    /// A whole-module import.
    Module(ModuleRef),
    /// An import of one named type from the module recorded in the [`TypeRef`].
    Type(TypeRef),
}
29
30impl From<&str> for ImportRef {
31    fn from(value: &str) -> Self {
32        ImportRef::Module(value.into())
33    }
34}
35
36impl PartialOrd for ImportRef {
37    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
38        Some(self.cmp(other))
39    }
40}
41
42impl Ord for ImportRef {
43    fn cmp(&self, other: &Self) -> Ordering {
44        match (self, other) {
45            (ImportRef::Module(a), ImportRef::Module(b)) => a.get().cmp(&b.get()),
46            (ImportRef::Type(a), ImportRef::Type(b)) => a.cmp(b),
47            (ImportRef::Module(_), ImportRef::Type(_)) => Ordering::Greater,
48            (ImportRef::Type(_), ImportRef::Module(_)) => Ordering::Less,
49        }
50    }
51}
52
/// Identifies a Python module, either by explicit name or as the not-yet-known default module.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
pub enum ModuleRef {
    /// A module identified by its name (e.g. `"package.sub_mod"`).
    Named(String),

    /// Default module that PyO3 creates.
    ///
    /// - For pure Rust project, the default module name is the crate name specified in `Cargo.toml`
    ///   or `project.name` specified in `pyproject.toml`
    /// - For mixed Rust/Python project, the default module name is `tool.maturin.module-name` specified in `pyproject.toml`
    ///
    /// Because the default module name cannot be known at compile time, it will be resolved at the time of the stub file generation.
    /// This is a placeholder for the default module name.
    #[default]
    Default,
}
68
69impl ModuleRef {
70    pub fn get(&self) -> Option<&str> {
71        match self {
72            Self::Named(name) => Some(name),
73            Self::Default => None,
74        }
75    }
76}
77
78impl From<&str> for ModuleRef {
79    fn from(s: &str) -> Self {
80        Self::Named(s.to_string())
81    }
82}
83
/// Identifies a single type to import (e.g. a class or an enum), as in
/// `from module import type`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
pub struct TypeRef {
    /// The module in which the type is defined.
    pub module: ModuleRef,
    /// The name of the type.
    pub name: String,
}
92
93impl TypeRef {
94    pub fn new(module_ref: ModuleRef, name: String) -> Self {
95        Self {
96            module: module_ref,
97            name,
98        }
99    }
100}
101
/// Represents how a type identifier should be qualified in stub files.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImportKind {
    /// Type is imported by name (`from module import Type`).
    /// It can be used unqualified in the target module.
    ByName,
    /// Type is reached through a module import (`from package import module`).
    /// It must be qualified as `module.Type`.
    Module,
    /// Type is defined in the same module as the usage site.
    /// It can be used unqualified.
    SameModule,
}
115
/// Represents a reference to a type identifier within a compound type expression.
///
/// Tracks which module the type comes from and how it should be qualified.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypeIdentifierRef {
    /// The module where this type is defined.
    pub module: ModuleRef,
    /// How this type should be qualified in stub files.
    pub import_kind: ImportKind,
}
125
/// Type information for creating Python stub files annotated by [PyStubType] trait.
///
/// Bundles the annotation text with the imports it needs and the type identifiers it mentions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypeInfo {
    /// The Python type name.
    pub name: String,

    /// The module this type belongs to.
    ///
    /// - `None`: Type has no source module (e.g., `typing.Any`, primitives, generic container types)
    /// - `Some(ModuleRef::Default)`: Type from current package's default module
    /// - `Some(ModuleRef::Named(path))`: Type from specific module (e.g., `"package.sub_mod"`)
    pub source_module: Option<ModuleRef>,

    /// Python modules that must be imported in the stub file.
    ///
    /// For example, when `name` is `typing.Sequence[int]`, `import` should contain `typing`.
    /// This makes it possible to use user-defined types in the stub file.
    pub import: HashSet<ImportRef>,

    /// Track all type identifiers referenced in the name expression.
    ///
    /// This enables context-aware qualification of identifiers within compound type expressions.
    /// For example, in `typing.Optional[ClassA]`, we need to track that `ClassA` is from a specific module
    /// and qualify it appropriately based on the target module context.
    ///
    /// - Key: bare identifier (e.g., "ClassA")
    /// - Value: TypeIdentifierRef containing module and import kind
    pub type_refs: HashMap<String, TypeIdentifierRef>,
}
155
156impl fmt::Display for TypeInfo {
157    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
158        write!(f, "{}", self.name)
159    }
160}
161
162impl TypeInfo {
163    /// A `None` type annotation.
164    pub fn none() -> Self {
165        // NOTE: since 3.10, NoneType is provided from types module,
166        // but there is no corresponding definitions prior to 3.10.
167        Self {
168            name: "None".to_string(),
169            source_module: None,
170            import: HashSet::new(),
171            type_refs: HashMap::new(),
172        }
173    }
174
175    /// A `typing.Any` type annotation.
176    pub fn any() -> Self {
177        Self {
178            name: "typing.Any".to_string(),
179            source_module: None,
180            import: hashset! { "typing".into() },
181            type_refs: HashMap::new(),
182        }
183    }
184
185    /// A `list[Type]` type annotation.
186    pub fn list_of<T: PyStubType>() -> Self {
187        let inner = T::type_output();
188        let mut import = inner.import.clone();
189        import.insert("builtins".into());
190
191        // Build type_refs from inner type
192        let mut type_refs = HashMap::new();
193        if let Some(ref source_module) = inner.source_module {
194            if let Some(_module_name) = source_module.get() {
195                // Extract bare type identifier from the (potentially qualified) name
196                let bare_name = inner
197                    .name
198                    .split('[')
199                    .next()
200                    .unwrap_or(&inner.name)
201                    .split('.')
202                    .next_back()
203                    .unwrap_or(&inner.name);
204                type_refs.insert(
205                    bare_name.to_string(),
206                    TypeIdentifierRef {
207                        module: source_module.clone(),
208                        import_kind: ImportKind::Module,
209                    },
210                );
211            }
212        }
213        type_refs.extend(inner.type_refs);
214
215        TypeInfo {
216            name: format!("builtins.list[{}]", inner.name),
217            source_module: None,
218            import,
219            type_refs,
220        }
221    }
222
223    /// A `set[Type]` type annotation.
224    pub fn set_of<T: PyStubType>() -> Self {
225        let inner = T::type_output();
226        let mut import = inner.import.clone();
227        import.insert("builtins".into());
228
229        // Build type_refs from inner type
230        let mut type_refs = HashMap::new();
231        if let Some(ref source_module) = inner.source_module {
232            if let Some(_module_name) = source_module.get() {
233                let bare_name = inner
234                    .name
235                    .split('[')
236                    .next()
237                    .unwrap_or(&inner.name)
238                    .split('.')
239                    .next_back()
240                    .unwrap_or(&inner.name);
241                type_refs.insert(
242                    bare_name.to_string(),
243                    TypeIdentifierRef {
244                        module: source_module.clone(),
245                        import_kind: ImportKind::Module,
246                    },
247                );
248            }
249        }
250        type_refs.extend(inner.type_refs);
251
252        TypeInfo {
253            name: format!("builtins.set[{}]", inner.name),
254            source_module: None,
255            import,
256            type_refs,
257        }
258    }
259
260    /// A `dict[Type]` type annotation.
261    pub fn dict_of<K: PyStubType, V: PyStubType>() -> Self {
262        let inner_k = K::type_output();
263        let inner_v = V::type_output();
264        let mut import = inner_k.import.clone();
265        import.extend(inner_v.import.clone());
266        import.insert("builtins".into());
267
268        // Build type_refs from both inner types
269        let mut type_refs = HashMap::new();
270        for inner in [&inner_k, &inner_v] {
271            if let Some(ref source_module) = inner.source_module {
272                if let Some(_module_name) = source_module.get() {
273                    let bare_name = inner
274                        .name
275                        .split('[')
276                        .next()
277                        .unwrap_or(&inner.name)
278                        .split('.')
279                        .next_back()
280                        .unwrap_or(&inner.name);
281                    type_refs.insert(
282                        bare_name.to_string(),
283                        TypeIdentifierRef {
284                            module: source_module.clone(),
285                            import_kind: ImportKind::Module,
286                        },
287                    );
288                }
289            }
290            type_refs.extend(inner.type_refs.clone());
291        }
292
293        TypeInfo {
294            name: format!("builtins.dict[{}, {}]", inner_k.name, inner_v.name),
295            source_module: None,
296            import,
297            type_refs,
298        }
299    }
300
301    /// A type annotation of a built-in type provided from `builtins` module, such as `int`, `str`, or `float`. Generic builtin types are also possible, such as `dict[str, str]`.
302    pub fn builtin(name: &str) -> Self {
303        Self {
304            name: format!("builtins.{name}"),
305            source_module: None,
306            import: hashset! { "builtins".into() },
307            type_refs: HashMap::new(),
308        }
309    }
310
311    /// Unqualified type.
312    pub fn unqualified(name: &str) -> Self {
313        Self {
314            name: name.to_string(),
315            source_module: None,
316            import: hashset! {},
317            type_refs: HashMap::new(),
318        }
319    }
320
321    /// A type annotation of a type that must be imported. The type name must be qualified with the module name:
322    ///
323    /// ```
324    /// pyo3_stub_gen::TypeInfo::with_module("pathlib.Path", "pathlib".into());
325    /// ```
326    pub fn with_module(name: &str, module: ModuleRef) -> Self {
327        let mut import = HashSet::new();
328        import.insert(ImportRef::Module(module.clone()));
329        Self {
330            name: name.to_string(),
331            source_module: Some(module),
332            import,
333            type_refs: HashMap::new(),
334        }
335    }
336
337    /// A type defined in the PyO3 module.
338    ///
339    /// - Types are referenced using fully qualified names to avoid symbol collision when used across modules.
340    /// - For example, if `A` is defined in `package.submod1`, it will be referenced as `submod1.A` when used in other modules.
341    /// - The module will be imported as `from package import submod1`.
342    /// - When used in the same module where it's defined, it will be automatically de-qualified during stub generation.
343    /// - The `source_module` field tracks which module the type belongs to for future use.
344    ///
345    /// ```
346    /// pyo3_stub_gen::TypeInfo::locally_defined("A", "package.submod1".into());
347    /// ```
348    pub fn locally_defined(type_name: &str, module: ModuleRef) -> Self {
349        let mut import = HashSet::new();
350        let mut type_refs = HashMap::new();
351
352        // Determine qualified name and import based on module
353        // We qualify all named modules; de-qualification for same-module usage happens during stub generation
354        let qualified_name = match module.get() {
355            Some(module_name) if !module_name.is_empty() => {
356                // Extract the last component of the module path for qualification
357                // e.g., "package.module.submodule" -> "submodule"
358                let module_component = module_name.rsplit('.').next().unwrap_or(module_name);
359                // Use Module import for cross-module references
360                import.insert(ImportRef::Module(module.clone()));
361
362                // Populate type_refs with the bare identifier for context-aware qualification
363                type_refs.insert(
364                    type_name.to_string(),
365                    TypeIdentifierRef {
366                        module: module.clone(),
367                        import_kind: ImportKind::Module,
368                    },
369                );
370
371                format!("{}.{}", module_component, type_name)
372            }
373            _ => {
374                // Default/empty module - treat like named modules but keep name unqualified
375                // Will be resolved to actual module name at runtime
376                import.insert(ImportRef::Module(module.clone()));
377                type_refs.insert(
378                    type_name.to_string(),
379                    TypeIdentifierRef {
380                        module: module.clone(),
381                        import_kind: ImportKind::Module,
382                    },
383                );
384                type_name.to_string()
385            }
386        };
387
388        Self {
389            name: qualified_name,
390            source_module: Some(module),
391            import,
392            type_refs,
393        }
394    }
395
396    /// Get the qualified name for use in a specific target module.
397    ///
398    /// - If the type has no source module, returns the name as-is
399    /// - If the type is from the same module as the target, returns unqualified name
400    /// - If the type is from a different module, returns qualified name with module component
401    ///
402    /// # Examples
403    ///
404    /// - Type A from "package.sub_mod" used in "package.sub_mod" -> "A"
405    /// - Type A from "package.sub_mod" used in "package.main_mod" -> "sub_mod.A"
406    pub fn qualified_name(&self, target_module: &str) -> String {
407        match &self.source_module {
408            None => self.name.clone(),
409            Some(module_ref) => {
410                let source = module_ref.get().unwrap_or(target_module);
411                if source == target_module {
412                    // Same module: unqualified
413                    // Strip module prefix if present (handles pre-qualified names from macros)
414                    let module_component = source.rsplit('.').next().unwrap_or(source);
415                    let prefix = format!("{}.", module_component);
416                    if let Some(stripped) = self.name.strip_prefix(&prefix) {
417                        stripped.to_string()
418                    } else {
419                        self.name.clone()
420                    }
421                } else {
422                    // Different module: qualify with last module component
423                    let module_component = source.rsplit('.').next().unwrap_or(source);
424                    // Strip existing module prefix if present (handles pre-qualified names from macros)
425                    let prefix = format!("{}.", module_component);
426                    let base_name = if let Some(stripped) = self.name.strip_prefix(&prefix) {
427                        stripped
428                    } else {
429                        &self.name
430                    };
431                    format!("{}.{}", module_component, base_name)
432                }
433            }
434        }
435    }
436
437    /// Check if this type is from the same module as the target module.
438    pub fn is_same_module(&self, target_module: &str) -> bool {
439        self.source_module.as_ref().and_then(|m| m.get()) == Some(target_module)
440    }
441
442    /// Check if this type is internal to the package (starts with package root).
443    pub fn is_internal_to_package(&self, package_root: &str) -> bool {
444        match &self.source_module {
445            Some(ModuleRef::Named(path)) => path.starts_with(package_root),
446            Some(ModuleRef::Default) => true,
447            None => false,
448        }
449    }
450
451    /// Get the qualified name for use in a specific target module with context-aware rewriting.
452    ///
453    /// This method handles compound type expressions by rewriting nested identifiers
454    /// based on the type_refs tracking information. For example:
455    /// - `typing.Optional[ClassA]` becomes `typing.Optional[sub_mod.ClassA]` when ClassA
456    ///   is from a different module.
457    ///
458    /// # Arguments
459    /// * `target_module` - The module where this type will be used
460    ///
461    /// # Returns
462    /// The qualified type name string with identifiers properly qualified
463    pub fn qualified_for_module(&self, target_module: &str) -> String {
464        // If no type_refs, use the simpler qualified_name method
465        if self.type_refs.is_empty() {
466            return self.qualified_name(target_module);
467        }
468
469        // Rewrite the expression with context-aware qualification
470        use crate::generate::qualifier::TypeExpressionQualifier;
471        TypeExpressionQualifier::qualify_expression(&self.name, &self.type_refs, target_module)
472    }
473
474    /// Resolve ModuleRef::Default to the actual module name.
475    /// Called at runtime when default module name is known.
476    pub fn resolve_default_module(&mut self, default_module_name: &str) {
477        // Resolve source_module
478        if let Some(ModuleRef::Default) = &self.source_module {
479            self.source_module = Some(ModuleRef::Named(default_module_name.to_string()));
480
481            // Update qualified name if needed
482            let module_component = default_module_name
483                .rsplit('.')
484                .next()
485                .unwrap_or(default_module_name);
486            if !self.name.contains('.') {
487                self.name = format!("{}.{}", module_component, self.name);
488            }
489        }
490
491        // Resolve import refs
492        let mut new_import = std::collections::HashSet::new();
493        for import_ref in &self.import {
494            match import_ref {
495                ImportRef::Module(ModuleRef::Default) => {
496                    new_import.insert(ImportRef::Module(ModuleRef::Named(
497                        default_module_name.to_string(),
498                    )));
499                }
500                other => {
501                    new_import.insert(other.clone());
502                }
503            }
504        }
505        self.import = new_import;
506
507        // Resolve type_refs
508        for type_ref in self.type_refs.values_mut() {
509            if let ModuleRef::Default = &type_ref.module {
510                type_ref.module = ModuleRef::Named(default_module_name.to_string());
511            }
512        }
513    }
514}
515
516impl ops::BitOr for TypeInfo {
517    type Output = Self;
518
519    fn bitor(mut self, rhs: Self) -> Self {
520        self.import.extend(rhs.import);
521        // Merge type_refs from both sides
522        let mut merged_type_refs = self.type_refs.clone();
523        merged_type_refs.extend(rhs.type_refs);
524        Self {
525            name: format!("{} | {}", self.name, rhs.name),
526            source_module: None, // Union types are synthetic, have no source module
527            import: self.import,
528            type_refs: merged_type_refs,
529        }
530    }
531}
532
/// Implement [PyStubType]
///
/// `impl_stub_type!(Ty = Base)` makes `Ty` report the same type information as `Base`;
/// `impl_stub_type!(Ty = A | B | ...)` makes it report the union of the listed types, built
/// with the `|` operator of `TypeInfo`. Both `type_output` and `type_input` are forwarded.
///
/// ```rust
/// use pyo3::*;
/// use pyo3_stub_gen::{impl_stub_type, derive::*};
///
/// #[gen_stub_pyclass]
/// #[pyclass]
/// struct A;
///
/// #[gen_stub_pyclass]
/// #[pyclass]
/// struct B;
///
/// enum E {
///     A(A),
///     B(B),
/// }
/// impl_stub_type!(E = A | B);
///
/// struct X(A);
/// impl_stub_type!(X = A);
///
/// struct Y {
///    a: A,
///    b: B,
/// }
/// impl_stub_type!(Y = (A, B));
/// ```
#[macro_export]
macro_rules! impl_stub_type {
    // One or more base types separated by `|`; a single base type is the one-element case.
    ($ty:ty = $($base:ty)|+) => {
        impl ::pyo3_stub_gen::PyStubType for $ty {
            fn type_output() -> ::pyo3_stub_gen::TypeInfo {
                // Fully-qualified calls: the expansion must not depend on `PyStubType` being
                // imported at the call site.
                $(<$base as ::pyo3_stub_gen::PyStubType>::type_output()) | *
            }
            fn type_input() -> ::pyo3_stub_gen::TypeInfo {
                $(<$base as ::pyo3_stub_gen::PyStubType>::type_input()) | *
            }
        }
    };
}
585
/// Annotate Rust types with Python type information.
///
/// Implementors must provide [`PyStubType::type_output`]; [`PyStubType::type_input`] has a
/// default that returns the same information.
pub trait PyStubType {
    /// The type to be used in the output signature, i.e. return type of the Python function or methods.
    fn type_output() -> TypeInfo;

    /// The type to be used in the input signature, i.e. the arguments of the Python function or methods.
    ///
    /// This defaults to the output type, but can be overridden for types that are not valid input types.
    /// For example, `Vec::<T>::type_output` returns `list[T]` while `Vec::<T>::type_input` returns `typing.Sequence[T]`.
    fn type_input() -> TypeInfo {
        Self::type_output()
    }
}
599
#[cfg(test)]
mod test {
    use super::*;
    use maplit::hashset;
    use std::collections::HashMap;
    use test_case::test_case;

    // Each case supplies a `TypeInfo`, the annotation text expected in its `name`, and the
    // import set it is expected to carry; the trailing string is the case name.
    #[test_case(bool::type_input(), "builtins.bool", hashset! { "builtins".into() } ; "bool_input")]
    #[test_case(<&str>::type_input(), "builtins.str", hashset! { "builtins".into() } ; "str_input")]
    #[test_case(Vec::<u32>::type_input(), "typing.Sequence[builtins.int]", hashset! { "typing".into(), "builtins".into() } ; "Vec_u32_input")]
    #[test_case(Vec::<u32>::type_output(), "builtins.list[builtins.int]", hashset! {  "builtins".into() } ; "Vec_u32_output")]
    #[test_case(HashMap::<u32, String>::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "HashMap_u32_String_input")]
    #[test_case(HashMap::<u32, String>::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "HashMap_u32_String_output")]
    #[test_case(indexmap::IndexMap::<u32, String>::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "IndexMap_u32_String_input")]
    #[test_case(indexmap::IndexMap::<u32, String>::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "IndexMap_u32_String_output")]
    #[test_case(HashMap::<u32, Vec<u32>>::type_input(), "typing.Mapping[builtins.int, typing.Sequence[builtins.int]]", hashset! { "builtins".into(), "typing".into() } ; "HashMap_u32_Vec_u32_input")]
    #[test_case(HashMap::<u32, Vec<u32>>::type_output(), "builtins.dict[builtins.int, builtins.list[builtins.int]]", hashset! { "builtins".into() } ; "HashMap_u32_Vec_u32_output")]
    #[test_case(HashSet::<u32>::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "HashSet_u32_input")]
    #[test_case(indexmap::IndexSet::<u32>::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "IndexSet_u32_input")]
    #[test_case(TypeInfo::dict_of::<u32, String>(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "dict_of_u32_String")]
    fn test(tinfo: TypeInfo, name: &str, import: HashSet<ImportRef>) {
        assert_eq!(tinfo.name, name);
        // An empty expectation is checked with `is_empty()`; otherwise the two sets must be equal.
        if import.is_empty() {
            assert!(tinfo.import.is_empty());
        } else {
            assert_eq!(tinfo.import, import);
        }
    }
}