diff --git a/api/api.go b/api/api.go index 0c9fcba..b6bc3b3 100644 --- a/api/api.go +++ b/api/api.go @@ -25,11 +25,12 @@ import ( ) type Config struct { - Logger *logrus.Entry - HealthHandler echo.HandlerFunc - CorsOrigins []string - HealthResponse map[string]interface{} - StatusResponse map[string]interface{} + Logger *logrus.Entry + LoggingMiddlwareConfig LoggingMiddlwareConfig + HealthHandler echo.HandlerFunc + CorsOrigins []string + HealthResponse map[string]interface{} + StatusResponse map[string]interface{} } func New(cfg Config) *echo.Echo { @@ -54,7 +55,7 @@ func New(cfg Config) *echo.Echo { e.Logger.SetOutput(os.Stdout) e.HideBanner = true e.HTTPErrorHandler = NewHTTPErrorHandler(e) - e.Use(LoggingMiddleware(cfg.Logger)) + e.Use(LoggingMiddlewareWithConfig(cfg.Logger, cfg.LoggingMiddlwareConfig)) if cfg.CorsOrigins != nil { e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ AllowOrigins: cfg.CorsOrigins, diff --git a/api/api_test.go b/api/api_test.go index 3f03242..468ff00 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -152,6 +152,52 @@ var _ = Describe("API", func() { Expect(logHook.Entries).To(HaveLen(1)) Expect(logHook.Entries[0].Level).To(Equal(logrus.DebugLevel)) }) + It("can log request and response headers", func() { + e = api.New(api.Config{ + Logger: logEntry, + LoggingMiddlwareConfig: api.LoggingMiddlwareConfig{ + RequestHeaders: true, + ResponseHeaders: true, + }, + }) + e.GET("/", func(c echo.Context) error { + c.Response().Header().Set("ResHead", "ResHeadVal") + return c.String(200, "ok") + }) + Expect(Serve(e, GetRequest("/", SetReqHeader("ReqHead", "ReqHeadVal")))).To(HaveResponseCode(200)) + Expect(logHook.Entries).To(HaveLen(1)) + Expect(logHook.Entries[0].Data).To(And( + HaveKeyWithValue("request_header.Reqhead", "ReqHeadVal"), + HaveKeyWithValue("response_header.Reshead", "ResHeadVal"), + )) + }) + It("can use custom DoLog, BeforeRequest, and AfterRequest hooks", func() { + doLogCalled := false + e = api.New(api.Config{ + Logger: logEntry, + LoggingMiddlwareConfig: api.LoggingMiddlwareConfig{ + BeforeRequest: func(_ echo.Context, e *logrus.Entry) *logrus.Entry { + return e.WithField("before", 1) + }, + AfterRequest: func(_ echo.Context, e *logrus.Entry) *logrus.Entry { + return e.WithField("after", 2) + }, + DoLog: func(c echo.Context, e *logrus.Entry) { + doLogCalled = true + api.LoggingMiddlewareDefaultDoLog(c, e) + }, + }, + }) + e.GET("/", func(c echo.Context) error { + return c.String(400, "") + }) + Expect(Serve(e, GetRequest("/"))).To(HaveResponseCode(400)) + Expect(doLogCalled).To(BeTrue()) + Expect(logHook.Entries[len(logHook.Entries)-1].Data).To(And( + HaveKeyWithValue("before", 1), + HaveKeyWithValue("after", 2), + )) + }) }) Describe("error handling", func() { @@ -300,5 +346,20 @@ var _ = Describe("API", func() { HaveKeyWithValue("debug_response_body", ContainSubstring("ok")), )) }) + It("can print memory stats every n requests", func() { + e.Use(api.DebugMiddleware(api.DebugMiddlewareConfig{Enabled: true, DumpMemoryEvery: 2})) + e.GET("/endpoint", func(c echo.Context) error { + return c.String(200, "ok") + }) + Serve(e, NewRequest("GET", "/endpoint", nil, SetReqHeader("Foo", "x"))) + Serve(e, NewRequest("GET", "/endpoint", nil, SetReqHeader("Foo", "x"))) + Expect(logHook.Entries).To(HaveLen(4)) + Expect(logHook.Entries[0].Message).To(Equal("request_debug")) + Expect(logHook.Entries[0].Data).ToNot(HaveKey("memory_sys")) + Expect(logHook.Entries[1].Message).To(Equal("request_finished")) + Expect(logHook.Entries[2].Message).To(Equal("request_debug")) + Expect(logHook.Entries[2].Data).To(HaveKey("memory_sys")) + Expect(logHook.Entries[3].Message).To(Equal("request_finished")) + }) }) }) diff --git a/api/debug.go b/api/debug.go index 29c469b..312517e 100644 --- a/api/debug.go +++ b/api/debug.go @@ -5,6 +5,8 @@ import ( "github.com/labstack/echo/middleware" "github.com/lithictech/go-aperitif/logctx" "net/http" + "runtime" + "sync/atomic" ) type DebugMiddlewareConfig struct { @@ -14,6 +16,9 @@ type DebugMiddlewareConfig struct { DumpRequestHeaders bool DumpResponseHeaders bool DumpAll bool + // Log out memory stats every 'n' requests. + // If <= 0, do not log them. + DumpMemoryEvery int } func DebugMiddleware(cfg DebugMiddlewareConfig) echo.MiddlewareFunc { @@ -30,7 +35,10 @@ func DebugMiddleware(cfg DebugMiddlewareConfig) echo.MiddlewareFunc { cfg.DumpResponseHeaders = true cfg.DumpResponseBody = true } + var requestCounter uint64 + dumpEveryUint := uint64(cfg.DumpMemoryEvery) bd := middleware.BodyDump(func(c echo.Context, reqBody []byte, resBody []byte) { + atomic.AddUint64(&requestCounter, 1) log := logctx.Logger(StdContext(c)) if cfg.DumpRequestBody { log = log.WithField("debug_request_body", string(reqBody)) @@ -44,6 +52,30 @@ func DebugMiddleware(cfg DebugMiddlewareConfig) echo.MiddlewareFunc { if cfg.DumpResponseHeaders { log = log.WithField("debug_response_headers", headerToMap(c.Response().Header())) } + if cfg.DumpMemoryEvery > 0 && (requestCounter%dumpEveryUint) == 0 { + var ms runtime.MemStats + runtime.ReadMemStats(&ms) + log = log.WithFields(map[string]interface{}{ + "memory_alloc": ms.Alloc, + "memory_total_alloc": ms.TotalAlloc, + "memory_sys": ms.Sys, + "memory_mallocs": ms.Mallocs, + "memory_frees": ms.Frees, + "memory_heap_alloc": ms.HeapAlloc, + "memory_heap_sys": ms.HeapSys, + "memory_heap_idle": ms.HeapIdle, + "memory_heap_inuse": ms.HeapInuse, + "memory_heap_released": ms.HeapReleased, + "memory_heap_objects": ms.HeapObjects, + "memory_stack_inuse": ms.StackInuse, + "memory_stack_sys": ms.StackSys, + "memory_other_sys": ms.OtherSys, + "memory_next_gc": ms.NextGC, + "memory_last_gc": ms.LastGC, + "memory_pause_total_ns": ms.PauseTotalNs, + "memory_num_gc": ms.NumGC, + }) + } log.Debug("request_debug") }) return bd diff --git a/api/logging.go b/api/logging.go index f6fdd2c..65ac53b 100644 --- a/api/logging.go +++ b/api/logging.go @@ -29,7 +29,31 @@ func SetLogger(c echo.Context, logger *logrus.Entry) { c.Set(logctx.LoggerKey, logger) } +type LoggingMiddlwareConfig struct { + // If true, log request headers. + RequestHeaders bool + // If true, log response headers. + ResponseHeaders bool + // If provided, the returned logger is stored in the context + // which is eventually passed to the handler. + // Use to add additional fields to the logger based on the request. + BeforeRequest func(echo.Context, *logrus.Entry) *logrus.Entry + // If provided, the returned logger is used for response logging. + // Use to add additional fields to the logger based on the request or response. + AfterRequest func(echo.Context, *logrus.Entry) *logrus.Entry + // The function that does the actual logging. + // By default, it will log at a certain level based on the status code of the response. + DoLog func(echo.Context, *logrus.Entry) +} + func LoggingMiddleware(outerLogger *logrus.Entry) echo.MiddlewareFunc { + return LoggingMiddlewareWithConfig(outerLogger, LoggingMiddlwareConfig{}) +} + +func LoggingMiddlewareWithConfig(outerLogger *logrus.Entry, cfg LoggingMiddlwareConfig) echo.MiddlewareFunc { + if cfg.DoLog == nil { + cfg.DoLog = LoggingMiddlewareDefaultDoLog + } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { start := time.Now() @@ -57,11 +81,16 @@ func LoggingMiddleware(outerLogger *logrus.Entry) echo.MiddlewareFunc { "request_bytes_in": bytesIn, string(logctx.RequestTraceIdKey): TraceId(c), }) - //for k, v := range req.Header { - // if len(v) > 0 && k != "Authorization" && k != "Cookie" { - // logger = logger.WithField("header."+k, v[0]) - // } - //} + if cfg.RequestHeaders { + for k, v := range req.Header { + if len(v) > 0 && k != "Authorization" && k != "Cookie" { + logger = logger.WithField("request_header."+k, v[0]) + } + } + } + if cfg.BeforeRequest != nil { + logger = cfg.BeforeRequest(c, logger) + } SetLogger(c, logger) @@ -80,28 +109,42 @@ func LoggingMiddleware(outerLogger *logrus.Entry) echo.MiddlewareFunc { "request_latency_ms": int(stop.Sub(start)) / 1000 / 1000, "request_bytes_out": strconv.FormatInt(res.Size, 10), }) + if cfg.ResponseHeaders { + for k, v := range res.Header() { + if len(v) > 0 && k != "Set-Cookie" { + logger = logger.WithField("response_header."+k, v[0]) + } + } + } if err != nil { logger = logger.WithField("request_error", err) } - - logMethod := logger.Info - if req.Method == http.MethodOptions { - logMethod = logger.Debug - } else if res.Status >= 500 { - logMethod = logger.Error - } else if res.Status >= 400 { - logMethod = logger.Warn - } else if req.URL.Path == HealthPath || req.URL.Path == StatusPath { - logMethod = logger.Debug + if cfg.BeforeRequest != nil { + logger = cfg.AfterRequest(c, logger) } - logMethod("request_finished") - + cfg.DoLog(c, logger) // c.Error is already called return nil } } } +func LoggingMiddlewareDefaultDoLog(c echo.Context, logger *logrus.Entry) { + req := c.Request() + res := c.Response() + logMethod := logger.Info + if req.Method == http.MethodOptions { + logMethod = logger.Debug + } else if res.Status >= 500 { + logMethod = logger.Error + } else if res.Status >= 400 { + logMethod = logger.Warn + } else if req.URL.Path == HealthPath || req.URL.Path == StatusPath { + logMethod = logger.Debug + } + logMethod("request_finished") +} + // Invoke next(c) within a function wrapped with defer, // so that if it panics, we can recover from it and pass on a 500. // Use the "named return parameter can be set in defer" trick so we can diff --git a/logctx/logctx.go b/logctx/logctx.go index 8b7890f..545ac42 100644 --- a/logctx/logctx.go +++ b/logctx/logctx.go @@ -67,6 +67,8 @@ func Logger(c context.Context) *logrus.Entry { return logger } +// ActiveTraceId returns the first valid trace value and type from the given context, +// or MissingTraceIdKey if there is none. func ActiveTraceId(c context.Context) (TraceIdKey, string) { if trace, ok := c.Value(RequestTraceIdKey).(string); ok { return RequestTraceIdKey, trace @@ -80,6 +82,12 @@ func ActiveTraceId(c context.Context) (TraceIdKey, string) { return MissingTraceIdKey, "no-trace-id-in-context" } +// ActiveTraceIdValue returns the value part of ActiveTraceId (does not return the TradeIdKey type part). +func ActiveTraceIdValue(c context.Context) string { + _, v := ActiveTraceId(c) + return v +} + func AddFieldsAndGet(c context.Context, fields map[string]interface{}) (context.Context, *logrus.Entry) { logger := Logger(c) logger = logger.WithFields(fields) diff --git a/logctx/logctx_test.go b/logctx/logctx_test.go index 29ab9e8..402a32c 100644 --- a/logctx/logctx_test.go +++ b/logctx/logctx_test.go @@ -30,6 +30,7 @@ var _ = Describe("logtools", func() { key, val := logctx.ActiveTraceId(c) Expect(key).To(Equal(logctx.RequestTraceIdKey)) Expect(val).To(Equal("abc")) + Expect(logctx.ActiveTraceIdValue(c)).To(Equal("abc")) }) It("returns a process trace id", func() { c := context.WithValue(bg, logctx.ProcessTraceIdKey, "abc") diff --git a/parallel/parallel.go b/parallel/parallel.go index 71ed4cf..47a48c0 100644 --- a/parallel/parallel.go +++ b/parallel/parallel.go @@ -1,11 +1,14 @@ package parallel import ( + "errors" "github.com/hashicorp/go-multierror" "github.com/lithictech/go-aperitif/mariobros" "sync" ) +var ErrInvalidParallelism = errors.New("degree of parallelism must be > 0") + type empty struct{} type Processor func(idx int) error @@ -22,8 +25,12 @@ type Processor func(idx int) error // and assign to the slice index while processing. // See ParallelForFiles for an example usage. func ForEach(total int, n int, process Processor) error { + if n <= 0 { + return ErrInvalidParallelism + } + semaphore := make(chan empty, n) - errors := make([]error, total) + errs := make([]error, total) wg := sync.WaitGroup{} wg.Add(total) @@ -32,11 +39,11 @@ func ForEach(total int, n int, process Processor) error { mario := mariobros.Yo("parallel.foreach") defer mario() semaphore <- empty{} - errors[i] = process(i) + errs[i] = process(i) <-semaphore wg.Done() }(i) } wg.Wait() - return multierror.Append(nil, errors...).ErrorOrNil() + return multierror.Append(nil, errs...).ErrorOrNil() } diff --git a/parallel/parallel_test.go b/parallel/parallel_test.go index 2ae38df..115e611 100644 --- a/parallel/parallel_test.go +++ b/parallel/parallel_test.go @@ -37,4 +37,8 @@ var _ = Describe("ParallelFor", func() { Expect(called).To(Equal(1000)) Expect(active).To(Equal(0)) }) + It("errors for 0 or negative n", func() { + err := parallel.ForEach(1, 0, nil) + Expect(err).To(BeIdenticalTo(parallel.ErrInvalidParallelism)) + }) })