Skip to content

Commit e9910e0

Browse files
authored
Merge pull request jax-ml#2642 from jakevdp/trunc
Add implementation of np.trunc
2 parents 52e779c + 46a8922 commit e9910e0

2 files changed

Lines changed: 7 additions & 0 deletions

File tree

jax/numpy/lax_numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,12 @@ def signbit(x):
653653
x = lax.bitcast_convert_type(x, int_type)
654654
return lax.convert_element_type(x >> (info.nexp + info.nmant), onp.bool)
655655

656+
657+
@_wraps(onp.trunc)
658+
def trunc(x):
659+
return where(lax.lt(x, lax._const(x, 0)), lax.ceil(x), lax.floor(x))
660+
661+
656662
def _normalize_float(x):
657663
info = finfo(_dtype(x))
658664
cond = lax.abs(x) < info.tiny

tests/lax_numpy_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
132132
op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
133133
op_record("signbit", 1, default_dtypes + bool_dtypes, all_shapes,
134134
jtu.rand_some_inf_and_nan, ["rev"]),
135+
op_record("trunc", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
135136
op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
136137
inexact=True),
137138
op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],

0 commit comments

Comments
 (0)