From 28cfcd02340acbea46adff247017300d96db3a39 Mon Sep 17 00:00:00 2001 From: Carlo Alberto Ferraris Date: Thu, 7 Sep 2023 01:02:31 +0000 Subject: [PATCH] handle multiple accept-encoding headers --- accepts.go | 14 ++++++++------ adapter.go | 2 +- adapter_test.go | 42 +++++++++++++++++++++++++++++++++--------- 3 files changed, 42 insertions(+), 16 deletions(-) diff --git a/accepts.go b/accepts.go index 147086e..ccb8d4e 100644 --- a/accepts.go +++ b/accepts.go @@ -39,14 +39,16 @@ func acceptedCompression(accept codings, comps comps) []string { // Errors encountered during parsing the codings are ignored. // // See: http://tools.ietf.org/html/rfc2616#section-14.3. -func parseEncodings(s string) codings { +func parseEncodings(vv []string) codings { c := make(codings) - for _, ss := range strings.Split(s, ",") { - coding, qvalue := parseCoding(ss) - if coding == "" { - continue + for _, v := range vv { + for _, sv := range strings.Split(v, ",") { + coding, qvalue := parseCoding(sv) + if coding == "" { + continue + } + c[coding] = qvalue } - c[coding] = qvalue } return c } diff --git a/adapter.go b/adapter.go index 3da818b..19c10c7 100644 --- a/adapter.go +++ b/adapter.go @@ -71,7 +71,7 @@ func Adapter(opts ...Option) (func(http.Handler) http.Handler, error) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { addVaryHeader(w.Header(), acceptEncoding) - accept := parseEncodings(r.Header.Get(acceptEncoding)) + accept := parseEncodings(r.Header.Values(acceptEncoding)) common := acceptedCompression(accept, c.compressor) if len(common) == 0 { h.ServeHTTP(w, r) diff --git a/adapter_test.go b/adapter_test.go index a133cb1..1c67710 100644 --- a/adapter_test.go +++ b/adapter_test.go @@ -6,11 +6,11 @@ import ( "context" "fmt" "io" - "io/ioutil" "net" "net/http" "net/http/httptest" "net/url" + "os" "strconv" "testing" @@ -59,8 +59,32 @@ func TestParseEncodings(t *testing.T) { } for eg, exp := range examples { - act := parseEncodings(eg) - assert.Equal(t, exp, act) + t.Run(eg, func(t *testing.T) { + act := parseEncodings([]string{eg}) + assert.Equal(t, exp, act) + }) + } +} + +func TestParseEncodings2(t *testing.T) { + t.Parallel() + + cases := []struct { + accept []string + parsed codings + }{ + {[]string{"gzip"}, codings{"gzip": 1}}, + {[]string{"gzip,gzip"}, codings{"gzip": 1}}, + {[]string{"gzip,gzip;q=0.8"}, codings{"gzip": 0.8}}, + {[]string{"gzip,gzip;q=0"}, codings{"gzip": 0}}, + {[]string{"gzip,br"}, codings{"gzip": 1, "br": 1}}, + {[]string{"gzip", "br"}, codings{"gzip": 1, "br": 1}}, + {[]string{"gzip", "gzip;q=0,br"}, codings{"gzip": 0, "br": 1}}, + } + for i, c := range cases { + t.Run(fmt.Sprint(i), func(t *testing.T) { + assert.Equal(t, c.parsed, parseEncodings(c.accept)) + }) } } @@ -443,7 +467,7 @@ func TestGzipHandlerNoBody(t *testing.T) { req.Header.Set("Accept-Encoding", "gzip") handler.ServeHTTP(rec, req) - body, err := ioutil.ReadAll(rec.Body) + body, err := io.ReadAll(rec.Body) if err != nil { t.Fatalf("Unexpected error reading response body: %v", err) } @@ -514,7 +538,7 @@ func TestGzipHandlerContentLength(t *testing.T) { } defer res.Body.Close() - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("Unexpected error reading response body in test iteration %d: %v", num, err) } @@ -624,7 +648,7 @@ func TestGzipHandlerDoubleWriteHeader(t *testing.T) { } req.Header.Set("Accept-Encoding", "gzip") wrapper.ServeHTTP(rec, req) - body, err := ioutil.ReadAll(rec.Body) + body, err := io.ReadAll(rec.Body) if err != nil { t.Fatalf("Unexpected error reading response body: %v", err) } @@ -652,7 +676,7 @@ func TestGzipHandlerDoubleVary(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Accept-Encoding", "gzip") wrapper.ServeHTTP(rec, req) - body, err := ioutil.ReadAll(rec.Body) + body, err := io.ReadAll(rec.Body) if err != nil { t.Fatalf("Unexpected error reading response body: %v", err) } @@ -1206,7 +1230,7 @@ func zstdStrLevel(s string, lvl kpzstd.EncoderLevel) []byte { } func benchmark(b *testing.B, parallel bool, size int, ae string, d int) { - bin, err := ioutil.ReadFile("testdata/benchmark.json") + bin, err := os.ReadFile("testdata/benchmark.json") if err != nil { b.Fatal(err) } @@ -1321,5 +1345,5 @@ func decodeGzip(i io.Reader) ([]byte, error) { if err != nil { return nil, err } - return ioutil.ReadAll(r) + return io.ReadAll(r) }