90 changes: 90 additions & 0 deletions random_test.go
@@ -0,0 +1,90 @@
package lz4

import (
	"bytes"
	"crypto/md5"
	"crypto/rand"
	"fmt"
	"io"
	"testing"
)

var testDataSizes = []int{
	0, 1, 16 << 10, 32 << 10, 64 << 10, 128 << 10, 256 << 10,
	512 << 10, 1 << 20, 2 << 20, 4 << 20, 8 << 20, 16 << 20,
	20 << 20, 32 << 20, 100 << 20,
}

// generateCompressedData generates random data of the given size, compresses it with lz4,
// and returns the compressed data along with the MD5 checksum of the original data.
func generateCompressedData(size int) ([]byte, []byte, error) {
	reader := io.LimitReader(rand.Reader, int64(size))
	hasher := md5.New()
	teeReader := io.TeeReader(reader, hasher)

	var compressed bytes.Buffer
	lz4Writer := NewWriter(&compressed)
	if _, err := io.Copy(lz4Writer, teeReader); err != nil {
		return nil, nil, fmt.Errorf("error writing compressed data: %w", err)
	}
	if err := lz4Writer.Close(); err != nil {
		return nil, nil, fmt.Errorf("error closing lz4 writer: %w", err)
	}
	checksum := hasher.Sum(nil)
	return compressed.Bytes(), checksum, nil
}

func TestReaderWithRandomData(t *testing.T) {
	for _, size := range testDataSizes {
		size := size // capture range variable
		t.Run(fmt.Sprintf("Size_%d_bytes", size), func(t *testing.T) {
			compressedData, originalChecksum, err := generateCompressedData(size)
			if err != nil {
				t.Fatalf("failed to generate compressed data: %v", err)
			}

			decompressReader := NewReader(bytes.NewReader(compressedData))
			decompressedData, err := io.ReadAll(decompressReader)
			if err != nil {
				t.Fatalf("decompression failed: %v", err)
			}

			if len(decompressedData) != size {
				t.Errorf("expected decompressed data size %d, got %d", size, len(decompressedData))
			}

			newChecksum := md5.Sum(decompressedData)
			if !bytes.Equal(originalChecksum, newChecksum[:]) {
				t.Errorf("checksum mismatch for size %d", size)
			}
		})
	}
}

func TestWriterToWithRandomData(t *testing.T) {
	for _, size := range testDataSizes {
		size := size // capture range variable
		t.Run(fmt.Sprintf("Size_%d_bytes", size), func(t *testing.T) {
			compressedData, originalChecksum, err := generateCompressedData(size)
			if err != nil {
				t.Fatalf("failed to generate compressed data: %v", err)
			}

			decompressReader := NewReader(bytes.NewReader(compressedData))
			var decompressed bytes.Buffer
			n, err := io.Copy(&decompressed, decompressReader)
			if err != nil {
				t.Fatalf("decompression failed: %v", err)
			}

			if n != int64(size) {
				t.Errorf("expected decompressed data size %d, got %d", size, n)
			}

			newChecksum := md5.Sum(decompressed.Bytes())
			if !bytes.Equal(originalChecksum, newChecksum[:]) {
				t.Errorf("checksum mismatch for size %d", size)
			}
		})
	}
}
12 changes: 7 additions & 5 deletions reader.go
@@ -36,7 +36,7 @@ type Reader struct {
 	src io.Reader // source reader
 	num int // concurrency level
 	frame *lz4stream.Frame // frame being read
-	data []byte // block buffer allocated in non concurrent mode
+	data []byte // block buffer allocated in non-concurrent mode
 	reads chan []byte // pending data
 	idx int // size of pending data
 	handler func(int)
@@ -157,9 +157,9 @@ func (r *Reader) Read(buf []byte) (n int, err error) {
 }

 // read uncompresses the next block as follow:
-// - if buf has enough room, the block is uncompressed into it directly
-// and the lenght of used space is returned
-// - else, the uncompress data is stored in r.data and 0 is returned
+//   - if buf has enough room, the block is uncompressed into it directly
+//     and the lenght of used space is returned
+//   - else, the uncompress data is stored in r.data and 0 is returned
 func (r *Reader) read(buf []byte) (int, error) {
 	block := r.frame.Blocks.Block
 	_, err := block.Read(r.frame, r.src, r.cum)
@@ -169,7 +169,9 @@ func (r *Reader) read(buf []byte) (int, error) {
 	var direct bool
 	dst := r.data[:cap(r.data)]
 	if len(buf) >= len(dst) {
-		// Uncompress directly into buf.
+		// Decompress directly into buf.
+		// trim r.data as it is not needed now
+		r.data = r.data[:0]
 		direct = true
 		dst = buf
 	}
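For context, the round trip these tests exercise reduces to the short sketch below. It is illustrative only and not part of the diff: it is written as if it lived in the lz4 package alongside the test file, roundTrip is a hypothetical helper name, and it relies only on the NewWriter, NewReader, Write, and Close calls already used by the tests above.

package lz4

import (
	"bytes"
	"io"
)

// roundTrip is an illustrative helper (not part of this PR). It compresses
// src with the streaming Writer and then reads it back through the Reader,
// mirroring what generateCompressedData plus the two tests do with random
// input and an MD5 comparison.
func roundTrip(src []byte) ([]byte, error) {
	var compressed bytes.Buffer
	zw := NewWriter(&compressed)
	if _, err := zw.Write(src); err != nil {
		return nil, err
	}
	// Close must be called so the remaining data and frame end are flushed.
	if err := zw.Close(); err != nil {
		return nil, err
	}
	return io.ReadAll(NewReader(&compressed))
}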