From 84d9d28cd99f59468677b5e41b02c9ccdbd7aa55 Mon Sep 17 00:00:00 2001 From: Siarhei Navatski Date: Sun, 15 Sep 2019 14:33:12 +0300 Subject: [PATCH] Add context to Check Signed-off-by: Siarhei Navatski --- async.go | 4 +- async_test.go | 8 ++-- checks.go | 21 +++++---- checks_test.go | 40 +++++++++++----- example_test.go | 10 ++-- handler.go | 7 +-- handler_test.go | 102 +++++++++++++++++++++++++++++++++++----- metrics_handler.go | 4 +- metrics_handler_test.go | 8 ++-- timeout.go | 8 ++-- timeout_test.go | 9 ++-- types.go | 3 +- 12 files changed, 166 insertions(+), 58 deletions(-) diff --git a/async.go b/async.go index 2c6bb9f..dabc044 100644 --- a/async.go +++ b/async.go @@ -52,7 +52,7 @@ func AsyncWithContext(ctx context.Context, check Check, interval time.Duration) // make a wrapper that runs the check, and swaps out the current head of // the channel with the latest result update := func() { - err := check() + err := check(ctx) <-result result <- err } @@ -76,7 +76,7 @@ func AsyncWithContext(ctx context.Context, check Check, interval time.Duration) }() // return a Check function that closes over our result and mutex - return func() error { + return func(_ context.Context) error { // peek at the head of the channel, then put it back err := <-result result <- err diff --git a/async_test.go b/async_test.go index afebcdd..1b4eda1 100644 --- a/async_test.go +++ b/async_test.go @@ -23,20 +23,20 @@ import ( ) func TestAsync(t *testing.T) { - async := Async(func() error { + async := Async(func(ctx context.Context) error { time.Sleep(50 * time.Millisecond) return nil }, 1*time.Millisecond) // expect the first call to return ErrNoData since it takes 50ms to return the first time - assert.EqualError(t, async(), "no data yet") + assert.EqualError(t, async(context.Background()), "no data yet") // wait for the first run to finish time.Sleep(100 * time.Millisecond) // make sure the next call returns nil ~immediately start := time.Now() - assert.NoError(t, async()) + assert.NoError(t, async(context.Background())) assert.WithinDuration(t, time.Now(), start, 1*time.Millisecond, "expected async() to return almost immediately") } @@ -46,7 +46,7 @@ func TestAsyncWithContext(t *testing.T) { // start an async check that counts how many times it was called calls := 0 - AsyncWithContext(ctx, func() error { + AsyncWithContext(ctx, func(ctx context.Context) error { calls++ time.Sleep(1 * time.Millisecond) return nil diff --git a/checks.go b/checks.go index a41eb32..c5dbd97 100644 --- a/checks.go +++ b/checks.go @@ -27,7 +27,7 @@ import ( // TCPDialCheck returns a Check that checks TCP connectivity to the provided // endpoint. func TCPDialCheck(addr string, timeout time.Duration) Check { - return func() error { + return func(ctx context.Context) error { conn, err := net.DialTimeout("tcp", addr, timeout) if err != nil { return err @@ -47,8 +47,13 @@ func HTTPGetCheck(url string, timeout time.Duration) Check { return http.ErrUseLastResponse }, } - return func() error { - resp, err := client.Get(url) + return func(ctx context.Context) error { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return err + } + req = req.WithContext(ctx) + resp, err := client.Do(req) if err != nil { return err } @@ -63,8 +68,8 @@ func HTTPGetCheck(url string, timeout time.Duration) Check { // DatabasePingCheck returns a Check that validates connectivity to a // database/sql.DB using Ping(). func DatabasePingCheck(database *sql.DB, timeout time.Duration) Check { - return func() error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) + return func(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() if database == nil { return fmt.Errorf("database is nil") @@ -77,8 +82,8 @@ func DatabasePingCheck(database *sql.DB, timeout time.Duration) Check { // to at least one IP address within the specified timeout. func DNSResolveCheck(host string, timeout time.Duration) Check { resolver := net.Resolver{} - return func() error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) + return func(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() addrs, err := resolver.LookupHost(ctx, host) if err != nil { @@ -94,7 +99,7 @@ func DNSResolveCheck(host string, timeout time.Duration) Check { // GoroutineCountCheck returns a Check that fails if too many goroutines are // running (which could indicate a resource leak). func GoroutineCountCheck(threshold int) Check { - return func() error { + return func(ctx context.Context) error { count := runtime.NumGoroutine() if count > threshold { return fmt.Errorf("too many goroutines (%d > %d)", count, threshold) diff --git a/checks_test.go b/checks_test.go index debe4ee..1137451 100644 --- a/checks_test.go +++ b/checks_test.go @@ -15,6 +15,7 @@ package healthcheck import ( + "context" "testing" "time" @@ -23,30 +24,47 @@ import ( ) func TestTCPDialCheck(t *testing.T) { - assert.NoError(t, TCPDialCheck("heptio.com:80", 5*time.Second)()) - assert.Error(t, TCPDialCheck("heptio.com:25327", 5*time.Second)()) + ctx := context.Background() + + assert.NoError(t, TCPDialCheck("heptio.com:80", 5*time.Second)(ctx)) + assert.Error(t, TCPDialCheck("heptio.com:25327", 5*time.Second)(ctx)) } func TestHTTPGetCheck(t *testing.T) { - assert.NoError(t, HTTPGetCheck("https://heptio.com", 5*time.Second)()) - assert.Error(t, HTTPGetCheck("http://heptio.com", 5*time.Second)(), "redirect should fail") - assert.Error(t, HTTPGetCheck("https://heptio.com/nonexistent", 5*time.Second)(), "404 should fail") + ctx := context.Background() + canceledCtx, cancel := context.WithCancel(ctx) + cancel() + + assert.NoError(t, HTTPGetCheck("https://heptio.com", 5*time.Second)(ctx)) + assert.Error(t, HTTPGetCheck("https://heptio.com", 5*time.Second)(canceledCtx)) + assert.Error(t, HTTPGetCheck("http://heptio.com", 5*time.Second)(ctx), "redirect should fail") + assert.Error(t, HTTPGetCheck("https://heptio.com/nonexistent", 5*time.Second)(ctx), "404 should fail") } func TestDatabasePingCheck(t *testing.T) { - assert.Error(t, DatabasePingCheck(nil, 1*time.Second)(), "nil DB should fail") + ctx := context.Background() + canceledCtx, cancel := context.WithCancel(ctx) + cancel() + + assert.Error(t, DatabasePingCheck(nil, 1*time.Second)(ctx), "nil DB should fail") db, _, err := sqlmock.New() assert.NoError(t, err) - assert.NoError(t, DatabasePingCheck(db, 1*time.Second)(), "ping should succeed") + assert.NoError(t, DatabasePingCheck(db, 1*time.Second)(ctx), "ping should succeed") + assert.Error(t, DatabasePingCheck(db, 1*time.Second)(canceledCtx), "ping should fail") } func TestDNSResolveCheck(t *testing.T) { - assert.NoError(t, DNSResolveCheck("heptio.com", 5*time.Second)()) - assert.Error(t, DNSResolveCheck("nonexistent.heptio.com", 5*time.Second)()) + ctx := context.Background() + canceledCtx, cancel := context.WithCancel(ctx) + cancel() + + assert.NoError(t, DNSResolveCheck("heptio.com", 5*time.Second)(ctx)) + assert.Error(t, DNSResolveCheck("nonexistent.heptio.com", 5*time.Second)(ctx)) + assert.Error(t, DNSResolveCheck("heptio.com", 5*time.Second)(canceledCtx)) } func TestGoroutineCountCheck(t *testing.T) { - assert.NoError(t, GoroutineCountCheck(1000)()) - assert.Error(t, GoroutineCountCheck(0)()) + assert.NoError(t, GoroutineCountCheck(1000)(context.Background())) + assert.Error(t, GoroutineCountCheck(0)(context.Background())) } diff --git a/example_test.go b/example_test.go index d7d7717..b28ee65 100644 --- a/example_test.go +++ b/example_test.go @@ -15,6 +15,7 @@ package healthcheck import ( + "context" "database/sql" "fmt" "net/http" @@ -23,10 +24,9 @@ import ( "strings" "time" - sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1" - "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" + sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1" ) func Example() { @@ -109,7 +109,7 @@ func Example_advanced() { HTTPGetCheck(upstreamURL, 500*time.Millisecond)) // Implement a custom check with a 50 millisecond timeout. - health.AddLivenessCheck("custom-check-with-timeout", Timeout(func() error { + health.AddLivenessCheck("custom-check-with-timeout", Timeout(func(ctx context.Context) error { // Simulate some work that could take a long time time.Sleep(time.Millisecond * 100) return nil @@ -146,12 +146,12 @@ func Example_metrics() { health := NewMetricsHandler(registry, "example") // Add a simple readiness check that always fails. - health.AddReadinessCheck("failing-check", func() error { + health.AddReadinessCheck("failing-check", func(ctx context.Context) error { return fmt.Errorf("example failure") }) // Add a liveness check that always succeeds - health.AddLivenessCheck("successful-check", func() error { + health.AddLivenessCheck("successful-check", func(ctx context.Context) error { return nil }) diff --git a/handler.go b/handler.go index 6ea9740..846cb0a 100644 --- a/handler.go +++ b/handler.go @@ -15,6 +15,7 @@ package healthcheck import ( + "context" "encoding/json" "net/http" "sync" @@ -59,11 +60,11 @@ func (s *basicHandler) AddReadinessCheck(name string, check Check) { s.readinessChecks[name] = check } -func (s *basicHandler) collectChecks(checks map[string]Check, resultsOut map[string]string, statusOut *int) { +func (s *basicHandler) collectChecks(ctx context.Context, checks map[string]Check, resultsOut map[string]string, statusOut *int) { s.checksMutex.RLock() defer s.checksMutex.RUnlock() for name, check := range checks { - if err := check(); err != nil { + if err := check(ctx); err != nil { *statusOut = http.StatusServiceUnavailable resultsOut[name] = err.Error() } else { @@ -81,7 +82,7 @@ func (s *basicHandler) handle(w http.ResponseWriter, r *http.Request, checks ... checkResults := make(map[string]string) status := http.StatusOK for _, checks := range checks { - s.collectChecks(checks, checkResults, &status) + s.collectChecks(r.Context(), checks, checkResults, &status) } // write out the response code and content type header diff --git a/handler_test.go b/handler_test.go index 2e894b4..cad0f89 100644 --- a/handler_test.go +++ b/handler_test.go @@ -15,6 +15,7 @@ package healthcheck import ( + "context" "errors" "net/http" "net/http/httptest" @@ -25,13 +26,14 @@ import ( func TestNewHandler(t *testing.T) { tests := []struct { - name string - method string - path string - live bool - ready bool - expect int - expectBody string + name string + method string + path string + live bool + ready bool + canceledContext bool + expect int + expectBody string }{ { name: "GET /foo should generate a 404", @@ -66,6 +68,16 @@ func TestNewHandler(t *testing.T) { expect: http.StatusOK, expectBody: "{}\n", }, + { + name: "with no checks and canceled context, /live should succeed", + method: "GET", + path: "/live", + live: true, + ready: true, + canceledContext: true, + expect: http.StatusOK, + expectBody: "{}\n", + }, { name: "with no checks, /ready should succeed", method: "GET", @@ -75,6 +87,16 @@ func TestNewHandler(t *testing.T) { expect: http.StatusOK, expectBody: "{}\n", }, + { + name: "with no checks and canceled context, /ready should succeed", + method: "GET", + path: "/ready", + live: true, + ready: true, + canceledContext: true, + expect: http.StatusOK, + expectBody: "{}\n", + }, { name: "with a failing readiness check, /live should still succeed", method: "GET", @@ -84,6 +106,16 @@ func TestNewHandler(t *testing.T) { expect: http.StatusOK, expectBody: "{}\n", }, + { + name: "with a failing readiness check and canceled context, /live should still succeed", + method: "GET", + path: "/live?full=1", + live: true, + ready: false, + canceledContext: true, + expect: http.StatusOK, + expectBody: "{}\n", + }, { name: "with a failing readiness check, /ready should fail", method: "GET", @@ -93,6 +125,16 @@ func TestNewHandler(t *testing.T) { expect: http.StatusServiceUnavailable, expectBody: "{\n \"test-readiness-check\": \"failed readiness check\"\n}\n", }, + { + name: "with a failing readiness check and canceled context, /ready should fail", + method: "GET", + path: "/ready?full=1", + live: true, + ready: false, + canceledContext: true, + expect: http.StatusServiceUnavailable, + expectBody: "{\n \"test-readiness-check\": \"context canceled\"\n}\n", + }, { name: "with a failing liveness check, /live should fail", method: "GET", @@ -102,6 +144,16 @@ func TestNewHandler(t *testing.T) { expect: http.StatusServiceUnavailable, expectBody: "{\n \"test-liveness-check\": \"failed liveness check\"\n}\n", }, + { + name: "with a failing liveness check and canceled context, /live should fail", + method: "GET", + path: "/live?full=1", + live: false, + ready: true, + canceledContext: true, + expect: http.StatusServiceUnavailable, + expectBody: "{\n \"test-liveness-check\": \"context canceled\"\n}\n", + }, { name: "with a failing liveness check, /ready should fail", method: "GET", @@ -112,7 +164,17 @@ func TestNewHandler(t *testing.T) { expectBody: "{\n \"test-liveness-check\": \"failed liveness check\"\n}\n", }, { - name: "with a failing liveness check, /ready without full=1 should fail with an empty body", + name: "with a failing liveness check and canceled context, /ready should fail", + method: "GET", + path: "/ready?full=1", + live: false, + ready: true, + canceledContext: true, + expect: http.StatusServiceUnavailable, + expectBody: "{\n \"test-liveness-check\": \"context canceled\"\n}\n", + }, + { + name: "with a fsailing liveness check, /ready without full=1 should fail with an empty body", method: "GET", path: "/ready", live: false, @@ -126,20 +188,36 @@ func TestNewHandler(t *testing.T) { h := NewHandler() if !tt.live { - h.AddLivenessCheck("test-liveness-check", func() error { - return errors.New("failed liveness check") + h.AddLivenessCheck("test-liveness-check", func(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return errors.New("failed liveness check") + } }) } if !tt.ready { - h.AddReadinessCheck("test-readiness-check", func() error { - return errors.New("failed readiness check") + h.AddReadinessCheck("test-readiness-check", func(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return errors.New("failed readiness check") + } }) } req, err := http.NewRequest(tt.method, tt.path, nil) assert.NoError(t, err) + if tt.canceledContext { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + req = req.WithContext(ctx) + } + reqStr := tt.method + " " + tt.path rr := httptest.NewRecorder() h.ServeHTTP(rr, req) diff --git a/metrics_handler.go b/metrics_handler.go index e95be0f..8619728 100644 --- a/metrics_handler.go +++ b/metrics_handler.go @@ -15,6 +15,7 @@ package healthcheck import ( + "context" "net/http" "github.com/prometheus/client_golang/prometheus" @@ -66,7 +67,8 @@ func (h *metricsHandler) wrap(name string, check Check) Check { ConstLabels: prometheus.Labels{"check": name}, }, func() float64 { - if check() == nil { + ctx := context.Background() + if check(ctx) == nil { return 0 } return 1 diff --git a/metrics_handler_test.go b/metrics_handler_test.go index caea359..6489349 100644 --- a/metrics_handler_test.go +++ b/metrics_handler_test.go @@ -15,6 +15,7 @@ package healthcheck import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -23,7 +24,6 @@ import ( "testing" "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/assert" ) @@ -31,13 +31,13 @@ func TestNewMetricsHandler(t *testing.T) { handler := NewMetricsHandler(prometheus.DefaultRegisterer, "test") for _, check := range []string{"aaa", "bbb", "ccc"} { - handler.AddLivenessCheck(check, func() error { + handler.AddLivenessCheck(check, func(ctx context.Context) error { return nil }) } for _, check := range []string{"ddd", "eee", "fff"} { - handler.AddLivenessCheck(check, func() error { + handler.AddLivenessCheck(check, func(ctx context.Context) error { return fmt.Errorf("failing health check %q", check) }) } @@ -74,7 +74,7 @@ test_healthcheck_status{check="fff"} 1 func TestNewMetricsHandlerEndpoints(t *testing.T) { handler := NewMetricsHandler(prometheus.NewRegistry(), "test") - handler.AddReadinessCheck("fail", func() error { + handler.AddReadinessCheck("fail", func(ctx context.Context) error { return fmt.Errorf("failing readiness check") }) diff --git a/timeout.go b/timeout.go index 86bf21d..6199da0 100644 --- a/timeout.go +++ b/timeout.go @@ -15,6 +15,7 @@ package healthcheck import ( + "context" "fmt" "time" ) @@ -39,13 +40,14 @@ func (e timeoutError) Temporary() bool { // Timeout adds a timeout to a Check. If the underlying check takes longer than // the timeout, it returns an error. func Timeout(check Check, timeout time.Duration) Check { - return func() error { + return func(ctx context.Context) error { + ctx, _ = context.WithTimeout(ctx, timeout) c := make(chan error, 1) - go func() { c <- check() }() + go func() { c <- check(ctx) }() select { case err := <-c: return err - case <-time.After(timeout): + case <-ctx.Done(): return timeoutError(timeout) } } diff --git a/timeout_test.go b/timeout_test.go index fd4f515..5ccfaaa 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -15,17 +15,18 @@ package healthcheck import ( + "context" "net" "testing" "time" ) func TestTimeout(t *testing.T) { - tooSlow := Timeout(func() error { + tooSlow := Timeout(func(ctx context.Context) error { time.Sleep(10 * time.Millisecond) return nil }, 1*time.Millisecond) - err := tooSlow() + err := tooSlow(context.Background()) if _, isTimeoutError := err.(timeoutError); !isTimeoutError { t.Errorf("expected a TimeoutError, got %v", err) } @@ -38,11 +39,11 @@ func TestTimeout(t *testing.T) { t.Errorf("expected Temporary() to be true, got %v", err) } - notTooSlow := Timeout(func() error { + notTooSlow := Timeout(func(ctx context.Context) error { time.Sleep(1 * time.Millisecond) return nil }, 10*time.Millisecond) - if err := notTooSlow(); err != nil { + if err := notTooSlow(context.Background()); err != nil { t.Errorf("expected success, got %v", err) } } diff --git a/types.go b/types.go index 714294b..182d4df 100644 --- a/types.go +++ b/types.go @@ -15,11 +15,12 @@ package healthcheck import ( + "context" "net/http" ) // Check is a health/readiness check. -type Check func() error +type Check func(ctx context.Context) error // Handler is an http.Handler with additional methods that register health and // readiness checks. It handles handle "/live" and "/ready" HTTP