diff --git a/include/nanobind/nb_cast.h b/include/nanobind/nb_cast.h index 3e0dc894..e74ee489 100644 --- a/include/nanobind/nb_cast.h +++ b/include/nanobind/nb_cast.h @@ -126,9 +126,11 @@ template struct type_caster && !is_std_char_v>> { NB_INLINE bool from_python(handle src, uint8_t flags, cleanup_list *) noexcept { if constexpr (std::is_floating_point_v) { - if constexpr (sizeof(T) == 8) { + if constexpr (sizeof(T) == sizeof(double)) { // Assume T, double, and Python float are all IEEE 754 binary64 return detail::load_f64(src.ptr(), flags, (double *) &value); + } else if constexpr (std::is_same_v) { + return detail::load_f32(src.ptr(), flags, &value); } else { double d; if (!detail::load_f64(src.ptr(), flags, &d)) diff --git a/include/nanobind/nb_lib.h b/include/nanobind/nb_lib.h index 058273e2..3541bce8 100644 --- a/include/nanobind/nb_lib.h +++ b/include/nanobind/nb_lib.h @@ -512,6 +512,7 @@ NB_CORE bool load_i32(PyObject *o, uint8_t flags, int32_t *out) noexcept; NB_CORE bool load_u32(PyObject *o, uint8_t flags, uint32_t *out) noexcept; NB_CORE bool load_i64(PyObject *o, uint8_t flags, int64_t *out) noexcept; NB_CORE bool load_u64(PyObject *o, uint8_t flags, uint64_t *out) noexcept; +NB_CORE bool load_f32(PyObject *o, uint8_t flags, float *out) noexcept; NB_CORE bool load_f64(PyObject *o, uint8_t flags, double *out) noexcept; // ======================================================================== diff --git a/src/common.cpp b/src/common.cpp index f3565921..e513567c 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -8,6 +8,7 @@ */ #include +#include #include "nb_internals.h" NAMESPACE_BEGIN(NB_NAMESPACE) @@ -923,6 +924,40 @@ bool load_f64(PyObject *o, uint8_t flags, double *out) noexcept { return false; } +bool load_f32(PyObject *o, uint8_t flags, float *out) noexcept { + static_assert(std::numeric_limits::is_iec559); + static_assert(std::numeric_limits::is_iec559); + bool is_float = PyFloat_CheckExact(o); + bool convert = flags & (uint8_t) cast_flags::convert; + +#if !defined(Py_LIMITED_API) + if (NB_LIKELY(is_float)) { + double d = PyFloat_AS_DOUBLE(o); + float result = static_cast(d); + if (convert || static_cast(result) == d || d != d) { + *out = result; + return true; + } else { + return false; + } + } +#endif + + if (is_float || convert) { + double d = PyFloat_AsDouble(o); + if (d != -1.0 || !PyErr_Occurred()) { + float result = static_cast(d); + if (convert || static_cast(result) == d || d != d) { + *out = result; + return true; + } + } else { + PyErr_Clear(); + } + } + + return false; +} #if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION) && PY_VERSION_HEX < 0x030c0000 // Direct access for compact integers. These functions are diff --git a/tests/test_functions.cpp b/tests/test_functions.cpp index 45d182c9..d22a3f64 100644 --- a/tests/test_functions.cpp +++ b/tests/test_functions.cpp @@ -258,6 +258,10 @@ NB_MODULE(test_functions_ext, m) { m.def("test_21_g", []() { return nb::int_(1.5); }); m.def("test_21_h", []() { return nb::int_(1e50); }); + // Test floating-point + m.def("test_21_dnc", [](double d) { return d + 1.0; }, nb::arg().noconvert()); + m.def("test_21_fnc", [](float f) { return f + 1.0f; }, nb::arg().noconvert()); + // Test capsule wrapper m.def("test_22", []() -> void * { return (void*) 1; }); m.def("test_23", []() -> void * { return nullptr; }); diff --git a/tests/test_functions.py b/tests/test_functions.py index cbd72235..026a4d1d 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -253,6 +253,18 @@ def test21_numpy_overloads(): assert t.test_11_sll(np.int32(5)) == 5 assert t.test_11_ull(np.int32(5)) == 5 + with pytest.raises(TypeError) as excinfo: + t.test_21_dnc(np.float64(21.0)) # Python type is not exactly float + assert "incompatible function arguments" in str(excinfo.value) + assert t.test_21_dnc(float(np.float64(21.0))) == 22.0 + assert t.test_21_dnc(float(np.float32(21.0))) == 22.0 + + assert t.test_21_fnc(float(np.float32(21.0))) == 22.0 + with pytest.raises(TypeError) as excinfo: + t.test_21_fnc(float(np.float64(21.1))) # Inexact narrowing to float32 + assert "incompatible function arguments" in str(excinfo.value) + assert t.test_21_fnc(float(np.float32(21.1))) == np.float32(22.1) + def test22_string_return(): assert t.test_12("hello") == "hello" diff --git a/tests/test_functions_ext.pyi.ref b/tests/test_functions_ext.pyi.ref index 3a025ba1..b8d3298b 100644 --- a/tests/test_functions_ext.pyi.ref +++ b/tests/test_functions_ext.pyi.ref @@ -120,8 +120,12 @@ def test_20(arg: str, /) -> object: ... def test_21(arg: int, /) -> int: ... +def test_21_dnc(arg: float) -> float: ... + def test_21_f(arg: float, /) -> int: ... +def test_21_fnc(arg: float) -> float: ... + def test_21_g() -> int: ... def test_21_h() -> int: ...