diff --git a/dataloader.go b/dataloader.go index 3719df7..8d79b42 100644 --- a/dataloader.go +++ b/dataloader.go @@ -414,7 +414,6 @@ func (b *batcher) batch(originalContext context.Context) { } ctx, finish := b.tracer.TraceBatch(originalContext, keys) - defer finish(items) func() { defer func() { @@ -432,6 +431,8 @@ func (b *batcher) batch(originalContext context.Context) { items = b.batchFn(ctx, keys) }() + defer finish(items) + if panicErr != nil { for _, req := range reqs { req.channel <- &Result{Error: fmt.Errorf("Panic received in batch function: %v", panicErr)} diff --git a/dataloader_test.go b/dataloader_test.go index d676731..71283f0 100644 --- a/dataloader_test.go +++ b/dataloader_test.go @@ -457,6 +457,26 @@ func TestLoader(t *testing.T) { } }) + t.Run("tracer's TraceBatch finish func is passed the Result slice", func(t *testing.T) { + t.Parallel() + identityLoader, _ := IDLoader(0) + tracer := new(RecordingTracer) + identityLoader.tracer = tracer + ctx := context.Background() + future := identityLoader.Load(ctx, StringKey("1")) + _, err := future() + if err != nil { + t.Error(err.Error()) + } + + calls := tracer.traceBatchFinishCalls + inner := []*Result{{Data: "1"}} + expected := [][]*Result{inner} + if !reflect.DeepEqual(calls, expected) { + t.Errorf("tracer did not receive expected results. Expected %#v, got %#v", expected, calls) + } + }) + } // test helpers @@ -586,6 +606,24 @@ func FaultyLoader() (*Loader, *[][]string) { return loader, &loadCalls } +type RecordingTracer struct { + traceBatchFinishCalls [][]*Result +} + +func (t *RecordingTracer) TraceLoad(ctx context.Context, key Key) (context.Context, TraceLoadFinishFunc) { + return ctx, func(Thunk) {} +} + +func (t *RecordingTracer) TraceLoadMany(ctx context.Context, keys Keys) (context.Context, TraceLoadManyFinishFunc) { + return ctx, func(ThunkMany) {} +} + +func (t *RecordingTracer) TraceBatch(ctx context.Context, keys Keys) (context.Context, TraceBatchFinishFunc) { + return ctx, func(results []*Result) { + t.traceBatchFinishCalls = append(t.traceBatchFinishCalls, results) + } +} + /////////////////////////////////////////////////// // Benchmarks ///////////////////////////////////////////////////