Skip to content

Commit

Permalink
Configure context cancel handling
Browse files Browse the repository at this point in the history
Kill worker go routine and stop blocking thunk function on cancellation
of context
  • Loading branch information
andy9775 committed Sep 9, 2018
1 parent e83bd43 commit 3d627b7
Show file tree
Hide file tree
Showing 7 changed files with 334 additions and 13 deletions.
13 changes: 13 additions & 0 deletions dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,16 @@ func (d *dataloader) LoadMany(ogCtx context.Context, keyArr ...Key) ThunkMany {
return result
}
}

/*
Should we be handling context cancellation??
is the current implementation of context canceler correct? That is, will the go routines be canceled
appropriately (we don't want them to block and leak)
specifically around the once strategy. The current implementation will execute the batch function no matter
what. It should therefore also be up to the user to handle canceling in the batch function in case it is
getting called or gets called.
The other worker go routines will cancel and stop waiting for new keys.
*/
20 changes: 18 additions & 2 deletions strategies/once/once.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ package once
import (
"context"

"github.com/go-log/log"

"github.com/andy9775/dataloader"
)

// Options contains the strategy configuration
type options struct {
inBackground bool
logger log.Logger
}

// Option accepts the dataloader and sets an option on it.
Expand Down Expand Up @@ -58,6 +61,13 @@ func WithInBackground() Option {
}
}

// WithLogger configures the logger for the strategy. Default is a no op logger.
func WithLogger(l log.Logger) Option {
return func(o *options) {
o.logger = l
}
}

// ===========================================================================================================

// Load returns a Thunk which either calls the batch function when invoked or waits for a result from a
Expand Down Expand Up @@ -116,8 +126,13 @@ func (s *onceStrategy) LoadMany(ctx context.Context, keyArr ...dataloader.Key) d
return result
}

result = <-resultChan
return result
select {
case <-ctx.Done():
s.options.logger.Log("worker cancelled")
return dataloader.NewResultMap(0)
case result = <-resultChan:
return result
}
}
}

Expand All @@ -142,4 +157,5 @@ func (*onceStrategy) LoadNoOp(context.Context) {}
// formatOptions configures the default values for the loader
func formatOptions(opts *options) {
opts.inBackground = false
opts.logger = log.DefaultLogger
}
72 changes: 72 additions & 0 deletions strategies/once/once_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,41 @@ func timeout(t *testing.T, timeoutChannel chan struct{}, after time.Duration) {
}()
}

type mockLogger struct {
logMsgs []string
m sync.Mutex
}

func (l *mockLogger) Log(v ...interface{}) {
l.m.Lock()
defer l.m.Unlock()

for _, value := range v {
switch i := value.(type) {
case string:
l.logMsgs = append(l.logMsgs, i)
default:
panic("mock logger only takes single log string")
}
}
}

func (l *mockLogger) Logf(format string, v ...interface{}) {
l.m.Lock()
defer l.m.Unlock()

l.logMsgs = append(l.logMsgs, fmt.Sprintf(format, v...))
}

func (l *mockLogger) Messages() []string {
l.m.Lock()
defer l.m.Unlock()

result := make([]string, len(l.logMsgs))
copy(result, l.logMsgs)
return result
}

// ================================================== tests ==================================================

// ========================= foreground calls =========================
Expand Down Expand Up @@ -198,3 +233,40 @@ func TestBatchLoadManyInBackgroundCalled(t *testing.T) {
assert.Equal(t, expectedResult, r.GetValue(key).Result.(string), "Expected result from batch function")
assert.Equal(t, 1, callCount, "Batch function expected to be called on LoadMany()")
}

// =========================================== cancellable context ===========================================

// TestCancellableContextLoadMany ensures that a call to cancel the context kills the background worker
func TestCancellableContextLoadMany(t *testing.T) {
// setup
closeChan := make(chan struct{})
timeout(t, closeChan, TEST_TIMEOUT*3)

expectedResult := "cancel_via_context"
cb := func() {
close(closeChan)
}

key := PrimaryKey(1)
result := dataloader.Result{Result: expectedResult, Err: nil}
/*
ensure the loader doesn't call batch after timeout. If it does, the test will timeout and panic
*/
log := mockLogger{logMsgs: make([]string, 2), m: sync.Mutex{}}
batch := getBatchFunction(cb, result)
strategy := once.NewOnceStrategy(
once.WithLogger(&log),
once.WithInBackground(),
)(2, batch) // expected 2 load calls
ctx, cancel := context.WithCancel(context.Background())

// invoke
go cancel()
thunk := strategy.LoadMany(ctx, key)
thunk()
time.Sleep(100 * time.Millisecond)

// assert
m := log.Messages()
assert.Equal(t, "worker cancelled", m[len(m)-1], "Expected worker to cancel and log exit")
}
17 changes: 12 additions & 5 deletions strategies/sozu/sozu.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ func NewSozuStrategy(opts ...Option) dataloader.StrategyFunction {
counter: strategies.NewCounter(capacity),

workerMutex: &sync.Mutex{},
keyChan: make(chan workerMessage, capacity),
goroutineStatus: notRunning,
options: o,

keyChan: make(chan workerMessage, capacity),
options: o,

keys: dataloader.NewKeys(capacity),
}
Expand Down Expand Up @@ -93,13 +94,12 @@ type sozuStrategy struct {
keys dataloader.Keys
batchFunc dataloader.BatchFunction

workerMutex *sync.Mutex
workerMutex *sync.Mutex
goroutineStatus int

keyChan chan workerMessage
closeChan chan struct{}

goroutineStatus int

options options
}

Expand Down Expand Up @@ -165,6 +165,8 @@ func (s *sozuStrategy) Load(ctx context.Context, key dataloader.Key) dataloader.
}

select {
case <-ctx.Done():
return dataloader.Result{Result: nil, Err: nil}
case r := <-resultChan:
result = r.GetValue(key)
return result
Expand Down Expand Up @@ -218,6 +220,8 @@ func (s *sozuStrategy) LoadMany(ctx context.Context, keyArr ...dataloader.Key) d
}

select {
case <-ctx.Done():
return dataloader.NewResultMap(0)
case r := <-resultChan:
result := dataloader.NewResultMap(len(keyArr))

Expand Down Expand Up @@ -273,6 +277,9 @@ func (s *sozuStrategy) startWorker(ctx context.Context) {
var r *dataloader.ResultMap
for r == nil {
select {
case <-ctx.Done():
s.options.logger.Log("worker cancelled")
return
case key := <-s.keyChan:
// if LoadNoOp passes a value through the chan, ignore the data and increment the counter
if key.resultChan != nil {
Expand Down
103 changes: 103 additions & 0 deletions strategies/sozu/sozu_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,41 @@ func timeout(t *testing.T, timeoutChannel chan struct{}, after time.Duration) {
}()
}

type mockLogger struct {
logMsgs []string
m sync.Mutex
}

func (l *mockLogger) Log(v ...interface{}) {
l.m.Lock()
defer l.m.Unlock()

for _, value := range v {
switch i := value.(type) {
case string:
l.logMsgs = append(l.logMsgs, i)
default:
panic("mock logger only takes single log string")
}
}
}

func (l *mockLogger) Logf(format string, v ...interface{}) {
l.m.Lock()
defer l.m.Unlock()

l.logMsgs = append(l.logMsgs, fmt.Sprintf(format, v...))
}

func (l *mockLogger) Messages() []string {
l.m.Lock()
defer l.m.Unlock()

result := make([]string, len(l.logMsgs))
copy(result, l.logMsgs)
return result
}

// ================================================== tests ==================================================

// ========================= test timeout =========================
Expand Down Expand Up @@ -533,3 +568,71 @@ func TestLoadManyBlocked(t *testing.T) {
"Expected result from thunkMany()",
)
}

// =========================================== cancellable context ===========================================

// TestCancellableContextLoad ensures that a call to cancel the context kills the background worker
func TestCancellableContextLoad(t *testing.T) {
// setup
closeChan := make(chan struct{})
timeout(t, closeChan, TEST_TIMEOUT*3)

callCount := 0
expectedResult := "cancel_via_context"
cb := func(keys dataloader.Keys) {
callCount += 1
close(closeChan)
}

key := PrimaryKey(1)
log := mockLogger{logMsgs: make([]string, 2), m: sync.Mutex{}}
batch := getBatchFunction(cb, expectedResult)
strategy := sozu.NewSozuStrategy(
sozu.WithLogger(&log),
)(2, batch) // expected 2 load calls
ctx, cancel := context.WithCancel(context.Background())

// invoke
go cancel()
thunk := strategy.Load(ctx, key)
thunk()
time.Sleep(100 * time.Millisecond)

// assert
assert.Equal(t, 0, callCount, "Batch should not have been called")
m := log.Messages()
assert.Equal(t, "worker cancelled", m[len(m)-1], "Expected worker to cancel and log exit")
}

// TestCancellableContextLoadMany ensures that a call to cancel the context kills the background worker
func TestCancellableContextLoadMany(t *testing.T) {
// setup
closeChan := make(chan struct{})
timeout(t, closeChan, TEST_TIMEOUT*3)

callCount := 0
expectedResult := "cancel_via_context"
cb := func(keys dataloader.Keys) {
callCount += 1
close(closeChan)
}

key := PrimaryKey(1)
log := mockLogger{logMsgs: make([]string, 2), m: sync.Mutex{}}
batch := getBatchFunction(cb, expectedResult)
strategy := sozu.NewSozuStrategy(
sozu.WithLogger(&log),
)(2, batch) // expected 2 load calls
ctx, cancel := context.WithCancel(context.Background())

// invoke
go cancel()
thunk := strategy.LoadMany(ctx, key)
thunk()
time.Sleep(100 * time.Millisecond)

// assert
assert.Equal(t, 0, callCount, "Batch should not have been called")
m := log.Messages()
assert.Equal(t, "worker cancelled", m[len(m)-1], "Expected worker to cancel and log exit")
}
19 changes: 13 additions & 6 deletions strategies/standard/standard.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ func NewStandardStrategy(opts ...Option) dataloader.StrategyFunction {
counter: strategies.NewCounter(capacity),

workerMutex: &sync.Mutex{},
keyChan: make(chan workerMessage, capacity),
closeChan: make(chan struct{}),
goroutineStatus: notRunning,
options: o,

keyChan: make(chan workerMessage, capacity),
closeChan: make(chan struct{}),
options: o,

keys: dataloader.NewKeys(capacity),
}
Expand Down Expand Up @@ -84,13 +85,12 @@ type standardStrategy struct {
keys dataloader.Keys
batchFunc dataloader.BatchFunction

workerMutex *sync.Mutex
workerMutex *sync.Mutex
goroutineStatus int

keyChan chan workerMessage
closeChan chan struct{}

goroutineStatus int

options options
}

Expand Down Expand Up @@ -132,6 +132,8 @@ func (s *standardStrategy) Load(ctx context.Context, key dataloader.Key) dataloa
}

select {
case <-ctx.Done():
return dataloader.Result{Result: nil, Err: nil}
case r := <-resultChan:
result = r.GetValue(key)
return result
Expand Down Expand Up @@ -170,6 +172,8 @@ func (s *standardStrategy) LoadMany(ctx context.Context, keyArr ...dataloader.Ke
}

select {
case <-ctx.Done():
return dataloader.NewResultMap(0)
case r := <-resultChan:
resultMap = buildResultMap(keyArr, r)
return resultMap
Expand Down Expand Up @@ -221,6 +225,9 @@ func (s *standardStrategy) startWorker(ctx context.Context) {
var r *dataloader.ResultMap
for r == nil {
select {
case <-ctx.Done():
s.options.logger.Logf("worker cancelled")
return
case key := <-s.keyChan:
// if LoadNoOp passes a value through the chan, ignore the data and increment the counter
if key.resultChan != nil {
Expand Down
Loading

0 comments on commit 3d627b7

Please sign in to comment.