diff --git a/CMakeLists.txt b/CMakeLists.txt index 76db52436a..2eb0d6c34e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -235,6 +235,7 @@ list(APPEND HWY_CONTRIB_SOURCES hwy/contrib/sort/sorting_networks-inl.h hwy/contrib/sort/traits-inl.h hwy/contrib/sort/traits128-inl.h + hwy/contrib/sort/order-emulate-inl.h hwy/contrib/sort/vqsort-inl.h hwy/contrib/sort/vqsort.cc hwy/contrib/sort/vqsort.h diff --git a/hwy/contrib/sort/BUILD b/hwy/contrib/sort/BUILD index 6171029d4f..082587918f 100644 --- a/hwy/contrib/sort/BUILD +++ b/hwy/contrib/sort/BUILD @@ -115,6 +115,7 @@ VQSORT_TEXTUAL_HDRS = [ "sorting_networks-inl.h", "traits-inl.h", "traits128-inl.h", + "order-emulate-inl.h", "vqsort-inl.h", # Placeholder for internal instrumentation. Do not remove. ] diff --git a/hwy/contrib/sort/order-emulate-inl.h b/hwy/contrib/sort/order-emulate-inl.h new file mode 100644 index 0000000000..ade42803b6 --- /dev/null +++ b/hwy/contrib/sort/order-emulate-inl.h @@ -0,0 +1,225 @@ +// Emulated floating-point total order +// +// This implementation sorts floating-point values by reinterpreting them as +// unsigned integer bit patterns instead of using the FPU. It does not depend on +// the floating-point control register, so there is no flush-to-zero handling. +// +// NaNs are already replaced by ±Inf before calling this code, so no special +// handling is needed here. +// Because ordering is emulated, we guarantee a stable rule for zeros: -0.0 +// always comes before +0.0. 
+//
+// SPDX-License-Identifier: BSD-3-Clause
+#if defined(HIGHWAY_HWY_CONTRIB_SORT_ORDER_EMULATE_TOGGLE) == \
+    defined(HWY_TARGET_TOGGLE)
+#ifdef HIGHWAY_HWY_CONTRIB_SORT_ORDER_EMULATE_TOGGLE
+#undef HIGHWAY_HWY_CONTRIB_SORT_ORDER_EMULATE_TOGGLE
+#else
+#define HIGHWAY_HWY_CONTRIB_SORT_ORDER_EMULATE_TOGGLE
+#endif
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hwy/contrib/sort/order.h"  // SortDescending
+#include "hwy/highway.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace detail {
+
+
+template <class VU, class D = DFromV<VU>, class DI = RebindToSigned<D>>
+HWY_API Vec<DI> LtBinKey(VU a) {
+  using TI = TFromD<DI>;
+  using VI = Vec<DI>;
+  const DI di;
+  const VI neg_flip = Set(di, TI(SignMask<TI>() - 1));
+  return Xor(BitCast(di, a), IfNegativeThenElseZero(BitCast(di, a), neg_flip));
+}
+
+template <class VU, class D = DFromV<VU>, class M = MFromD<D>, HWY_IF_UNSIGNED_D(D)>
+HWY_API M LtBin(VU a, VU b) {
+  return RebindMask(D{}, Lt(LtBinKey(a), LtBinKey(b)));
+}
+
+template <class Base, class Order>
+struct OrderEmulate : public Base {
+  using T = typename Base::LaneType;
+  using TF = typename Base::KeyType;
+
+  HWY_INLINE bool Equal1(const T* a, const T* b) const {
+    return *a == *b;
+  }
+
+  template <class D>
+  HWY_INLINE Mask<D> EqualKeys(D, Vec<D> a, Vec<D> b) const {
+    return Eq(a, b);  // Bitwise equality, -0 != +0, +-NaN is equal to itself
+  }
+
+  template <class D>
+  HWY_INLINE Mask<D> NotEqualKeys(D, Vec<D> a, Vec<D> b) const {
+    return Ne(a, b);  // bitwise inequality, -0 != +0, +-NaN is equal to itself
+  }
+
+  HWY_INLINE bool Compare1(const T* a_, const T* b_) const {
+    const T a = *a_;
+    const T b = *b_;
+    // specialized less than, -0.0 < +0.0, and NaNs are not ordered
+    using TI = MakeSigned<T>;
+    constexpr int kMSB = 8 * sizeof(T) - 1;
+    constexpr T neg_flip = T((T(1) << kMSB) - 1);
+    const T a_neg = 0 - (a >> kMSB);
+    const T b_neg = 0 - (b >> kMSB);
+    // Signed-domain keys (xor 0x7FFF.. only for negatives)
+    const T sa = a ^ (a_neg & neg_flip);
+    const T sb = b ^ (b_neg & neg_flip);
+    return static_cast<TI>(sa) < static_cast<TI>(sb);
+  }
+  template <class D>
+  HWY_INLINE Mask<D> Compare(D, Vec<D> a, Vec<D> b) const {
+    // specialized less than, -0.0 < +0.0, and NaNs are not ordered
+    return LtBin(a, b);
+  }
+
+  // Two halves of Sort2, used in ScanMinMax.
+  template <class D>
+  HWY_INLINE Vec<D> First(D /* tag */, const Vec<D> a, const Vec<D> b) const {
+    return IfThenElse(LtBin(a, b), a, b);
+  }
+
+  template <class D>
+  HWY_INLINE Vec<D> Last(D /* tag */, const Vec<D> a, const Vec<D> b) const {
+    return IfThenElse(LtBin(a, b), b, a);
+  }
+
+  template <class D>
+  HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v,
+                                 T* HWY_RESTRICT /* buf */) const {
+    const RebindToSigned<D> di;
+    using VI = Vec<decltype(di)>;
+    VI key = LtBinKey(v);
+    VI min = MinOfLanes(di, key);
+    Mask<D> m = RebindMask(d, Eq(min, key));
+    return MaxOfLanes(d, IfThenElseZero(m, v));
+  }
+
+  template <class D>
+  HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v,
+                                T* HWY_RESTRICT /* buf */) const {
+    const RebindToSigned<D> di;
+    using VI = Vec<decltype(di)>;
+    VI key = LtBinKey(v);
+    VI max = MaxOfLanes(di, key);
+    Mask<D> m = RebindMask(d, Eq(max, key));
+    return MaxOfLanes(d, IfThenElseZero(m, v));
+  }
+
+  template <class D>
+  HWY_INLINE Vec<D> FirstValue(D d) const {
+    return Set(d, BitCastScalar<T>(NegativeInfOrLowestValue<TF>()));
+  }
+
+  template <class D>
+  HWY_INLINE Vec<D> LastValue(D d) const {
+    return Set(d, BitCastScalar<T>(PositiveInfOrHighestValue<TF>()));
+  }
+
+  // Returns the next distinct smaller value unless already -inf.
+  template <class D, class V = Vec<D>>
+  HWY_INLINE V PrevValue(D, V v) const {
+    return NextSortValueBits</*IsDown=*/true>(v);
+  }
+
+  // Next representable value in total order by ±1 ULP, saturating at ±Inf.
+ // IsDown = false → next larger + // IsDown = true → next smaller + template + HWY_INLINE V NextSortValueBits(V u) const { + const DFromV d; + using M = Mask; + constexpr T kSignBit = SignMask(); + constexpr T kBoundaryUp = SignMask() - 1; + const V sign_bit = Set(d, kSignBit); + const V all1 = Set(d, T(~T(0))); + const V one = Set(d, T(1)); + // Detect saturation at ±Inf + const M is_target_inf = Eq(u, IsDown ? FirstValue(d) : LastValue(d)); + // Transform to monotonic space: flip sign for positives, invert for negatives + const M is_neg = TestBit(u, sign_bit); + const V key = Xor(u, IfThenElse(is_neg, all1, sign_bit)); + // Boundary detection: +0/-0 swap needs a step of 2 instead of 1 + const V boundary = Set(d, IsDown ? kSignBit : kBoundaryUp); + const M at_boundary = Eq(key, boundary); + // Step size: normally 1, but 2 at zero-boundary + const V step = Add(one, IfThenElseZero(at_boundary, one)); + // Apply increment/decrement unless already at ±Inf + const V key2 = IfThenElse(is_target_inf, key, + IsDown ? 
Sub(key, step) : Add(key, step)); + // Transform back from monotonic space + const M neg_out = Lt(key2, sign_bit); + return Xor(key2, IfThenElse(neg_out, all1, sign_bit)); + } +}; + +template +struct OrderEmulate : public OrderEmulate { + using T = typename Base::LaneType; + + HWY_INLINE const OrderEmulate& AscBase() const { + return *this; + } + + HWY_INLINE bool Compare1(const T* a, const T* b) const { + return AscBase().Compare1(b, a); + } + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return AscBase().Compare(d, b, a); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return AscBase().Last(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return AscBase().First(d, a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + T* HWY_RESTRICT b) const { + return AscBase().LastOfLanes(d, v, b); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + T* HWY_RESTRICT b) const { + return AscBase().FirstOfLanes(d, v, b); + } + + template + HWY_INLINE Vec FirstValue(D d) const { + return AscBase().LastValue(d); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return AscBase().FirstValue(d); + } + + template > + HWY_INLINE V PrevValue(D, V v) const { + return this->template NextSortValueBits(v); + } +}; + +} // namespace detail +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_ORDER_EMULATE_TOGGLE diff --git a/hwy/contrib/sort/sort_test.cc b/hwy/contrib/sort/sort_test.cc index 588b481d13..dd61af5739 100644 --- a/hwy/contrib/sort/sort_test.cc +++ b/hwy/contrib/sort/sort_test.cc @@ -61,17 +61,32 @@ using detail::OrderDescendingKV128; using detail::Traits128; #endif // !HAVE_INTEL && HWY_TARGET != HWY_SCALAR +template +inline void IotaWrapper(T *first, T *last, T val) { + std::iota(first, last, val); +} +// Emulate std::iota for hwy::float16_t, some compliers mostly on ARM & Longarch +// complain about 
operator++ on that type. +template <> +inline void IotaWrapper(hwy::float16_t *first, hwy::float16_t *last, + hwy::float16_t val) { + float v = ConvertScalarTo(val); + for (; first != last; ++first, ++v) { + *first = ConvertScalarTo(v); + } +} + template void TestSortIota(hwy::ThreadPool& pool) { pool.Run(128, 300, [](uint64_t task, size_t /*thread*/) { const size_t num = static_cast(task); Key keys[300]; - std::iota(keys, keys + num, Key{0}); + IotaWrapper(keys, keys + num, ConvertScalarTo(0)); VQSort(keys, num, hwy::SortAscending()); for (size_t i = 0; i < num; ++i) { - if (keys[i] != static_cast(i)) { + if (keys[i] != ConvertScalarTo(i)) { HWY_ABORT("num %zu i %zu: not iota, got %.0f\n", num, i, - static_cast(keys[i])); + ConvertScalarTo(keys[i])); } } }); @@ -86,10 +101,9 @@ void TestAllSortIota() { TestSortIota(pool); TestSortIota(pool); } + TestSortIota(pool); TestSortIota(pool); - if (hwy::HaveFloat64()) { - TestSortIota(pool); - } + TestSortIota(pool); #endif } diff --git a/hwy/contrib/sort/sort_unit_test.cc b/hwy/contrib/sort/sort_unit_test.cc index 75d2b1c130..1dd5a6da98 100644 --- a/hwy/contrib/sort/sort_unit_test.cc +++ b/hwy/contrib/sort/sort_unit_test.cc @@ -54,81 +54,124 @@ using detail::Traits128; #if VQSORT_ENABLED || HWY_IDE +template +void ForSortFloatTypesDynamic(const Func& func) { + func(hwy::float16_t()); + func(float()); + func(double()); +} + +template +inline T AddWrapper(T a, T b) { + return a + b; +} +template <> +inline hwy::float16_t AddWrapper(hwy::float16_t a, hwy::float16_t b) { + return ConvertScalarTo(ConvertScalarTo(a) + ConvertScalarTo(b)); +} + // Verify the corner cases of LargerSortValue/SmallerSortValue, used to // implement PrevValue/NextValue. 
struct TestFloatLargerSmaller { - template - HWY_NOINLINE void operator()(T, D d) { - const Vec p0 = Zero(d); - const Vec p1 = Set(d, ConvertScalarTo(1)); - const Vec pinf = Inf(d); - const Vec peps = Set(d, hwy::Epsilon()); - const Vec pmax = Set(d, hwy::HighestValue()); + template + HWY_NOINLINE void operator()(TF, DF) { + TraitsLane> st_smaller; + TraitsLane> st_larger; + + using T = typename decltype(st_smaller)::LaneType; + using D = Rebind; + const RebindToUnsigned du; + const D d; - const Vec n0 = Neg(p0); - const Vec n1 = Neg(p1); - const Vec ninf = Neg(pinf); - const Vec neps = Neg(peps); - const Vec nmax = Neg(pmax); + const Vec p0 = Zero(d); + const Vec n0 = BitCast(d, Set(du, SignMask())); + + const Vec p1 = Set(d, BitCastScalar(ConvertScalarTo(1))); + const Vec pinf = BitCast(d, Set(du, ExponentMask())); + const Vec peps = Set(d, BitCastScalar(hwy::Epsilon())); + const Vec pmax = Set(d, BitCastScalar(hwy::HighestValue())); + + const Vec n1 = Or(p1, n0); + const Vec ninf = Or(pinf, n0); + const Vec neps = Or(peps, n0); + const Vec nmax = Or(pmax, n0); // Larger(0) is the smallest subnormal, typically eps * FLT_MIN. - const RebindToUnsigned du; const Vec psub = BitCast(d, Set(du, 1)); - const Vec nsub = Neg(psub); - HWY_ASSERT(AllTrue(d, Lt(psub, peps))); - HWY_ASSERT(AllTrue(d, Gt(nsub, neps))); + const Vec nsub = BitCast(d, Set(du, 1 | SignMask())); + HWY_ASSERT(AllTrue(d, st_smaller.Compare(d, psub, peps))); + HWY_ASSERT(AllTrue(d, st_larger.Compare(d, nsub, neps))); // +/-0 moves to +/- smallest subnormal. 
- HWY_ASSERT_VEC_EQ(d, psub, detail::LargerSortValue(d, p0)); - HWY_ASSERT_VEC_EQ(d, nsub, detail::SmallerSortValue(d, p0)); - HWY_ASSERT_VEC_EQ(d, psub, detail::LargerSortValue(d, n0)); - HWY_ASSERT_VEC_EQ(d, nsub, detail::SmallerSortValue(d, n0)); + HWY_ASSERT_VEC_EQ(d, psub, st_larger.PrevValue(d, p0)); + HWY_ASSERT_VEC_EQ(d, nsub, st_smaller.PrevValue(d, p0)); + HWY_ASSERT_VEC_EQ(d, psub, st_larger.PrevValue(d, n0)); + HWY_ASSERT_VEC_EQ(d, nsub, st_smaller.PrevValue(d, n0)); // The next magnitude larger than 1 is (1 + eps) by definition. - HWY_ASSERT_VEC_EQ(d, Add(p1, peps), detail::LargerSortValue(d, p1)); - HWY_ASSERT_VEC_EQ(d, Add(n1, neps), detail::SmallerSortValue(d, n1)); + const Vec one_plus_eps = Set(d, BitCastScalar( + AddWrapper(ConvertScalarTo(1), hwy::Epsilon()) + )); + HWY_ASSERT_VEC_EQ(d, one_plus_eps, st_larger.PrevValue(d, p1)); + HWY_ASSERT_VEC_EQ(d, Or(one_plus_eps, n0), st_smaller.PrevValue(d, n1)); + // 1-eps and -1+eps are slightly different, but we can still ensure the // next values are less than 1 / greater than -1. - HWY_ASSERT(AllTrue(d, Gt(p1, detail::SmallerSortValue(d, p1)))); - HWY_ASSERT(AllTrue(d, Lt(n1, detail::LargerSortValue(d, n1)))); + HWY_ASSERT(AllTrue(d, st_larger.Compare(d, p1, st_smaller.PrevValue(d, p1)))); + HWY_ASSERT(AllTrue(d, st_smaller.Compare(d, n1, st_larger.PrevValue(d, n1)))); // Even for large (finite) values, we can move toward/away from infinity. 
- HWY_ASSERT_VEC_EQ(d, pinf, detail::LargerSortValue(d, pmax)); - HWY_ASSERT_VEC_EQ(d, ninf, detail::SmallerSortValue(d, nmax)); - HWY_ASSERT(AllTrue(d, Gt(pmax, detail::SmallerSortValue(d, pmax)))); - HWY_ASSERT(AllTrue(d, Lt(nmax, detail::LargerSortValue(d, nmax)))); + HWY_ASSERT_VEC_EQ(d, pinf, st_larger.PrevValue(d, pmax)); + HWY_ASSERT_VEC_EQ(d, ninf, st_smaller.PrevValue(d, nmax)); + HWY_ASSERT(AllTrue(d, st_larger.Compare(d, pmax, st_smaller.PrevValue(d, pmax)))); + HWY_ASSERT(AllTrue(d, st_smaller.Compare(d, nmax, st_larger.PrevValue(d, nmax)))); // For infinities, results are unchanged or the extremal finite value. - HWY_ASSERT_VEC_EQ(d, pinf, detail::LargerSortValue(d, pinf)); - HWY_ASSERT_VEC_EQ(d, pmax, detail::SmallerSortValue(d, pinf)); - HWY_ASSERT_VEC_EQ(d, nmax, detail::LargerSortValue(d, ninf)); - HWY_ASSERT_VEC_EQ(d, ninf, detail::SmallerSortValue(d, ninf)); + HWY_ASSERT_VEC_EQ(d, pinf, st_larger.PrevValue(d, pinf)); + HWY_ASSERT_VEC_EQ(d, pmax, st_smaller.PrevValue(d, pinf)); + HWY_ASSERT_VEC_EQ(d, nmax, st_larger.PrevValue(d, ninf)); + HWY_ASSERT_VEC_EQ(d, ninf, st_smaller.PrevValue(d, ninf)); } }; HWY_NOINLINE void TestAllFloatLargerSmaller() { - ForFloatTypesDynamic(ForPartialVectors()); + ForSortFloatTypesDynamic(ForPartialVectors()); } +template , HWY_IF_FLOAT_D(D)> +HWY_API Mask IsInfWrapper(V v) { + return IsInf(v); +} + +template , HWY_IF_UNSIGNED_D(D)> +HWY_API Mask IsInfWrapper(V v) { + const D d; + const V m_exp = Set(d, ExponentMask()); + const V m_mant = Set(d, MantissaMask()); + return And(Eq(And(v, m_mant), Zero(d)), Eq(And(v, m_exp), m_exp)); +} + // Previously, LastValue was the largest normal float, so we injected that // value into arrays containing only infinities. Ensure that does not happen. 
struct TestFloatInf { - template - HWY_NOINLINE void operator()(T, D d) { + template + HWY_NOINLINE void operator()(TF, DF) { + using T = typename TraitsLane>::LaneType; + const Rebind d; const size_t N = Lanes(d); const size_t num = N * 3; auto in = hwy::AllocateAligned(num); HWY_ASSERT(in); - Fill(d, GetLane(Inf(d)), num, in.get()); + Fill(d, BitCastScalar(PositiveInfOrHighestValue()), num, in.get()); VQSort(in.get(), num, SortAscending()); for (size_t i = 0; i < num; i += N) { - HWY_ASSERT(AllTrue(d, IsInf(LoadU(d, in.get() + i)))); + HWY_ASSERT(AllTrue(d, IsInfWrapper(LoadU(d, in.get() + i)))); } } }; HWY_NOINLINE void TestAllFloatInf() { // TODO(janwas): bfloat16_t not yet supported. - ForFloatTypesDynamic(ForPartialVectors()); + ForSortFloatTypesDynamic(ForPartialVectors()); } template diff --git a/hwy/contrib/sort/traits-inl.h b/hwy/contrib/sort/traits-inl.h index b02ed8b9ab..96cadb7f59 100644 --- a/hwy/contrib/sort/traits-inl.h +++ b/hwy/contrib/sort/traits-inl.h @@ -27,6 +27,7 @@ #include "hwy/contrib/sort/order.h" // SortDescending #include "hwy/contrib/sort/shared-inl.h" // SortConstants +#include "hwy/contrib/sort/order-emulate-inl.h" // Soft float #include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); @@ -540,12 +541,57 @@ struct OrderDescendingKV64 : public KeyValue64 { } }; +#if !HWY_HAVE_FLOAT16 +template <> +struct OrderAscending : + public OrderEmulate, SortAscending> { + using Order = SortAscending; + using OrderForSortingNetwork = OrderAscending; + static constexpr bool IsKV() { return false; } +}; + +template <> +struct OrderDescending : + public OrderEmulate, SortDescending> { + using Order = SortDescending; + using OrderForSortingNetwork = OrderDescending; + static constexpr bool IsKV() { return false; } +}; +#endif + + +#if !HWY_HAVE_FLOAT64 +template <> +struct OrderAscending : + public OrderEmulate, SortAscending> { + using Order = SortAscending; + using OrderForSortingNetwork = OrderAscending; + static constexpr bool IsKV() { return false; } 
+};
+template <>
+struct OrderDescending<double>
+    : public OrderEmulate<KeyLane<uint64_t, double>, SortDescending> {
+  using Order = SortDescending;
+  using OrderForSortingNetwork = OrderDescending<double>;
+  static constexpr bool IsKV() { return false; }
+};
+#endif
+
 // Shared code that depends on Order.
 template <class Base>
 struct TraitsLane : public Base {
   using TraitsForSortingNetwork = TraitsLane<typename Base::OrderForSortingNetwork>;
 
+  static constexpr bool IsEmulatedMinMax() {
+    using LaneType = typename Base::LaneType;
+    using KeyType = typename Base::KeyType;
+    return (
+        (HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3 && sizeof(LaneType) == 8) ||
+        (!IsFloat<LaneType>() && IsFloat<KeyType>())  // emulated float
+    );
+  }
+
   // For each lane i: replaces a[i] with the first and b[i] with the second
   // according to Base.
   // Corresponds to a conditional swap, which is one "node" of a sorting
@@ -553,45 +599,36 @@ struct TraitsLane : public Base {
   template <class D>
   HWY_INLINE void Sort2(D d, Vec<D>& a, Vec<D>& b) const {
     const Base* base = static_cast<const Base*>(this);
     const Vec<D> a_copy = a;
     // Prior to AVX3, there is no native 64-bit Min/Max, so they compile to 4
     // instructions. We can reduce it to a compare + 2 IfThenElse.
-#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3
-    if (sizeof(TFromD<D>) == 8) {
+    HWY_IF_CONSTEXPR (IsEmulatedMinMax()) {
       const Mask<D> cmp = base->Compare(d, a, b);
       a = IfThenElse(cmp, a, b);
       b = IfThenElse(cmp, b, a_copy);
-      return;
     }
-#endif
-    a = base->First(d, a, b);
-    b = base->Last(d, a_copy, b);
+    else {
+      a = base->First(d, a, b);
+      b = base->Last(d, a_copy, b);
+    }
   }
 
   // Conditionally swaps even-numbered lanes with their odd-numbered neighbor.
-  template <class D, HWY_IF_T_SIZE_D(D, 8)>
+  template <class D>
   HWY_INLINE Vec<D> SortPairsDistance1(D d, Vec<D> v) const {
     const Base* base = static_cast<const Base*>(this);
     Vec<D> swapped = base->ReverseKeys2(d, v);
     // Further to the above optimization, Sort2+OddEvenKeys compile to four
     // instructions; we can save one by combining two blends.
-#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3
-    const Vec<D> cmp = VecFromMask(d, base->Compare(d, v, swapped));
-    return IfVecThenElse(DupOdd(cmp), swapped, v);
-#else
-    Sort2(d, v, swapped);
-    return base->OddEvenKeys(swapped, v);
-#endif
-  }
-
-  // (See above - we use Sort2 for non-64-bit types.)
-  template <class D, HWY_IF_NOT_T_SIZE_D(D, 8)>
-  HWY_INLINE Vec<D> SortPairsDistance1(D d, Vec<D> v) const {
-    const Base* base = static_cast<const Base*>(this);
-    Vec<D> swapped = base->ReverseKeys2(d, v);
-    Sort2(d, v, swapped);
-    return base->OddEvenKeys(swapped, v);
+    HWY_IF_CONSTEXPR (IsEmulatedMinMax()) {
+      const Vec<D> cmp = VecFromMask(d, base->Compare(d, v, swapped));
+      return IfVecThenElse(DupOdd(cmp), swapped, v);
+    }
+    else {
+      Sort2(d, v, swapped);
+      return base->OddEvenKeys(swapped, v);
+    }
   }
 
   // Swaps with the vector formed by reversing contiguous groups of 4 keys.
diff --git a/hwy/contrib/sort/vqsort-inl.h b/hwy/contrib/sort/vqsort-inl.h
index 1f98bcccf2..576ee38623 100644
--- a/hwy/contrib/sort/vqsort-inl.h
+++ b/hwy/contrib/sort/vqsort-inl.h
@@ -1921,9 +1921,24 @@ HWY_INLINE bool HandleSpecialCases(D d, Traits st, T* HWY_RESTRICT keys,
 
 #endif  // VQSORT_ENABLED
-
-template <class D, class Traits, typename T, HWY_IF_FLOAT(T)>
+
+template <class V, class D = DFromV<V>, HWY_IF_FLOAT_D(D)>
+HWY_API Mask<D> IsNaNWrapper(V v) {
+  return IsNaN(v);
+}
+
+template <class V, class D = DFromV<V>, HWY_IF_UNSIGNED_D(D)>
+HWY_API Mask<D> IsNaNWrapper(V v) {
+  const D d;
+  const Vec<D> m_exp = Set(d, ExponentMask<MakeFloat<TFromD<D>>>());
+  const Vec<D> m_mant = Set(d, MantissaMask<MakeFloat<TFromD<D>>>());
+  return AndNot(Eq(And(v, m_mant), Zero(d)), Eq(And(v, m_exp), m_exp));
+}
+
+template <class D, class Traits, typename T, HWY_IF_FLOAT(typename Traits::KeyType)>
 HWY_INLINE size_t CountAndReplaceNaN(D d, Traits st, T* HWY_RESTRICT keys,
                                      size_t num) {
+  using TF = typename Traits::KeyType;
   const size_t N = Lanes(d);
   // Will be sorted to the back of the array.
const Vec sentinel = st.LastValue(d); @@ -1931,7 +1946,7 @@ HWY_INLINE size_t CountAndReplaceNaN(D d, Traits st, T* HWY_RESTRICT keys, size_t i = 0; if (num >= N) { for (; i <= num - N; i += N) { - const Mask is_nan = IsNaN(LoadU(d, keys + i)); + const Mask is_nan = IsNaNWrapper(LoadU(d, keys + i)); BlendedStore(sentinel, is_nan, d, keys + i); num_nan += CountTrue(d, is_nan); } @@ -1940,14 +1955,14 @@ HWY_INLINE size_t CountAndReplaceNaN(D d, Traits st, T* HWY_RESTRICT keys, const size_t remaining = num - i; HWY_DASSERT(remaining < N); const Vec v = LoadN(d, keys + i, remaining); - const Mask is_nan = IsNaN(v); + const Mask is_nan = IsNaNWrapper(v); StoreN(IfThenElse(is_nan, sentinel, v), d, keys + i, remaining); num_nan += CountTrue(d, is_nan); return num_nan; } // IsNaN is not implemented for non-float, so skip it. -template +template HWY_INLINE size_t CountAndReplaceNaN(D, Traits, T* HWY_RESTRICT, size_t) { return 0; } diff --git a/hwy/contrib/sort/vqsort.h b/hwy/contrib/sort/vqsort.h index 8f05415800..837274dc25 100644 --- a/hwy/contrib/sort/vqsort.h +++ b/hwy/contrib/sort/vqsort.h @@ -65,7 +65,6 @@ HWY_CONTRIB_DLLEXPORT void VQSort(int64_t* HWY_RESTRICT keys, size_t n, HWY_CONTRIB_DLLEXPORT void VQSort(int64_t* HWY_RESTRICT keys, size_t n, SortDescending); -// These two must only be called if hwy::HaveFloat16() is true. HWY_CONTRIB_DLLEXPORT void VQSort(float16_t* HWY_RESTRICT keys, size_t n, SortAscending); HWY_CONTRIB_DLLEXPORT void VQSort(float16_t* HWY_RESTRICT keys, size_t n, @@ -76,7 +75,6 @@ HWY_CONTRIB_DLLEXPORT void VQSort(float* HWY_RESTRICT keys, size_t n, HWY_CONTRIB_DLLEXPORT void VQSort(float* HWY_RESTRICT keys, size_t n, SortDescending); -// These two must only be called if hwy::HaveFloat64() is true. 
HWY_CONTRIB_DLLEXPORT void VQSort(double* HWY_RESTRICT keys, size_t n, SortAscending); HWY_CONTRIB_DLLEXPORT void VQSort(double* HWY_RESTRICT keys, size_t n, @@ -128,7 +126,6 @@ HWY_CONTRIB_DLLEXPORT void VQPartialSort(int64_t* HWY_RESTRICT keys, size_t n, HWY_CONTRIB_DLLEXPORT void VQPartialSort(int64_t* HWY_RESTRICT keys, size_t n, size_t k, SortDescending); -// These two must only be called if hwy::HaveFloat16() is true. HWY_CONTRIB_DLLEXPORT void VQPartialSort(float16_t* HWY_RESTRICT keys, size_t n, size_t k, SortAscending); HWY_CONTRIB_DLLEXPORT void VQPartialSort(float16_t* HWY_RESTRICT keys, size_t n, @@ -139,7 +136,6 @@ HWY_CONTRIB_DLLEXPORT void VQPartialSort(float* HWY_RESTRICT keys, size_t n, HWY_CONTRIB_DLLEXPORT void VQPartialSort(float* HWY_RESTRICT keys, size_t n, size_t k, SortDescending); -// These two must only be called if hwy::HaveFloat64() is true. HWY_CONTRIB_DLLEXPORT void VQPartialSort(double* HWY_RESTRICT keys, size_t n, size_t k, SortAscending); HWY_CONTRIB_DLLEXPORT void VQPartialSort(double* HWY_RESTRICT keys, size_t n, @@ -190,7 +186,6 @@ HWY_CONTRIB_DLLEXPORT void VQSelect(int64_t* HWY_RESTRICT keys, size_t n, HWY_CONTRIB_DLLEXPORT void VQSelect(int64_t* HWY_RESTRICT keys, size_t n, size_t k, SortDescending); -// These two must only be called if hwy::HaveFloat16() is true. HWY_CONTRIB_DLLEXPORT void VQSelect(float16_t* HWY_RESTRICT keys, size_t n, size_t k, SortAscending); HWY_CONTRIB_DLLEXPORT void VQSelect(float16_t* HWY_RESTRICT keys, size_t n, @@ -201,7 +196,6 @@ HWY_CONTRIB_DLLEXPORT void VQSelect(float* HWY_RESTRICT keys, size_t n, HWY_CONTRIB_DLLEXPORT void VQSelect(float* HWY_RESTRICT keys, size_t n, size_t k, SortDescending); -// These two must only be called if hwy::HaveFloat64() is true. 
HWY_CONTRIB_DLLEXPORT void VQSelect(double* HWY_RESTRICT keys, size_t n, size_t k, SortAscending); HWY_CONTRIB_DLLEXPORT void VQSelect(double* HWY_RESTRICT keys, size_t n, @@ -250,14 +244,12 @@ class HWY_CONTRIB_DLLEXPORT Sorter { void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortAscending) const; void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortDescending) const; - // These two must only be called if hwy::HaveFloat16() is true. void operator()(float16_t* HWY_RESTRICT keys, size_t n, SortAscending) const; void operator()(float16_t* HWY_RESTRICT keys, size_t n, SortDescending) const; void operator()(float* HWY_RESTRICT keys, size_t n, SortAscending) const; void operator()(float* HWY_RESTRICT keys, size_t n, SortDescending) const; - // These two must only be called if hwy::HaveFloat64() is true. void operator()(double* HWY_RESTRICT keys, size_t n, SortAscending) const; void operator()(double* HWY_RESTRICT keys, size_t n, SortDescending) const; diff --git a/hwy/contrib/sort/vqsort_f16a.cc b/hwy/contrib/sort/vqsort_f16a.cc index a5db4f76ee..8797c8806e 100644 --- a/hwy/contrib/sort/vqsort_f16a.cc +++ b/hwy/contrib/sort/vqsort_f16a.cc @@ -29,37 +29,17 @@ namespace HWY_NAMESPACE { namespace { void SortF16Asc(float16_t* HWY_RESTRICT keys, const size_t num) { -#if HWY_HAVE_FLOAT16 return VQSortStatic(keys, num, SortAscending()); -#else - (void)keys; - (void)num; - if (Unpredictable1()) HWY_ASSERT(0); -#endif } void PartialSortF16Asc(float16_t* HWY_RESTRICT keys, const size_t num, const size_t k) { -#if HWY_HAVE_FLOAT16 return VQPartialSortStatic(keys, num, k, SortAscending()); -#else - (void)keys; - (void)num; - (void)k; - if (Unpredictable1()) HWY_ASSERT(0); -#endif } void SelectF16Asc(float16_t* HWY_RESTRICT keys, const size_t num, const size_t k) { -#if HWY_HAVE_FLOAT16 return VQSelectStatic(keys, num, k, SortAscending()); -#else - (void)keys; - (void)num; - (void)k; - if (Unpredictable1()) HWY_ASSERT(0); -#endif } } // namespace diff --git 
a/hwy/contrib/sort/vqsort_f16d.cc b/hwy/contrib/sort/vqsort_f16d.cc index ccedd49fc1..c9945b4981 100644 --- a/hwy/contrib/sort/vqsort_f16d.cc +++ b/hwy/contrib/sort/vqsort_f16d.cc @@ -29,37 +29,17 @@ namespace HWY_NAMESPACE { namespace { void SortF16Desc(float16_t* HWY_RESTRICT keys, const size_t num) { -#if HWY_HAVE_FLOAT16 return VQSortStatic(keys, num, SortDescending()); -#else - (void)keys; - (void)num; - if (Unpredictable1()) HWY_ASSERT(0); -#endif } void PartialSortF16Desc(float16_t* HWY_RESTRICT keys, const size_t num, const size_t k) { -#if HWY_HAVE_FLOAT16 return VQPartialSortStatic(keys, num, k, SortDescending()); -#else - (void)keys; - (void)num; - (void)k; - if (Unpredictable1()) HWY_ASSERT(0); -#endif } void SelectF16Desc(float16_t* HWY_RESTRICT keys, const size_t num, const size_t k) { -#if HWY_HAVE_FLOAT16 return VQSelectStatic(keys, num, k, SortDescending()); -#else - (void)keys; - (void)num; - (void)k; - if (Unpredictable1()) HWY_ASSERT(0); -#endif } } // namespace diff --git a/hwy/contrib/sort/vqsort_f64a.cc b/hwy/contrib/sort/vqsort_f64a.cc index 39e1c1642a..e2a4651e3b 100644 --- a/hwy/contrib/sort/vqsort_f64a.cc +++ b/hwy/contrib/sort/vqsort_f64a.cc @@ -28,36 +28,16 @@ namespace HWY_NAMESPACE { namespace { void SortF64Asc(double* HWY_RESTRICT keys, const size_t num) { -#if HWY_HAVE_FLOAT64 return VQSortStatic(keys, num, SortAscending()); -#else - (void)keys; - (void)num; - HWY_ASSERT(0); -#endif } void PartialSortF64Asc(double* HWY_RESTRICT keys, const size_t num, const size_t k) { -#if HWY_HAVE_FLOAT64 return VQPartialSortStatic(keys, num, k, SortAscending()); -#else - (void)keys; - (void)num; - (void)k; - HWY_ASSERT(0); -#endif } void SelectF64Asc(double* HWY_RESTRICT keys, const size_t num, const size_t k) { -#if HWY_HAVE_FLOAT64 return VQSelectStatic(keys, num, k, SortAscending()); -#else - (void)keys; - (void)num; - (void)k; - HWY_ASSERT(0); -#endif } } // namespace diff --git a/hwy/contrib/sort/vqsort_f64d.cc 
b/hwy/contrib/sort/vqsort_f64d.cc index 585fd2b1de..7e0cb6dc7b 100644 --- a/hwy/contrib/sort/vqsort_f64d.cc +++ b/hwy/contrib/sort/vqsort_f64d.cc @@ -28,37 +28,17 @@ namespace HWY_NAMESPACE { namespace { void SortF64Desc(double* HWY_RESTRICT keys, const size_t num) { -#if HWY_HAVE_FLOAT64 return VQSortStatic(keys, num, SortDescending()); -#else - (void)keys; - (void)num; - HWY_ASSERT(0); -#endif } void PartialSortF64Desc(double* HWY_RESTRICT keys, const size_t num, const size_t k) { -#if HWY_HAVE_FLOAT64 return VQPartialSortStatic(keys, num, k, SortDescending()); -#else - (void)keys; - (void)num; - (void)k; - HWY_ASSERT(0); -#endif } void SelectF64Desc(double* HWY_RESTRICT keys, const size_t num, const size_t k) { -#if HWY_HAVE_FLOAT64 return VQSelectStatic(keys, num, k, SortDescending()); -#else - (void)keys; - (void)num; - (void)k; - HWY_ASSERT(0); -#endif } } // namespace diff --git a/meson.build b/meson.build index 864c407f6b..d2529f8fd8 100644 --- a/meson.build +++ b/meson.build @@ -132,6 +132,7 @@ hwy_contrib_headers = files( 'hwy/contrib/sort/sorting_networks-inl.h', 'hwy/contrib/sort/traits-inl.h', 'hwy/contrib/sort/traits128-inl.h', + 'hwy/contrib/sort/order-emulate-inl.h', 'hwy/contrib/sort/vqsort-inl.h', 'hwy/contrib/sort/vqsort.h', 'hwy/contrib/thread_pool/futex.h',