
Merge pull request #7 from wandb/context-cancel
context cancelation
mumbleskates authored Sep 25, 2024
2 parents c70b8da + 80752bb commit 70e36f1
Showing 6 changed files with 373 additions and 58 deletions.
78 changes: 71 additions & 7 deletions cleanupsdemo/cleanups_demo.go
@@ -13,6 +13,8 @@ func main() {
const cycles = 100
const batchSize = 100

ctx := context.Background()

// This demonstrates the library's cleanup functionality, where forgotten
// executors that own goroutines and have been discarded without being awaited
// will clean themselves up without permanently leaking goroutines. This is
@@ -22,14 +24,34 @@ func main() {
// of the thunks that are registered with `runtime.SetFinalizer()` in
// group.go and collect.go, which close channels and call cancel functions.

println("leaking ~", cycles*(batchSize+2), "goroutines from abandoned group executors...")
println("leaking goroutines from group executors...")

// Dependent contexts are always canceled, too.
//
// This binary generally tests that executors and collectors always clean up
// after themselves, that any extra goroutines and channels and contexts
// they open will be shut down. The easiest thing we can measure is running
// goroutines, so to measure things that aren't goroutines (like unclosed
// contexts) we can just start new goroutines that wait for contexts we
// expect to be canceled.
leakDependent := func(ctx context.Context) {
// All cancelable child contexts are also canceled
ctx, cancel := context.WithCancel(ctx)
_ = cancel
// Leak a goroutine whose halting depends on the given context being
// canceled.
go func() {
<-ctx.Done()
}()
}

// Leak just a crazy number of goroutines
for i := 0; i < cycles; i++ {
func() {
g := parallel.Collect[int](parallel.Limited(context.Background(), batchSize))
g := parallel.Collect[int](parallel.Limited(ctx, batchSize))
for j := 0; j < batchSize; j++ {
g.Go(func(context.Context) (int, error) {
g.Go(func(ctx context.Context) (int, error) {
leakDependent(ctx)
return 1, nil
})
}
@@ -38,23 +60,65 @@ func main() {

func() {
defer func() { _ = recover() }()
g := parallel.Feed(parallel.Unlimited(context.Background()), func(context.Context, int) error {
g := parallel.Feed(parallel.Unlimited(ctx), func(context.Context, int) error {
panic("feed function panics")
})
g.Go(func(context.Context) (int, error) {
g.Go(func(ctx context.Context) (int, error) {
leakDependent(ctx)
return 1, nil
})
// Leak the group without awaiting it
}()

func() {
defer func() { _ = recover() }()
g := parallel.Collect[int](parallel.Unlimited(context.Background()))
g.Go(func(context.Context) (int, error) {
g := parallel.Collect[int](parallel.Unlimited(ctx))
g.Go(func(ctx context.Context) (int, error) {
leakDependent(ctx)
panic("op panics")
})
// Leak the group without awaiting it
}()

// Start some executors that complete normally without error
{
g := parallel.Unlimited(ctx)
g.Go(func(ctx context.Context) {
leakDependent(ctx)
})
g.Wait()
}
{
g := parallel.Limited(ctx, 0)
g.Go(func(ctx context.Context) {
leakDependent(ctx)
})
g.Wait()
}
{
g := parallel.Collect[int](parallel.Limited(ctx, batchSize))
g.Go(func(ctx context.Context) (int, error) {
leakDependent(ctx)
return 1, nil
})
_, err := g.Wait()
if err != nil {
panic(err)
}
}
{
g := parallel.Feed(parallel.Unlimited(ctx), func(context.Context, int) error {
return nil
})
g.Go(func(ctx context.Context) (int, error) {
leakDependent(ctx)
return 1, nil
})
err := g.Wait()
if err != nil {
panic(err)
}
}
}

println("monitoring and running GC...")
41 changes: 32 additions & 9 deletions collect.go
@@ -223,6 +223,20 @@ func FeedWithErrs[T any](executor Executor, receiver func(context.Context, T) er
return making
}

// groupError returns the error associated with a group's context; if the error
// was errGroupDone, that doesn't count as an error and nil is returned instead.
func groupError(ctx context.Context) error {
err := context.Cause(ctx)
// We are explicitly using == here to check for the exact value of our
// sentinel error, not using errors.Is(), because we don't actually want to
// find it if it's in wrapped errors. We *only* want to know whether the
// cancelation error is *exactly* errGroupDone.
if err == errGroupDone {
return nil
}
return err
}
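
The ==-versus-errors.Is distinction above is worth seeing in isolation. A minimal standalone sketch (not part of this diff; errDone is a local stand-in for the unexported sentinel):

package main

import (
	"errors"
	"fmt"
)

var errDone = errors.New("executor done") // stand-in for the unexported sentinel

func main() {
	wrapped := fmt.Errorf("shutting down: %w", errDone)
	fmt.Println(wrapped == errDone)          // false: not the exact sentinel value
	fmt.Println(errors.Is(wrapped, errDone)) // true: errors.Is unwraps, which groupError must avoid
}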

var _ ErrGroupExecutor = &errGroup{}

type errGroup struct {
@@ -242,7 +256,7 @@ func (eg *errGroup) Go(op func(context.Context) error) {
func (eg *errGroup) Wait() error {
eg.g.Wait()
ctx, _ := eg.g.getContext()
return context.Cause(ctx)
return groupError(ctx)
}

func makePipeGroup[T any, R any](executor Executor) *pipeGroup[T, R] {
@@ -290,6 +304,14 @@ func (pg *pipeGroup[T, R]) doWait() {
// even in case of a panic by deferring it, and that always only happens at
// the end of the function... so, we just put an inner function here to make
// it happen "early."

// Runs last: We must make completely certain that we cancel the context
// owned by the pipeGroup. This context is shared between the executor and
// the pipeWorkers; we take charge of making sure this cancelation happens
// as soon as possible here, and we want it to happen at the very end after
// everything else in case something else wanted to set the cancel cause of
// the context to an actual error instead of our "no error" sentinel value.
defer pg.pipeWorkers.cancel(errGroupDone)
func() {
// Runs second: Close the results chan and unblock the pipe worker.
// Because we're deferring this, it will happen even if there is a panic
@@ -300,11 +322,12 @@
runtime.SetFinalizer(pg, nil)
}
}()
// Runs first: Wait for inputs
pg.g.Wait()
// Runs first: Wait for inputs. Wait "quietly", not canceling the
// context yet so if there is an error later we can still see it
pg.g.quietWait()
}()
// Runs third: Wait for outputs to be done
pg.pipeWorkers.Wait()
pg.pipeWorkers.quietWait()
}
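
The "runs last / second / first / third" choreography above relies on deferred calls running in reverse registration order, with an inner function used to make one deferred step fire "early". A minimal standalone sketch (not part of this diff) of that ordering:

package main

import "fmt"

func main() {
	// Deferred first, so it runs last, when main returns.
	defer fmt.Println("runs last")
	func() {
		// An inner defer fires when the inner function returns, which is how
		// a deferred step can be made to happen "early".
		defer fmt.Println("runs second")
		fmt.Println("runs first")
	}()
	fmt.Println("runs third")
}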

var _ CollectingExecutor[int] = collectingGroup[int]{}
@@ -329,7 +352,7 @@ func (cg collectingGroup[T]) Go(op func(context.Context) (T, error)) {
func (cg collectingGroup[T]) Wait() ([]T, error) {
cg.doWait()
ctx, _ := cg.g.getContext()
if err := context.Cause(ctx); err != nil {
if err := groupError(ctx); err != nil {
// We have an error; return it
return nil, err
}
@@ -358,7 +381,7 @@ func (fg feedingGroup[T]) Go(op func(context.Context) (T, error)) {
func (fg feedingGroup[T]) Wait() error {
fg.doWait()
ctx, _ := fg.g.getContext()
return context.Cause(ctx)
return groupError(ctx)
}

var _ AllErrsExecutor = multiErrGroup{}
@@ -381,7 +404,7 @@ func (meg multiErrGroup) Wait() MultiError {
meg.doWait()
err := CombineErrors(*meg.res...)
ctx, _ := meg.g.getContext()
if cause := context.Cause(ctx); cause != nil {
if cause := groupError(ctx); cause != nil {
return CombineErrors(cause, err)
}
return err
@@ -415,7 +438,7 @@ func (ceg collectingMultiErrGroup[T]) Wait() ([]T, MultiError) {
ceg.doWait()
res, err := ceg.res.values, CombineErrors(ceg.res.errs...)
ctx, _ := ceg.g.getContext()
if cause := context.Cause(ctx); cause != nil {
if cause := groupError(ctx); cause != nil {
return res, CombineErrors(cause, err)
}
return res, err
@@ -439,7 +462,7 @@ func (feg feedingMultiErrGroup[T]) Wait() MultiError {
feg.doWait()
err := CombineErrors(*feg.res...)
ctx, _ := feg.g.getContext()
if cause := context.Cause(ctx); cause != nil {
if cause := groupError(ctx); cause != nil {
return CombineErrors(cause, err)
}
return err
21 changes: 21 additions & 0 deletions collect_test.go
@@ -7,6 +7,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestErrGroup(t *testing.T) {
@@ -149,6 +150,26 @@ func TestFeedErroring(t *testing.T) {
assert.Subset(t, []int{1, 2, 3}, res)
}

func TestFeedLastReceiverErrs(t *testing.T) {
t.Parallel()
// Even when the very very last item through the pipe group causes an error,
// the group's context shouldn't be canceled yet and it should still be able
// to set the error.
g := Feed(Limited(context.Background(), 0), func(ctx context.Context, val int) error {
if val == 10 {
return errors.New("boom")
} else {
return nil
}
})
for i := 1; i <= 10; i++ {
g.Go(func(ctx context.Context) (int, error) {
return i, nil
})
}
require.Error(t, g.Wait())
}

func TestFeedErroringInReceiver(t *testing.T) {
t.Parallel()
g := Feed(Unlimited(context.Background()), func(ctx context.Context, val int) error {
72 changes: 65 additions & 7 deletions group.go
@@ -16,7 +16,16 @@ const bufferSize = 8

const misuseMessage = "parallel executor misuse: don't reuse executors"

var errPanicked = errors.New("panicked")
var (
errPanicked = errors.New("panicked")
// errGroupDone is a sentinel error value used to cancel an execution
// context when it has completed without error.
errGroupDone = errors.New("executor done")
errGroupAbandoned = errors.New("executor abandoned")

// Contexts are canceled with this error when executors are awaited.
GroupDoneError = errGroupDone
)
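
Because GroupDoneError is exported, code holding an executor's context can tell a normal shutdown apart from a failure by inspecting the cancel cause. A hedged usage sketch (not part of this diff; the import path github.com/wandb/parallel is assumed):

package main

import (
	"context"
	"fmt"

	"github.com/wandb/parallel" // assumed module path
)

func main() {
	g := parallel.Unlimited(context.Background())
	var taskCtx context.Context
	g.Go(func(ctx context.Context) { taskCtx = ctx }) // capture the executor's context
	g.Wait()                                          // cancels that context with the "done" sentinel
	// After a normal Wait, the cancel cause is exactly GroupDoneError.
	fmt.Println(context.Cause(taskCtx) == parallel.GroupDoneError) // true
}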

// WorkerPanic represents a panic value propagated from a task within a parallel
// executor, and is the main type of panic that you might expect to receive.
@@ -54,6 +63,9 @@ type Executor interface {

// internal
getContext() (context.Context, context.CancelCauseFunc)
// Waits without canceling the context with errGroupDone. The caller of this
// function promises that they will be responsible for canceling the context
quietWait()
}

// Creates a basic executor which runs all the functions given in one goroutine
@@ -73,8 +85,16 @@ func Unlimited(ctx context.Context) Executor {
// should still be cleaned up.
func Limited(ctx context.Context, maxGoroutines int) Executor {
if maxGoroutines < 1 {
// When maxGoroutines is non-positive, we return the trivial executor
// type directly.
gctx, cancel := context.WithCancelCause(ctx)
return &runner{ctx: gctx, cancel: cancel}
g := &runner{ctx: gctx, cancel: cancel}
// This executor still needs to make certain that its context always
// gets canceled!
runtime.SetFinalizer(g, func(doomed *runner) {
doomed.cancel(errGroupAbandoned)
})
return g
}
making := &limitedGroup{
g: makeGroup(context.WithCancelCause(ctx)),
@@ -83,15 +103,29 @@
}
runtime.SetFinalizer(making, func(doomed *limitedGroup) {
close(doomed.ops)
doomed.g.cancel(nil)
})
return making
}

// Base executor with an interface that runs everything serially.
// Base executor with an interface that runs everything serially. This can be
// returned directly from Limited in a special case, and otherwise it is just
// composed as inner struct fields for the base concurrent group struct.
//
// The lifecycle of the context is important: When the executor is set up we
// create a cancelable context, and we need to guarantee that it is eventually
// canceled or it can stay resident indefinitely in the known children of a
// parent context, effectively leaking memory. To do this, we guarantee that the
// context is canceled in one of a couple ways:
// 1. if the executor is abandoned without awaiting, a runtime finalizer that
// is registered immediately after we create the executor will cancel it
// 2. if the executor is awaited and completes normally, after everything else
// has completed the context will be canceled with the errGroupDone sentinel
// 3. if there is a panic or another kind of error that causes the executor to
// terminate early (such as with ErrGroup), the context is canceled with
// that error in the normal way.
type runner struct {
ctx context.Context // Closed when we panic or get garbage collected
cancel context.CancelCauseFunc // Only close the dying channel one time
ctx context.Context // Execution context
cancel context.CancelCauseFunc // Cancel for the ctx; must always be called
awaited atomic.Bool // Set when Wait() is called
}
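
The lifecycle guarantee described above is a general pattern: a finalizer acts as a safety net that cancels the context if the owner is dropped without being awaited, while a normal Wait clears the finalizer and cancels with a "done" value instead. A minimal standalone sketch (not part of this diff; the worker type is hypothetical):

package main

import (
	"context"
	"errors"
	"runtime"
)

type worker struct {
	ctx    context.Context
	cancel context.CancelCauseFunc
}

func newWorker(parent context.Context) *worker {
	ctx, cancel := context.WithCancelCause(parent)
	w := &worker{ctx: ctx, cancel: cancel}
	// Path 1: abandoned without Wait; the finalizer still cancels the context.
	runtime.SetFinalizer(w, func(doomed *worker) {
		doomed.cancel(errors.New("abandoned"))
	})
	return w
}

func (w *worker) Wait() {
	runtime.SetFinalizer(w, nil) // the safety net is no longer needed
	w.cancel(errors.New("done")) // Path 2: normal completion
}

func main() {
	w := newWorker(context.Background())
	w.Wait()
	<-w.ctx.Done() // either path guarantees the context ends up canceled
}
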

@@ -108,15 +142,25 @@ func (n *runner) Go(op func(context.Context)) {
}

func (n *runner) Wait() {
n.quietWait()
n.cancel(errGroupDone)
}

func (n *runner) quietWait() {
n.awaited.Store(true)
runtime.SetFinalizer(n, nil)
}

func (n *runner) getContext() (context.Context, context.CancelCauseFunc) {
return n.ctx, n.cancel
}

func makeGroup(ctx context.Context, cancel context.CancelCauseFunc) *group {
return &group{runner: runner{ctx: ctx, cancel: cancel}}
g := &group{runner: runner{ctx: ctx, cancel: cancel}}
runtime.SetFinalizer(g, func(doomed *group) {
doomed.cancel(errGroupAbandoned)
})
return g
}

// Base concurrent executor
@@ -181,7 +225,13 @@ func (g *group) Go(op func(context.Context)) {
}

func (g *group) Wait() {
defer g.cancel(errGroupDone)
g.quietWait()
}

func (g *group) quietWait() {
g.awaited.Store(true)
runtime.SetFinalizer(g, nil)
g.wg.Wait()
g.checkPanic()
}
@@ -260,6 +310,14 @@ func (lg *limitedGroup) Wait() {
lg.g.Wait()
}

func (lg *limitedGroup) quietWait() {
if !lg.awaited.Swap(true) {
close(lg.ops)
runtime.SetFinalizer(lg, nil) // Don't try to close this chan again :)
}
lg.g.quietWait()
}

func (lg *limitedGroup) getContext() (context.Context, context.CancelCauseFunc) {
return lg.g.getContext()
}
