Skip to content

Commit

Permalink
Allow non-pointer types
Browse files Browse the repository at this point in the history
  • Loading branch information
rgngl committed Dec 26, 2023
1 parent b1e1334 commit c6a8b4a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 40 deletions.
29 changes: 13 additions & 16 deletions queue.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,26 @@
// Package lfg implements a lock-free, multiple-producer, multiple-consumer queue.
// Queue[T any] is a generic type, where the items added to the queue are of type *T.
package lfg

import (
"sync/atomic"
"unsafe"

"golang.org/x/sys/cpu"
)

const cacheLinesize = unsafe.Sizeof(cpu.CacheLinePad{})

// Queue[T any] is a lock-free, multiple-producer, multiple-consumer queue.
type Queue[T any] struct {
buf []*T
buf []T

mask int64

_ [cacheLinesize]byte
_ cpu.CacheLinePad
consumerBarrier atomic.Int64
_ [cacheLinesize]byte
_ cpu.CacheLinePad
consumerCursor atomic.Int64

_ [cacheLinesize]byte
_ cpu.CacheLinePad
producerBarrier atomic.Int64
_ [cacheLinesize]byte
_ cpu.CacheLinePad
producerCursor atomic.Int64
}

Expand All @@ -35,13 +31,13 @@ func NewQueue[T any](size uint) *Queue[T] {
}

return &Queue[T]{
buf: make([]*T, size),
buf: make([]T, size),
mask: int64(size - 1),
}
}

// Enqueue adds an item to the queue. It returns false if the buffer is full.
func (b *Queue[T]) Enqueue(v *T) bool {
func (b *Queue[T]) Enqueue(v T) bool {
var pc, cb int64

for {
Expand All @@ -57,7 +53,7 @@ func (b *Queue[T]) Enqueue(v *T) bool {
}
}

atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&b.buf[(pc+1)&b.mask])), unsafe.Pointer(v))
b.buf[(pc+1)&b.mask] = v

for {
if b.producerBarrier.CompareAndSwap(pc, pc+1) {
Expand All @@ -69,31 +65,32 @@ func (b *Queue[T]) Enqueue(v *T) bool {
}

// Dequeue removes an item from the queue. It returns false if the buffer is empty.
func (b *Queue[T]) Dequeue() (*T, bool) {
func (b *Queue[T]) Dequeue() (T, bool) {
var cc, pb int64

for {
cc = b.consumerCursor.Load()
pb = b.producerBarrier.Load()

if pb == cc {
return nil, false
var zero T
return zero, false
}

if b.consumerCursor.CompareAndSwap(cc, cc+1) {
break
}
}

v := atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&b.buf[(cc+1)&b.mask])))
v := b.buf[(cc+1)&b.mask]

for {
if b.consumerBarrier.CompareAndSwap(cc, cc+1) {
break
}
}

return (*T)(v), true
return v, true
}

func isPot(n uint) bool {
Expand Down
44 changes: 20 additions & 24 deletions queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,54 +15,54 @@ func TestRingBufferSingleThread(t *testing.T) {
t.Run("empty queue must return false", func(t *testing.T) {
v, ok := b.Dequeue()
assert.False(t, ok)
assert.Nil(t, v)
assert.Zero(t, v)
})

t.Run("enqueue one and dequeue one", func(t *testing.T) {
ok := b.Enqueue(intPtr(0))
ok := b.Enqueue(0)
assert.True(t, ok)

v, ok := b.Dequeue()
assert.True(t, ok)
assert.Equal(t, 0, *v)
assert.Equal(t, 0, v)
})

t.Run("enqueue until buffer is full", func(t *testing.T) {
ok := b.Enqueue(intPtr(1))
ok := b.Enqueue(1)
assert.True(t, ok)
ok = b.Enqueue(intPtr(2))
ok = b.Enqueue(2)
assert.True(t, ok)
ok = b.Enqueue(intPtr(3))
ok = b.Enqueue(3)
assert.True(t, ok)
ok = b.Enqueue(intPtr(4))
ok = b.Enqueue(4)
assert.False(t, ok)
})

t.Run("dequeue one and check value", func(t *testing.T) {
v, ok := b.Dequeue()
assert.True(t, ok)
assert.Equal(t, 1, *v)
assert.Equal(t, 1, v)
})

t.Run("enqueing one after dequeueing one must succeed", func(t *testing.T) {
ok := b.Enqueue(intPtr(5))
t.Run("enqueuing one after dequeueing one must succeed", func(t *testing.T) {
ok := b.Enqueue(5)
assert.True(t, ok)
})

t.Run("dequeue until buffer is empty", func(t *testing.T) {
v, ok := b.Dequeue()
assert.True(t, ok)
assert.Equal(t, 2, *v)
assert.Equal(t, 2, v)
v, ok = b.Dequeue()
assert.True(t, ok)
assert.Equal(t, 3, *v)
assert.Equal(t, 3, v)
v, ok = b.Dequeue()
assert.True(t, ok)
assert.Equal(t, 5, *v)
assert.Equal(t, 5, v)

v, ok = b.Dequeue()
assert.False(t, ok)
assert.Nil(t, v)
assert.Zero(t, v)
})

t.Run("creating queues with invalid sizes must panic", func(t *testing.T) {
Expand All @@ -77,18 +77,18 @@ func TestRingBufferSingleThread(t *testing.T) {
}

func TestRingBufferSPSC(t *testing.T) {
b := NewQueue[int](4)
b := NewQueue[int](128)

wg := sync.WaitGroup{}
wg.Add(2)

count := 1_000
count := 1_000_000

go func() {
defer wg.Done()

for i := 0; i < count; {
ok := b.Enqueue(intPtr(i))
ok := b.Enqueue(i)
if ok {
i++
}
Expand All @@ -104,7 +104,7 @@ func TestRingBufferSPSC(t *testing.T) {
v, ok := b.Dequeue()
if ok {
i++
if expected != *v {
if expected != v {
panic("unexpected value")
}
expected++
Expand All @@ -116,7 +116,7 @@ func TestRingBufferSPSC(t *testing.T) {
}

func BenchmarkRingBufferMPSC(b *testing.B) {
buf := NewQueue[testMsg](1024)
buf := NewQueue[*testMsg](1024)

const producerCount = 4
countPerProducer := b.N
Expand Down Expand Up @@ -155,7 +155,7 @@ func BenchmarkRingBufferMPSC(b *testing.B) {
}

func BenchmarkRingBufferSPSC(b *testing.B) {
buf := NewQueue[testMsg](1024)
buf := NewQueue[*testMsg](1024)

wg := sync.WaitGroup{}
wg.Add(2)
Expand Down Expand Up @@ -188,7 +188,3 @@ func BenchmarkRingBufferSPSC(b *testing.B) {

wg.Wait()
}

func intPtr(i int) *int {
return &i
}

0 comments on commit c6a8b4a

Please sign in to comment.