diff --git a/.gitignore b/.gitignore index 3d48313..c14509c 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,4 @@ vendor/ +todo.txt diff --git a/dataloader.go b/dataloader.go index 0869084..2b6f97c 100644 --- a/dataloader.go +++ b/dataloader.go @@ -32,7 +32,7 @@ type BatchFunction func(context.Context, Keys) *ResultMap // Thunk returns a result for the key that it was generated for. // Calling the Thunk function will block until the result is returned from the batch function. -type Thunk func() Result +type Thunk func() (Result, bool) // ThunkMany returns a result map for the keys that it was generated for. // Calling ThunkMany will block until the result is returned from the batch function. @@ -128,21 +128,21 @@ func (d *dataloader) Load(ogCtx context.Context, key Key) Thunk { if r, ok := d.cache.GetResult(ctx, key); ok { d.logger.Logf("cache hit for: %d", key) d.strategy.LoadNoOp(ctx) - return func() Result { + return func() (Result, bool) { finish(r) - return r + return r, ok } } thunk := d.strategy.Load(ctx, key) - return func() Result { - result := thunk() + return func() (Result, bool) { + result, ok := thunk() d.cache.SetResult(ctx, key, result) finish(result) - return result + return result, ok } } @@ -173,16 +173,3 @@ func (d *dataloader) LoadMany(ogCtx context.Context, keyArr ...Key) ThunkMany { return result } } - -/* - Should we be handling context cancellation?? - - - is the current implementation of context canceler correct? That is, will the go routines be canceled - appropriately (we don't want them to block and leak) - specifically around the once strategy. The current implementation will execute the batch function no matter - what. It should therefore also be up to the user to handle canceling in the batch function in case it is - getting called or gets called. - - The other worker go routines will cancel and stop waiting for new keys. -*/ diff --git a/dataloader_test.go b/dataloader_test.go index f6ceef1..6ced464 100644 --- a/dataloader_test.go +++ b/dataloader_test.go @@ -101,7 +101,7 @@ func newMockStrategy() func(int, dataloader.BatchFunction) dataloader.Strategy { } func (s *mockStrategy) Load(ctx context.Context, key dataloader.Key) dataloader.Thunk { - return func() dataloader.Result { + return func() (dataloader.Result, bool) { keys := dataloader.NewKeys(1) keys.Append(key) r := s.batchFunc(ctx, keys) @@ -147,7 +147,8 @@ func TestLoadCacheHit(t *testing.T) { // invoke / assert thunk := loader.Load(context.Background(), key) - r := thunk() + r, ok := thunk() + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, expectedResult.Result.(string), r.Result.(string), "Expected result from thunk") assert.Equal(t, 0, callCount, "Expected batch function to not be called") } @@ -174,12 +175,14 @@ func TestLoadManyCacheHit(t *testing.T) { thunk := loader.LoadMany(context.Background(), key, key2) r := thunk() + returned, ok := r.GetValue(key) + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, expectedResult.Result.(string), - r.(dataloader.ResultMap).GetValue(key).Result.(string), + returned.Result.(string), "Expected result from thunk", ) - assert.Equal(t, 2, r.(dataloader.ResultMap).Length(), "Expected 2 result from cache") + assert.Equal(t, 2, r.Length(), "Expected 2 result from cache") assert.Equal(t, 0, callCount, "Expected batch function to not be called") } @@ -201,7 +204,8 @@ func TestLoadCacheMiss(t *testing.T) { // invoke / assert thunk := loader.Load(context.Background(), key) - r := thunk() + r, ok := thunk() + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, result.Result.(string), r.Result.(string), "Expected result from thunk") assert.Equal(t, 1, callCount, "Expected batch function to be called") } @@ -223,9 +227,11 @@ func TestLoadManyCacheMiss(t *testing.T) { thunk := loader.LoadMany(context.Background(), key) r := thunk() + returned, ok := r.GetValue(key) + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, result.Result.(string), - r.(dataloader.ResultMap).GetValue(key).Result.(string), + returned.Result.(string), "Expected result from thunk", ) assert.Equal(t, 1, callCount, "Expected batch function to be called") diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..3b1fd01 --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module github.com/andy9775/dataloader + +require ( + github.com/bouk/monkey v1.0.0 + github.com/davecgh/go-spew v1.1.0 + github.com/opentracing/opentracing-go v1.0.2 + github.com/pmezard/go-difflib v1.0.0 + github.com/stretchr/testify v1.2.2 + golang.org/x/net v0.0.0-20180808004115-f9ce57c11b24 +) diff --git a/result.go b/result.go index a4507f9..252bd85 100644 --- a/result.go +++ b/result.go @@ -9,7 +9,7 @@ type Result struct { // ResultMap maps each loaded elements Result against the elements unique identifier (Key) type ResultMap interface { Set(string, Result) - GetValue(Key) Result + GetValue(Key) (Result, bool) Length() int // Keys returns a slice of all unique identifiers used in the containing map (keys) Keys() []string @@ -35,15 +35,15 @@ func (r *resultMap) Set(identifier string, value Result) { r.r[identifier] = value } -// GetValue returns the value from the results for the provided key. -// If no value exists, returns nil -func (r *resultMap) GetValue(key Key) Result { +// GetValue returns the value from the results for the provided key and true +// if the value was found, otherwise false. +func (r *resultMap) GetValue(key Key) (Result, bool) { if key == nil { - return Result{} + return Result{}, false } - // No need to check ok, missing value from map[Any]interface{} is nil by default. - return r.r[key.String()] + result, ok := r.r[key.String()] + return result, ok } func (r *resultMap) GetValueForString(key string) Result { diff --git a/result_test.go b/result_test.go new file mode 100644 index 0000000..ccd92c1 --- /dev/null +++ b/result_test.go @@ -0,0 +1,36 @@ +package dataloader_test + +import ( + "testing" + + "github.com/andy9775/dataloader" + "github.com/stretchr/testify/assert" +) + +// ================================================== tests ================================================== +// TestEnsureOKForResult tests getting the result value with a valid key expecting a valid value +func TestEnsureOKForResult(t *testing.T) { + // setup + rmap := dataloader.NewResultMap(2) + key := PrimaryKey(1) + value := dataloader.Result{Result: 1, Err: nil} + rmap.Set(key.String(), value) + + // invoke/assert + result, ok := rmap.GetValue(key) + assert.True(t, ok, "Expected valid result to have been found") + assert.Equal(t, result, value, "Expected result") +} +func TestEnsureNotOKForResult(t *testing.T) { + // setup + rmap := dataloader.NewResultMap(2) + key := PrimaryKey(1) + key2 := PrimaryKey(2) + value := dataloader.Result{Result: 1, Err: nil} + rmap.Set(key.String(), value) + + // invoke/assert + result, ok := rmap.GetValue(key2) + assert.False(t, ok, "Expected valid result to have been found") + assert.Nil(t, result.Result, "Expected nil result") +} diff --git a/strategies/once/once.go b/strategies/once/once.go index c54dbae..cc317c1 100644 --- a/strategies/once/once.go +++ b/strategies/once/once.go @@ -74,35 +74,41 @@ func WithLogger(l log.Logger) Option { // background go routine (blocking if no data is available). Note that if the strategy is configured to // run in the background, calling Load again will spin up another background go routine. func (s *onceStrategy) Load(ctx context.Context, key dataloader.Key) dataloader.Thunk { - var result dataloader.Result + + type data struct { + r dataloader.Result + ok bool + } + var result data if s.options.inBackground { - resultChan := make(chan dataloader.Result) + resultChan := make(chan data) // don't check if result is nil before starting in case a new key is passed in go func() { - resultChan <- (*s.batchFunc(ctx, dataloader.NewKeysWith(key))).GetValue(key) + r, ok := (*s.batchFunc(ctx, dataloader.NewKeysWith(key))).GetValue(key) + resultChan <- data{r, ok} }() // call batch in background and block util it returns - return func() dataloader.Result { - if result.Result != nil || result.Err != nil { - return result + return func() (dataloader.Result, bool) { + if result.r.Result != nil || result.r.Err != nil { + return result.r, result.ok } result = <-resultChan - return result + return result.r, result.ok } } // call batch when thunk is called - return func() dataloader.Result { - if result.Result != nil || result.Err != nil { - return result + return func() (dataloader.Result, bool) { + if result.ok { + return result.r, result.ok } - result = (*s.batchFunc(ctx, dataloader.NewKeysWith(key))).GetValue(key) - return result + result.r, result.ok = (*s.batchFunc(ctx, dataloader.NewKeysWith(key))).GetValue(key) + return result.r, result.ok } } diff --git a/strategies/once/once_test.go b/strategies/once/once_test.go index c72dd30..230a354 100644 --- a/strategies/once/once_test.go +++ b/strategies/once/once_test.go @@ -113,11 +113,13 @@ func TestBatchLoadInForegroundCalled(t *testing.T) { thunk := strategy.Load(context.Background(), key) assert.Equal(t, 0, callCount, "Batch function not expected to be called on Load()") - r := thunk() + r, ok := thunk() + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, 1, callCount, "Batch function expected to be called on thunk()") assert.Equal(t, expectedResult, r.Result.(string), "Expected result from batch function") - r = thunk() + r, ok = thunk() + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, 1, callCount, "Batch function expected to be called on thunk()") assert.Equal(t, expectedResult, r.Result.(string), "Expected result from batch function") } @@ -143,12 +145,16 @@ func TestBatchLoadManyInForegroundCalled(t *testing.T) { assert.Equal(t, 0, callCount, "Batch function not expected to be called on LoadMany()") r := thunkMany() + returned, ok := r.GetValue(key) + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, 1, callCount, "Batch function expected to be called on thunkMany()") - assert.Equal(t, expectedResult, r.GetValue(key).Result.(string), "Expected result from batch function") + assert.Equal(t, expectedResult, returned.Result.(string), "Expected result from batch function") r = thunkMany() + returned, ok = r.GetValue(key) + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, 1, callCount, "Batch function expected to be called on thunkMany()") - assert.Equal(t, expectedResult, r.GetValue(key).Result.(string), "Expected result from batch function") + assert.Equal(t, expectedResult, returned.Result.(string), "Expected result from batch function") } // ========================= background calls ========================= @@ -189,9 +195,12 @@ func TestBatchLoadInBackgroundCalled(t *testing.T) { assert.Equal(t, 1, callCount, "Batch function expected to be called on Load() in background") - r := thunk() + r, ok := thunk() + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, expectedResult, r.Result.(string), "Expected value from batch function") - r = thunk() + + r, ok = thunk() + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, expectedResult, r.Result.(string), "Expected value from batch function") assert.Equal(t, 1, callCount, "Batch function expected to be called on Load() in background") } @@ -228,9 +237,14 @@ func TestBatchLoadManyInBackgroundCalled(t *testing.T) { assert.Equal(t, 1, callCount, "Batch function expected to be called on LoadMany()") r := thunkMany() - assert.Equal(t, expectedResult, r.GetValue(key).Result.(string), "Expected result from batch function") + returned, ok := r.GetValue(key) + assert.True(t, ok, "Expected result to have been found") + assert.Equal(t, expectedResult, returned.Result.(string), "Expected result from batch function") + r = thunkMany() - assert.Equal(t, expectedResult, r.GetValue(key).Result.(string), "Expected result from batch function") + returned, ok = r.GetValue(key) + assert.True(t, ok, "Expected result to have been found") + assert.Equal(t, expectedResult, returned.Result.(string), "Expected result from batch function") assert.Equal(t, 1, callCount, "Batch function expected to be called on LoadMany()") } @@ -270,3 +284,79 @@ func TestCancellableContextLoadMany(t *testing.T) { m := log.Messages() assert.Equal(t, "worker cancelled", m[len(m)-1], "Expected worker to cancel and log exit") } + +// =============================================== result keys =============================================== +// TestKeyHandling ensures that processed and unprocessed keys by the batch function are handled correctly +// This test accomplishes this by skipping processing a single key and then asserts that they skipped key +// returns the correct ok value when loading the data +func TestKeyHandling(t *testing.T) { + // setup + expectedResult := map[PrimaryKey]interface{}{ + PrimaryKey(1): "valid_result", + PrimaryKey(2): nil, + PrimaryKey(3): "__skip__", // this key should be skipped by the batch function + } + + batch := func(ctx context.Context, keys dataloader.Keys) *dataloader.ResultMap { + m := dataloader.NewResultMap(2) + for i := 0; i < keys.Length(); i++ { + key := keys.Keys()[i].(PrimaryKey) + if expectedResult[key] != "__skip__" { + m.Set(key.String(), dataloader.Result{Result: expectedResult[key], Err: nil}) + } + } + return &m + } + + // invoke/assert + + // iterate through multiple strategies table test style + strategies := []dataloader.Strategy{ + once.NewOnceStrategy()(3, batch), + once.NewOnceStrategy(once.WithInBackground())(3, batch), + } + for _, strategy := range strategies { + + // Load + for key, expected := range expectedResult { + thunk := strategy.Load(context.Background(), key) + r, ok := thunk() + + switch expected.(type) { + case string: + if expected == "__skip__" { + assert.False(t, ok, "Expected skipped result to not be found") + assert.Nil(t, r.Result, "Expected skipped result to be nil") + } else { + assert.True(t, ok, "Expected processed result to be found") + assert.Equal(t, r.Result, expected, "Expected result") + } + case nil: + assert.True(t, ok, "Expected processed result to be found") + assert.Nil(t, r.Result, "Expected result to be nil") + } + } + + // LoadMany + thunkMany := strategy.LoadMany(context.Background(), PrimaryKey(1), PrimaryKey(2), PrimaryKey(3)) + for key, expected := range expectedResult { + result := thunkMany() + r, ok := result.GetValue(key) + + switch expected.(type) { + case string: + if expected == "__skip__" { + assert.False(t, ok, "Expected skipped result to not be found") + assert.Nil(t, r.Result, "Expected skipped result to be nil") + } else { + assert.True(t, ok, "Expected processed result to be found") + assert.Equal(t, r.Result, expected, "Expected result") + } + case nil: + assert.True(t, ok, "Expected processed result to be found") + assert.Nil(t, r.Result, "Expected result to be nil") + } + + } + } +} diff --git a/strategies/sozu/sozu.go b/strategies/sozu/sozu.go index d205018..361d918 100644 --- a/strategies/sozu/sozu.go +++ b/strategies/sozu/sozu.go @@ -125,6 +125,7 @@ func (s *sozuStrategy) Load(ctx context.Context, key dataloader.Key) dataloader. s.keyChan <- message // pass key to the worker go routine var result dataloader.Result + var ok bool /* TODO: clean up If a worker go routine is in the process of calling the batch function and another @@ -140,9 +141,9 @@ func (s *sozuStrategy) Load(ctx context.Context, key dataloader.Key) dataloader. This solution isn't clean, or totally efficient but ensures that a worker will pick up the key and process it. */ - return func() dataloader.Result { + return func() (dataloader.Result, bool) { if result.Result != nil || result.Err != nil { - return result + return result, ok } for { @@ -159,17 +160,17 @@ func (s *sozuStrategy) Load(ctx context.Context, key dataloader.Key) dataloader. */ select { case r := <-resultChan: - result = r.GetValue(key) - return result + result, ok = r.GetValue(key) + return result, ok default: } select { case <-ctx.Done(): - return dataloader.Result{Result: nil, Err: nil} + return dataloader.Result{Result: nil, Err: nil}, false case r := <-resultChan: - result = r.GetValue(key) - return result + result, ok = r.GetValue(key) + return result, ok case <-s.closeChan: /* Current worker closed, therefore no readers reading off of the key chan to get @@ -211,7 +212,9 @@ func (s *sozuStrategy) LoadMany(ctx context.Context, keyArr ...dataloader.Key) d result := dataloader.NewResultMap(len(keyArr)) for _, k := range keyArr { - result.Set(k.String(), r.GetValue(k)) + if val, ok := r.GetValue(k); ok { + result.Set(k.String(), val) + } } resultMap = result @@ -226,7 +229,9 @@ func (s *sozuStrategy) LoadMany(ctx context.Context, keyArr ...dataloader.Key) d result := dataloader.NewResultMap(len(keyArr)) for _, k := range keyArr { - result.Set(k.String(), r.GetValue(k)) + if val, ok := r.GetValue(k); ok { + result.Set(k.String(), val) + } } resultMap = result diff --git a/strategies/sozu/sozu_test.go b/strategies/sozu/sozu_test.go index 581a6cc..183973c 100644 --- a/strategies/sozu/sozu_test.go +++ b/strategies/sozu/sozu_test.go @@ -167,7 +167,8 @@ func TestLoadTimeoutTriggered(t *testing.T) { toWG.Wait() assert.True(t, timedOut, "Expected function to timeout") - r := thunk() + r, ok := thunk() + assert.True(t, ok, "Expected result to have been found") assert.Equal( t, fmt.Sprintf("2_%s", expectedResult), @@ -176,7 +177,8 @@ func TestLoadTimeoutTriggered(t *testing.T) { ) // test double call to thunk - r = thunk() + r, ok = thunk() + assert.True(t, ok, "Expected result to have been found") assert.Equal( t, fmt.Sprintf("2_%s", expectedResult), @@ -251,9 +253,11 @@ func TestLoadManyTimeoutTriggered(t *testing.T) { assert.True(t, timedOut, "Expected function to have timed out") r1 := thunkMany2() + returned1, ok := r1.GetValue(key2) + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, fmt.Sprintf("2_%s", expectedResult), - r1.(dataloader.ResultMap).GetValue(key2).Result.(string), + returned1.Result.(string), "Expected batch function to return on thunkMany()", ) @@ -261,7 +265,7 @@ func TestLoadManyTimeoutTriggered(t *testing.T) { assert.Equal( t, 3, - r1.(dataloader.ResultMap).Length()+r2.(dataloader.ResultMap).Length(), + r1.Length()+r2.Length(), "Expected 3 total results from both thunkMany function", ) @@ -270,7 +274,7 @@ func TestLoadManyTimeoutTriggered(t *testing.T) { assert.Equal( t, 3, - r1.(dataloader.ResultMap).Length()+r2.(dataloader.ResultMap).Length(), + r1.Length()+r2.Length(), "Expected 3 total results from both thunkMany function", ) assert.Equal(t, 1, callCount, "Batch function expected to be called once") @@ -336,7 +340,8 @@ func TestLoadTriggered(t *testing.T) { assert.Equal(t, 1, callCount, "Batch function expected to be called once") // capacity is 2, called with 1 key and 1 cache hit assert.Equal(t, 1, len(k), "Expected to be called with 1 keys") - r1 := thunk() + r1, ok := thunk() + assert.True(t, ok, "Expected result to have been found") assert.Equal( t, fmt.Sprintf("1_%s", expectedResult), @@ -347,7 +352,8 @@ func TestLoadTriggered(t *testing.T) { assert.False(t, timedOut, "Expected function to not timeout") // test double call to thunk - r1 = thunk() + r1, ok = thunk() + assert.True(t, ok, "Expected result to have been found") assert.Equal( t, fmt.Sprintf("1_%s", expectedResult), @@ -417,10 +423,12 @@ func TestLoadManyTriggered(t *testing.T) { // capacity is 2, called with 2 keys and one cache hit assert.Equal(t, 2, len(k), "Expected to be called with 2 keys") r1 := thunk() + returned1, ok := r1.GetValue(key) + assert.True(t, ok, "Expected result to have been found") assert.Equal( t, fmt.Sprintf("1_%s", expectedResult), - r1.(dataloader.ResultMap).GetValue(key).Result, + returned1.Result, "Expected batch function to return on thunk()", ) @@ -428,10 +436,12 @@ func TestLoadManyTriggered(t *testing.T) { // test double call to thunk r1 = thunk() // don't block on second call + returned1, ok = r1.GetValue(key) + assert.True(t, ok, "Expected result to have been found") assert.Equal( t, fmt.Sprintf("1_%s", expectedResult), - r1.(dataloader.ResultMap).GetValue(key).Result, + returned1.Result, "Expected batch function to return on thunk()", ) assert.Equal(t, 1, callCount, "Batch function expected to be called once ") @@ -481,16 +491,18 @@ func TestLoadBlocked(t *testing.T) { thunk := strategy.Load(context.Background(), key) // --------- Load - call 1 strategy.LoadNoOp(context.Background()) // --------- LoadNoOp - call 2 - r := thunk() // block until batch function executes + r, ok := thunk() // block until batch function executes + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, 1, callCount, "Batch function should have been called once") assert.False(t, timedOut, "Batch function should not have timed out") assert.Equal(t, 1, len(k), "Should have been called with one key") assert.Equal(t, fmt.Sprintf("1_%s", expectedResult), r.Result.(string), "Expected result from thunk()") // test double call to thunk - r = thunk() // don't block on second call + r, ok = thunk() // don't block on second call + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, 1, callCount, "Batch function should have been called once") assert.False(t, timedOut, "Batch function should not have timed out") assert.Equal(t, 1, len(k), "Should have been called with one key") @@ -548,10 +560,12 @@ func TestLoadManyBlocked(t *testing.T) { assert.Equal(t, 1, callCount, "Batch function should have been called once") assert.False(t, timedOut, "Batch function should not have timed out") assert.Equal(t, 2, len(k), "Should have been called with two keys") + returned, ok := r.GetValue(key2) + assert.True(t, ok, "Expected result to have been found") assert.Equal( t, fmt.Sprintf("2_%s", expectedResult), - r.(dataloader.ResultMap).GetValue(key2).Result.(string), + returned.Result.(string), "Expected result from thunkMany()", ) @@ -561,10 +575,12 @@ func TestLoadManyBlocked(t *testing.T) { assert.Equal(t, 1, callCount, "Batch function should have been called once") assert.False(t, timedOut, "Batch function should not have timed out") assert.Equal(t, 2, len(k), "Should have been called with two keys") + returned, ok = r.GetValue(key2) + assert.True(t, ok, "Expected result to have been found") assert.Equal( t, fmt.Sprintf("2_%s", expectedResult), - r.(dataloader.ResultMap).GetValue(key2).Result.(string), + returned.Result.(string), "Expected result from thunkMany()", ) } @@ -636,3 +652,70 @@ func TestCancellableContextLoadMany(t *testing.T) { m := log.Messages() assert.Equal(t, "worker cancelled", m[len(m)-1], "Expected worker to cancel and log exit") } + +// =============================================== result keys =============================================== +// TestKeyHandling ensure that the strategy properly handles unprocessed and nil keys +func TestKeyHandling(t *testing.T) { + // setup + expectedResult := map[PrimaryKey]interface{}{ + PrimaryKey(1): "valid_result", + PrimaryKey(2): nil, + PrimaryKey(3): "__skip__", // this key should be skipped by the batch function + } + + batch := func(ctx context.Context, keys dataloader.Keys) *dataloader.ResultMap { + m := dataloader.NewResultMap(2) + for i := 0; i < keys.Length(); i++ { + key := keys.Keys()[i].(PrimaryKey) + if expectedResult[key] != "__skip__" { + m.Set(key.String(), dataloader.Result{Result: expectedResult[key], Err: nil}) + } + } + return &m + } + + // invoke/assert + strategy := sozu.NewSozuStrategy()(3, batch) + + // Load + for key, expected := range expectedResult { + thunk := strategy.Load(context.Background(), key) + r, ok := thunk() + + switch expected.(type) { + case string: + if expected == "__skip__" { + assert.False(t, ok, "Expected skipped result to not be found") + assert.Nil(t, r.Result, "Expected skipped result to be nil") + } else { + assert.True(t, ok, "Expected processed result to be found") + assert.Equal(t, r.Result, expected, "Expected result") + } + case nil: + assert.True(t, ok, "Expected processed result to be found") + assert.Nil(t, r.Result, "Expected result to be nil") + } + } + + // LoadMany + thunkMany := strategy.LoadMany(context.Background(), PrimaryKey(1), PrimaryKey(2), PrimaryKey(3)) + for key, expected := range expectedResult { + result := thunkMany() + r, ok := result.GetValue(key) + + switch expected.(type) { + case string: + if expected == "__skip__" { + assert.False(t, ok, "Expected skipped result to not be found") + assert.Nil(t, r.Result, "Expected skipped result to be nil") + } else { + assert.True(t, ok, "Expected processed result to be found") + assert.Equal(t, r.Result, expected, "Expected result") + } + case nil: + assert.True(t, ok, "Expected processed result to be found") + assert.Nil(t, r.Result, "Expected result to be nil") + } + + } +} diff --git a/strategies/standard/standard.go b/strategies/standard/standard.go index 1856d26..a13673f 100644 --- a/strategies/standard/standard.go +++ b/strategies/standard/standard.go @@ -110,10 +110,11 @@ func (s *standardStrategy) Load(ctx context.Context, key dataloader.Key) dataloa s.keyChan <- message // pass key to the worker go routine (buffered channel) var result dataloader.Result + var ok bool - return func() dataloader.Result { + return func() (dataloader.Result, bool) { if result.Result != nil || result.Err != nil { - return result + return result, ok } /* @@ -126,20 +127,20 @@ func (s *standardStrategy) Load(ctx context.Context, key dataloader.Key) dataloa */ select { case r := <-resultChan: - result = r.GetValue(key) - return result + result, ok = r.GetValue(key) + return result, ok default: } select { case <-ctx.Done(): - return dataloader.Result{Result: nil, Err: nil} + return dataloader.Result{Result: nil, Err: nil}, false case r := <-resultChan: - result = r.GetValue(key) - return result + result, ok = r.GetValue(key) + return result, ok case <-s.closeChan: - result = (*s.batchFunc(ctx, dataloader.NewKeysWith(key))).GetValue(key) - return result + result, ok = (*s.batchFunc(ctx, dataloader.NewKeysWith(key))).GetValue(key) + return result, ok } } } @@ -268,7 +269,9 @@ func buildResultMap(keyArr []dataloader.Key, r dataloader.ResultMap) dataloader. results := dataloader.NewResultMap(len(keyArr)) for _, k := range keyArr { - results.Set(k.String(), r.GetValue(k)) + if val, ok := r.GetValue(k); ok { + results.Set(k.String(), val) + } } return results diff --git a/strategies/standard/standard_test.go b/strategies/standard/standard_test.go index 5a08a59..10b6971 100644 --- a/strategies/standard/standard_test.go +++ b/strategies/standard/standard_test.go @@ -163,7 +163,8 @@ func TestLoadNoTimeout(t *testing.T) { assert.Equal(t, 1, callCount, "Batch function expected to be called once") assert.Equal(t, 2, len(k), "Expected to be called with 2 keys") - r := thunk() + r, ok := thunk() + assert.True(t, ok, "Expected result to have been found") assert.Equal( t, fmt.Sprintf("2_%s", expectedResult), @@ -174,7 +175,8 @@ func TestLoadNoTimeout(t *testing.T) { assert.False(t, timedOut, "Expected loader not to timeout") // test double call to thunk - r = thunk() + r, ok = thunk() + assert.True(t, ok, "Expected result to have been found") assert.Equal( t, fmt.Sprintf("2_%s", expectedResult), @@ -244,10 +246,12 @@ func TestLoadManyNoTimeout(t *testing.T) { assert.Equal(t, 3, len(k), "Expected to be called with 2 keys") r := thunk() + returned, ok := r.GetValue(key2) + assert.True(t, ok, "Expected result to have been found") assert.Equal( t, fmt.Sprintf("2_%s", expectedResult), - r.(dataloader.ResultMap).GetValue(key2).Result.(string), + returned.Result.(string), "Expected result from thunk()", ) @@ -255,10 +259,12 @@ func TestLoadManyNoTimeout(t *testing.T) { // test double call to thunk r = thunk() + returned, ok = r.GetValue(key2) + assert.True(t, ok, "Expected result to have been found") assert.Equal( t, fmt.Sprintf("2_%s", expectedResult), - r.(dataloader.ResultMap).GetValue(key2).Result.(string), + returned.Result.(string), "Expected result from thunk()", ) assert.Equal(t, 1, callCount, "Batch function expected to be called once") @@ -328,7 +334,8 @@ func TestLoadTimeout(t *testing.T) { assert.Equal(t, 1, len(k), "Expected to be called with 1 key") assert.True(t, timedOut, "Expected loader to timeout") - r := thunk2() + r, ok := thunk2() + assert.True(t, ok, "Expected result to have been found") assert.Equal( t, fmt.Sprintf("2_%s", expectedResult), @@ -339,7 +346,8 @@ func TestLoadTimeout(t *testing.T) { // don't wait below - callback is not executing in background go routine - ensure wg doesn't go negative wg.Add(1) thunk := strategy.Load(context.Background(), key) // --------- Load - call 3 - r = thunk() + r, ok = thunk() + assert.True(t, ok, "Expected result to have been found") // called once in go routine after timeout, once in thunk assert.Equal(t, 2, callCount, "Batch function expected to be called twice") @@ -350,7 +358,8 @@ func TestLoadTimeout(t *testing.T) { ) // test double call to thunk - r = thunk() + r, ok = thunk() + assert.True(t, ok, "Expected result to have been found") // called once in go routine after timeout, once in thunk assert.Equal(t, 2, callCount, "Batch function expected to be called twice") @@ -425,10 +434,12 @@ func TestLoadManyTimeout(t *testing.T) { assert.True(t, timedOut, "Expected loader to timeout") r := thunkMany2() + returned, ok := r.GetValue(key2) + assert.True(t, ok, "Expected result to have been found") assert.Equal( t, fmt.Sprintf("2_%s", expectedResult), - r.(dataloader.ResultMap).GetValue(key2).Result.(string), + returned.Result.(string), "Expected result from thunkMany()", ) @@ -439,9 +450,11 @@ func TestLoadManyTimeout(t *testing.T) { // called once in go routine after timeout, once in thunkMany assert.Equal(t, 2, callCount, "Batch function expected to be called twice") + returned, ok = r.GetValue(key3) + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, fmt.Sprintf("3_%s", expectedResult), - r.(dataloader.ResultMap).GetValue(key3).Result.(string), + returned.Result.(string), "Expected result from thunkMany", ) assert.Equal(t, 2, len(k), "Expected to be called with 2 keys") // second function call @@ -451,9 +464,11 @@ func TestLoadManyTimeout(t *testing.T) { // called once in go routine after timeout, once in thunkMany assert.Equal(t, 2, callCount, "Batch function expected to be called twice") + returned, ok = r.GetValue(key3) + assert.True(t, ok, "Expected result to have been found") assert.Equal(t, fmt.Sprintf("3_%s", expectedResult), - r.(dataloader.ResultMap).GetValue(key3).Result.(string), + returned.Result.(string), "Expected result from thunkMany", ) } @@ -525,3 +540,70 @@ func TestCancellableContextLoadMany(t *testing.T) { m := log.Messages() assert.Equal(t, "worker cancelled", m[len(m)-1], "Expected worker to cancel and log exit") } + +// =============================================== result keys =============================================== +// TestKeyHandling ensure that the strategy properly handles unprocessed and nil keys +func TestKeyHandling(t *testing.T) { + // setup + expectedResult := map[PrimaryKey]interface{}{ + PrimaryKey(1): "valid_result", + PrimaryKey(2): nil, + PrimaryKey(3): "__skip__", // this key should be skipped by the batch function + } + + batch := func(ctx context.Context, keys dataloader.Keys) *dataloader.ResultMap { + m := dataloader.NewResultMap(2) + for i := 0; i < keys.Length(); i++ { + key := keys.Keys()[i].(PrimaryKey) + if expectedResult[key] != "__skip__" { + m.Set(key.String(), dataloader.Result{Result: expectedResult[key], Err: nil}) + } + } + return &m + } + + // invoke/assert + strategy := standard.NewStandardStrategy()(3, batch) + + // Load + for key, expected := range expectedResult { + thunk := strategy.Load(context.Background(), key) + r, ok := thunk() + + switch expected.(type) { + case string: + if expected == "__skip__" { + assert.False(t, ok, "Expected skipped result to not be found") + assert.Nil(t, r.Result, "Expected skipped result to be nil") + } else { + assert.True(t, ok, "Expected processed result to be found") + assert.Equal(t, r.Result, expected, "Expected result") + } + case nil: + assert.True(t, ok, "Expected processed result to be found") + assert.Nil(t, r.Result, "Expected result to be nil") + } + } + + // LoadMany + thunkMany := strategy.LoadMany(context.Background(), PrimaryKey(1), PrimaryKey(2), PrimaryKey(3)) + for key, expected := range expectedResult { + result := thunkMany() + r, ok := result.GetValue(key) + + switch expected.(type) { + case string: + if expected == "__skip__" { + assert.False(t, ok, "Expected skipped result to not be found") + assert.Nil(t, r.Result, "Expected skipped result to be nil") + } else { + assert.True(t, ok, "Expected processed result to be found") + assert.Equal(t, r.Result, expected, "Expected result") + } + case nil: + assert.True(t, ok, "Expected processed result to be found") + assert.Nil(t, r.Result, "Expected result to be nil") + } + + } +}