Skip to content

Commit d8ac6f4

Browse files
authored
Optimize AVX intrinsics for .NET8 (#597)
* pot * fix projs * test * opt * 512 * cleanup * params ---------
1 parent 1c2f221 commit d8ac6f4

File tree

3 files changed

+21
-47
lines changed

3 files changed

+21
-47
lines changed

BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
namespace BitFaster.Caching.Benchmarks.Lfu
99
{
10+
#if Windows
11+
[DisassemblyDiagnoser(printSource: true, maxDepth: 4)]
12+
#endif
1013
[SimpleJob(RuntimeMoniker.Net60)]
1114
[SimpleJob(RuntimeMoniker.Net80)]
1215
[SimpleJob(RuntimeMoniker.Net90)]

BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
namespace BitFaster.Caching.Benchmarks.Lfu
99
{
10+
#if Windows
11+
[DisassemblyDiagnoser(printSource: true, maxDepth: 4)]
12+
#endif
1013
[SimpleJob(RuntimeMoniker.Net60)]
1114
[SimpleJob(RuntimeMoniker.Net80)]
1215
[SimpleJob(RuntimeMoniker.Net90)]

BitFaster.Caching/Lfu/CmSketchCore.cs

Lines changed: 15 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -255,39 +255,26 @@ private void Reset()
255255
}
256256

257257
#if !NETSTANDARD2_0
258+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
258259
private unsafe int EstimateFrequencyAvx(T value)
259260
{
260261
int blockHash = Spread(comparer.GetHashCode(value));
261262
int counterHash = Rehash(blockHash);
262263
int block = (blockHash & blockMask) << 3;
263264

264-
Vector128<int> h = Vector128.Create(counterHash);
265-
h = Avx2.ShiftRightLogicalVariable(h.AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
265+
Vector128<int> h = Avx2.ShiftRightLogicalVariable(Vector128.Create(counterHash).AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
266+
Vector128<int> index = Avx2.ShiftLeftLogical(Avx2.And(Avx2.ShiftRightLogical(h, 1), Vector128.Create(15)), 2);
267+
Vector128<int> blockOffset = Avx2.Add(Avx2.Add(Vector128.Create(block), Avx2.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));
266268

267-
var index = Avx2.ShiftRightLogical(h, 1);
268-
index = Avx2.And(index, Vector128.Create(15)); // j - counter index
269-
Vector128<int> offset = Avx2.And(h, Vector128.Create(1));
270-
Vector128<int> blockOffset = Avx2.Add(Vector128.Create(block), offset); // i - table index
271-
blockOffset = Avx2.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1)
269+
Vector256<ulong> indexLong = Avx2.PermuteVar8x32(Vector256.Create(index, Vector128<int>.Zero), Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7)).AsUInt64();
272270

273271
#if NET6_0_OR_GREATER
274272
long* tablePtr = tableAddr;
275273
#else
276274
fixed (long* tablePtr = table)
277275
#endif
278276
{
279-
Vector256<long> tableVector = Avx2.GatherVector256(tablePtr, blockOffset, 8);
280-
index = Avx2.ShiftLeftLogical(index, 2);
281-
282-
// convert index from int to long via permute
283-
Vector256<long> indexLong = Vector256.Create(index, Vector128<int>.Zero).AsInt64();
284-
Vector256<int> permuteMask2 = Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7);
285-
indexLong = Avx2.PermuteVar8x32(indexLong.AsInt32(), permuteMask2).AsInt64();
286-
tableVector = Avx2.ShiftRightLogicalVariable(tableVector, indexLong.AsUInt64());
287-
tableVector = Avx2.And(tableVector, Vector256.Create(0xfL));
288-
289-
Vector256<int> permuteMask = Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7);
290-
Vector128<ushort> count = Avx2.PermuteVar8x32(tableVector.AsInt32(), permuteMask)
277+
Vector128<ushort> count = Avx2.PermuteVar8x32(Avx2.And(Avx2.ShiftRightLogicalVariable(Avx2.GatherVector256(tablePtr, blockOffset, 8), indexLong), Vector256.Create(0xfL)).AsInt32(), Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7))
291278
.GetLower()
292279
.AsUInt16();
293280

@@ -302,52 +289,33 @@ private unsafe int EstimateFrequencyAvx(T value)
302289
}
303290
}
304291

292+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
305293
private unsafe void IncrementAvx(T value)
306294
{
307295
int blockHash = Spread(comparer.GetHashCode(value));
308296
int counterHash = Rehash(blockHash);
309297
int block = (blockHash & blockMask) << 3;
310298

311-
Vector128<int> h = Vector128.Create(counterHash);
312-
h = Avx2.ShiftRightLogicalVariable(h.AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
299+
Vector128<int> h = Avx2.ShiftRightLogicalVariable(Vector128.Create(counterHash).AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
300+
Vector128<int> index = Avx2.ShiftLeftLogical(Avx2.And(Avx2.ShiftRightLogical(h, 1), Vector128.Create(15)), 2);
301+
Vector128<int> blockOffset = Avx2.Add(Avx2.Add(Vector128.Create(block), Avx2.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));
313302

314-
Vector128<int> index = Avx2.ShiftRightLogical(h, 1);
315-
index = Avx2.And(index, Vector128.Create(15)); // j - counter index
316-
Vector128<int> offset = Avx2.And(h, Vector128.Create(1));
317-
Vector128<int> blockOffset = Avx2.Add(Vector128.Create(block), offset); // i - table index
318-
blockOffset = Avx2.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1)
303+
Vector256<ulong> offsetLong = Avx2.PermuteVar8x32(Vector256.Create(index, Vector128<int>.Zero), Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7)).AsUInt64();
304+
Vector256<long> mask = Avx2.ShiftLeftLogicalVariable(Vector256.Create(0xfL), offsetLong);
319305

320306
#if NET6_0_OR_GREATER
321307
long* tablePtr = tableAddr;
322308
#else
323309
fixed (long* tablePtr = table)
324310
#endif
325311
{
326-
Vector256<long> tableVector = Avx2.GatherVector256(tablePtr, blockOffset, 8);
327-
328-
// j == index
329-
index = Avx2.ShiftLeftLogical(index, 2);
330-
Vector256<long> offsetLong = Vector256.Create(index, Vector128<int>.Zero).AsInt64();
331-
332-
Vector256<int> permuteMask = Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7);
333-
offsetLong = Avx2.PermuteVar8x32(offsetLong.AsInt32(), permuteMask).AsInt64();
334-
335-
// mask = (0xfL << offset)
336-
Vector256<long> fifteen = Vector256.Create(0xfL);
337-
Vector256<long> mask = Avx2.ShiftLeftLogicalVariable(fifteen, offsetLong.AsUInt64());
338-
339-
// (table[i] & mask) != mask)
340312
// Note masked is 'equal' - therefore use AndNot below
341-
Vector256<long> masked = Avx2.CompareEqual(Avx2.And(tableVector, mask), mask);
342-
343-
// 1L << offset
344-
Vector256<long> inc = Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), offsetLong.AsUInt64());
313+
Vector256<long> masked = Avx2.CompareEqual(Avx2.And(Avx2.GatherVector256(tablePtr, blockOffset, 8), mask), mask);
345314

346315
// Mask to zero out non matches (add zero below) - first operand is NOT then AND result (order matters)
347-
inc = Avx2.AndNot(masked, inc);
316+
Vector256<long> inc = Avx2.AndNot(masked, Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), offsetLong));
348317

349-
Vector256<byte> result = Avx2.CompareEqual(masked.AsByte(), Vector256<byte>.Zero);
350-
bool wasInc = Avx2.MoveMask(result.AsByte()) == unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111));
318+
bool wasInc = Avx2.MoveMask(Avx2.CompareEqual(masked.AsByte(), Vector256<byte>.Zero).AsByte()) == unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111));
351319

352320
tablePtr[blockOffset.GetElement(0)] += inc.GetElement(0);
353321
tablePtr[blockOffset.GetElement(1)] += inc.GetElement(1);

0 commit comments

Comments
 (0)