Skip to content

Commit 600ae64

Browse files
committed
Advanced params support
1 parent bca4f04 commit 600ae64

10 files changed

+279
-40
lines changed

ZstdNet.Benchmarks/CompressionBenchmarks.cs

+3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Collections.Generic;
23
using System.IO;
34
using System.Threading.Tasks;
45
using BenchmarkDotNet.Attributes;
@@ -18,6 +19,7 @@ public class CompressionOverheadBenchmarks
1819
private byte[] Buffer;
1920

2021
private readonly Compressor Compressor = new Compressor(CompressionOptions.Default);
22+
private readonly Compressor CompressorAdvanced = new Compressor(new CompressionOptions(null, new Dictionary<ZSTD_cParameter, int>()));
2123
private readonly Decompressor Decompressor = new Decompressor();
2224

2325
[GlobalSetup]
@@ -38,6 +40,7 @@ public void GlobalSetup()
3840
}
3941

4042
[Benchmark] public void Compress() => Compressor.Wrap(Data, Buffer, 0);
43+
[Benchmark] public void CompressAdvanced() => CompressorAdvanced.Wrap(Data, Buffer, 0);
4144
[Benchmark] public void Decompress() => Decompressor.Unwrap(CompressedData, Buffer, 0);
4245

4346
[Benchmark]

ZstdNet.Tests/Binding_Tests.cs

+27
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Collections.Generic;
23
using System.Linq;
34
using System.Text;
45
using System.Threading;
@@ -42,6 +43,32 @@ public void CompressAndDecompress_workCorrectly([Values(false, true)] bool useDi
4243
CollectionAssert.AreEqual(data, decompressed);
4344
}
4445

46+
// Round-trips the sample through compressors configured via advanced parameters,
// once with ZSTD_c_checksumFlag = 0 and once with ZSTD_c_checksumFlag = 1,
// with and without a dictionary.
[Test]
public void CompressAndDecompress_worksCorrectly_advanced([Values(false, true)] bool useDictionary)
{
	var data = GenerateSample();
	var dict = useDictionary ? BuildDictionary() : null;

	// Compresses `data` with a fresh options/compressor pair that sets only the checksum flag.
	byte[] CompressWithChecksumFlag(int checksumFlag)
	{
		var advancedParams = new Dictionary<ZSTD_cParameter, int> {{ZSTD_cParameter.ZSTD_c_checksumFlag, checksumFlag}};
		using(var options = new CompressionOptions(dict, advancedParams))
		using(var compressor = new Compressor(options))
			return compressor.Wrap(data);
	}

	var withoutChecksum = CompressWithChecksumFlag(0);
	var withChecksum = CompressWithChecksumFlag(1);

	// Enabling the checksum flag is expected to grow the output by exactly 4 bytes.
	Assert.AreEqual(withoutChecksum.Length + 4, withChecksum.Length);

	// Both outputs must decode back to the original sample using an empty advanced-parameter set.
	using(var options = new DecompressionOptions(dict, new Dictionary<ZSTD_dParameter, int>()))
	using(var decompressor = new Decompressor(options))
	{
		foreach(var compressed in new[] {withoutChecksum, withChecksum})
			CollectionAssert.AreEqual(data, decompressor.Unwrap(compressed));
	}
}
71+
4572
[Test]
4673
public void DecompressWithDictionary_worksCorrectly_onDataCompressedWithoutIt()
4774
{

ZstdNet.Tests/SteamingCompressionTests.cs

+74-29
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Collections.Generic;
23
using System.IO;
34
using System.Linq;
45
using System.Threading;
@@ -17,8 +18,8 @@ internal static class DataGenerator
1718
{
1819
private static readonly Random Random = new Random(1234);
1920

20-
public const int LargeBufferSize = 1 * 1024 * 1024;
21-
public const int SmallBufferSize = 1 * 1024;
21+
public const int LargeBufferSize = 1024 * 1024;
22+
public const int SmallBufferSize = 1024;
2223

2324
public static MemoryStream GetSmallStream(DataFill dataFill) => GetStream(SmallBufferSize, dataFill);
2425
public static MemoryStream GetLargeStream(DataFill dataFill) => GetStream(LargeBufferSize, dataFill);
@@ -168,24 +169,30 @@ public void StreamingCompressionFlushDataFromInternalBuffers()
168169
[Test]
169170
public void CompressionImprovesWithDictionary()
170171
{
171-
var trainingData = new byte[100][];
172-
for(int i = 0; i < trainingData.Length; i++)
173-
trainingData[i] = DataGenerator.GetSmallBuffer(DataFill.Random);
174-
175-
var dict = DictBuilder.TrainFromBuffer(trainingData);
172+
var dict = TrainDict();
176173
var compressionOptions = new CompressionOptions(dict);
177174

178-
var dataStream = DataGenerator.GetSmallStream(DataFill.Random);
175+
var dataStream = DataGenerator.GetSmallStream(DataFill.Sequential);
179176

180177
var normalResultStream = new MemoryStream();
181178
using(var compressionStream = new CompressionStream(normalResultStream))
182179
dataStream.CopyTo(compressionStream);
183180

181+
dataStream.Seek(0, SeekOrigin.Begin);
182+
184183
var dictResultStream = new MemoryStream();
185184
using(var compressionStream = new CompressionStream(dictResultStream, compressionOptions))
186185
dataStream.CopyTo(compressionStream);
187186

188187
Assert.Greater(normalResultStream.Length, dictResultStream.Length);
188+
189+
dictResultStream.Seek(0, SeekOrigin.Begin);
190+
191+
var resultStream = new MemoryStream();
192+
using(var decompressionStream = new DecompressionStream(dictResultStream, new DecompressionOptions(dict)))
193+
decompressionStream.CopyTo(resultStream);
194+
195+
Assert.AreEqual(dataStream.ToArray(), resultStream.ToArray());
189196
}
190197

191198
[Test]
@@ -234,67 +241,97 @@ public void RoundTrip_StreamingToBatch()
234241

235242
[Test, Combinatorial, Parallelizable(ParallelScope.Children)]
236243
public void RoundTrip_StreamingToStreaming(
244+
[Values(false, true)] bool useDict, [Values(false, true)] bool advanced,
237245
[Values(1, 2, 7, 101, 1024, 65535, DataGenerator.LargeBufferSize, DataGenerator.LargeBufferSize + 1)] int zstdBufferSize,
238246
[Values(1, 2, 7, 101, 1024, 65535, DataGenerator.LargeBufferSize, DataGenerator.LargeBufferSize + 1)] int copyBufferSize)
239247
{
248+
var dict = useDict ? TrainDict() : null;
240249
var testStream = DataGenerator.GetLargeStream(DataFill.Sequential);
241250

251+
const int offset = 1;
252+
var buffer = new byte[copyBufferSize + offset + 1];
253+
242254
var tempStream = new MemoryStream();
243-
using(var compressionStream = new CompressionStream(tempStream, zstdBufferSize))
244-
testStream.CopyTo(compressionStream, copyBufferSize);
255+
using(var compressionStream = new CompressionStream(tempStream, new CompressionOptions(dict, advanced ? new Dictionary<ZSTD_cParameter, int> {{ZSTD_cParameter.ZSTD_c_windowLog, 11}, {ZSTD_cParameter.ZSTD_c_checksumFlag, 1}, {ZSTD_cParameter.ZSTD_c_nbWorkers, 4}} : null), zstdBufferSize))
256+
{
257+
int bytesRead;
258+
while((bytesRead = testStream.Read(buffer, offset, copyBufferSize)) > 0)
259+
compressionStream.Write(buffer, offset, bytesRead);
260+
}
245261

246262
tempStream.Seek(0, SeekOrigin.Begin);
247263

248264
var resultStream = new MemoryStream();
249-
using(var decompressionStream = new DecompressionStream(tempStream, zstdBufferSize))
250-
decompressionStream.CopyTo(resultStream, copyBufferSize);
265+
using(var decompressionStream = new DecompressionStream(tempStream, new DecompressionOptions(dict, advanced ? new Dictionary<ZSTD_dParameter, int> {{ZSTD_dParameter.ZSTD_d_windowLogMax, 11}} : null), zstdBufferSize))
266+
{
267+
int bytesRead;
268+
while((bytesRead = decompressionStream.Read(buffer, offset, copyBufferSize)) > 0)
269+
resultStream.Write(buffer, offset, bytesRead);
270+
}
251271

252272
Assert.AreEqual(testStream.ToArray(), resultStream.ToArray());
253273
}
254274

255275
[Test, Combinatorial, Parallelizable(ParallelScope.Children)]
256276
public async Task RoundTrip_StreamingToStreamingAsync(
277+
[Values(false, true)] bool useDict, [Values(false, true)] bool advanced,
257278
[Values(1, 2, 7, 101, 1024, 65535, DataGenerator.LargeBufferSize, DataGenerator.LargeBufferSize + 1)] int zstdBufferSize,
258279
[Values(1, 2, 7, 101, 1024, 65535, DataGenerator.LargeBufferSize, DataGenerator.LargeBufferSize + 1)] int copyBufferSize)
259280
{
281+
var dict = useDict ? TrainDict() : null;
260282
var testStream = DataGenerator.GetLargeStream(DataFill.Sequential);
261283

284+
const int offset = 1;
285+
var buffer = new byte[copyBufferSize + offset + 1];
286+
262287
var tempStream = new MemoryStream();
263-
await using(var compressionStream = new CompressionStream(tempStream, zstdBufferSize))
264-
await testStream.CopyToAsync(compressionStream, copyBufferSize);
288+
await using(var compressionStream = new CompressionStream(tempStream, new CompressionOptions(dict, advanced ? new Dictionary<ZSTD_cParameter, int> {{ZSTD_cParameter.ZSTD_c_windowLog, 11}, {ZSTD_cParameter.ZSTD_c_checksumFlag, 1}, {ZSTD_cParameter.ZSTD_c_nbWorkers, 4}} : null), zstdBufferSize))
289+
{
290+
int bytesRead;
291+
while((bytesRead = await testStream.ReadAsync(buffer, offset, copyBufferSize)) > 0)
292+
await compressionStream.WriteAsync(buffer, offset, bytesRead);
293+
}
265294

266295
tempStream.Seek(0, SeekOrigin.Begin);
267296

268297
var resultStream = new MemoryStream();
269-
await using(var decompressionStream = new DecompressionStream(tempStream, zstdBufferSize))
270-
await decompressionStream.CopyToAsync(resultStream, copyBufferSize);
298+
await using(var decompressionStream = new DecompressionStream(tempStream, new DecompressionOptions(dict, advanced ? new Dictionary<ZSTD_dParameter, int> {{ZSTD_dParameter.ZSTD_d_windowLogMax, 11}} : null), zstdBufferSize))
299+
{
300+
int bytesRead;
301+
while((bytesRead = await decompressionStream.ReadAsync(buffer, offset, copyBufferSize)) > 0)
302+
await resultStream.WriteAsync(buffer, offset, bytesRead);
303+
}
271304

272305
Assert.AreEqual(testStream.ToArray(), resultStream.ToArray());
273306
}
274307

275308
[Test, Explicit("stress")]
276-
public void RoundTrip_StreamingToStreaming_Stress([Values(true, false)] bool async)
309+
public void RoundTrip_StreamingToStreaming_Stress([Values(true, false)] bool useDict, [Values(true, false)] bool async)
277310
{
278311
long i = 0;
312+
var dict = useDict ? TrainDict() : null;
313+
var compressionOptions = new CompressionOptions(dict);
314+
var decompressionOptions = new DecompressionOptions(dict);
279315
Enumerable.Range(0, 10000)
280316
.AsParallel()
281317
.WithDegreeOfParallelism(Environment.ProcessorCount * 4)
282-
.ForAll(_ =>
318+
.ForAll(n =>
283319
{
284-
var buffer = new byte[13];
285-
var testStream = DataGenerator.GetSmallStream(DataFill.Random);
320+
var testStream = DataGenerator.GetSmallStream(DataFill.Sequential);
321+
var cBuffer = new byte[1 + (int)(n % (testStream.Length * 11))];
322+
var dBuffer = new byte[1 + (int)(n % (testStream.Length * 13))];
286323

287324
var tempStream = new MemoryStream();
288-
using(var compressionStream = new CompressionStream(tempStream, 511))
325+
using(var compressionStream = new CompressionStream(tempStream, compressionOptions, 1 + (int)(n % (testStream.Length * 17))))
289326
{
290327
int bytesRead;
291-
int offset = (int)(Interlocked.Read(ref i) % buffer.Length);
292-
while((bytesRead = testStream.Read(buffer, offset, buffer.Length - offset)) > 0)
328+
int offset = n % cBuffer.Length;
329+
while((bytesRead = testStream.Read(cBuffer, offset, cBuffer.Length - offset)) > 0)
293330
{
294331
if(async)
295-
compressionStream.WriteAsync(buffer, offset, bytesRead).GetAwaiter().GetResult();
332+
compressionStream.WriteAsync(cBuffer, offset, bytesRead).GetAwaiter().GetResult();
296333
else
297-
compressionStream.Write(buffer, offset, bytesRead);
334+
compressionStream.Write(cBuffer, offset, bytesRead);
298335
if(Interlocked.Increment(ref i) % 100 == 0)
299336
GC.Collect(GC.MaxGeneration, GCCollectionMode.Forced, true, true);
300337
}
@@ -303,13 +340,13 @@ public void RoundTrip_StreamingToStreaming_Stress([Values(true, false)] bool asy
303340
tempStream.Seek(0, SeekOrigin.Begin);
304341

305342
var resultStream = new MemoryStream();
306-
using(var decompressionStream = new DecompressionStream(tempStream, 511))
343+
using(var decompressionStream = new DecompressionStream(tempStream, decompressionOptions, 1 + (int)(n % (testStream.Length * 19))))
307344
{
308345
int bytesRead;
309-
int offset = (int)(Interlocked.Read(ref i) % buffer.Length);
310-
while((bytesRead = async ? decompressionStream.ReadAsync(buffer, offset, buffer.Length - offset).GetAwaiter().GetResult() : decompressionStream.Read(buffer, offset, buffer.Length - offset)) > 0)
346+
int offset = n % dBuffer.Length;
347+
while((bytesRead = async ? decompressionStream.ReadAsync(dBuffer, offset, dBuffer.Length - offset).GetAwaiter().GetResult() : decompressionStream.Read(dBuffer, offset, dBuffer.Length - offset)) > 0)
311348
{
312-
resultStream.Write(buffer, offset, bytesRead);
349+
resultStream.Write(dBuffer, offset, bytesRead);
313350
if(Interlocked.Increment(ref i) % 100 == 0)
314351
GC.Collect(GC.MaxGeneration, GCCollectionMode.Forced, true, true);
315352
}
@@ -318,5 +355,13 @@ public void RoundTrip_StreamingToStreaming_Stress([Values(true, false)] bool asy
318355
Assert.AreEqual(testStream.ToArray(), resultStream.ToArray());
319356
});
320357
}
358+
359+
// Trains a dictionary from 100 small sequentially-filled sample buffers.
private static byte[] TrainDict()
{
	const int sampleCount = 100;

	// Materialise to byte[][] so TrainFromBuffer receives the same array type as before.
	var samples = Enumerable.Range(0, sampleCount)
		.Select(_ => DataGenerator.GetSmallBuffer(DataFill.Sequential))
		.ToArray();

	return DictBuilder.TrainFromBuffer(samples);
}
321366
}
322367
}

ZstdNet/CompressionOptions.cs

+32
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Collections.Generic;
23
using size_t = System.UIntPtr;
34

45
namespace ZstdNet
@@ -24,6 +25,36 @@ public CompressionOptions(byte[] dict, int compressionLevel = DefaultCompression
2425
GC.SuppressFinalize(this); // No unmanaged resources
2526
}
2627

28+
/// <summary>
/// Creates compression options with an optional dictionary and an optional set of
/// advanced zstd compression parameters.
/// </summary>
/// <param name="dict">Dictionary bytes, forwarded to the (dict, compressionLevel) constructor; may be null.</param>
/// <param name="advancedParams">
/// Advanced parameters to apply to each compression context. When null, no advanced
/// parameters are recorded and <see cref="AdvancedParams"/> stays null. Every value is
/// checked against the bounds reported by ZSTD_cParam_getBounds for its key.
/// </param>
/// <param name="compressionLevel">Compression level, forwarded to the (dict, compressionLevel) constructor.</param>
/// <exception cref="ArgumentOutOfRangeException">
/// A value in <paramref name="advancedParams"/> lies outside the bounds reported for its parameter.
/// </exception>
public CompressionOptions(byte[] dict, IReadOnlyDictionary<ZSTD_cParameter, int> advancedParams, int compressionLevel = DefaultCompressionLevel)
	: this(dict, compressionLevel)
{
	if(advancedParams == null)
		return;

	// Validate into a private snapshot rather than keeping the caller's instance:
	// the caller may still hold a mutable Dictionary behind the read-only interface,
	// and later edits would otherwise bypass the range checks below.
	var validatedParams = new Dictionary<ZSTD_cParameter, int>(advancedParams.Count);

	foreach(var param in advancedParams)
	{
		var bounds = ExternMethods.ZSTD_cParam_getBounds(param.Key);
		// NOTE(review): EnsureZstdSuccess presumably throws when zstd reports an error code (e.g. unknown parameter) - confirm.
		bounds.error.EnsureZstdSuccess();

		if(param.Value < bounds.lowerBound || param.Value > bounds.upperBound)
			throw new ArgumentOutOfRangeException(nameof(advancedParams), $"Advanced parameter '{param.Key}' is out of range [{bounds.lowerBound}, {bounds.upperBound}]");

		validatedParams[param.Key] = param.Value;
	}

	// An empty input still yields a non-null (empty) AdvancedParams, as before.
	this.AdvancedParams = validatedParams;
}
45+
46+
// Pushes these options onto the given compression context: first the compression
// level (unless the advanced parameters carry their own ZSTD_c_compressionLevel),
// then every advanced parameter, each via ZSTD_CCtx_setParameter.
internal void ApplyCompressionParams(IntPtr cctx)
{
	var levelOverridden = AdvancedParams != null
		&& AdvancedParams.ContainsKey(ZSTD_cParameter.ZSTD_c_compressionLevel);

	if(!levelOverridden)
		ExternMethods.ZSTD_CCtx_setParameter(cctx, ZSTD_cParameter.ZSTD_c_compressionLevel, CompressionLevel).EnsureZstdSuccess();

	if(AdvancedParams != null)
	{
		foreach(var entry in AdvancedParams)
			ExternMethods.ZSTD_CCtx_setParameter(cctx, entry.Key, entry.Value).EnsureZstdSuccess();
	}
}
57+
2758
~CompressionOptions() => Dispose(false);
2859

2960
public void Dispose()
@@ -51,6 +82,7 @@ private void Dispose(bool disposing)
5182

5283
public readonly int CompressionLevel;
5384
public readonly byte[] Dictionary;
85+
public readonly IReadOnlyDictionary<ZSTD_cParameter, int> AdvancedParams;
5486

5587
internal IntPtr Cdict;
5688
}

ZstdNet/CompressionStream.cs

+9-4
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,15 @@ public CompressionStream(Stream stream, CompressionOptions options, int bufferSi
4040
innerStream = stream;
4141

4242
cStream = ZSTD_createCStream().EnsureZstdSuccess();
43-
if(options.Cdict == IntPtr.Zero)
44-
ZSTD_initCStream(cStream, options.CompressionLevel).EnsureZstdSuccess();
45-
else
46-
ZSTD_initCStream_usingCDict(cStream, options.Cdict).EnsureZstdSuccess();
43+
ZSTD_CCtx_reset(cStream, ZSTD_ResetDirective.ZSTD_reset_session_only).EnsureZstdSuccess();
44+
45+
if(options != null)
46+
{
47+
options.ApplyCompressionParams(cStream);
48+
49+
if(options.Cdict != IntPtr.Zero)
50+
ZSTD_CCtx_refCDict(cStream, options.Cdict).EnsureZstdSuccess();
51+
}
4752

4853
this.bufferSize = bufferSize > 0 ? bufferSize : (int)ZSTD_CStreamOutSize().EnsureZstdSuccess();
4954
outputBuffer = ArrayPool<byte>.Shared.Rent(this.bufferSize);

ZstdNet/Compressor.cs

+10-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ public Compressor(CompressionOptions options)
1313
{
1414
Options = options;
1515
cctx = ExternMethods.ZSTD_createCCtx().EnsureZstdSuccess();
16+
17+
options.ApplyCompressionParams(cctx);
18+
19+
if(options.Cdict != IntPtr.Zero)
20+
ExternMethods.ZSTD_CCtx_refCDict(cctx, options.Cdict).EnsureZstdSuccess();
1621
}
1722

1823
~Compressor() => Dispose(false);
@@ -72,9 +77,11 @@ public int Wrap(ReadOnlySpan<byte> src, byte[] dst, int offset)
7277

7378
public int Wrap(ReadOnlySpan<byte> src, Span<byte> dst)
7479
{
75-
var dstSize = Options.Cdict == IntPtr.Zero
76-
? ExternMethods.ZSTD_compressCCtx(cctx, dst, (size_t)dst.Length, src, (size_t)src.Length, Options.CompressionLevel)
77-
: ExternMethods.ZSTD_compress_usingCDict(cctx, dst, (size_t)dst.Length, src, (size_t)src.Length, Options.Cdict);
80+
var dstSize = Options.AdvancedParams != null
81+
? ExternMethods.ZSTD_compress2(cctx, dst, (size_t)dst.Length, src, (size_t)src.Length)
82+
: Options.Cdict == IntPtr.Zero
83+
? ExternMethods.ZSTD_compressCCtx(cctx, dst, (size_t)dst.Length, src, (size_t)src.Length, Options.CompressionLevel)
84+
: ExternMethods.ZSTD_compress_usingCDict(cctx, dst, (size_t)dst.Length, src, (size_t)src.Length, Options.Cdict);
7885

7986
return (int)dstSize.EnsureZstdSuccess();
8087
}

0 commit comments

Comments
 (0)