Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support cache skipping for Load() calls that throw SkipCacheError #111

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,24 @@ func (p *PanicErrorWrapper) Error() string {
return p.panicError.Error()
}

// SkipCacheError wraps the error interface.
// The cache should not store SkipCacheErrors.
type SkipCacheError struct {
err error
}

func (s *SkipCacheError) Error() string {
return s.err.Error()
}

func (s *SkipCacheError) Unwrap() error {
return s.err
}

func NewSkipCacheError(err error) *SkipCacheError {
return &SkipCacheError{err: err}
}

// Loader implements the dataloader.Interface.
type Loader[K comparable, V any] struct {
// the batch function to be used by this loader
Expand Down Expand Up @@ -232,7 +250,8 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
result.mu.RLock()
defer result.mu.RUnlock()
var ev *PanicErrorWrapper
if result.value.Error != nil && errors.As(result.value.Error, &ev) {
var es *SkipCacheError
if result.value.Error != nil && (errors.As(result.value.Error, &ev) || errors.As(result.value.Error, &es)){
l.Clear(ctx, key)
}
return result.value.Data, result.value.Error
Expand Down
63 changes: 63 additions & 0 deletions dataloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,45 @@ func TestLoader(t *testing.T) {
}
})

t.Run("test Load Method not caching results with errors of type SkipCacheError", func(t *testing.T) {
t.Parallel()
skipCacheLoader, loadCalls := SkipCacheErrorLoader(3, "1")
ctx := context.Background()
futures1 := skipCacheLoader.LoadMany(ctx, []string{"1", "2", "3"})
_, errs1 := futures1()
var errCount int = 0
var nilCount int = 0
for _, err := range errs1 {
if err == nil {
nilCount++
} else {
errCount++
}
}
if errCount != 1 {
t.Error("Expected an error on only key \"1\"")
}

if nilCount != 2 {
t.Error("Expected the other errors to be nil")
}

futures2 := skipCacheLoader.LoadMany(ctx, []string{"2", "3", "1"})
_, errs2 := futures2()
// There should be no errors in the second batch, as the only key that was not cached
// this time around will not throw an error
if errs2 != nil {
t.Error("Expected LoadMany() to return nil error slice when no errors occurred")
}

calls := (*loadCalls)[1]
expected := []string{"1"}

if !reflect.DeepEqual(calls, expected) {
t.Errorf("Expected load calls %#v, got %#v", expected, calls)
}
})

t.Run("test Load Method Panic Safety in multiple keys", func(t *testing.T) {
t.Parallel()
defer func() {
Expand Down Expand Up @@ -622,6 +661,30 @@ func ErrorCacheLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
return errorCacheLoader, &loadCalls
}

func SkipCacheErrorLoader[K comparable](max int, onceErrorKey K) (*Loader[K, K], *[][]K) {
var mu sync.Mutex
var loadCalls [][]K
errorThrown := false
skipCacheErrorLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] {
var results []*Result[K]
mu.Lock()
loadCalls = append(loadCalls, keys)
mu.Unlock()
// return a non cacheable error for the first occurence of onceErrorKey
for _, k := range keys {
if !errorThrown && k == onceErrorKey {
results = append(results, &Result[K]{k, NewSkipCacheError(fmt.Errorf("non cacheable error"))})
errorThrown = true
} else {
results = append(results, &Result[K]{k, nil})
}
}

return results
}, WithBatchCapacity[K, K](max))
return skipCacheErrorLoader, &loadCalls
}

func BadLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
var mu sync.Mutex
var loadCalls [][]K
Expand Down