Skip to content

Commit 5ef26a3

Browse files
authored
Allocate pinned buffer for vectorized code (#601)
* align * unsafe as ptr * always pin * try without pad * direct * cleanup * align 64 * freq bench * fix colors * rem comments ---------
1 parent 25ea2bd commit 5ef26a3

File tree

5 files changed

+390
-10
lines changed

5 files changed

+390
-10
lines changed

BitFaster.Caching.Benchmarks/BitFaster.Caching.Benchmarks.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
<PropertyGroup>
44
<OutputType>Exe</OutputType>
5+
<LangVersion>latest</LangVersion>
56
<TargetFrameworks>net48;net6.0;net8.0</TargetFrameworks>
67
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
78
<!-- https://stackoverflow.com/a/59916801/131345 -->
Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics.CodeAnalysis;
4+
5+
#if NET6_0_OR_GREATER
6+
using System.Runtime.Intrinsics;
7+
using System.Runtime.Intrinsics.X86;
8+
#endif
9+
10+
namespace BitFaster.Caching.Benchmarks.Lfu
11+
{
12+
internal class CmSketchNoPin<T, I>
13+
where T : notnull
14+
where I : struct, IsaProbe
15+
{
16+
private const long ResetMask = 0x7777777777777777L;
17+
private const long OneMask = 0x1111111111111111L;
18+
19+
private long[] table;
20+
private int sampleSize;
21+
private int blockMask;
22+
private int size;
23+
24+
private readonly IEqualityComparer<T> comparer;
25+
26+
/// <summary>
27+
/// Initializes a new instance of the CmSketch class with the specified maximum size and equality comparer.
28+
/// </summary>
29+
/// <param name="maximumSize">The maximum size.</param>
30+
/// <param name="comparer">The equality comparer.</param>
31+
public CmSketchNoPin(long maximumSize, IEqualityComparer<T> comparer)
32+
{
33+
EnsureCapacity(maximumSize);
34+
this.comparer = comparer;
35+
}
36+
37+
/// <summary>
38+
/// Gets the reset sample size.
39+
/// </summary>
40+
public int ResetSampleSize => this.sampleSize;
41+
42+
/// <summary>
43+
/// Gets the size.
44+
/// </summary>
45+
public int Size => this.size;
46+
47+
/// <summary>
48+
/// Estimate the frequency of the specified value, up to the maximum of 15.
49+
/// </summary>
50+
/// <param name="value">The value.</param>
51+
/// <returns>The estimated frequency of the value.</returns>
52+
public int EstimateFrequency(T value)
53+
{
54+
#if NET48
55+
return EstimateFrequencyStd(value);
56+
#else
57+
58+
I isa = default;
59+
60+
if (isa.IsAvx2Supported)
61+
{
62+
return EstimateFrequencyAvx(value);
63+
}
64+
else
65+
{
66+
return EstimateFrequencyStd(value);
67+
}
68+
#endif
69+
}
70+
71+
/// <summary>
72+
/// Increment the count of the specified value.
73+
/// </summary>
74+
/// <param name="value">The value.</param>
75+
public void Increment(T value)
76+
{
77+
#if NET48
78+
IncrementStd(value);
79+
#else
80+
81+
I isa = default;
82+
83+
if (isa.IsAvx2Supported)
84+
{
85+
IncrementAvx(value);
86+
}
87+
else
88+
{
89+
IncrementStd(value);
90+
}
91+
#endif
92+
}
93+
94+
/// <summary>
95+
/// Clears the count for all items.
96+
/// </summary>
97+
public void Clear()
98+
{
99+
table = new long[table.Length];
100+
size = 0;
101+
}
102+
103+
// [MemberNotNull(nameof(table))]
104+
private void EnsureCapacity(long maximumSize)
105+
{
106+
int maximum = (int)Math.Min(maximumSize, int.MaxValue >> 1);
107+
108+
table = new long[Math.Max(BitOps.CeilingPowerOfTwo(maximum), 8)];
109+
blockMask = (int)((uint)table.Length >> 3) - 1;
110+
sampleSize = (maximumSize == 0) ? 10 : (10 * maximum);
111+
112+
size = 0;
113+
}
114+
115+
private unsafe int EstimateFrequencyStd(T value)
116+
{
117+
var count = stackalloc int[4];
118+
int blockHash = Spread(comparer.GetHashCode(value));
119+
int counterHash = Rehash(blockHash);
120+
int block = (blockHash & blockMask) << 3;
121+
122+
for (int i = 0; i < 4; i++)
123+
{
124+
int h = (int)((uint)counterHash >> (i << 3));
125+
int index = (h >> 1) & 15;
126+
int offset = h & 1;
127+
count[i] = (int)(((ulong)table[block + offset + (i << 1)] >> (index << 2)) & 0xfL);
128+
}
129+
return Math.Min(Math.Min(count[0], count[1]), Math.Min(count[2], count[3]));
130+
}
131+
132+
private unsafe void IncrementStd(T value)
133+
{
134+
var index = stackalloc int[8];
135+
int blockHash = Spread(comparer.GetHashCode(value));
136+
int counterHash = Rehash(blockHash);
137+
int block = (blockHash & blockMask) << 3;
138+
139+
for (int i = 0; i < 4; i++)
140+
{
141+
int h = (int)((uint)counterHash >> (i << 3));
142+
index[i] = (h >> 1) & 15;
143+
int offset = h & 1;
144+
index[i + 4] = block + offset + (i << 1);
145+
}
146+
147+
bool added =
148+
IncrementAt(index[4], index[0])
149+
| IncrementAt(index[5], index[1])
150+
| IncrementAt(index[6], index[2])
151+
| IncrementAt(index[7], index[3]);
152+
153+
if (added && (++size == sampleSize))
154+
{
155+
Reset();
156+
}
157+
}
158+
159+
// Applies another round of hashing for additional randomization
160+
private static int Rehash(int x)
161+
{
162+
x = (int)(x * 0x31848bab);
163+
x ^= (int)((uint)x >> 14);
164+
return x;
165+
}
166+
167+
// Applies a supplemental hash functions to defends against poor quality hash.
168+
private static int Spread(int x)
169+
{
170+
x ^= (int)((uint)x >> 17);
171+
x = (int)(x * 0xed5ad4bb);
172+
x ^= (int)((uint)x >> 11);
173+
x = (int)(x * 0xac4c1b51);
174+
x ^= (int)((uint)x >> 15);
175+
return x;
176+
}
177+
178+
private bool IncrementAt(int i, int j)
179+
{
180+
int offset = j << 2;
181+
long mask = (0xfL << offset);
182+
183+
if ((table[i] & mask) != mask)
184+
{
185+
table[i] += (1L << offset);
186+
return true;
187+
}
188+
189+
return false;
190+
}
191+
192+
private void Reset()
193+
{
194+
// unroll, almost 2x faster
195+
int count0 = 0;
196+
int count1 = 0;
197+
int count2 = 0;
198+
int count3 = 0;
199+
200+
for (int i = 0; i < table.Length; i += 4)
201+
{
202+
count0 += BitOps.BitCount(table[i] & OneMask);
203+
count1 += BitOps.BitCount(table[i + 1] & OneMask);
204+
count2 += BitOps.BitCount(table[i + 2] & OneMask);
205+
count3 += BitOps.BitCount(table[i + 3] & OneMask);
206+
207+
table[i] = (long)((ulong)table[i] >> 1) & ResetMask;
208+
table[i + 1] = (long)((ulong)table[i + 1] >> 1) & ResetMask;
209+
table[i + 2] = (long)((ulong)table[i + 2] >> 1) & ResetMask;
210+
table[i + 3] = (long)((ulong)table[i + 3] >> 1) & ResetMask;
211+
}
212+
213+
count0 = (count0 + count1) + (count2 + count3);
214+
215+
size = (size - (count0 >> 2)) >> 1;
216+
}
217+
218+
#if NET6_0_OR_GREATER
219+
private unsafe int EstimateFrequencyAvx(T value)
220+
{
221+
int blockHash = Spread(comparer.GetHashCode(value));
222+
int counterHash = Rehash(blockHash);
223+
int block = (blockHash & blockMask) << 3;
224+
225+
Vector128<int> h = Vector128.Create(counterHash);
226+
h = Avx2.ShiftRightLogicalVariable(h.AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
227+
228+
var index = Avx2.ShiftRightLogical(h, 1);
229+
index = Avx2.And(index, Vector128.Create(15)); // j - counter index
230+
Vector128<int> offset = Avx2.And(h, Vector128.Create(1));
231+
Vector128<int> blockOffset = Avx2.Add(Vector128.Create(block), offset); // i - table index
232+
blockOffset = Avx2.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1)
233+
234+
fixed (long* tablePtr = table)
235+
{
236+
Vector256<long> tableVector = Avx2.GatherVector256(tablePtr, blockOffset, 8);
237+
index = Avx2.ShiftLeftLogical(index, 2);
238+
239+
// convert index from int to long via permute
240+
Vector256<long> indexLong = Vector256.Create(index, Vector128<int>.Zero).AsInt64();
241+
Vector256<int> permuteMask2 = Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7);
242+
indexLong = Avx2.PermuteVar8x32(indexLong.AsInt32(), permuteMask2).AsInt64();
243+
tableVector = Avx2.ShiftRightLogicalVariable(tableVector, indexLong.AsUInt64());
244+
tableVector = Avx2.And(tableVector, Vector256.Create(0xfL));
245+
246+
Vector256<int> permuteMask = Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7);
247+
Vector128<ushort> count = Avx2.PermuteVar8x32(tableVector.AsInt32(), permuteMask)
248+
.GetLower()
249+
.AsUInt16();
250+
251+
// set the zeroed high parts of the long value to ushort.Max
252+
#if NET6_0
253+
count = Avx2.Blend(count, Vector128<ushort>.AllBitsSet, 0b10101010);
254+
#else
255+
count = Avx2.Blend(count, Vector128.Create(ushort.MaxValue), 0b10101010);
256+
#endif
257+
258+
return Avx2.MinHorizontal(count).GetElement(0);
259+
}
260+
}
261+
262+
private unsafe void IncrementAvx(T value)
263+
{
264+
int blockHash = Spread(comparer.GetHashCode(value));
265+
int counterHash = Rehash(blockHash);
266+
int block = (blockHash & blockMask) << 3;
267+
268+
Vector128<int> h = Vector128.Create(counterHash);
269+
h = Avx2.ShiftRightLogicalVariable(h.AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
270+
271+
Vector128<int> index = Avx2.ShiftRightLogical(h, 1);
272+
index = Avx2.And(index, Vector128.Create(15)); // j - counter index
273+
Vector128<int> offset = Avx2.And(h, Vector128.Create(1));
274+
Vector128<int> blockOffset = Avx2.Add(Vector128.Create(block), offset); // i - table index
275+
blockOffset = Avx2.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1)
276+
277+
fixed (long* tablePtr = table)
278+
{
279+
Vector256<long> tableVector = Avx2.GatherVector256(tablePtr, blockOffset, 8);
280+
281+
// j == index
282+
index = Avx2.ShiftLeftLogical(index, 2);
283+
Vector256<long> offsetLong = Vector256.Create(index, Vector128<int>.Zero).AsInt64();
284+
285+
Vector256<int> permuteMask = Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7);
286+
offsetLong = Avx2.PermuteVar8x32(offsetLong.AsInt32(), permuteMask).AsInt64();
287+
288+
// mask = (0xfL << offset)
289+
Vector256<long> fifteen = Vector256.Create(0xfL);
290+
Vector256<long> mask = Avx2.ShiftLeftLogicalVariable(fifteen, offsetLong.AsUInt64());
291+
292+
// (table[i] & mask) != mask)
293+
// Note masked is 'equal' - therefore use AndNot below
294+
Vector256<long> masked = Avx2.CompareEqual(Avx2.And(tableVector, mask), mask);
295+
296+
// 1L << offset
297+
Vector256<long> inc = Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), offsetLong.AsUInt64());
298+
299+
// Mask to zero out non matches (add zero below) - first operand is NOT then AND result (order matters)
300+
inc = Avx2.AndNot(masked, inc);
301+
302+
Vector256<byte> result = Avx2.CompareEqual(masked.AsByte(), Vector256<byte>.Zero);
303+
bool wasInc = Avx2.MoveMask(result.AsByte()) == unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111));
304+
305+
tablePtr[blockOffset.GetElement(0)] += inc.GetElement(0);
306+
tablePtr[blockOffset.GetElement(1)] += inc.GetElement(1);
307+
tablePtr[blockOffset.GetElement(2)] += inc.GetElement(2);
308+
tablePtr[blockOffset.GetElement(3)] += inc.GetElement(3);
309+
310+
if (wasInc && (++size == sampleSize))
311+
{
312+
Reset();
313+
}
314+
}
315+
}
316+
#endif
317+
}
318+
}

BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
namespace BitFaster.Caching.Benchmarks.Lfu
99
{
1010
[SimpleJob(RuntimeMoniker.Net60)]
11+
[SimpleJob(RuntimeMoniker.Net80)]
12+
[SimpleJob(RuntimeMoniker.Net90)]
1113
[MemoryDiagnoser(displayGenColumns: false)]
1214
[HideColumns("Job", "Median", "RatioSD", "Alloc Ratio")]
13-
[ColumnChart(Title ="Sketch Frequency ({JOB})")]
15+
[ColumnChart(Title = "Sketch Frequency ({JOB})", Colors = "#cd5c5c,#fa8072,#ffa07a")]
1416
public class SketchFrequency
1517
{
1618
const int sketchSize = 1_048_576;
@@ -20,6 +22,7 @@ public class SketchFrequency
2022
private CmSketchFlat<int, DetectIsa> flatAvx;
2123

2224
private CmSketchCore<int, DisableHardwareIntrinsics> blockStd;
25+
private CmSketchNoPin<int, DetectIsa> blockAvxNoPin;
2326
private CmSketchCore<int, DetectIsa> blockAvx;
2427

2528
[Params(32_768, 524_288, 8_388_608, 134_217_728)]
@@ -32,6 +35,7 @@ public void Setup()
3235
flatAvx = new CmSketchFlat<int, DetectIsa>(Size, EqualityComparer<int>.Default);
3336

3437
blockStd = new CmSketchCore<int, DisableHardwareIntrinsics>(Size, EqualityComparer<int>.Default);
38+
blockAvxNoPin = new CmSketchNoPin<int, DetectIsa>(Size, EqualityComparer<int>.Default);
3539
blockAvx = new CmSketchCore<int, DetectIsa>(Size, EqualityComparer<int>.Default);
3640
}
3741

@@ -66,7 +70,17 @@ public int FrequencyBlock()
6670
}
6771

6872
[Benchmark(OperationsPerInvoke = iterations)]
69-
public int FrequencyBlockAvx()
73+
public int FrequencyBlockAvxNotPinned()
74+
{
75+
int count = 0;
76+
for (int i = 0; i < iterations; i++)
77+
count += blockAvxNoPin.EstimateFrequency(i) > blockAvx.EstimateFrequency(i + 1) ? 1 : 0;
78+
79+
return count;
80+
}
81+
82+
[Benchmark(OperationsPerInvoke = iterations)]
83+
public int FrequencyBlockAvxPinned()
7084
{
7185
int count = 0;
7286
for (int i = 0; i < iterations; i++)

0 commit comments

Comments
 (0)