Skip to content

Commit

Permalink
Improve long running streamed response handling
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Ellis (OpenFaaS Ltd) <[email protected]>
  • Loading branch information
alexellis committed Jan 11, 2024
1 parent f91e762 commit b4254fe
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 36 deletions.
16 changes: 14 additions & 2 deletions httputil/write_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@ import (
)

func NewHttpWriteInterceptor(w http.ResponseWriter) *HttpWriteInterceptor {
return &HttpWriteInterceptor{w, 0}
return &HttpWriteInterceptor{
ResponseWriter: w,
statusCode: 0,
bytesWritten: 0,
}
}

type HttpWriteInterceptor struct {
http.ResponseWriter
statusCode int
statusCode int
bytesWritten int64
}

func (c *HttpWriteInterceptor) Status() int {
Expand All @@ -22,6 +27,10 @@ func (c *HttpWriteInterceptor) Status() int {
return c.statusCode
}

func (c *HttpWriteInterceptor) BytesWritten() int64 {
return c.bytesWritten
}

func (c *HttpWriteInterceptor) Header() http.Header {
return c.ResponseWriter.Header()
}
Expand All @@ -30,6 +39,9 @@ func (c *HttpWriteInterceptor) Write(data []byte) (int, error) {
if c.statusCode == 0 {
c.WriteHeader(http.StatusOK)
}

c.bytesWritten += int64(len(data))

return c.ResponseWriter.Write(data)
}

Expand Down
48 changes: 29 additions & 19 deletions httputil/write_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,42 @@ import (
"testing"
)

func Test_StatusIsRecorded(t *testing.T) {
wantCode := http.StatusAccepted
gotCode := 0
func Test_WriteCountsBytes(t *testing.T) {
w := httptest.NewRecorder()
wi := NewHttpWriteInterceptor(w)

next := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(wantCode)
writeStr := "hello world"
wi.Write([]byte(writeStr))

want := int64(len(writeStr))
got := wi.BytesWritten()
if got != want {
t.Errorf("want bytes: %d, got %d", want, got)
}
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ww := NewHttpWriteInterceptor(w)
next(ww, r)
gotCode = ww.Status()
}))
}

defer func() {
s.Close()
}()
func Test_WriteGetsStatusCode(t *testing.T) {
w := httptest.NewRecorder()
wi := NewHttpWriteInterceptor(w)

req, _ := http.NewRequest("GET", s.URL, nil)
_, err := http.DefaultClient.Do(req)
wi.WriteHeader(http.StatusTeapot)

if err != nil {
t.Fatalf("Error doing request: %v", err)
want := http.StatusTeapot
got := wi.Status()
if got != want {
t.Errorf("want status code: %d, got %d", want, got)
}
}

func Test_WriteGetsStatusCode_WithoutWriteHeader(t *testing.T) {
w := httptest.NewRecorder()
wi := NewHttpWriteInterceptor(w)

wi.Write([]byte("hello world"))

if gotCode != wantCode {
t.Errorf("got code %d, want %d", gotCode, wantCode)
want := http.StatusOK
got := wi.Status()
if got != want {
t.Errorf("want default status code: %d, got %d", want, got)
}
}
4 changes: 2 additions & 2 deletions httputil/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package httputil

import (
"fmt"
"io/ioutil"
"io"
"log"
"net/http"
"net/http/httptest"
Expand All @@ -24,7 +24,7 @@ func TestFirstWriteSetsStatusCode(t *testing.T) {
t.Fatalf("incorrect status code in the original response object: %d", w.Result().StatusCode)
}

out, _ := ioutil.ReadAll(w.Result().Body)
out, _ := io.ReadAll(w.Result().Body)
if string(out) != `{"value": "ok"}` {
t.Fatalf("incorrect response content: %q", out)
}
Expand Down
42 changes: 29 additions & 13 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ import (
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"time"

"github.com/gorilla/mux"
"github.com/openfaas/faas-provider/httputil"
fhttputil "github.com/openfaas/faas-provider/httputil"
"github.com/openfaas/faas-provider/types"
)

Expand Down Expand Up @@ -68,6 +69,14 @@ func NewHandlerFunc(config types.FaaSConfig, resolver BaseURLResolver, verbose b

proxyClient := NewProxyClientFromConfig(config)

reverseProxy := httputil.ReverseProxy{}
reverseProxy.Director = func(req *http.Request) {
// At least an empty director is required to prevent runtime errors.
}

// Errors are common during disconnect of client, no need to log them.
reverseProxy.ErrorLog = log.New(io.Discard, "", 0)

return func(w http.ResponseWriter, r *http.Request) {
if r.Body != nil {
defer r.Body.Close()
Expand All @@ -81,7 +90,7 @@ func NewHandlerFunc(config types.FaaSConfig, resolver BaseURLResolver, verbose b
http.MethodGet,
http.MethodOptions,
http.MethodHead:
proxyRequest(w, r, proxyClient, resolver, verbose)
proxyRequest(w, r, proxyClient, resolver, &reverseProxy, verbose)

default:
w.WriteHeader(http.StatusMethodNotAllowed)
Expand Down Expand Up @@ -134,15 +143,15 @@ func NewProxyClient(timeout time.Duration, maxIdleConns int, maxIdleConnsPerHost
}

// proxyRequest handles the actual resolution of and then request to the function service.
func proxyRequest(w http.ResponseWriter, originalReq *http.Request, proxyClient *http.Client, resolver BaseURLResolver, verbose bool) {
func proxyRequest(w http.ResponseWriter, originalReq *http.Request, proxyClient *http.Client, resolver BaseURLResolver, reverseProxy *httputil.ReverseProxy, verbose bool) {
ctx := originalReq.Context()

pathVars := mux.Vars(originalReq)
functionName := pathVars["name"]
if functionName == "" {
w.Header().Add(openFaaSInternalHeader, "proxy")

httputil.Errorf(w, http.StatusBadRequest, "Provide function name in the request path")
fhttputil.Errorf(w, http.StatusBadRequest, "Provide function name in the request path")
return
}

Expand All @@ -152,7 +161,7 @@ func proxyRequest(w http.ResponseWriter, originalReq *http.Request, proxyClient

// TODO: Should record the 404/not found error in Prometheus.
log.Printf("resolver error: no endpoints for %s: %s\n", functionName, err.Error())
httputil.Errorf(w, http.StatusServiceUnavailable, "No endpoints available for: %s.", functionName)
fhttputil.Errorf(w, http.StatusServiceUnavailable, "No endpoints available for: %s.", functionName)
return
}

Expand All @@ -161,35 +170,42 @@ func proxyRequest(w http.ResponseWriter, originalReq *http.Request, proxyClient

w.Header().Add(openFaaSInternalHeader, "proxy")

httputil.Errorf(w, http.StatusInternalServerError, "Failed to resolve service: %s.", functionName)
fhttputil.Errorf(w, http.StatusInternalServerError, "Failed to resolve service: %s.", functionName)
return
}

if proxyReq.Body != nil {
defer proxyReq.Body.Close()
}

start := time.Now()
if verbose {
start := time.Now()
defer func() {
seconds := time.Since(start)
log.Printf("%s took %f seconds\n", functionName, seconds.Seconds())
}()
}

if v := originalReq.Header.Get("Accept"); v == "text/event-stream" {
reverseProxy.ServeHTTP(w, proxyReq)
return
}

response, err := proxyClient.Do(proxyReq.WithContext(ctx))
seconds := time.Since(start)

if err != nil {
log.Printf("error with proxy request to: %s, %s\n", proxyReq.URL.String(), err.Error())

w.Header().Add(openFaaSInternalHeader, "proxy")

httputil.Errorf(w, http.StatusInternalServerError, "Can't reach service for: %s.", functionName)
fhttputil.Errorf(w, http.StatusInternalServerError, "Can't reach service for: %s.", functionName)
return
}

if response.Body != nil {
defer response.Body.Close()
}

if verbose {
log.Printf("%s took %f seconds\n", functionName, seconds.Seconds())
}

clientHeader := w.Header()
copyHeaders(clientHeader, &response.Header)
w.Header().Set("Content-Type", getContentType(originalReq.Header, response.Header))
Expand Down

0 comments on commit b4254fe

Please sign in to comment.