Skip to content

Commit

Permalink
Added SSE2 implementation of Vec3::sSelect, Vec4::sSelect and UVec4::…
Browse files Browse the repository at this point in the history
…sSelect (#1314)

* This works around an issue where Firefox has problems with the _mm_blendv_ps intrinsic when compiling to WASM. See: https://x.com/fforw/status/1848540672481214765.
* Also made DVec3::sSelect more consistent.
  • Loading branch information
jrouwe authored Nov 1, 2024
1 parent da2d6cb commit e4debe8
Show file tree
Hide file tree
Showing 12 changed files with 51 additions and 30 deletions.
4 changes: 2 additions & 2 deletions Jolt/Math/DVec3.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ class [[nodiscard]] alignas(JPH_DVECTOR_ALIGNMENT) DVec3
/// Calculates inMul1 * inMul2 + inAdd
static JPH_INLINE DVec3 sFusedMultiplyAdd(DVec3Arg inMul1, DVec3Arg inMul2, DVec3Arg inAdd);

/// Component wise select, returns inV1 when highest bit of inControl = 0 and inV2 when highest bit of inControl = 1
static JPH_INLINE DVec3 sSelect(DVec3Arg inV1, DVec3Arg inV2, DVec3Arg inControl);
/// Component wise select, returns inNotSet when highest bit of inControl = 0 and inSet when highest bit of inControl = 1
static JPH_INLINE DVec3 sSelect(DVec3Arg inNotSet, DVec3Arg inSet, DVec3Arg inControl);

/// Logical or (component wise)
static JPH_INLINE DVec3 sOr(DVec3Arg inV1, DVec3Arg inV2);
Expand Down
12 changes: 6 additions & 6 deletions Jolt/Math/DVec3.inl
Original file line number Diff line number Diff line change
Expand Up @@ -315,21 +315,21 @@ DVec3 DVec3::sFusedMultiplyAdd(DVec3Arg inMul1, DVec3Arg inMul2, DVec3Arg inAdd)
#endif
}

DVec3 DVec3::sSelect(DVec3Arg inV1, DVec3Arg inV2, DVec3Arg inControl)
DVec3 DVec3::sSelect(DVec3Arg inNotSet, DVec3Arg inSet, DVec3Arg inControl)
{
#if defined(JPH_USE_AVX)
return _mm256_blendv_pd(inV1.mValue, inV2.mValue, inControl.mValue);
return _mm256_blendv_pd(inNotSet.mValue, inSet.mValue, inControl.mValue);
#elif defined(JPH_USE_SSE4_1)
Type v = { _mm_blendv_pd(inV1.mValue.mLow, inV2.mValue.mLow, inControl.mValue.mLow), _mm_blendv_pd(inV1.mValue.mHigh, inV2.mValue.mHigh, inControl.mValue.mHigh) };
Type v = { _mm_blendv_pd(inNotSet.mValue.mLow, inSet.mValue.mLow, inControl.mValue.mLow), _mm_blendv_pd(inNotSet.mValue.mHigh, inSet.mValue.mHigh, inControl.mValue.mHigh) };
return sFixW(v);
#elif defined(JPH_USE_NEON)
Type v = { vbslq_f64(vreinterpretq_u64_s64(vshrq_n_s64(vreinterpretq_s64_f64(inControl.mValue.val[0]), 63)), inV2.mValue.val[0], inV1.mValue.val[0]),
vbslq_f64(vreinterpretq_u64_s64(vshrq_n_s64(vreinterpretq_s64_f64(inControl.mValue.val[1]), 63)), inV2.mValue.val[1], inV1.mValue.val[1]) };
Type v = { vbslq_f64(vreinterpretq_u64_s64(vshrq_n_s64(vreinterpretq_s64_f64(inControl.mValue.val[0]), 63)), inSet.mValue.val[0], inNotSet.mValue.val[0]),
vbslq_f64(vreinterpretq_u64_s64(vshrq_n_s64(vreinterpretq_s64_f64(inControl.mValue.val[1]), 63)), inSet.mValue.val[1], inNotSet.mValue.val[1]) };
return sFixW(v);
#else
DVec3 result;
for (int i = 0; i < 3; i++)
result.mF64[i] = BitCast<uint64>(inControl.mF64[i])? inV2.mF64[i] : inV1.mF64[i];
result.mF64[i] = (BitCast<uint64>(inControl.mF64[i]) & (uint64(1) << 63))? inSet.mF64[i] : inNotSet.mF64[i];
#ifdef JPH_FLOATING_POINT_EXCEPTIONS_ENABLED
result.mF64[3] = result.mF64[2];
#endif // JPH_FLOATING_POINT_EXCEPTIONS_ENABLED
Expand Down
4 changes: 2 additions & 2 deletions Jolt/Math/UVec4.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ class [[nodiscard]] alignas(JPH_VECTOR_ALIGNMENT) UVec4
/// Equals (component wise)
static JPH_INLINE UVec4 sEquals(UVec4Arg inV1, UVec4Arg inV2);

/// Component wise select, returns inV1 when highest bit of inControl = 0 and inV2 when highest bit of inControl = 1
static JPH_INLINE UVec4 sSelect(UVec4Arg inV1, UVec4Arg inV2, UVec4Arg inControl);
/// Component wise select, returns inNotSet when highest bit of inControl = 0 and inSet when highest bit of inControl = 1
static JPH_INLINE UVec4 sSelect(UVec4Arg inNotSet, UVec4Arg inSet, UVec4Arg inControl);

/// Logical or (component wise)
static JPH_INLINE UVec4 sOr(UVec4Arg inV1, UVec4Arg inV2);
Expand Down
13 changes: 8 additions & 5 deletions Jolt/Math/UVec4.inl
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,19 @@ UVec4 UVec4::sEquals(UVec4Arg inV1, UVec4Arg inV2)
#endif
}

UVec4 UVec4::sSelect(UVec4Arg inV1, UVec4Arg inV2, UVec4Arg inControl)
UVec4 UVec4::sSelect(UVec4Arg inNotSet, UVec4Arg inSet, UVec4Arg inControl)
{
#if defined(JPH_USE_SSE4_1)
return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(inV1.mValue), _mm_castsi128_ps(inV2.mValue), _mm_castsi128_ps(inControl.mValue)));
#if defined(JPH_USE_SSE4_1) && !defined(JPH_PLATFORM_WASM) // _mm_blendv_ps has problems on FireFox
return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(inNotSet.mValue), _mm_castsi128_ps(inSet.mValue), _mm_castsi128_ps(inControl.mValue)));
#elif defined(JPH_USE_SSE)
__m128 is_set = _mm_castsi128_ps(_mm_srai_epi32(inControl.mValue, 31));
return _mm_castps_si128(_mm_or_ps(_mm_and_ps(is_set, _mm_castsi128_ps(inSet.mValue)), _mm_andnot_ps(is_set, _mm_castsi128_ps(inNotSet.mValue))));
#elif defined(JPH_USE_NEON)
return vbslq_u32(vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(inControl.mValue), 31)), inV2.mValue, inV1.mValue);
return vbslq_u32(vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(inControl.mValue), 31)), inSet.mValue, inNotSet.mValue);
#else
UVec4 result;
for (int i = 0; i < 4; i++)
result.mU32[i] = inControl.mU32[i] ? inV2.mU32[i] : inV1.mU32[i];
result.mU32[i] = (inControl.mU32[i] & 0x80000000u) ? inSet.mU32[i] : inNotSet.mU32[i];
return result;
#endif
}
Expand Down
4 changes: 2 additions & 2 deletions Jolt/Math/Vec3.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ class [[nodiscard]] alignas(JPH_VECTOR_ALIGNMENT) Vec3
/// Calculates inMul1 * inMul2 + inAdd
static JPH_INLINE Vec3 sFusedMultiplyAdd(Vec3Arg inMul1, Vec3Arg inMul2, Vec3Arg inAdd);

/// Component wise select, returns inV1 when highest bit of inControl = 0 and inV2 when highest bit of inControl = 1
static JPH_INLINE Vec3 sSelect(Vec3Arg inV1, Vec3Arg inV2, UVec4Arg inControl);
/// Component wise select, returns inNotSet when highest bit of inControl = 0 and inSet when highest bit of inControl = 1
static JPH_INLINE Vec3 sSelect(Vec3Arg inNotSet, Vec3Arg inSet, UVec4Arg inControl);

/// Logical or (component wise)
static JPH_INLINE Vec3 sOr(Vec3Arg inV1, Vec3Arg inV2);
Expand Down
16 changes: 10 additions & 6 deletions Jolt/Math/Vec3.inl
Original file line number Diff line number Diff line change
Expand Up @@ -266,18 +266,22 @@ Vec3 Vec3::sFusedMultiplyAdd(Vec3Arg inMul1, Vec3Arg inMul2, Vec3Arg inAdd)
#endif
}

Vec3 Vec3::sSelect(Vec3Arg inV1, Vec3Arg inV2, UVec4Arg inControl)
Vec3 Vec3::sSelect(Vec3Arg inNotSet, Vec3Arg inSet, UVec4Arg inControl)
{
#if defined(JPH_USE_SSE4_1)
Type v = _mm_blendv_ps(inV1.mValue, inV2.mValue, _mm_castsi128_ps(inControl.mValue));
#if defined(JPH_USE_SSE4_1) && !defined(JPH_PLATFORM_WASM) // _mm_blendv_ps has problems on FireFox
Type v = _mm_blendv_ps(inNotSet.mValue, inSet.mValue, _mm_castsi128_ps(inControl.mValue));
return sFixW(v);
#elif defined(JPH_USE_SSE)
__m128 is_set = _mm_castsi128_ps(_mm_srai_epi32(inControl.mValue, 31));
Type v = _mm_or_ps(_mm_and_ps(is_set, inSet.mValue), _mm_andnot_ps(is_set, inNotSet.mValue));
return sFixW(v);
#elif defined(JPH_USE_NEON)
Type v = vbslq_f32(vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(inControl.mValue), 31)), inV2.mValue, inV1.mValue);
Type v = vbslq_f32(vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(inControl.mValue), 31)), inSet.mValue, inNotSet.mValue);
return sFixW(v);
#else
Vec3 result;
for (int i = 0; i < 3; i++)
result.mF32[i] = inControl.mU32[i] ? inV2.mF32[i] : inV1.mF32[i];
result.mF32[i] = (inControl.mU32[i] & 0x80000000u) ? inSet.mF32[i] : inNotSet.mF32[i];
#ifdef JPH_FLOATING_POINT_EXCEPTIONS_ENABLED
result.mF32[3] = result.mF32[2];
#endif // JPH_FLOATING_POINT_EXCEPTIONS_ENABLED
Expand Down Expand Up @@ -715,7 +719,7 @@ Vec3 Vec3::Normalized() const

Vec3 Vec3::NormalizedOr(Vec3Arg inZeroValue) const
{
#if defined(JPH_USE_SSE4_1)
#if defined(JPH_USE_SSE4_1) && !defined(JPH_PLATFORM_WASM) // _mm_blendv_ps has problems on FireFox
Type len_sq = _mm_dp_ps(mValue, mValue, 0x7f);
Type is_zero = _mm_cmpeq_ps(len_sq, _mm_setzero_ps());
#ifdef JPH_FLOATING_POINT_EXCEPTIONS_ENABLED
Expand Down
4 changes: 2 additions & 2 deletions Jolt/Math/Vec4.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ class [[nodiscard]] alignas(JPH_VECTOR_ALIGNMENT) Vec4
/// Calculates inMul1 * inMul2 + inAdd
static JPH_INLINE Vec4 sFusedMultiplyAdd(Vec4Arg inMul1, Vec4Arg inMul2, Vec4Arg inAdd);

/// Component wise select, returns inV1 when highest bit of inControl = 0 and inV2 when highest bit of inControl = 1
static JPH_INLINE Vec4 sSelect(Vec4Arg inV1, Vec4Arg inV2, UVec4Arg inControl);
/// Component wise select, returns inNotSet when highest bit of inControl = 0 and inSet when highest bit of inControl = 1
static JPH_INLINE Vec4 sSelect(Vec4Arg inNotSet, Vec4Arg inSet, UVec4Arg inControl);

/// Logical or (component wise)
static JPH_INLINE Vec4 sOr(Vec4Arg inV1, Vec4Arg inV2);
Expand Down
13 changes: 8 additions & 5 deletions Jolt/Math/Vec4.inl
Original file line number Diff line number Diff line change
Expand Up @@ -251,16 +251,19 @@ Vec4 Vec4::sFusedMultiplyAdd(Vec4Arg inMul1, Vec4Arg inMul2, Vec4Arg inAdd)
#endif
}

Vec4 Vec4::sSelect(Vec4Arg inV1, Vec4Arg inV2, UVec4Arg inControl)
Vec4 Vec4::sSelect(Vec4Arg inNotSet, Vec4Arg inSet, UVec4Arg inControl)
{
#if defined(JPH_USE_SSE4_1)
return _mm_blendv_ps(inV1.mValue, inV2.mValue, _mm_castsi128_ps(inControl.mValue));
#if defined(JPH_USE_SSE4_1) && !defined(JPH_PLATFORM_WASM) // _mm_blendv_ps has problems on FireFox
return _mm_blendv_ps(inNotSet.mValue, inSet.mValue, _mm_castsi128_ps(inControl.mValue));
#elif defined(JPH_USE_SSE)
__m128 is_set = _mm_castsi128_ps(_mm_srai_epi32(inControl.mValue, 31));
return _mm_or_ps(_mm_and_ps(is_set, inSet.mValue), _mm_andnot_ps(is_set, inNotSet.mValue));
#elif defined(JPH_USE_NEON)
return vbslq_f32(vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(inControl.mValue), 31)), inV2.mValue, inV1.mValue);
return vbslq_f32(vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(inControl.mValue), 31)), inSet.mValue, inNotSet.mValue);
#else
Vec4 result;
for (int i = 0; i < 4; i++)
result.mF32[i] = inControl.mU32[i] ? inV2.mF32[i] : inV1.mF32[i];
result.mF32[i] = (inControl.mU32[i] & 0x80000000u) ? inSet.mF32[i] : inNotSet.mF32[i];
return result;
#endif
}
Expand Down
5 changes: 5 additions & 0 deletions UnitTests/Math/DVec3Tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,13 @@ TEST_SUITE("DVec3Tests")

TEST_CASE("TestDVec3Select")
{
// Control value whose bit pattern has only the highest bit (bit 63) set
const double cTrue2 = BitCast<double>(uint64(1) << 63);
// Control value whose bit pattern has every bit set except the highest bit
const double cFalse2 = BitCast<double>(~uint64(0) >> 1);

// Select using the DVec3::cTrue / DVec3::cFalse control constants
CHECK(DVec3::sSelect(DVec3(1, 2, 3), DVec3(4, 5, 6), DVec3(DVec3::cTrue, DVec3::cFalse, DVec3::cTrue)) == DVec3(4, 2, 6));
CHECK(DVec3::sSelect(DVec3(1, 2, 3), DVec3(4, 5, 6), DVec3(DVec3::cFalse, DVec3::cTrue, DVec3::cFalse)) == DVec3(1, 5, 3));
// Same expected results with cTrue2 / cFalse2: only the highest bit of each control component should decide the outcome
CHECK(DVec3::sSelect(DVec3(1, 2, 3), DVec3(4, 5, 6), DVec3(cTrue2, cFalse2, cTrue2)) == DVec3(4, 2, 6));
CHECK(DVec3::sSelect(DVec3(1, 2, 3), DVec3(4, 5, 6), DVec3(cFalse2, cTrue2, cFalse2)) == DVec3(1, 5, 3));
}

TEST_CASE("TestDVec3BitOps")
Expand Down
2 changes: 2 additions & 0 deletions UnitTests/Math/UVec4Tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ TEST_SUITE("UVec4Tests")
{
CHECK(UVec4::sSelect(UVec4(1, 2, 3, 4), UVec4(5, 6, 7, 8), UVec4(0x80000000U, 0, 0x80000000U, 0)) == UVec4(5, 2, 7, 4));
CHECK(UVec4::sSelect(UVec4(1, 2, 3, 4), UVec4(5, 6, 7, 8), UVec4(0, 0x80000000U, 0, 0x80000000U)) == UVec4(1, 6, 3, 8));
CHECK(UVec4::sSelect(UVec4(1, 2, 3, 4), UVec4(5, 6, 7, 8), UVec4(0xffffffffU, 0x7fffffffU, 0xffffffffU, 0x7fffffffU)) == UVec4(5, 2, 7, 4));
CHECK(UVec4::sSelect(UVec4(1, 2, 3, 4), UVec4(5, 6, 7, 8), UVec4(0x7fffffffU, 0xffffffffU, 0x7fffffffU, 0xffffffffU)) == UVec4(1, 6, 3, 8));
}

TEST_CASE("TestUVec4BitOps")
Expand Down
2 changes: 2 additions & 0 deletions UnitTests/Math/Vec3Tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ TEST_SUITE("Vec3Tests")
{
CHECK(Vec3::sSelect(Vec3(1, 2, 3), Vec3(4, 5, 6), UVec4(0x80000000U, 0, 0x80000000U, 0)) == Vec3(4, 2, 6));
CHECK(Vec3::sSelect(Vec3(1, 2, 3), Vec3(4, 5, 6), UVec4(0, 0x80000000U, 0, 0x80000000U)) == Vec3(1, 5, 3));
CHECK(Vec3::sSelect(Vec3(1, 2, 3), Vec3(4, 5, 6), UVec4(0xffffffffU, 0x7fffffffU, 0xffffffffU, 0x7fffffffU)) == Vec3(4, 2, 6));
CHECK(Vec3::sSelect(Vec3(1, 2, 3), Vec3(4, 5, 6), UVec4(0x7fffffffU, 0xffffffffU, 0x7fffffffU, 0xffffffffU)) == Vec3(1, 5, 3));
}

TEST_CASE("TestVec3BitOps")
Expand Down
2 changes: 2 additions & 0 deletions UnitTests/Math/Vec4Tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ TEST_SUITE("Vec4Tests")
{
CHECK(Vec4::sSelect(Vec4(1, 2, 3, 4), Vec4(5, 6, 7, 8), UVec4(0x80000000U, 0, 0x80000000U, 0)) == Vec4(5, 2, 7, 4));
CHECK(Vec4::sSelect(Vec4(1, 2, 3, 4), Vec4(5, 6, 7, 8), UVec4(0, 0x80000000U, 0, 0x80000000U)) == Vec4(1, 6, 3, 8));
CHECK(Vec4::sSelect(Vec4(1, 2, 3, 4), Vec4(5, 6, 7, 8), UVec4(0xffffffffU, 0x7fffffffU, 0xffffffffU, 0x7fffffffU)) == Vec4(5, 2, 7, 4));
CHECK(Vec4::sSelect(Vec4(1, 2, 3, 4), Vec4(5, 6, 7, 8), UVec4(0x7fffffffU, 0xffffffffU, 0x7fffffffU, 0xffffffffU)) == Vec4(1, 6, 3, 8));
}

TEST_CASE("TestVec4BitOps")
Expand Down

0 comments on commit e4debe8

Please sign in to comment.