pyo3_stub_gen/
util.rs

1use pyo3::{prelude::*, types::*};
2use std::{borrow::Cow, ffi::CString};
3
4pub fn all_builtin_types(any: &Bound<'_, PyAny>) -> bool {
5    if any.is_instance_of::<PyString>()
6        || any.is_instance_of::<PyBool>()
7        || any.is_instance_of::<PyInt>()
8        || any.is_instance_of::<PyFloat>()
9        || any.is_instance_of::<PyComplex>()
10        || any.is_none()
11    {
12        return true;
13    }
14    if any.is_instance_of::<PyDict>() {
15        return any
16            .downcast::<PyDict>()
17            .map(|dict| {
18                dict.into_iter()
19                    .all(|(k, v)| all_builtin_types(&k) && all_builtin_types(&v))
20            })
21            .unwrap_or(false);
22    }
23    if any.is_instance_of::<PyList>() {
24        return any
25            .downcast::<PyList>()
26            .map(|list| list.into_iter().all(|v| all_builtin_types(&v)))
27            .unwrap_or(false);
28    }
29    if any.is_instance_of::<PyTuple>() {
30        return any
31            .downcast::<PyTuple>()
32            .map(|list| list.into_iter().all(|v| all_builtin_types(&v)))
33            .unwrap_or(false);
34    }
35    false
36}
37
38/// whether eval(repr(any)) == any
39pub fn valid_external_repr(any: &Bound<'_, PyAny>) -> Option<bool> {
40    let globals = get_globals(any).ok()?;
41    let fmt_str = any.repr().ok()?.to_string();
42    let fmt_cstr = CString::new(fmt_str.clone()).ok()?;
43    let new_any = any.py().eval(&fmt_cstr, Some(&globals), None).ok()?;
44    new_any.eq(any).ok()
45}
46
47fn get_globals<'py>(any: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyDict>> {
48    let type_object = any.get_type();
49    let type_name = type_object.getattr("__name__")?;
50    let type_name: Cow<str> = type_name.extract()?;
51    let globals = PyDict::new(any.py());
52    globals.set_item(type_name, type_object)?;
53    Ok(globals)
54}
55
/// Spells non-finite `float` values as evaluable Python expressions.
///
/// Python's `repr(float('inf'))` is `"inf"`. That is not an expression that
/// evaluates on its own: `inf` must first be defined or imported. For a `float`
/// holding NaN, +infinity or -infinity this function returns `float('nan')`,
/// `float('inf')` or `float('-inf')`, which need no imports. It returns `None`
/// for finite floats and for any object that is not a `float`.
///
/// FIXME: Only a top-level PyFloat is handled. Containers such as `[inf]` and
/// complex numbers such as `(inf+0j)` are not handled yet and will produce
/// invalid stubs.
#[cfg(feature = "infer_signature")]
fn try_special_float_repr(any: &Bound<'_, PyAny>) -> Option<String> {
    if !any.is_instance_of::<PyFloat>() {
        return None;
    }
    let literal = match any.extract::<f64>().ok()? {
        x if x.is_nan() => "float('nan')",
        x if x.is_infinite() && x.is_sign_positive() => "float('inf')",
        x if x.is_infinite() => "float('-inf')",
        _ => return None,
    };
    Some(literal.to_string())
}
82
83#[cfg_attr(not(feature = "infer_signature"), allow(unused_variables))]
84pub fn fmt_py_obj<T: for<'py> pyo3::IntoPyObjectExt<'py>>(obj: T) -> String {
85    #[cfg(feature = "infer_signature")]
86    {
87        pyo3::Python::initialize();
88        pyo3::Python::attach(|py| -> String {
89            if let Ok(any) = obj.into_bound_py_any(py) {
90                // Check for special float values first (inf, nan)
91                if let Some(special) = try_special_float_repr(&any) {
92                    return special;
93                }
94                if all_builtin_types(&any) || valid_external_repr(&any).is_some_and(|valid| valid) {
95                    if let Ok(py_str) = any.repr() {
96                        return py_str.to_string();
97                    }
98                }
99            }
100            "...".to_owned()
101        })
102    }
103    #[cfg(not(feature = "infer_signature"))]
104    {
105        "...".to_owned()
106    }
107}
108
#[cfg(all(test, feature = "infer_signature"))]
mod test {
    use super::*;

    /// A `#[pyclass]` with no custom `__repr__`; the tests expect any value
    /// containing it to format as the `...` placeholder.
    #[pyclass]
    #[derive(Debug)]
    struct A {}

    #[test]
    fn test_fmt_dict() {
        pyo3::Python::initialize();
        pyo3::Python::attach(|py| {
            let mapping = PyDict::new(py);
            let _ = mapping.set_item("k1", "v1");
            let _ = mapping.set_item("k2", 2);
            assert_eq!("{'k1': 'v1', 'k2': 2}", fmt_py_obj(mapping.as_unbound()));

            // A single unformattable value makes the whole dict unformattable.
            let _ = mapping.set_item("k3", A {});
            assert_eq!("...", fmt_py_obj(mapping.as_unbound()));
        })
    }

    #[test]
    fn test_fmt_list() {
        pyo3::Python::initialize();
        pyo3::Python::attach(|py| {
            let ints = PyList::new(py, [1, 2]).unwrap();
            assert_eq!("[1, 2]", fmt_py_obj(ints.as_unbound()));

            let objects = PyList::new(py, [A {}, A {}]).unwrap();
            assert_eq!("...", fmt_py_obj(objects.as_unbound()));
        })
    }

    #[test]
    fn test_fmt_tuple() {
        pyo3::Python::initialize();
        pyo3::Python::attach(|py| {
            let pair = PyTuple::new(py, [1, 2]).unwrap();
            assert_eq!("(1, 2)", fmt_py_obj(pair.as_unbound()));

            // One-element tuples keep their trailing comma.
            let single = PyTuple::new(py, [1]).unwrap();
            assert_eq!("(1,)", fmt_py_obj(single.as_unbound()));

            let objects = PyTuple::new(py, [A {}]).unwrap();
            assert_eq!("...", fmt_py_obj(objects.as_unbound()));
        })
    }

    #[test]
    fn test_fmt_other() {
        let no_value: Option<usize> = None;
        let cases = [
            // str: quoting follows Python's repr rules
            ("'123'", fmt_py_obj("123")),
            ("\"don't\"", fmt_py_obj("don't")),
            ("'str\\\\'", fmt_py_obj("str\\")),
            // bool
            ("True", fmt_py_obj(true)),
            ("False", fmt_py_obj(false)),
            // int
            ("123", fmt_py_obj(123)),
            // float
            ("1.23", fmt_py_obj(1.23)),
            // None
            ("None", fmt_py_obj(no_value)),
            // a plain #[pyclass] instance falls back to the placeholder
            ("...", fmt_py_obj(A {})),
        ];
        for (expected, actual) in cases {
            assert_eq!(expected, actual);
        }
    }

    #[test]
    fn test_fmt_special_float_values() {
        // Non-finite floats must become valid Python expressions,
        // whichever Rust float width they came from.
        let expected = ["float('inf')", "float('-inf')", "float('nan')"];
        assert_eq!([f64::INFINITY, f64::NEG_INFINITY, f64::NAN].map(fmt_py_obj), expected);
        assert_eq!([f32::INFINITY, f32::NEG_INFINITY, f32::NAN].map(fmt_py_obj), expected);
    }

    #[test]
    fn test_fmt_enum() {
        // A simple #[pyclass] enum reprs as `Number.Float`. That text evaluates
        // back to the same variant once `Number` is placed in the eval globals.
        #[pyclass(eq, eq_int)]
        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
        pub enum Number {
            Float,
            Integer,
        }
        assert_eq!("Number.Float", fmt_py_obj(Number::Float));
    }
}