Skip to content

Commit

Permalink
Improve type_caster for floating-point types.
Browse files Browse the repository at this point in the history
  • Loading branch information
hpkfft committed Dec 20, 2024
1 parent b7c4f1a commit 1e008be
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 34 deletions.
20 changes: 16 additions & 4 deletions include/nanobind/nb_cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,22 @@ template <typename T>
struct type_caster<T, enable_if_t<std::is_arithmetic_v<T> && !is_std_char_v<T>>> {
NB_INLINE bool from_python(handle src, uint8_t flags, cleanup_list *) noexcept {
if constexpr (std::is_floating_point_v<T>) {
if constexpr (sizeof(T) == 8)
return detail::load_f64(src.ptr(), flags, &value);
else
return detail::load_f32(src.ptr(), flags, &value);
if constexpr (sizeof(T) == 8) {
// Assume T, double, and Python float are all IEEE 754 binary64
return detail::load_f64(src.ptr(), flags, (double *) &value);
} else {
double d;
if (!detail::load_f64(src.ptr(), flags, &d))
return false;
T result = static_cast<T>(d);
if ((flags & (uint8_t) cast_flags::convert)
|| static_cast<double>(result) == d
|| (result != result && d != d)) {
value = result;
return true;
}
return false;
}
} else {
if constexpr (std::is_signed_v<T>) {
if constexpr (sizeof(T) == 8)
Expand Down
1 change: 0 additions & 1 deletion include/nanobind/nb_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,6 @@ NB_CORE bool load_i32(PyObject *o, uint8_t flags, int32_t *out) noexcept;
NB_CORE bool load_u32(PyObject *o, uint8_t flags, uint32_t *out) noexcept;
NB_CORE bool load_i64(PyObject *o, uint8_t flags, int64_t *out) noexcept;
NB_CORE bool load_u64(PyObject *o, uint8_t flags, uint64_t *out) noexcept;
NB_CORE bool load_f32(PyObject *o, uint8_t flags, float *out) noexcept;
NB_CORE bool load_f64(PyObject *o, uint8_t flags, double *out) noexcept;

// ========================================================================
Expand Down
31 changes: 2 additions & 29 deletions src/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -904,18 +904,16 @@ bool load_f64(PyObject *o, uint8_t flags, double *out) noexcept {

#if !defined(Py_LIMITED_API)
if (NB_LIKELY(is_float)) {
*out = (double) PyFloat_AS_DOUBLE(o);
*out = PyFloat_AS_DOUBLE(o);
return true;
}

is_float = false;
#endif

if (is_float || (flags & (uint8_t) cast_flags::convert)) {
double result = PyFloat_AsDouble(o);

if (result != -1.0 || !PyErr_Occurred()) {
*out = (double) result;
*out = result;
return true;
} else {
PyErr_Clear();
Expand All @@ -925,31 +923,6 @@ bool load_f64(PyObject *o, uint8_t flags, double *out) noexcept {
return false;
}

bool load_f32(PyObject *o, uint8_t flags, float *out) noexcept {
bool is_float = PyFloat_CheckExact(o);

#if !defined(Py_LIMITED_API)
if (NB_LIKELY(is_float)) {
*out = (float) PyFloat_AS_DOUBLE(o);
return true;
}

is_float = false;
#endif

if (is_float || (flags & (uint8_t) cast_flags::convert)) {
double result = PyFloat_AsDouble(o);

if (result != -1.0 || !PyErr_Occurred()) {
*out = (float) result;
return true;
} else {
PyErr_Clear();
}
}

return false;
}

#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION) && PY_VERSION_HEX < 0x030c0000
// Direct access for compact integers. These functions are
Expand Down

0 comments on commit 1e008be

Please sign in to comment.