pyo3_stub_gen/
stub_type.rs

1mod builtins;
2mod collections;
3mod pyo3;
4
5#[cfg(feature = "numpy")]
6mod numpy;
7
8use maplit::hashset;
9use std::{collections::HashSet, fmt, ops};
10
/// Reference to the Python module that provides a type used in a stub file.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` sort every `Named`
/// value before `Default`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
pub enum ModuleRef {
    /// A module referred to by its explicit name, e.g. `typing` or `builtins`.
    Named(String),

    /// Placeholder for the module that PyO3 creates by default.
    ///
    /// Where the concrete name comes from depends on the project layout:
    /// - pure Rust project: the crate name from `Cargo.toml`, or `project.name` from
    ///   `pyproject.toml`;
    /// - mixed Rust/Python project: `tool.maturin.module-name` from `pyproject.toml`.
    ///
    /// That name is not known at compile time, so this placeholder stands in for it until
    /// it is resolved during stub file generation.
    #[default]
    Default,
}
26
27impl ModuleRef {
28    pub fn get(&self) -> Option<&str> {
29        match self {
30            Self::Named(name) => Some(name),
31            Self::Default => None,
32        }
33    }
34}
35
36impl From<&str> for ModuleRef {
37    fn from(s: &str) -> Self {
38        Self::Named(s.to_string())
39    }
40}
41
42/// Type information for creating Python stub files annotated by [PyStubType] trait.
43#[derive(Debug, Clone, PartialEq, Eq)]
44pub struct TypeInfo {
45    /// The Python type name.
46    pub name: String,
47
48    /// Python modules must be imported in the stub file.
49    ///
50    /// For example, when `name` is `typing.Sequence[int]`, `import` should contain `typing`.
51    /// This makes it possible to use user-defined types in the stub file.
52    pub import: HashSet<ModuleRef>,
53}
54
55impl fmt::Display for TypeInfo {
56    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
57        write!(f, "{}", self.name)
58    }
59}
60
61impl TypeInfo {
62    /// A `None` type annotation.
63    pub fn none() -> Self {
64        // NOTE: since 3.10, NoneType is provided from types module,
65        // but there is no corresponding definitions prior to 3.10.
66        Self {
67            name: "None".to_string(),
68            import: HashSet::new(),
69        }
70    }
71
72    /// A `typing.Any` type annotation.
73    pub fn any() -> Self {
74        Self {
75            name: "typing.Any".to_string(),
76            import: hashset! { "typing".into() },
77        }
78    }
79
80    /// A `list[Type]` type annotation.
81    pub fn list_of<T: PyStubType>() -> Self {
82        let TypeInfo { name, mut import } = T::type_output();
83        import.insert("builtins".into());
84        TypeInfo {
85            name: format!("builtins.list[{}]", name),
86            import,
87        }
88    }
89
90    /// A `set[Type]` type annotation.
91    pub fn set_of<T: PyStubType>() -> Self {
92        let TypeInfo { name, mut import } = T::type_output();
93        import.insert("builtins".into());
94        TypeInfo {
95            name: format!("builtins.set[{}]", name),
96            import,
97        }
98    }
99
100    /// A `dict[Type]` type annotation.
101    pub fn dict_of<K: PyStubType, V: PyStubType>() -> Self {
102        let TypeInfo {
103            name: name_k,
104            mut import,
105        } = K::type_output();
106        let TypeInfo {
107            name: name_v,
108            import: import_v,
109        } = V::type_output();
110        import.extend(import_v);
111        import.insert("builtins".into());
112        TypeInfo {
113            name: format!("builtins.set[{}, {}]", name_k, name_v),
114            import,
115        }
116    }
117
118    /// A type annotation of a built-in type provided from `builtins` module, such as `int`, `str`, or `float`. Generic builtin types are also possible, such as `dict[str, str]`.
119    pub fn builtin(name: &str) -> Self {
120        Self {
121            name: format!("builtins.{name}"),
122            import: hashset! { "builtins".into() },
123        }
124    }
125
126    /// Unqualified type.
127    pub fn unqualified(name: &str) -> Self {
128        Self {
129            name: name.to_string(),
130            import: hashset! {},
131        }
132    }
133
134    /// A type annotation of a type that must be imported. The type name must be qualified with the module name:
135    ///
136    /// ```
137    /// pyo3_stub_gen::TypeInfo::with_module("pathlib.Path", "pathlib".into());
138    /// ```
139    pub fn with_module(name: &str, module: ModuleRef) -> Self {
140        let mut import = HashSet::new();
141        import.insert(module);
142        Self {
143            name: name.to_string(),
144            import,
145        }
146    }
147}
148
149impl ops::BitOr for TypeInfo {
150    type Output = Self;
151
152    fn bitor(mut self, rhs: Self) -> Self {
153        self.import.extend(rhs.import);
154        Self {
155            name: format!("{} | {}", self.name, rhs.name),
156            import: self.import,
157        }
158    }
159}
160
/// Implement [PyStubType] for a type by delegating to one or more types that already implement it.
///
/// - `impl_stub_type!(X = A)` gives `X` the same input and output annotations as `A`.
/// - `impl_stub_type!(E = A | B)` makes `E` a union. Its annotations are those of `A` and `B`
///   joined with the `|` operator on [TypeInfo]. The imports of all alternatives are merged.
///
/// The generated calls use fully-qualified `<Base as PyStubType>::…` syntax. As a result, the
/// caller does not need to import [PyStubType] for the expansion to compile.
///
/// ```rust
/// use pyo3::*;
/// use pyo3_stub_gen::{impl_stub_type, derive::*};
///
/// #[gen_stub_pyclass]
/// #[pyclass]
/// struct A;
///
/// #[gen_stub_pyclass]
/// #[pyclass]
/// struct B;
///
/// enum E {
///     A(A),
///     B(B),
/// }
/// impl_stub_type!(E = A | B);
///
/// struct X(A);
/// impl_stub_type!(X = A);
///
/// struct Y {
///    a: A,
///    b: B,
/// }
/// impl_stub_type!(Y = (A, B));
/// ```
#[macro_export]
macro_rules! impl_stub_type {
    // A single base type is a one-element `|`-separated list, so this one arm handles both
    // `X = A` and `E = A | B | ...`. A separate single-type arm placed after it could never match.
    ($ty: ty = $($base:ty)|+) => {
        impl ::pyo3_stub_gen::PyStubType for $ty {
            fn type_output() -> ::pyo3_stub_gen::TypeInfo {
                $(<$base as ::pyo3_stub_gen::PyStubType>::type_output()) | *
            }
            fn type_input() -> ::pyo3_stub_gen::TypeInfo {
                $(<$base as ::pyo3_stub_gen::PyStubType>::type_input()) | *
            }
        }
    };
}
213
214/// Annotate Rust types with Python type information.
215pub trait PyStubType {
216    /// The type to be used in the output signature, i.e. return type of the Python function or methods.
217    fn type_output() -> TypeInfo;
218
219    /// The type to be used in the input signature, i.e. the arguments of the Python function or methods.
220    ///
221    /// This defaults to the output type, but can be overridden for types that are not valid input types.
222    /// For example, `Vec::<T>::type_output` returns `list[T]` while `Vec::<T>::type_input` returns `typing.Sequence[T]`.
223    fn type_input() -> TypeInfo {
224        Self::type_output()
225    }
226}
227
#[cfg(test)]
mod test {
    // Each case checks the exact annotation string and the exact set of imports that the
    // `PyStubType` impls of common Rust types produce.
    use super::*;
    use maplit::hashset;
    use std::collections::HashMap;
    use test_case::test_case;

    #[test_case(bool::type_input(), "builtins.bool", hashset! { "builtins".into() } ; "bool_input")]
    #[test_case(<&str>::type_input(), "builtins.str", hashset! { "builtins".into() } ; "str_input")]
    #[test_case(Vec::<u32>::type_input(), "typing.Sequence[builtins.int]", hashset! { "typing".into(), "builtins".into() } ; "Vec_u32_input")]
    #[test_case(Vec::<u32>::type_output(), "builtins.list[builtins.int]", hashset! {  "builtins".into() } ; "Vec_u32_output")]
    #[test_case(HashMap::<u32, String>::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "HashMap_u32_String_input")]
    #[test_case(HashMap::<u32, String>::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "HashMap_u32_String_output")]
    #[test_case(indexmap::IndexMap::<u32, String>::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "IndexMap_u32_String_input")]
    #[test_case(indexmap::IndexMap::<u32, String>::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "IndexMap_u32_String_output")]
    #[test_case(HashMap::<u32, Vec<u32>>::type_input(), "typing.Mapping[builtins.int, typing.Sequence[builtins.int]]", hashset! { "builtins".into(), "typing".into() } ; "HashMap_u32_Vec_u32_input")]
    #[test_case(HashMap::<u32, Vec<u32>>::type_output(), "builtins.dict[builtins.int, builtins.list[builtins.int]]", hashset! { "builtins".into() } ; "HashMap_u32_Vec_u32_output")]
    #[test_case(HashSet::<u32>::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "HashSet_u32_input")]
    #[test_case(indexmap::IndexSet::<u32>::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "IndexSet_u32_input")]
    fn test(tinfo: TypeInfo, name: &str, import: HashSet<ModuleRef>) {
        let TypeInfo {
            name: actual_name,
            import: actual_import,
        } = tinfo;
        assert_eq!(actual_name, name);
        if !import.is_empty() {
            assert_eq!(actual_import, import);
        } else {
            assert!(actual_import.is_empty());
        }
    }
}