pure/
lib.rs

1#![allow(deprecated)]
2
3mod chrono_types;
4mod custom_exceptions;
5mod manual_overloading;
6mod manual_submit;
7mod overloading;
8mod overriding;
9mod rust_type_marker;
10mod skip_stub_type_test;
11mod time_types;
12
13use chrono_types::*;
14use custom_exceptions::*;
15use manual_overloading::*;
16use manual_submit::*;
17use overloading::*;
18use overriding::*;
19use rust_type_marker::*;
20use skip_stub_type_test::*;
21use time_types::*;
22
// On macOS only, attaches the repository README as the docs of this empty module, presumably so
// that rustdoc checks the README's Rust code blocks as doctests.
// NOTE(review): the reason for restricting this to macOS is not visible here — confirm.
#[cfg_attr(target_os = "macos", doc = include_str!("../../../README.md"))]
mod readme {}
25
26use ahash::RandomState;
27use pyo3::{prelude::*, types::*};
28use pyo3_stub_gen::{define_stub_info_gatherer, derive::*, module_doc, module_variable};
29use rust_decimal::Decimal;
30use std::{collections::HashMap, path::PathBuf};
31
32/// Returns the sum of two numbers as a string.
33#[gen_stub_pyfunction]
34#[pyfunction]
35fn sum(v: Vec<u32>) -> u32 {
36    v.iter().sum()
37}
38
39#[gen_stub_pyfunction]
40#[pyfunction]
41fn read_dict(dict: HashMap<usize, HashMap<usize, usize>>) {
42    for (k, v) in dict {
43        for (k2, v2) in v {
44            println!("{k} {k2} {v2}");
45        }
46    }
47}
48
49#[gen_stub_pyfunction]
50#[pyfunction]
51fn create_dict(n: usize) -> HashMap<usize, Vec<usize>> {
52    let mut dict = HashMap::new();
53    for i in 0..n {
54        dict.insert(i, (0..i).collect());
55    }
56    dict
57}
58
59/// Add two decimal numbers with high precision
60#[gen_stub_pyfunction]
61#[pyfunction]
62fn add_decimals(a: Decimal, b: Decimal) -> Decimal {
63    a + b
64}
65
// Field-less Python class that extends Python's `datetime.date` (`PyDate`); it exercises stub
// generation for a class whose base is a built-in type.
#[gen_stub_pyclass]
#[pyclass(extends=PyDate)]
struct MyDate;
69
// Base Python class `A`; `subclass` allows it to be subclassed (e.g. by `B` in this file).
#[gen_stub_pyclass]
#[pyclass(subclass)]
#[derive(Debug)]
struct A {
    // Read/write property. The `gen_stub(default = ...)` expression (`A::default().x`, i.e. 2)
    // is presumably what the stub shows as this field's default — confirm against the output.
    #[gen_stub(default = A::default().x)]
    #[pyo3(get, set)]
    x: usize,

    // Only a getter is generated from the field attribute; a (deprecated) setter for `y` is
    // provided by `set_y` in the `#[pymethods]` block.
    #[pyo3(get)]
    y: usize,
}
81
82impl Default for A {
83    fn default() -> Self {
84        Self { x: 2, y: 10 }
85    }
86}
87
// Python-visible methods and class attributes of `A`. Each item exercises a different kind of
// member for the stub generator; item order and doc comments feed the generated stub.
#[gen_stub_pymethods]
#[pymethods]
impl A {
    /// This is a constructor of :class:`A`.
    #[new]
    fn new(x: usize) -> Self {
        Self { x, y: 10 }
    }
    // Exposed to Python as `A.NUM` (renamed with `#[pyo3(name)]`).
    /// class attribute NUM1
    #[classattr]
    #[pyo3(name = "NUM")]
    const NUM1: usize = 2;

    /// deprecated class attribute NUM3 (will show warning)
    #[deprecated(since = "1.0.0", note = "This constant is deprecated")]
    #[classattr]
    const NUM3: usize = 3;
    // A `#[classattr]` function: Python sees `A.NUM2` as the returned value, not as a method.
    /// class attribute NUM2
    #[expect(non_snake_case)]
    #[classattr]
    fn NUM2() -> usize {
        2
    }
    #[classmethod]
    fn classmethod_test1(cls: &Bound<'_, PyType>) {
        _ = cls;
    }

    #[deprecated(since = "1.0.0", note = "This classmethod is deprecated")]
    #[classmethod]
    fn deprecated_classmethod(cls: &Bound<'_, PyType>) {
        _ = cls;
    }

    // Same as `classmethod_test1`, but with the class parameter left unnamed (`_`).
    #[classmethod]
    fn classmethod_test2(_: &Bound<'_, PyType>) {}

    // Prints `x` to the process's stdout (Rust side), returning nothing to Python.
    fn show_x(&self) {
        println!("x = {}", self.x);
    }

    // Returns the very dict object it receives; the lifetime `'a` ties the result to the input.
    fn ref_test<'a>(&self, x: Bound<'a, PyDict>) -> Bound<'a, PyDict> {
        x
    }

    // `async` method: from Python, calling it yields an awaitable that resolves to `x`.
    async fn async_get_x(&self) -> usize {
        self.x
    }

    // Still registered with Python by `#[pymethods]`, but omitted from the stub by
    // `#[gen_stub(skip)]`.
    #[gen_stub(skip)]
    fn need_skip(&self) {}

    // The `#[deprecated]` members below test how deprecations are reflected in the stub. The
    // crate-level `#![allow(deprecated)]` presumably silences warnings from generated wrappers.
    #[deprecated(since = "1.0.0", note = "This method is deprecated")]
    fn deprecated_method(&self) {
        println!("This method is deprecated");
    }

    // Read-only property named after the function (`deprecated_getter`), returning `x`.
    #[deprecated(since = "1.0.0", note = "This method is deprecated")]
    #[getter]
    fn deprecated_getter(&self) -> usize {
        self.x
    }

    // Setter for property `y` (PyO3 strips the `set_` prefix), pairing with the field getter.
    #[deprecated(since = "1.0.0", note = "This setter is deprecated")]
    #[setter]
    fn set_y(&mut self, value: usize) {
        self.y = value;
    }

    #[deprecated(since = "1.0.0", note = "This staticmethod is deprecated")]
    #[staticmethod]
    fn deprecated_staticmethod() -> usize {
        42
    }
}
163
164#[gen_stub_pyfunction]
165#[pyfunction]
166#[pyo3(signature = (x = 2))]
167fn create_a(x: usize) -> A {
168    A { x, y: 10 }
169}
170
// Python subclass of `A` (permitted by `#[pyclass(subclass)]` on `A`); no fields or methods of
// its own are defined in this file.
#[gen_stub_pyclass]
#[pyclass(extends=A)]
#[derive(Debug)]
struct B;
175
/// `C` only implements `FromPyObject` (it is not a `#[pyclass]`): Python callers pass a plain
/// integer, which is extracted into `x`. Its `PyStubType` impl makes stubs type it like `usize`.
#[derive(Debug)]
struct C {
    x: usize,
}
181#[gen_stub_pyfunction]
182#[pyfunction(signature = (c=None))]
183fn print_c(c: Option<C>) {
184    if let Some(c) = c {
185        println!("{}", c.x);
186    } else {
187        println!("None");
188    }
189}
190impl FromPyObject<'_> for C {
191    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
192        Ok(C { x: ob.extract()? })
193    }
194}
195impl pyo3_stub_gen::PyStubType for C {
196    fn type_output() -> pyo3_stub_gen::TypeInfo {
197        usize::type_output()
198    }
199}
200
201/// Returns the length of the string.
202#[gen_stub_pyfunction]
203#[pyfunction]
204fn str_len(x: &str) -> PyResult<usize> {
205    Ok(x.len())
206}
207
208#[gen_stub_pyfunction]
209#[pyfunction]
210fn echo_path(path: PathBuf) -> PyResult<PathBuf> {
211    Ok(path)
212}
213
214#[gen_stub_pyfunction]
215#[pyfunction]
216fn ahash_dict() -> HashMap<String, i32, RandomState> {
217    let mut map: HashMap<String, i32, RandomState> = HashMap::with_hasher(RandomState::new());
218    map.insert("apple".to_string(), 3);
219    map.insert("banana".to_string(), 2);
220    map.insert("orange".to_string(), 5);
221    map
222}
223
// Simple enum exposed to Python with variants renamed one by one to `FLOAT` / `INTEGER`;
// `eq, eq_int` enable equality comparisons from Python.
#[gen_stub_pyclass_enum]
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Number {
    #[pyo3(name = "FLOAT")]
    Float,
    #[pyo3(name = "INTEGER")]
    Integer,
}
233
// Same variants as `Number`, but the upper-case Python names come from the enum-level
// `rename_all = "UPPERCASE"` instead of per-variant `name` attributes.
#[gen_stub_pyclass_enum]
#[pyclass(eq, eq_int)]
#[pyo3(rename_all = "UPPERCASE")]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NumberRenameAll {
    /// Float variant
    Float,
    Integer,
}
243
// Complex (data-carrying) enum with upper-case Python names requested via `rename_all`. The
// `Integer` variant's Python constructor takes `int` with a default of 2.
#[gen_stub_pyclass_complex_enum]
#[pyclass]
#[pyo3(rename_all = "UPPERCASE")]
#[derive(Debug, Clone)]
pub enum NumberComplex {
    /// Float variant
    Float(f64),
    /// Integer variant
    #[pyo3(constructor = (int=2))]
    Integer {
        /// The integer value
        int: i32,
    },
}
258
// Complex enum using PyO3's default constructors. It covers struct-like variants, a tuple
// variant (`RegularPolygon`) and an empty struct-like variant (`Nothing {}`).
/// Example from PyO3 documentation for complex enum
/// https://pyo3.rs/v0.25.1/class.html#complex-enums
#[gen_stub_pyclass_complex_enum]
#[pyclass]
enum Shape1 {
    Circle { radius: f64 },
    Rectangle { width: f64, height: f64 },
    RegularPolygon(u32, f64),
    Nothing {},
}
269
// Variants analogous to `Shape1`, but with custom `constructor` signatures. `Circle` has a
// default radius of 1.0, and `Rectangle` takes keyword-only `width`/`height` (after `*`).
// `RegularPolygon` takes a required `side_count` plus a `radius` that defaults to 1.0.
/// Example from PyO3 documentation for complex enum
/// https://pyo3.rs/v0.25.1/class.html#complex-enums
#[gen_stub_pyclass_complex_enum]
#[pyclass]
enum Shape2 {
    #[pyo3(constructor = (radius=1.0))]
    Circle {
        radius: f64,
    },
    #[pyo3(constructor = (*, width, height))]
    Rectangle {
        width: f64,
        height: f64,
    },
    #[pyo3(constructor = (side_count, radius=1.0))]
    RegularPolygon {
        side_count: u32,
        radius: f64,
    },
    Nothing {},
}
291
292#[gen_stub_pymethods]
293#[pymethods]
294impl Number {
295    #[getter]
296    /// Whether the number is a float.
297    fn is_float(&self) -> bool {
298        matches!(self, Self::Float)
299    }
300
301    #[getter]
302    /// Whether the number is an integer.
303    fn is_integer(&self) -> bool {
304        matches!(self, Self::Integer)
305    }
306}
307
// Python class wrapping a single `Decimal`, exposed as a read-only `value` property; it
// exercises `Decimal` as a class attribute type in the stub.
#[gen_stub_pyclass]
#[pyclass]
pub struct DecimalHolder {
    #[pyo3(get)]
    value: Decimal,
}
314
315#[gen_stub_pymethods]
316#[pymethods]
317impl DecimalHolder {
318    #[new]
319    fn new(value: Decimal) -> Self {
320        Self { value }
321    }
322}
323
// Declare module-level constants for the generated stub (`MY_CONSTANT2` also records the value
// 123). The runtime values are added separately with `m.add(...)` in the `#[pymodule]` function.
module_variable!("pure", "MY_CONSTANT1", usize);
module_variable!("pure", "MY_CONSTANT2", usize, 123);
326
327#[gen_stub_pyfunction]
328#[pyfunction]
329async fn async_num() -> i32 {
330    123
331}
332
// Exercises how a `#[deprecated]` free function is reflected in the stub; the crate-level
// `#![allow(deprecated)]` presumably silences warnings from the generated wrapper.
#[gen_stub_pyfunction]
#[pyfunction]
#[deprecated(since = "1.0.0", note = "This function is deprecated")]
fn deprecated_function() {
    println!("This function is deprecated");
}
339
// Tests that a parameter default whose type is a concrete pyclass (the `Number` enum) rather
// than `Any` can be rendered in the stub. Returns its argument unchanged.
#[gen_stub_pyfunction]
#[pyfunction]
#[pyo3(signature = (num = Number::Float))]
fn default_value(num: Number) -> Number {
    num
}
347
// Tests for the treatment of `*args` and `**kwargs` in functions are `func_with_star_arg_typed`,
// `func_with_star_arg` and `func_with_kwargs` further below.
349
// `eq` and `ord` derive Python's rich comparisons from the Rust `PartialEq`/`PartialOrd` impls.
/// Test struct for eq and ord comparison methods
#[gen_stub_pyclass]
#[pyclass(eq, ord)]
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub struct ComparableStruct {
    #[pyo3(get)]
    pub value: i32,
}
358
359#[gen_stub_pymethods]
360#[pymethods]
361impl ComparableStruct {
362    #[new]
363    fn new(value: i32) -> Self {
364        Self { value }
365    }
366}
367
// `hash` derives Python's `__hash__` from the Rust `Hash` impl, and PyO3 requires `frozen` for it.
// `str` derives `__str__` from the `Display` impl.
/// Test struct for hash and str methods
#[gen_stub_pyclass]
#[pyclass(eq, hash, frozen, str)]
#[derive(Debug, Clone, Hash, PartialEq)]
pub struct HashableStruct {
    #[pyo3(get)]
    pub name: String,
}
376
377impl std::fmt::Display for HashableStruct {
378    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379        write!(f, "HashableStruct({})", self.name)
380    }
381}
382
383#[gen_stub_pymethods]
384#[pymethods]
385impl HashableStruct {
386    #[new]
387    fn new(name: String) -> Self {
388        Self { name }
389    }
390}
391
// Same runtime behaviour as `func_with_star_arg`. `override_type(type_repr = "str")` changes only
// the stub annotation of `*args`, presumably so each element is typed `str`; any object is
// still accepted at runtime.
/// Takes a variable number of arguments and returns their string representation.
#[gen_stub_pyfunction]
#[pyfunction]
#[pyo3(signature = (*args))]
fn func_with_star_arg_typed(
    #[gen_stub(override_type(type_repr = "str"))] args: &Bound<PyTuple>,
) -> String {
    args.to_string()
}
401
// `args` is the tuple of all positional arguments; the result is that tuple's string form.
/// Takes a variable number of arguments and returns their string representation.
#[gen_stub_pyfunction]
#[pyfunction]
#[pyo3(signature = (*args))]
fn func_with_star_arg(args: &Bound<PyTuple>) -> String {
    args.to_string()
}
409
// `kwargs` is `None` when the caller passes no keyword arguments, so this returns whether any
// keyword arguments were given.
/// Takes a variable number of keyword arguments and does nothing
#[gen_stub_pyfunction]
#[pyfunction]
#[pyo3(signature = (**kwargs))]
fn func_with_kwargs(kwargs: Option<&Bound<PyDict>>) -> bool {
    kwargs.is_some()
}
417
// Module docstring for the generated stub; `{}` is filled with this crate's package name.
module_doc!("pure", "Document for {} ...", env!("CARGO_PKG_NAME"));
419
/// Initializes the Python module
#[pymodule]
fn pure(m: &Bound<PyModule>) -> PyResult<()> {
    // Runtime values of the constants declared to the stub generator with `module_variable!`.
    m.add("MY_CONSTANT1", 19937)?;
    m.add("MY_CONSTANT2", 123)?;
    // Classes. Items not defined in this file (e.g. `ManualSubmit`, `DataContainer`) presumably
    // come from the sibling modules via the glob `use` imports at the top of the file.
    m.add_class::<A>()?;
    m.add_class::<B>()?;
    m.add_class::<MyDate>()?;
    m.add_class::<Number>()?;
    m.add_class::<NumberRenameAll>()?;
    m.add_class::<NumberComplex>()?;
    m.add_class::<Shape1>()?;
    m.add_class::<Shape2>()?;
    m.add_class::<ManualSubmit>()?;
    m.add_class::<PartialManualSubmit>()?;
    m.add_class::<OverrideType>()?;
    m.add_class::<ComparableStruct>()?;
    m.add_class::<HashableStruct>()?;
    m.add_class::<DecimalHolder>()?;
    m.add_class::<DataContainer>()?;
    m.add_class::<Placeholder>()?;
    m.add_class::<Calculator>()?;
    m.add_class::<InstanceValue>()?;
    m.add_class::<Problem>()?;
    m.add_class::<CustomStubType>()?;
    m.add_class::<NormalClass>()?;
    m.add_class::<CustomEnum>()?;
    m.add_class::<CustomComplexEnum>()?;
    // Free functions.
    m.add_function(wrap_pyfunction!(sum, m)?)?;
    m.add_function(wrap_pyfunction!(create_dict, m)?)?;
    m.add_function(wrap_pyfunction!(read_dict, m)?)?;
    m.add_function(wrap_pyfunction!(create_a, m)?)?;
    m.add_function(wrap_pyfunction!(print_c, m)?)?;
    m.add_function(wrap_pyfunction!(str_len, m)?)?;
    m.add_function(wrap_pyfunction!(echo_path, m)?)?;
    m.add_function(wrap_pyfunction!(ahash_dict, m)?)?;
    m.add_function(wrap_pyfunction!(async_num, m)?)?;
    m.add_function(wrap_pyfunction!(deprecated_function, m)?)?;
    m.add_function(wrap_pyfunction!(default_value, m)?)?;
    m.add_function(wrap_pyfunction!(fn_override_type, m)?)?;
    m.add_function(wrap_pyfunction!(fn_with_python_param, m)?)?;
    m.add_function(wrap_pyfunction!(fn_with_python_stub, m)?)?;
    m.add_function(wrap_pyfunction!(overload_example_1, m)?)?;
    m.add_function(wrap_pyfunction!(overload_example_2, m)?)?;
    m.add_function(wrap_pyfunction!(as_tuple, m)?)?;
    m.add_function(wrap_pyfunction!(manual_overload_example_1, m)?)?;
    m.add_function(wrap_pyfunction!(manual_overload_example_2, m)?)?;
    m.add_function(wrap_pyfunction!(manual_overload_as_tuple, m)?)?;
    m.add_function(wrap_pyfunction!(add_decimals, m)?)?;
    m.add_function(wrap_pyfunction!(process_container, m)?)?;
    m.add_function(wrap_pyfunction!(sum_list, m)?)?;
    m.add_function(wrap_pyfunction!(create_containers, m)?)?;
    // Test-cases for `*args` and `**kwargs`
    m.add_function(wrap_pyfunction!(func_with_star_arg, m)?)?;
    m.add_function(wrap_pyfunction!(func_with_star_arg_typed, m)?)?;
    m.add_function(wrap_pyfunction!(func_with_kwargs, m)?)?;

    // Test cases for type: ignore functionality
    m.add_function(wrap_pyfunction!(test_type_ignore_specific, m)?)?;
    m.add_function(wrap_pyfunction!(test_type_ignore_all, m)?)?;
    m.add_function(wrap_pyfunction!(test_type_ignore_pyright, m)?)?;
    m.add_function(wrap_pyfunction!(test_type_ignore_custom, m)?)?;
    m.add_function(wrap_pyfunction!(test_type_ignore_no_comment_all, m)?)?;
    m.add_function(wrap_pyfunction!(test_type_ignore_no_comment_specific, m)?)?;

    // Test case for custom exceptions
    // `MyError` is added as a type object rather than through `add_class`, unlike `NotIntError`.
    m.add("MyError", m.py().get_type::<MyError>())?;
    m.add_class::<NotIntError>()?;

    // Test class for type: ignore functionality
    m.add_class::<TypeIgnoreTest>()?;

    // Test cases for time crate types
    m.add_function(wrap_pyfunction!(get_date, m)?)?;
    m.add_function(wrap_pyfunction!(get_time, m)?)?;
    m.add_function(wrap_pyfunction!(get_duration, m)?)?;
    m.add_function(wrap_pyfunction!(get_primitive_datetime, m)?)?;
    m.add_function(wrap_pyfunction!(get_offset_datetime, m)?)?;
    m.add_function(wrap_pyfunction!(get_utc_offset, m)?)?;
    m.add_function(wrap_pyfunction!(get_utc_datetime, m)?)?;
    m.add_function(wrap_pyfunction!(add_duration_to_date, m)?)?;
    m.add_function(wrap_pyfunction!(time_difference, m)?)?;

    // Test cases for chrono crate types
    m.add_function(wrap_pyfunction!(get_naive_date, m)?)?;
    m.add_function(wrap_pyfunction!(get_naive_time, m)?)?;
    m.add_function(wrap_pyfunction!(get_naive_datetime, m)?)?;
    m.add_function(wrap_pyfunction!(get_datetime_utc, m)?)?;
    m.add_function(wrap_pyfunction!(get_datetime_fixed_offset, m)?)?;
    m.add_function(wrap_pyfunction!(get_chrono_duration, m)?)?;
    m.add_function(wrap_pyfunction!(get_fixed_offset, m)?)?;
    m.add_function(wrap_pyfunction!(get_utc, m)?)?;
    m.add_function(wrap_pyfunction!(add_chrono_duration_to_date, m)?)?;
    m.add_function(wrap_pyfunction!(naive_time_difference, m)?)?;
    Ok(())
}
516
// `type_ignore = [...]` lists mypy-style rule codes the stub generator attaches to this
// function's `# type: ignore` comment.
/// Test function with type: ignore for specific rules
#[gen_stub_pyfunction]
#[gen_stub(type_ignore = ["arg-type", "return-value"])]
#[pyfunction]
fn test_type_ignore_specific() -> i32 {
    42
}
524
// Bare `type_ignore` (no rule list) requests a catch-all `# type: ignore`.
/// Test function with type: ignore (without equals for catch-all)
#[gen_stub_pyfunction]
#[gen_stub(type_ignore)]
#[pyfunction]
fn test_type_ignore_all() -> i32 {
    42
}
532
// Same as `test_type_ignore_specific`, but with Pyright rule names instead of mypy codes.
/// Test function with Pyright diagnostic rules
#[gen_stub_pyfunction]
#[gen_stub(type_ignore = ["reportGeneralTypeIssues", "reportReturnType"])]
#[pyfunction]
fn test_type_ignore_pyright() -> i32 {
    42
}
540
// Mixes an unknown rule name ("custom-rule") with a known mypy code.
/// Test function with custom (unknown) rule
#[gen_stub_pyfunction]
#[gen_stub(type_ignore = ["custom-rule", "attr-defined"])]
#[pyfunction]
fn test_type_ignore_custom() -> i32 {
    42
}
548
// NOTE: Doc comments MUST NOT be added to the next two functions (`test_type_ignore_no_comment_*`),
// as they test that `type_ignore` on a function without a doc comment is handled correctly,
// i.e. that the `# type: ignore` comment is emitted after `...`, not before.
552
// Catch-all `type_ignore` on a function with no doc comment (plain `//` comments only here).
#[gen_stub_pyfunction]
#[gen_stub(type_ignore)]
#[pyfunction]
fn test_type_ignore_no_comment_all() -> i32 {
    42
}
559
// Rule-specific `type_ignore` on a function with no doc comment (plain `//` comments only here).
#[gen_stub_pyfunction]
#[gen_stub(type_ignore=["arg-type", "reportIncompatibleMethodOverride"])]
#[pyfunction]
fn test_type_ignore_no_comment_specific() -> i32 {
    42
}
566
// Field-less class whose only purpose is to carry the `type_ignore` method test cases.
/// Test class for method type: ignore functionality
#[gen_stub_pyclass]
#[pyclass]
pub struct TypeIgnoreTest {}
571
// Methods carrying `type_ignore` attributes, mirroring the free-function test cases above.
#[gen_stub_pymethods]
#[pymethods]
impl TypeIgnoreTest {
    #[new]
    fn new() -> Self {
        Self {}
    }

    /// Test method with type: ignore for specific rules
    #[gen_stub(type_ignore = ["union-attr", "return-value"])]
    fn test_method_ignore(&self, value: i32) -> i32 {
        // NOTE(review): overflows for |value| > i32::MAX / 2 (panics in debug builds).
        value * 2
    }

    /// Test method with type: ignore (without equals for catch-all)
    #[gen_stub(type_ignore)]
    fn test_method_all_ignore(&self) -> i32 {
        42
    }
}
592
// Test type aliases WITHOUT docstrings (backward compatibility)
pyo3_stub_gen::type_alias!("pure", SimpleAlias = Option<usize>);
pyo3_stub_gen::type_alias!("pure", StrIntMap = HashMap<String, i32>);

// Type alias referring to a class defined in this file
pyo3_stub_gen::type_alias!(
    "pure",
    MaybeDecimal = Option<Bound<'static, DecimalHolder>>
);

// Direct union type syntax (no impl_stub_type! needed)
pyo3_stub_gen::type_alias!("pure", NumberOrStringAlias = i32 | String);

// Union of classes defined in this file, using the direct syntax
pyo3_stub_gen::type_alias!("pure", StructUnion = Bound<'static, ComparableStruct> | Bound<'static, HashableStruct>);

// Additional test cases for the union syntax
pyo3_stub_gen::type_alias!("pure", TripleUnion = i32 | String | bool);
pyo3_stub_gen::type_alias!("pure", GenericUnion = Option<i32> | Vec<String>);
pyo3_stub_gen::type_alias!("pure", SingleTypeAlias = Option<usize>); // Backward compatibility test

// Test type aliases WITH docstrings: the trailing string argument is the alias's docstring
pyo3_stub_gen::type_alias!(
    "pure",
    DocumentedAlias = Option<usize>,
    "This is a simple type alias with documentation"
);

pyo3_stub_gen::type_alias!(
    "pure",
    DocumentedUnion = i32 | String,
    "A union type with documentation"
);

pyo3_stub_gen::type_alias!(
    "pure",
    DocumentedMap = HashMap<String, Vec<i32>>,
    "A map type alias with detailed documentation.\n\nThis can have multiple lines of documentation."
);
632
// Test type aliases written in Python syntax (without docstrings). The raw string is Python
// source parsed by the macro, so its content (including indentation) is significant.
pyo3_stub_gen::derive::gen_type_alias_from_python!(
    "pure",
    r#"
    from typing import TypeAlias
    import collections.abc

    CallbackType: TypeAlias = collections.abc.Callable[[str], None]
    OptionalCallback: TypeAlias = collections.abc.Callable[[str], None] | None
    SequenceOfInts: TypeAlias = collections.abc.Sequence[int]
    "#
);

// Test type aliases using Python syntax (with docstrings): a string literal directly after an
// alias documents it; `UndocumentedCallback` deliberately has none.
pyo3_stub_gen::derive::gen_type_alias_from_python!(
    "pure",
    r#"
    from typing import TypeAlias
    import collections.abc

    DocumentedCallback: TypeAlias = collections.abc.Callable[[str], None]
    """A callback function that takes a string and returns nothing"""

    UndocumentedCallback: TypeAlias = collections.abc.Callable[[int], bool]

    MultiLineDocCallback: TypeAlias = collections.abc.Callable[[str, int], bool]
    """
    A callback with multi-line documentation.

    This callback takes a string and an integer, and returns a boolean.
    """
    "#
);

// Test RustType markers in type aliases (TypeAlias syntax). `pyo3_stub_gen.RustType["..."]`
// presumably stands for the stub type of the named Rust type (`DataContainer`, defined in a
// sibling module).
pyo3_stub_gen::derive::gen_type_alias_from_python!(
    "pure",
    r#"
    from typing import TypeAlias

    SimpleContainer: TypeAlias = pyo3_stub_gen.RustType["DataContainer"]
    ContainerList: TypeAlias = list[pyo3_stub_gen.RustType["DataContainer"]]
    ContainerMap: TypeAlias = dict[str, pyo3_stub_gen.RustType["DataContainer"]]
    OptionalContainer: TypeAlias = pyo3_stub_gen.RustType["DataContainer"] | None
    "#
);

// Test RustType markers in type aliases (Python 3.12+ type statement syntax)
pyo3_stub_gen::derive::gen_type_alias_from_python!(
    "pure",
    r#"
    type ContainerTuple = tuple[pyo3_stub_gen.RustType["DataContainer"], pyo3_stub_gen.RustType["DataContainer"]]
    type NestedContainer = list[list[pyo3_stub_gen.RustType["DataContainer"]]]
    "#
);
688
// Defines `stub_info`, which gathers the items registered by the `gen_stub_*` and alias macros
// in this crate; presumably invoked by the stub-generation binary to write the `.pyi` file.
define_stub_info_gatherer!(stub_info);
690
/// Smoke test: it only requires that this crate's unit-test binary builds and links.
/// NOTE(review): the specific "link problem" this guards against is not visible here.
#[cfg(test)]
mod test {
    #[test]
    fn test() {
        let four = 2 + 2;
        assert_eq!(four, 4);
    }
}