pyo3_stub_gen/
util.rs

1use pyo3::{prelude::*, types::*};
2use std::{borrow::Cow, ffi::CString};
3
4pub fn all_builtin_types(any: &Bound<'_, PyAny>) -> bool {
5    if any.is_instance_of::<PyString>()
6        || any.is_instance_of::<PyBool>()
7        || any.is_instance_of::<PyInt>()
8        || any.is_instance_of::<PyFloat>()
9        || any.is_instance_of::<PyComplex>()
10        || any.is_none()
11    {
12        return true;
13    }
14    if any.is_instance_of::<PyDict>() {
15        return any
16            .downcast::<PyDict>()
17            .map(|dict| {
18                dict.into_iter()
19                    .all(|(k, v)| all_builtin_types(&k) && all_builtin_types(&v))
20            })
21            .unwrap_or(false);
22    }
23    if any.is_instance_of::<PyList>() {
24        return any
25            .downcast::<PyList>()
26            .map(|list| list.into_iter().all(|v| all_builtin_types(&v)))
27            .unwrap_or(false);
28    }
29    if any.is_instance_of::<PyTuple>() {
30        return any
31            .downcast::<PyTuple>()
32            .map(|list| list.into_iter().all(|v| all_builtin_types(&v)))
33            .unwrap_or(false);
34    }
35    false
36}
37
38/// whether eval(repr(any)) == any
39pub fn valid_external_repr(any: &Bound<'_, PyAny>) -> Option<bool> {
40    let globals = get_globals(any).ok()?;
41    let fmt_str = any.repr().ok()?.to_string();
42    let fmt_cstr = CString::new(fmt_str.clone()).ok()?;
43    let new_any = any.py().eval(&fmt_cstr, Some(&globals), None).ok()?;
44    new_any.eq(any).ok()
45}
46
47fn get_globals<'py>(any: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyDict>> {
48    let type_object = any.get_type();
49    let type_name = type_object.getattr("__name__")?;
50    let type_name: Cow<str> = type_name.extract()?;
51    let globals = PyDict::new(any.py());
52    globals.set_item(type_name, type_object)?;
53    Ok(globals)
54}
55
56#[cfg_attr(not(feature = "infer_signature"), allow(unused_variables))]
57pub fn fmt_py_obj<T: for<'py> pyo3::IntoPyObjectExt<'py>>(obj: T) -> String {
58    #[cfg(feature = "infer_signature")]
59    {
60        pyo3::Python::initialize();
61        pyo3::Python::attach(|py| -> String {
62            if let Ok(any) = obj.into_bound_py_any(py) {
63                if all_builtin_types(&any) || valid_external_repr(&any).is_some_and(|valid| valid) {
64                    if let Ok(py_str) = any.repr() {
65                        return py_str.to_string();
66                    }
67                }
68            }
69            "...".to_owned()
70        })
71    }
72    #[cfg(not(feature = "infer_signature"))]
73    {
74        "...".to_owned()
75    }
76}
77
#[cfg(all(test, feature = "infer_signature"))]
mod test {
    use super::*;

    /// Empty pyclass; the tests below expect `fmt_py_obj` to render it, and any
    /// container holding it, as `...`.
    #[pyclass]
    #[derive(Debug)]
    struct A {}

    /// A dict of builtin values is rendered by its repr until an `A` is inserted.
    #[test]
    fn test_fmt_dict() {
        pyo3::Python::initialize();
        pyo3::Python::attach(|py| {
            let mapping = PyDict::new(py);
            let _ = mapping.set_item("k1", "v1");
            let _ = mapping.set_item("k2", 2);
            let rendered = fmt_py_obj(mapping.as_unbound());
            assert_eq!("{'k1': 'v1', 'k2': 2}", rendered);

            // A value of class A cannot be formatted, so the whole dict falls back.
            let _ = mapping.set_item("k3", A {});
            let rendered = fmt_py_obj(mapping.as_unbound());
            assert_eq!("...", rendered);
        })
    }

    /// A list of ints is rendered by its repr; a list of `A` falls back.
    #[test]
    fn test_fmt_list() {
        pyo3::Python::initialize();
        pyo3::Python::attach(|py| {
            let ints = PyList::new(py, [1, 2]).unwrap();
            let rendered = fmt_py_obj(ints.as_unbound());
            assert_eq!("[1, 2]", rendered);

            // Elements of class A cannot be formatted.
            let instances = PyList::new(py, [A {}, A {}]).unwrap();
            let rendered = fmt_py_obj(instances.as_unbound());
            assert_eq!("...", rendered);
        })
    }

    /// Tuples of ints (including the one-element form) are rendered by their
    /// repr; a tuple holding an `A` falls back.
    #[test]
    fn test_fmt_tuple() {
        pyo3::Python::initialize();
        pyo3::Python::attach(|py| {
            let pair = PyTuple::new(py, [1, 2]).unwrap();
            let rendered = fmt_py_obj(pair.as_unbound());
            assert_eq!("(1, 2)", rendered);

            let single = PyTuple::new(py, [1]).unwrap();
            let rendered = fmt_py_obj(single.as_unbound());
            assert_eq!("(1,)", rendered);

            // Elements of class A cannot be formatted.
            let instances = PyTuple::new(py, [A {}]).unwrap();
            let rendered = fmt_py_obj(instances.as_unbound());
            assert_eq!("...", rendered);
        })
    }

    /// Scalar Rust values are rendered as the repr of their Python conversion.
    #[test]
    fn test_fmt_other() {
        // str
        let string_cases = [
            ("'123'", "123"),
            ("\"don't\"", "don't"),
            ("'str\\\\'", "str\\"),
        ];
        for (expected, input) in string_cases {
            assert_eq!(expected, fmt_py_obj(input));
        }

        // bool
        for (expected, input) in [("True", true), ("False", false)] {
            assert_eq!(expected, fmt_py_obj(input));
        }

        // int
        assert_eq!("123", fmt_py_obj(123));

        // float
        assert_eq!("1.23", fmt_py_obj(1.23));

        // None
        let absent: Option<usize> = None;
        assert_eq!("None", fmt_py_obj(absent));

        // A value of class A cannot be formatted.
        assert_eq!("...", fmt_py_obj(A {}));
    }

    /// A pyclass enum declared with `eq, eq_int` is rendered as `Type.Variant`.
    #[test]
    fn test_fmt_enum() {
        #[pyclass(eq, eq_int)]
        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
        pub enum Number {
            Float,
            Integer,
        }

        let rendered = fmt_py_obj(Number::Float);
        assert_eq!("Number.Float", rendered);
    }
}