pyo3_stub_gen/
stub_type.rs

1mod builtins;
2mod collections;
3mod pyo3;
4
5#[cfg(feature = "numpy")]
6mod numpy;
7
8#[cfg(feature = "either")]
9mod either;
10
11#[cfg(feature = "rust_decimal")]
12mod rust_decimal;
13
14use maplit::hashset;
15use std::cmp::Ordering;
16use std::{collections::HashSet, fmt, ops};
17
18/// Indicates what to import.
19/// Module: The purpose is to import the entire module(eg import builtins).
20/// Type: The purpose is to import the types in the module(eg from moduleX import typeX).
21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22pub enum ImportRef {
23    Module(ModuleRef),
24    Type(TypeRef),
25}
26
27impl From<&str> for ImportRef {
28    fn from(value: &str) -> Self {
29        ImportRef::Module(value.into())
30    }
31}
32
33impl PartialOrd for ImportRef {
34    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
35        Some(self.cmp(other))
36    }
37}
38
39impl Ord for ImportRef {
40    fn cmp(&self, other: &Self) -> Ordering {
41        match (self, other) {
42            (ImportRef::Module(a), ImportRef::Module(b)) => a.get().cmp(&b.get()),
43            (ImportRef::Type(a), ImportRef::Type(b)) => a.cmp(b),
44            (ImportRef::Module(_), ImportRef::Type(_)) => Ordering::Greater,
45            (ImportRef::Type(_), ImportRef::Module(_)) => Ordering::Less,
46        }
47    }
48}
49
/// The module that a type or an import refers to.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
pub enum ModuleRef {
    // Variant order matters: the derived `Ord` places every `Named` before `Default`.
    /// An explicitly named module, e.g. `"typing"` or `"pathlib"`.
    Named(String),

    /// Placeholder for the default module that PyO3 creates.
    ///
    /// Its real name depends on the project layout:
    /// - pure Rust project: the crate name from `Cargo.toml`, or `project.name` from `pyproject.toml`;
    /// - mixed Rust/Python project: `tool.maturin.module-name` from `pyproject.toml`.
    ///
    /// That name is unknown at compile time, so it is only resolved when the stub file is generated.
    #[default]
    Default,
}
65
66impl ModuleRef {
67    pub fn get(&self) -> Option<&str> {
68        match self {
69            Self::Named(name) => Some(name),
70            Self::Default => None,
71        }
72    }
73}
74
75impl From<&str> for ModuleRef {
76    fn from(s: &str) -> Self {
77        Self::Named(s.to_string())
78    }
79}
80
81/// Indicates the type of import(eg class enum).
82/// from module import type.
83/// name, type name. module, module name(which type defined).
84#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
85pub struct TypeRef {
86    pub module: ModuleRef,
87    pub name: String,
88}
89
90impl TypeRef {
91    pub fn new(module_ref: ModuleRef, name: String) -> Self {
92        Self {
93            module: module_ref,
94            name,
95        }
96    }
97}
98
99/// Type information for creating Python stub files annotated by [PyStubType] trait.
100#[derive(Debug, Clone, PartialEq, Eq)]
101pub struct TypeInfo {
102    /// The Python type name.
103    pub name: String,
104
105    /// Python modules must be imported in the stub file.
106    ///
107    /// For example, when `name` is `typing.Sequence[int]`, `import` should contain `typing`.
108    /// This makes it possible to use user-defined types in the stub file.
109    pub import: HashSet<ImportRef>,
110}
111
112impl fmt::Display for TypeInfo {
113    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
114        write!(f, "{}", self.name)
115    }
116}
117
118impl TypeInfo {
119    /// A `None` type annotation.
120    pub fn none() -> Self {
121        // NOTE: since 3.10, NoneType is provided from types module,
122        // but there is no corresponding definitions prior to 3.10.
123        Self {
124            name: "None".to_string(),
125            import: HashSet::new(),
126        }
127    }
128
129    /// A `typing.Any` type annotation.
130    pub fn any() -> Self {
131        Self {
132            name: "typing.Any".to_string(),
133            import: hashset! { "typing".into() },
134        }
135    }
136
137    /// A `list[Type]` type annotation.
138    pub fn list_of<T: PyStubType>() -> Self {
139        let TypeInfo { name, mut import } = T::type_output();
140        import.insert("builtins".into());
141        TypeInfo {
142            name: format!("builtins.list[{name}]"),
143            import,
144        }
145    }
146
147    /// A `set[Type]` type annotation.
148    pub fn set_of<T: PyStubType>() -> Self {
149        let TypeInfo { name, mut import } = T::type_output();
150        import.insert("builtins".into());
151        TypeInfo {
152            name: format!("builtins.set[{name}]"),
153            import,
154        }
155    }
156
157    /// A `dict[Type]` type annotation.
158    pub fn dict_of<K: PyStubType, V: PyStubType>() -> Self {
159        let TypeInfo {
160            name: name_k,
161            mut import,
162        } = K::type_output();
163        let TypeInfo {
164            name: name_v,
165            import: import_v,
166        } = V::type_output();
167        import.extend(import_v);
168        import.insert("builtins".into());
169        TypeInfo {
170            name: format!("builtins.set[{name_k}, {name_v}]"),
171            import,
172        }
173    }
174
175    /// A type annotation of a built-in type provided from `builtins` module, such as `int`, `str`, or `float`. Generic builtin types are also possible, such as `dict[str, str]`.
176    pub fn builtin(name: &str) -> Self {
177        Self {
178            name: format!("builtins.{name}"),
179            import: hashset! { "builtins".into() },
180        }
181    }
182
183    /// Unqualified type.
184    pub fn unqualified(name: &str) -> Self {
185        Self {
186            name: name.to_string(),
187            import: hashset! {},
188        }
189    }
190
191    /// A type annotation of a type that must be imported. The type name must be qualified with the module name:
192    ///
193    /// ```
194    /// pyo3_stub_gen::TypeInfo::with_module("pathlib.Path", "pathlib".into());
195    /// ```
196    pub fn with_module(name: &str, module: ModuleRef) -> Self {
197        let mut import = HashSet::new();
198        import.insert(ImportRef::Module(module));
199        Self {
200            name: name.to_string(),
201            import,
202        }
203    }
204
205    /// A type defined in the PyO3 module.
206    ///
207    /// - Types defined in the same module can be referenced without import.
208    ///   But when it is used in another submodule, it must be imported.
209    /// - For example, if `A` is defined in `submod1`, it can be used as `A` in `submod1`.
210    ///   In `submod2`, it must be imported as `from submod1 import A`.
211    ///
212    /// ```
213    /// pyo3_stub_gen::TypeInfo::locally_defined("A", "submod1".into());
214    /// ```
215    pub fn locally_defined(type_name: &str, module: ModuleRef) -> Self {
216        let mut import = HashSet::new();
217        let type_ref = TypeRef::new(module, type_name.to_string());
218        import.insert(ImportRef::Type(type_ref));
219
220        Self {
221            name: type_name.to_string(),
222            import,
223        }
224    }
225}
226
227impl ops::BitOr for TypeInfo {
228    type Output = Self;
229
230    fn bitor(mut self, rhs: Self) -> Self {
231        self.import.extend(rhs.import);
232        Self {
233            name: format!("{} | {}", self.name, rhs.name),
234            import: self.import,
235        }
236    }
237}
238
/// Implement [PyStubType] for a type by delegating to one or more existing stub types.
///
/// - `impl_stub_type!(X = A)` makes `X` annotate exactly like `A`.
/// - `impl_stub_type!(E = A | B)` makes `E` annotate as the union `A | B`. Both
///   `type_output` and `type_input` are combined with the `|` operator of `TypeInfo`.
///
/// The expansion refers to this crate through `$crate` and calls each base type with
/// fully qualified `<Base as PyStubType>` syntax. As a result, the call site needs neither
/// [PyStubType] in scope nor this crate available under the name `pyo3_stub_gen`.
///
/// ```rust
/// use pyo3::*;
/// use pyo3_stub_gen::{impl_stub_type, derive::*};
///
/// #[gen_stub_pyclass]
/// #[pyclass]
/// struct A;
///
/// #[gen_stub_pyclass]
/// #[pyclass]
/// struct B;
///
/// enum E {
///     A(A),
///     B(B),
/// }
/// impl_stub_type!(E = A | B);
///
/// struct X(A);
/// impl_stub_type!(X = A);
///
/// struct Y {
///    a: A,
///    b: B,
/// }
/// impl_stub_type!(Y = (A, B));
/// ```
#[macro_export]
macro_rules! impl_stub_type {
    // A single base type is a one-element `|`-list, so this arm covers both
    // `X = A` and `E = A | B`.
    ($ty:ty = $($base:ty)|+) => {
        impl $crate::PyStubType for $ty {
            fn type_output() -> $crate::TypeInfo {
                $(<$base as $crate::PyStubType>::type_output()) | *
            }
            fn type_input() -> $crate::TypeInfo {
                $(<$base as $crate::PyStubType>::type_input()) | *
            }
        }
    };
}
291
292/// Annotate Rust types with Python type information.
293pub trait PyStubType {
294    /// The type to be used in the output signature, i.e. return type of the Python function or methods.
295    fn type_output() -> TypeInfo;
296
297    /// The type to be used in the input signature, i.e. the arguments of the Python function or methods.
298    ///
299    /// This defaults to the output type, but can be overridden for types that are not valid input types.
300    /// For example, `Vec::<T>::type_output` returns `list[T]` while `Vec::<T>::type_input` returns `typing.Sequence[T]`.
301    fn type_input() -> TypeInfo {
302        Self::type_output()
303    }
304}
305
#[cfg(test)]
mod test {
    use super::*;
    use maplit::hashset;
    use std::collections::HashMap;
    use test_case::test_case;

    // Each case checks both the rendered annotation and the exact import set it requires.
    #[test_case(bool::type_input(), "builtins.bool", hashset! { "builtins".into() } ; "bool_input")]
    #[test_case(<&str>::type_input(), "builtins.str", hashset! { "builtins".into() } ; "str_input")]
    #[test_case(Vec::<u32>::type_input(), "typing.Sequence[builtins.int]", hashset! { "typing".into(), "builtins".into() } ; "Vec_u32_input")]
    #[test_case(Vec::<u32>::type_output(), "builtins.list[builtins.int]", hashset! {  "builtins".into() } ; "Vec_u32_output")]
    #[test_case(HashMap::<u32, String>::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "HashMap_u32_String_input")]
    #[test_case(HashMap::<u32, String>::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "HashMap_u32_String_output")]
    #[test_case(indexmap::IndexMap::<u32, String>::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "IndexMap_u32_String_input")]
    #[test_case(indexmap::IndexMap::<u32, String>::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "IndexMap_u32_String_output")]
    #[test_case(HashMap::<u32, Vec<u32>>::type_input(), "typing.Mapping[builtins.int, typing.Sequence[builtins.int]]", hashset! { "builtins".into(), "typing".into() } ; "HashMap_u32_Vec_u32_input")]
    #[test_case(HashMap::<u32, Vec<u32>>::type_output(), "builtins.dict[builtins.int, builtins.list[builtins.int]]", hashset! { "builtins".into() } ; "HashMap_u32_Vec_u32_output")]
    #[test_case(HashSet::<u32>::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "HashSet_u32_input")]
    #[test_case(indexmap::IndexSet::<u32>::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "IndexSet_u32_input")]
    #[test_case(TypeInfo::none(), "None", hashset! {} ; "none_type")]
    #[test_case(TypeInfo::any(), "typing.Any", hashset! { "typing".into() } ; "any_type")]
    #[test_case(TypeInfo::list_of::<u32>(), "builtins.list[builtins.int]", hashset! { "builtins".into() } ; "list_of_u32")]
    #[test_case(TypeInfo::set_of::<u32>(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "set_of_u32")]
    #[test_case(TypeInfo::with_module("pathlib.Path", "pathlib".into()), "pathlib.Path", hashset! { "pathlib".into() } ; "with_module_pathlib")]
    #[test_case(TypeInfo::builtin("int") | TypeInfo::none(), "builtins.int | None", hashset! { "builtins".into() } ; "union_builtin_int_none")]
    fn test(tinfo: TypeInfo, name: &str, import: HashSet<ImportRef>) {
        assert_eq!(tinfo.name, name);
        // `assert_eq!` also covers the empty case, so no separate `is_empty` branch is needed.
        assert_eq!(tinfo.import, import);
    }
}