Skip to content
This repository has been archived by the owner on Mar 15, 2022. It is now read-only.

Add context to Check #27

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions async.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func AsyncWithContext(ctx context.Context, check Check, interval time.Duration)
// make a wrapper that runs the check, and swaps out the current head of
// the channel with the latest result
update := func() {
err := check()
err := check(ctx)
<-result
result <- err
}
Expand All @@ -76,7 +76,7 @@ func AsyncWithContext(ctx context.Context, check Check, interval time.Duration)
}()

// return a Check function that closes over our result and mutex
return func() error {
return func(_ context.Context) error {
// peek at the head of the channel, then put it back
err := <-result
result <- err
Expand Down
8 changes: 4 additions & 4 deletions async_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,20 @@ import (
)

func TestAsync(t *testing.T) {
async := Async(func() error {
async := Async(func(ctx context.Context) error {
time.Sleep(50 * time.Millisecond)
return nil
}, 1*time.Millisecond)

// expect the first call to return ErrNoData since it takes 50ms to return the first time
assert.EqualError(t, async(), "no data yet")
assert.EqualError(t, async(context.Background()), "no data yet")

// wait for the first run to finish
time.Sleep(100 * time.Millisecond)

// make sure the next call returns nil ~immediately
start := time.Now()
assert.NoError(t, async())
assert.NoError(t, async(context.Background()))
assert.WithinDuration(t, time.Now(), start, 1*time.Millisecond,
"expected async() to return almost immediately")
}
Expand All @@ -46,7 +46,7 @@ func TestAsyncWithContext(t *testing.T) {

// start an async check that counts how many times it was called
calls := 0
AsyncWithContext(ctx, func() error {
AsyncWithContext(ctx, func(ctx context.Context) error {
calls++
time.Sleep(1 * time.Millisecond)
return nil
Expand Down
21 changes: 13 additions & 8 deletions checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
// TCPDialCheck returns a Check that checks TCP connectivity to the provided
// endpoint.
func TCPDialCheck(addr string, timeout time.Duration) Check {
return func() error {
return func(ctx context.Context) error {
conn, err := net.DialTimeout("tcp", addr, timeout)
if err != nil {
return err
Expand All @@ -47,8 +47,13 @@ func HTTPGetCheck(url string, timeout time.Duration) Check {
return http.ErrUseLastResponse
},
}
return func() error {
resp, err := client.Get(url)
return func(ctx context.Context) error {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return err
}
req = req.WithContext(ctx)
resp, err := client.Do(req)
if err != nil {
return err
}
Expand All @@ -63,8 +68,8 @@ func HTTPGetCheck(url string, timeout time.Duration) Check {
// DatabasePingCheck returns a Check that validates connectivity to a
// database/sql.DB using Ping().
func DatabasePingCheck(database *sql.DB, timeout time.Duration) Check {
return func() error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
return func(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
if database == nil {
return fmt.Errorf("database is nil")
Expand All @@ -77,8 +82,8 @@ func DatabasePingCheck(database *sql.DB, timeout time.Duration) Check {
// to at least one IP address within the specified timeout.
func DNSResolveCheck(host string, timeout time.Duration) Check {
resolver := net.Resolver{}
return func() error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
return func(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
addrs, err := resolver.LookupHost(ctx, host)
if err != nil {
Expand All @@ -94,7 +99,7 @@ func DNSResolveCheck(host string, timeout time.Duration) Check {
// GoroutineCountCheck returns a Check that fails if too many goroutines are
// running (which could indicate a resource leak).
func GoroutineCountCheck(threshold int) Check {
return func() error {
return func(ctx context.Context) error {
count := runtime.NumGoroutine()
if count > threshold {
return fmt.Errorf("too many goroutines (%d > %d)", count, threshold)
Expand Down
40 changes: 29 additions & 11 deletions checks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package healthcheck

import (
"context"
"testing"
"time"

Expand All @@ -23,30 +24,47 @@ import (
)

func TestTCPDialCheck(t *testing.T) {
assert.NoError(t, TCPDialCheck("heptio.com:80", 5*time.Second)())
assert.Error(t, TCPDialCheck("heptio.com:25327", 5*time.Second)())
ctx := context.Background()

assert.NoError(t, TCPDialCheck("heptio.com:80", 5*time.Second)(ctx))
assert.Error(t, TCPDialCheck("heptio.com:25327", 5*time.Second)(ctx))
}

func TestHTTPGetCheck(t *testing.T) {
assert.NoError(t, HTTPGetCheck("https://heptio.com", 5*time.Second)())
assert.Error(t, HTTPGetCheck("http://heptio.com", 5*time.Second)(), "redirect should fail")
assert.Error(t, HTTPGetCheck("https://heptio.com/nonexistent", 5*time.Second)(), "404 should fail")
ctx := context.Background()
canceledCtx, cancel := context.WithCancel(ctx)
cancel()

assert.NoError(t, HTTPGetCheck("https://heptio.com", 5*time.Second)(ctx))
assert.Error(t, HTTPGetCheck("https://heptio.com", 5*time.Second)(canceledCtx))
assert.Error(t, HTTPGetCheck("http://heptio.com", 5*time.Second)(ctx), "redirect should fail")
assert.Error(t, HTTPGetCheck("https://heptio.com/nonexistent", 5*time.Second)(ctx), "404 should fail")
}

func TestDatabasePingCheck(t *testing.T) {
assert.Error(t, DatabasePingCheck(nil, 1*time.Second)(), "nil DB should fail")
ctx := context.Background()
canceledCtx, cancel := context.WithCancel(ctx)
cancel()

assert.Error(t, DatabasePingCheck(nil, 1*time.Second)(ctx), "nil DB should fail")

db, _, err := sqlmock.New()
assert.NoError(t, err)
assert.NoError(t, DatabasePingCheck(db, 1*time.Second)(), "ping should succeed")
assert.NoError(t, DatabasePingCheck(db, 1*time.Second)(ctx), "ping should succeed")
assert.Error(t, DatabasePingCheck(db, 1*time.Second)(canceledCtx), "ping should fail")
}

func TestDNSResolveCheck(t *testing.T) {
assert.NoError(t, DNSResolveCheck("heptio.com", 5*time.Second)())
assert.Error(t, DNSResolveCheck("nonexistent.heptio.com", 5*time.Second)())
ctx := context.Background()
canceledCtx, cancel := context.WithCancel(ctx)
cancel()

assert.NoError(t, DNSResolveCheck("heptio.com", 5*time.Second)(ctx))
assert.Error(t, DNSResolveCheck("nonexistent.heptio.com", 5*time.Second)(ctx))
assert.Error(t, DNSResolveCheck("heptio.com", 5*time.Second)(canceledCtx))
}

func TestGoroutineCountCheck(t *testing.T) {
assert.NoError(t, GoroutineCountCheck(1000)())
assert.Error(t, GoroutineCountCheck(0)())
assert.NoError(t, GoroutineCountCheck(1000)(context.Background()))
assert.Error(t, GoroutineCountCheck(0)(context.Background()))
}
10 changes: 5 additions & 5 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package healthcheck

import (
"context"
"database/sql"
"fmt"
"net/http"
Expand All @@ -23,10 +24,9 @@ import (
"strings"
"time"

sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
)

func Example() {
Expand Down Expand Up @@ -109,7 +109,7 @@ func Example_advanced() {
HTTPGetCheck(upstreamURL, 500*time.Millisecond))

// Implement a custom check with a 50 millisecond timeout.
health.AddLivenessCheck("custom-check-with-timeout", Timeout(func() error {
health.AddLivenessCheck("custom-check-with-timeout", Timeout(func(ctx context.Context) error {
// Simulate some work that could take a long time
time.Sleep(time.Millisecond * 100)
return nil
Expand Down Expand Up @@ -146,12 +146,12 @@ func Example_metrics() {
health := NewMetricsHandler(registry, "example")

// Add a simple readiness check that always fails.
health.AddReadinessCheck("failing-check", func() error {
health.AddReadinessCheck("failing-check", func(ctx context.Context) error {
return fmt.Errorf("example failure")
})

// Add a liveness check that always succeeds
health.AddLivenessCheck("successful-check", func() error {
health.AddLivenessCheck("successful-check", func(ctx context.Context) error {
return nil
})

Expand Down
7 changes: 4 additions & 3 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package healthcheck

import (
"context"
"encoding/json"
"net/http"
"sync"
Expand Down Expand Up @@ -59,11 +60,11 @@ func (s *basicHandler) AddReadinessCheck(name string, check Check) {
s.readinessChecks[name] = check
}

func (s *basicHandler) collectChecks(checks map[string]Check, resultsOut map[string]string, statusOut *int) {
func (s *basicHandler) collectChecks(ctx context.Context, checks map[string]Check, resultsOut map[string]string, statusOut *int) {
s.checksMutex.RLock()
defer s.checksMutex.RUnlock()
for name, check := range checks {
if err := check(); err != nil {
if err := check(ctx); err != nil {
*statusOut = http.StatusServiceUnavailable
resultsOut[name] = err.Error()
} else {
Expand All @@ -81,7 +82,7 @@ func (s *basicHandler) handle(w http.ResponseWriter, r *http.Request, checks ...
checkResults := make(map[string]string)
status := http.StatusOK
for _, checks := range checks {
s.collectChecks(checks, checkResults, &status)
s.collectChecks(r.Context(), checks, checkResults, &status)
}

// write out the response code and content type header
Expand Down
Loading