Skip to content

Commit 5669d38

Browse files
authored
Implement LFU sketch using arm64 intrinsics (redux) (#648)
* basic impl * run tests * fix * table lookup * opt * opt * temp * cleanup * endif * fix return * cleanup ---------
1 parent d8ac6f4 commit 5669d38

File tree

8 files changed

+289
-13
lines changed

8 files changed

+289
-13
lines changed

BitFaster.Caching.Benchmarks/BitFaster.Caching.Benchmarks.csproj

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
<PropertyGroup>
44
<OutputType>Exe</OutputType>
55
<LangVersion>latest</LangVersion>
6-
<TargetFrameworks>net48;net6.0;net8.0</TargetFrameworks>
6+
<TargetFrameworks>net6.0;net8.0</TargetFrameworks>
77
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
88
<!-- https://stackoverflow.com/a/59916801/131345 -->
99
<IsWindows Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::Windows)))' == 'true'">true</IsWindows>
1010
<IsLinux Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::Linux)))' == 'true'">true</IsLinux>
1111
<IsMacOS Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::OSX)))' == 'true'">true</IsMacOS>
12+
<IsArm64 Condition="$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture) == Arm64">true</IsArm64>
13+
<IsX64 Condition="$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture) == X64">true</IsX64>
1214
</PropertyGroup>
1315

1416
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
@@ -41,5 +43,11 @@
4143
<PropertyGroup Condition="'$(IsMacOS)'=='true'">
4244
<DefineConstants>MacOS</DefineConstants>
4345
</PropertyGroup>
46+
<!-- Append to DefineConstants instead of overwriting it, so constants set by earlier property groups (e.g. MacOS) are still defined on Arm64/X64 hosts. -->
<PropertyGroup Condition="'$(IsArm64)'=='true'">
47+
<DefineConstants>$(DefineConstants);Arm64</DefineConstants>
48+
</PropertyGroup>
49+
<PropertyGroup Condition="'$(IsX64)'=='true'">
50+
<DefineConstants>$(DefineConstants);X64</DefineConstants>
51+
</PropertyGroup>
4452

45-
</Project>
53+
</Project>

BitFaster.Caching.Benchmarks/Lfu/CmSketchNoPin.cs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Diagnostics.CodeAnalysis;
4+
using System.Runtime.CompilerServices;
5+
46

57
#if NET6_0_OR_GREATER
68
using System.Runtime.Intrinsics;
9+
using System.Runtime.Intrinsics.Arm;
710
using System.Runtime.Intrinsics.X86;
811
#endif
912

@@ -61,6 +64,12 @@ public int EstimateFrequency(T value)
6164
{
6265
return EstimateFrequencyAvx(value);
6366
}
67+
#if NET6_0_OR_GREATER
68+
else if (isa.IsArm64Supported)
69+
{
70+
return EstimateFrequencyArm(value);
71+
}
72+
#endif
6473
else
6574
{
6675
return EstimateFrequencyStd(value);
@@ -84,6 +93,12 @@ public void Increment(T value)
8493
{
8594
IncrementAvx(value);
8695
}
96+
#if NET6_0_OR_GREATER
97+
else if (isa.IsArm64Supported)
98+
{
99+
IncrementArm(value);
100+
}
101+
#endif
87102
else
88103
{
89104
IncrementStd(value);
@@ -314,5 +329,94 @@ private unsafe void IncrementAvx(T value)
314329
}
315330
}
316331
#endif
332+
333+
#if NET6_0_OR_GREATER
334+
/// <summary>
/// Increments the four 4-bit counters selected by the hash of <paramref name="value"/>,
/// using Arm64 AdvSimd intrinsics. A counter already at 15 is left unchanged. If at least
/// one counter was incremented, size is bumped and Reset() is called once it equals sampleSize.
/// </summary>
/// <param name="value">The value whose counters are incremented.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
335+
private unsafe void IncrementArm(T value)
336+
{
337+
int blockHash = Spread(comparer.GetHashCode(value));
338+
int counterHash = Rehash(blockHash);
339+
// Masked hash scaled by 8; presumably each block spans 8 longs of the table - TODO confirm.
int block = (blockHash & blockMask) << 3;
340+
341+
// Lane i holds counterHash shifted by the per-lane count (0, -8, -16, -24); a negative count
// shifts right (SSHL semantics), so each lane sees a different byte of the hash in its low bits.
Vector128<int> h = AdvSimd.ShiftArithmetic(Vector128.Create(counterHash), Vector128.Create(0, -8, -16, -24));
342+
// Bits 1..4 of each lane: a counter index in the range 0..15.
Vector128<int> index = AdvSimd.And(AdvSimd.ShiftRightLogical(h, 1), Vector128.Create(0xf));
343+
// Table slot per lane: block + (bit 0 of the lane) + (0, 2, 4, 6).
Vector128<int> blockOffset = AdvSimd.Add(AdvSimd.Add(Vector128.Create(block), AdvSimd.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));
344+
345+
fixed (long* tablePtr = table)
346+
{
347+
// Pull the four slot indices out of the vector; they are used for both the loads and the stores.
int t0 = AdvSimd.Extract(blockOffset, 0);
348+
int t1 = AdvSimd.Extract(blockOffset, 1);
349+
int t2 = AdvSimd.Extract(blockOffset, 2);
350+
int t3 = AdvSimd.Extract(blockOffset, 3);
351+
352+
// A = [table[t0], table[t1]], B = [table[t2], table[t3]].
Vector128<long> tableVectorA = Vector128.Create(AdvSimd.LoadVector64(tablePtr + t0), AdvSimd.LoadVector64(tablePtr + t1));
353+
Vector128<long> tableVectorB = Vector128.Create(AdvSimd.LoadVector64(tablePtr + t2), AdvSimd.LoadVector64(tablePtr + t3));
354+
355+
// index * 4: the bit offset of the 4-bit counter inside its long (maximum 60, so no saturation occurs).
index = AdvSimd.ShiftLeftLogicalSaturate(index, 2);
356+
357+
// Place two offsets into int lanes 0 and 2 of a zero vector so that, reinterpreted as
// two 64-bit lanes via AsInt64(), each lane holds one offset.
Vector128<int> longOffA = AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 0), 2, index, 1);
358+
Vector128<int> longOffB = AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 2), 2, index, 3);
359+
360+
// 0xf shifted left to each counter's position: the mask of that counter's four bits.
Vector128<long> fifteen = Vector128.Create(0xfL);
361+
Vector128<long> maskA = AdvSimd.ShiftArithmetic(fifteen, longOffA.AsInt64());
362+
Vector128<long> maskB = AdvSimd.ShiftArithmetic(fifteen, longOffB.AsInt64());
363+
364+
// All-ones in lanes whose counter is NOT already 15 (all four bits set), zero otherwise.
Vector128<long> maskedA = AdvSimd.Not(AdvSimd.Arm64.CompareEqual(AdvSimd.And(tableVectorA, maskA), maskA));
365+
Vector128<long> maskedB = AdvSimd.Not(AdvSimd.Arm64.CompareEqual(AdvSimd.And(tableVectorB, maskB), maskB));
366+
367+
// Per-lane increment: 1 shifted to the counter position, or 0 when the counter is saturated.
var one = Vector128.Create(1L);
368+
Vector128<long> incA = AdvSimd.And(maskedA, AdvSimd.ShiftArithmetic(one, longOffA.AsInt64()));
369+
Vector128<long> incB = AdvSimd.And(maskedB, AdvSimd.ShiftArithmetic(one, longOffB.AsInt64()));
370+
371+
// Apply the increments to the four table slots.
tablePtr[t0] += AdvSimd.Extract(incA, 0);
372+
tablePtr[t1] += AdvSimd.Extract(incA, 1);
373+
tablePtr[t2] += AdvSimd.Extract(incB, 0);
374+
tablePtr[t3] += AdvSimd.Extract(incB, 1);
375+
376+
// Horizontal max over both increment vectors: non-zero iff at least one counter was incremented.
var max = AdvSimd.Arm64.MaxAcross(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.MaxAcross(incA.AsInt32()), 1, AdvSimd.Arm64.MaxAcross(incB.AsInt32()), 0).AsInt16());
377+
378+
// size is only advanced when something was incremented; Reset() runs when it reaches sampleSize.
if (max.ToScalar() != 0 && (++size == sampleSize))
379+
{
380+
Reset();
381+
}
382+
}
383+
}
384+
385+
/// <summary>
/// Estimates the frequency of <paramref name="value"/> using Arm64 AdvSimd intrinsics:
/// reads the four 4-bit counters selected by the value's hash and returns the smallest.
/// </summary>
/// <param name="value">The value whose frequency is estimated.</param>
/// <returns>The minimum of the four selected counters (each counter is masked to 0..15).</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
386+
private unsafe int EstimateFrequencyArm(T value)
387+
{
388+
int blockHash = Spread(comparer.GetHashCode(value));
389+
int counterHash = Rehash(blockHash);
390+
// Masked hash scaled by 8; presumably each block spans 8 longs of the table - TODO confirm.
int block = (blockHash & blockMask) << 3;
391+
392+
// Lane i holds counterHash shifted by the per-lane count (0, -8, -16, -24); a negative count
// shifts right (SSHL semantics), so each lane sees a different byte of the hash in its low bits.
Vector128<int> h = AdvSimd.ShiftArithmetic(Vector128.Create(counterHash), Vector128.Create(0, -8, -16, -24));
393+
// Bits 1..4 of each lane: a counter index in the range 0..15.
Vector128<int> index = AdvSimd.And(AdvSimd.ShiftRightLogical(h, 1), Vector128.Create(0xf));
394+
// Table slot per lane: block + (bit 0 of the lane) + (0, 2, 4, 6).
Vector128<int> blockOffset = AdvSimd.Add(AdvSimd.Add(Vector128.Create(block), AdvSimd.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));
395+
396+
fixed (long* tablePtr = table)
397+
{
398+
// A = [table[slot0], table[slot1]], B = [table[slot2], table[slot3]].
Vector128<long> tableVectorA = Vector128.Create(AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 0)), AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 1)));
399+
Vector128<long> tableVectorB = Vector128.Create(AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 2)), AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 3)));
400+
401+
// index * 4: the bit offset of the 4-bit counter inside its long.
index = AdvSimd.ShiftLeftLogicalSaturate(index, 2);
402+
403+
// Spread two offsets into int lanes 0 and 2 (the low halves of the 64-bit lanes) and negate
// them, so the ShiftArithmetic calls below shift each long RIGHT by its counter offset.
Vector128<int> indexA = AdvSimd.Negate(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 0), 2, index, 1));
404+
Vector128<int> indexB = AdvSimd.Negate(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 2), 2, index, 3));
405+
406+
// Shift each counter down to bit 0 and mask to four bits: a and b each hold two counter values.
var fifteen = Vector128.Create(0xfL);
407+
Vector128<long> a = AdvSimd.And(AdvSimd.ShiftArithmetic(tableVectorA, indexA.AsInt64()), fifteen);
408+
Vector128<long> b = AdvSimd.And(AdvSimd.ShiftArithmetic(tableVectorB, indexB.AsInt64()), fifteen);
409+
410+
// Pack the four counters into one vector of four 32-bit lanes using byte table lookups.
// The first lookup copies bytes 0-3 and 8-11 of a into the low 8 bytes; the 0xFF indices are
// out of range and, per TBL semantics, produce zero. The second (extension) lookup does the same
// for b into the high 8 bytes; per TBX semantics its out-of-range lanes keep the value already in min.
// Before: < 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, A, B, C, D, E, F >
411+
// After: < 0, 1, 2, 3, 8, 9, A, B, 4, 5, 6, 7, C, D, E, F >
412+
var min = AdvSimd.Arm64.VectorTableLookup(a.AsByte(), Vector128.Create(0x0B0A090803020100, 0xFFFFFFFFFFFFFFFF).AsByte());
413+
min = AdvSimd.Arm64.VectorTableLookupExtension(min, b.AsByte(), Vector128.Create(0xFFFFFFFFFFFFFFFF, 0x0B0A090803020100).AsByte());
414+
415+
// Horizontal minimum across the four 32-bit counters.
var min32 = AdvSimd.Arm64.MinAcross(min.AsInt32());
416+
417+
return min32.ToScalar();
418+
}
419+
}
420+
#endif
317421
}
318422
}

BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public int FrequencyFlat()
5151

5252
return count;
5353
}
54-
54+
#if X64
5555
[Benchmark(OperationsPerInvoke = iterations)]
5656
public int FrequencyFlatAvx()
5757
{
@@ -61,7 +61,7 @@ public int FrequencyFlatAvx()
6161

6262
return count;
6363
}
64-
64+
#endif
6565
[Benchmark(OperationsPerInvoke = iterations)]
6666
public int FrequencyBlock()
6767
{
@@ -73,7 +73,11 @@ public int FrequencyBlock()
7373
}
7474

7575
[Benchmark(OperationsPerInvoke = iterations)]
76+
#if Arm64
77+
public int FrequencyBlockNeonNotPinned()
78+
#else
7679
public int FrequencyBlockAvxNotPinned()
80+
#endif
7781
{
7882
int count = 0;
7983
for (int i = 0; i < iterations; i++)
@@ -83,7 +87,12 @@ public int FrequencyBlockAvxNotPinned()
8387
}
8488

8589
[Benchmark(OperationsPerInvoke = iterations)]
90+
91+
#if Arm64
92+
public int FrequencyBlockNeonPinned()
93+
#else
8694
public int FrequencyBlockAvxPinned()
95+
#endif
8796
{
8897
int count = 0;
8998
for (int i = 0; i < iterations; i++)

BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ public class SketchIncrement
2727
private CmSketchNoPin<int, DetectIsa> blockAvxNoPin;
2828
private CmSketchCore<int, DetectIsa> blockAvx;
2929

30+
3031
[Params(32_768, 524_288, 8_388_608, 134_217_728)]
3132
public int Size { get; set; }
3233

@@ -49,7 +50,7 @@ public void IncFlat()
4950
flatStd.Increment(i);
5051
}
5152
}
52-
53+
#if X64
5354
[Benchmark(OperationsPerInvoke = iterations)]
5455
public void IncFlatAvx()
5556
{
@@ -58,7 +59,7 @@ public void IncFlatAvx()
5859
flatAvx.Increment(i);
5960
}
6061
}
61-
62+
#endif
6263
[Benchmark(OperationsPerInvoke = iterations)]
6364
public void IncBlock()
6465
{
@@ -69,7 +70,11 @@ public void IncBlock()
6970
}
7071

7172
[Benchmark(OperationsPerInvoke = iterations)]
73+
#if Arm64
74+
public void IncBlockNeonNotPinned()
75+
#else
7276
public void IncBlockAvxNotPinned()
77+
#endif
7378
{
7479
for (int i = 0; i < iterations; i++)
7580
{
@@ -78,7 +83,11 @@ public void IncBlockAvxNotPinned()
7883
}
7984

8085
[Benchmark(OperationsPerInvoke = iterations)]
86+
#if Arm64
87+
public void IncBlockNeonPinned()
88+
#else
8189
public void IncBlockAvxPinned()
90+
#endif
8291
{
8392
for (int i = 0; i < iterations; i++)
8493
{

BitFaster.Caching.UnitTests/Intrinsics.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
#if NETCOREAPP3_1_OR_GREATER
22
using System.Runtime.Intrinsics.X86;
33
#endif
4+
#if NET6_0_OR_GREATER
5+
using System.Runtime.Intrinsics.Arm;
6+
#endif
7+
48
using Xunit;
59

610
namespace BitFaster.Caching.UnitTests
@@ -10,8 +14,14 @@ public static class Intrinsics
1014
public static void SkipAvxIfNotSupported<I>()
1115
{
1216
#if NETCOREAPP3_1_OR_GREATER
17+
#if NET6_0_OR_GREATER
18+
// when we are trying to test Avx2/Arm64, skip the test if it's not supported
19+
Skip.If(typeof(I) == typeof(DetectIsa) && !(Avx2.IsSupported || AdvSimd.Arm64.IsSupported));
20+
#else
1321
// when we are trying to test Avx2, skip the test if it's not supported
1422
Skip.If(typeof(I) == typeof(DetectIsa) && !Avx2.IsSupported);
23+
#endif
24+
1525
#else
1626
Skip.If(true);
1727
#endif

BitFaster.Caching.UnitTests/Lfu/CmSketchTests.cs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66

77
namespace BitFaster.Caching.UnitTests.Lfu
88
{
9-
// Test with AVX2 if it is supported
10-
public class CMSketchAvx2Tests : CmSketchTestBase<DetectIsa>
9+
// Test with AVX2/ARM64 if it is supported
10+
public class CMSketchIntrinsicsTests : CmSketchTestBase<DetectIsa>
1111
{
1212
}
1313

14-
// Test with AVX2 disabled
14+
// Test with AVX2/ARM64 disabled
1515
public class CmSketchTests : CmSketchTestBase<DisableHardwareIntrinsics>
1616
{
1717
}
@@ -29,14 +29,23 @@ public CmSketchTestBase()
2929
public void Repro()
3030
{
3131
sketch = new CmSketchCore<int, I>(1_048_576, EqualityComparer<int>.Default);
32+
var baseline = new CmSketchCore<int, DisableHardwareIntrinsics>(1_048_576, EqualityComparer<int>.Default);
3233

3334
for (int i = 0; i < 1_048_576; i++)
3435
{
3536
if (i % 3 == 0)
3637
{
3738
sketch.Increment(i);
39+
baseline.Increment(i);
3840
}
3941
}
42+
43+
baseline.Size.Should().Be(sketch.Size);
44+
45+
for (int i = 0; i < 1_048_576; i++)
46+
{
47+
sketch.EstimateFrequency(i).Should().Be(baseline.EstimateFrequency(i));
48+
}
4049
}
4150

4251

BitFaster.Caching/Intrinsics.cs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
using System.Runtime.Intrinsics.X86;
33
#endif
44

5+
// Use the same guard as the AdvSimd.Arm64.IsSupported probe below (NET6_0_OR_GREATER); NET6_0 alone
// is only defined for the net6.0 target, so newer targets would reference AdvSimd without this import.
#if NET6_0_OR_GREATER
6+
using System.Runtime.Intrinsics.Arm;
7+
#endif
8+
59
namespace BitFaster.Caching
610
{
711
/// <summary>
@@ -12,7 +16,14 @@ public interface IsaProbe
1216
/// <summary>
1317
/// Gets a value indicating whether AVX2 is supported.
1418
/// </summary>
15-
bool IsAvx2Supported { get; }
19+
bool IsAvx2Supported { get; }
20+
21+
#if NET6_0_OR_GREATER
22+
/// <summary>
23+
/// Gets a value indicating whether Arm64 is supported.
24+
/// </summary>
25+
bool IsArm64Supported { get => false; }
26+
#endif
1627
}
1728

1829
/// <summary>
@@ -25,7 +36,15 @@ public interface IsaProbe
2536
public bool IsAvx2Supported => false;
2637
#else
2738
/// <inheritdoc/>
28-
public bool IsAvx2Supported => Avx2.IsSupported;
39+
public bool IsAvx2Supported => Avx2.IsSupported;
40+
#endif
41+
42+
#if NET6_0_OR_GREATER
43+
/// <inheritdoc/>
44+
public bool IsArm64Supported => AdvSimd.Arm64.IsSupported;
45+
#else
46+
/// <inheritdoc/>
47+
public bool IsArm64Supported => false;
2948
#endif
3049
}
3150

@@ -35,6 +54,9 @@ public interface IsaProbe
3554
public readonly struct DisableHardwareIntrinsics : IsaProbe
3655
{
3756
/// <inheritdoc/>
38-
public bool IsAvx2Supported => false;
57+
public bool IsAvx2Supported => false;
58+
59+
/// <inheritdoc/>
60+
public bool IsArm64Supported => false;
3961
}
4062
}

0 commit comments

Comments
 (0)