Skip to content

Commit

Permalink
Add dedicated detail::load_f32() function.
Browse files Browse the repository at this point in the history
  • Loading branch information
hpkfft committed Jan 7, 2025
1 parent 1e008be commit b025460
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 1 deletion.
4 changes: 3 additions & 1 deletion include/nanobind/nb_cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,11 @@ template <typename T>
struct type_caster<T, enable_if_t<std::is_arithmetic_v<T> && !is_std_char_v<T>>> {
NB_INLINE bool from_python(handle src, uint8_t flags, cleanup_list *) noexcept {
if constexpr (std::is_floating_point_v<T>) {
if constexpr (sizeof(T) == 8) {
if constexpr (sizeof(T) == sizeof(double)) {
// Assume T, double, and Python float are all IEEE 754 binary64
return detail::load_f64(src.ptr(), flags, (double *) &value);
} else if constexpr (std::is_same_v<T, float>) {
return detail::load_f32(src.ptr(), flags, &value);
} else {
double d;
if (!detail::load_f64(src.ptr(), flags, &d))
Expand Down
1 change: 1 addition & 0 deletions include/nanobind/nb_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,7 @@ NB_CORE bool load_i32(PyObject *o, uint8_t flags, int32_t *out) noexcept;
NB_CORE bool load_u32(PyObject *o, uint8_t flags, uint32_t *out) noexcept;
NB_CORE bool load_i64(PyObject *o, uint8_t flags, int64_t *out) noexcept;
NB_CORE bool load_u64(PyObject *o, uint8_t flags, uint64_t *out) noexcept;
NB_CORE bool load_f32(PyObject *o, uint8_t flags, float *out) noexcept;
NB_CORE bool load_f64(PyObject *o, uint8_t flags, double *out) noexcept;

// ========================================================================
Expand Down
35 changes: 35 additions & 0 deletions src/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/

#include <nanobind/nanobind.h>
#include <limits>
#include "nb_internals.h"

NAMESPACE_BEGIN(NB_NAMESPACE)
Expand Down Expand Up @@ -923,6 +924,40 @@ bool load_f64(PyObject *o, uint8_t flags, double *out) noexcept {
return false;
}

bool load_f32(PyObject *o, uint8_t flags, float *out) noexcept {
static_assert(std::numeric_limits<float>::is_iec559);
static_assert(std::numeric_limits<double>::is_iec559);
bool is_float = PyFloat_CheckExact(o);
bool convert = flags & (uint8_t) cast_flags::convert;

#if !defined(Py_LIMITED_API)
if (NB_LIKELY(is_float)) {
double d = PyFloat_AS_DOUBLE(o);
float result = static_cast<float>(d);
if (convert || static_cast<double>(result) == d || d != d) {
*out = result;
return true;
} else {
return false;
}
}
#endif

if (is_float || convert) {
double d = PyFloat_AsDouble(o);
if (d != -1.0 || !PyErr_Occurred()) {
float result = static_cast<float>(d);
if (convert || static_cast<double>(result) == d || d != d) {
*out = result;
return true;
}
} else {
PyErr_Clear();
}
}

return false;
}

#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION) && PY_VERSION_HEX < 0x030c0000
// Direct access for compact integers. These functions are
Expand Down
4 changes: 4 additions & 0 deletions tests/test_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,10 @@ NB_MODULE(test_functions_ext, m) {
m.def("test_21_g", []() { return nb::int_(1.5); });
m.def("test_21_h", []() { return nb::int_(1e50); });

// Test floating-point
m.def("test_21_dnc", [](double d) { return d + 1.0; }, nb::arg().noconvert());
m.def("test_21_fnc", [](float f) { return f + 1.0f; }, nb::arg().noconvert());

// Test capsule wrapper
m.def("test_22", []() -> void * { return (void*) 1; });
m.def("test_23", []() -> void * { return nullptr; });
Expand Down
12 changes: 12 additions & 0 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,18 @@ def test21_numpy_overloads():
assert t.test_11_sll(np.int32(5)) == 5
assert t.test_11_ull(np.int32(5)) == 5

with pytest.raises(TypeError) as excinfo:
t.test_21_dnc(np.float64(21.0)) # Python type is not exactly float
assert "incompatible function arguments" in str(excinfo.value)
assert t.test_21_dnc(float(np.float64(21.0))) == 22.0
assert t.test_21_dnc(float(np.float32(21.0))) == 22.0

assert t.test_21_fnc(float(np.float32(21.0))) == 22.0
with pytest.raises(TypeError) as excinfo:
t.test_21_fnc(float(np.float64(21.1))) # Inexact narrowing to float32
assert "incompatible function arguments" in str(excinfo.value)
assert t.test_21_fnc(float(np.float32(21.1))) == np.float32(22.1)


def test22_string_return():
assert t.test_12("hello") == "hello"
Expand Down
4 changes: 4 additions & 0 deletions tests/test_functions_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,12 @@ def test_20(arg: str, /) -> object: ...

def test_21(arg: int, /) -> int: ...

def test_21_dnc(arg: float) -> float: ...

def test_21_f(arg: float, /) -> int: ...

def test_21_fnc(arg: float) -> float: ...

def test_21_g() -> int: ...

def test_21_h() -> int: ...
Expand Down

0 comments on commit b025460

Please sign in to comment.