Skip to content

Commit

Permalink
handle multiple accept-encoding headers
Browse files Browse the repository at this point in the history
  • Loading branch information
CAFxX committed Sep 7, 2023
1 parent e701083 commit 28cfcd0
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 16 deletions.
14 changes: 8 additions & 6 deletions accepts.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,16 @@ func acceptedCompression(accept codings, comps comps) []string {
// Errors encountered during parsing the codings are ignored.
//
// See: http://tools.ietf.org/html/rfc2616#section-14.3.
func parseEncodings(s string) codings {
func parseEncodings(vv []string) codings {
c := make(codings)
for _, ss := range strings.Split(s, ",") {
coding, qvalue := parseCoding(ss)
if coding == "" {
continue
for _, v := range vv {
for _, sv := range strings.Split(v, ",") {
coding, qvalue := parseCoding(sv)
if coding == "" {
continue
}
c[coding] = qvalue
}
c[coding] = qvalue
}
return c
}
Expand Down
2 changes: 1 addition & 1 deletion adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func Adapter(opts ...Option) (func(http.Handler) http.Handler, error) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
addVaryHeader(w.Header(), acceptEncoding)

accept := parseEncodings(r.Header.Get(acceptEncoding))
accept := parseEncodings(r.Header.Values(acceptEncoding))
common := acceptedCompression(accept, c.compressor)
if len(common) == 0 {
h.ServeHTTP(w, r)
Expand Down
42 changes: 33 additions & 9 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strconv"
"testing"

Expand Down Expand Up @@ -59,8 +59,32 @@ func TestParseEncodings(t *testing.T) {
}

for eg, exp := range examples {
act := parseEncodings(eg)
assert.Equal(t, exp, act)
t.Run(eg, func(t *testing.T) {
act := parseEncodings([]string{eg})
assert.Equal(t, exp, act)
})
}
}

func TestParseEncodings2(t *testing.T) {
t.Parallel()

cases := []struct {
accept []string
parsed codings
}{
{[]string{"gzip"}, codings{"gzip": 1}},
{[]string{"gzip,gzip"}, codings{"gzip": 1}},
{[]string{"gzip,gzip;q=0.8"}, codings{"gzip": 0.8}},
{[]string{"gzip,gzip;q=0"}, codings{"gzip": 0}},
{[]string{"gzip,br"}, codings{"gzip": 1, "br": 1}},
{[]string{"gzip", "br"}, codings{"gzip": 1, "br": 1}},
{[]string{"gzip", "gzip;q=0,br"}, codings{"gzip": 0, "br": 1}},
}
for i, c := range cases {
t.Run(fmt.Sprint(i), func(t *testing.T) {
assert.Equal(t, c.parsed, parseEncodings(c.accept))
})
}
}

Expand Down Expand Up @@ -443,7 +467,7 @@ func TestGzipHandlerNoBody(t *testing.T) {
req.Header.Set("Accept-Encoding", "gzip")
handler.ServeHTTP(rec, req)

body, err := ioutil.ReadAll(rec.Body)
body, err := io.ReadAll(rec.Body)
if err != nil {
t.Fatalf("Unexpected error reading response body: %v", err)
}
Expand Down Expand Up @@ -514,7 +538,7 @@ func TestGzipHandlerContentLength(t *testing.T) {
}
defer res.Body.Close()

body, err := ioutil.ReadAll(res.Body)
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Unexpected error reading response body in test iteration %d: %v", num, err)
}
Expand Down Expand Up @@ -624,7 +648,7 @@ func TestGzipHandlerDoubleWriteHeader(t *testing.T) {
}
req.Header.Set("Accept-Encoding", "gzip")
wrapper.ServeHTTP(rec, req)
body, err := ioutil.ReadAll(rec.Body)
body, err := io.ReadAll(rec.Body)
if err != nil {
t.Fatalf("Unexpected error reading response body: %v", err)
}
Expand Down Expand Up @@ -652,7 +676,7 @@ func TestGzipHandlerDoubleVary(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "gzip")
wrapper.ServeHTTP(rec, req)
body, err := ioutil.ReadAll(rec.Body)
body, err := io.ReadAll(rec.Body)
if err != nil {
t.Fatalf("Unexpected error reading response body: %v", err)
}
Expand Down Expand Up @@ -1206,7 +1230,7 @@ func zstdStrLevel(s string, lvl kpzstd.EncoderLevel) []byte {
}

func benchmark(b *testing.B, parallel bool, size int, ae string, d int) {
bin, err := ioutil.ReadFile("testdata/benchmark.json")
bin, err := os.ReadFile("testdata/benchmark.json")
if err != nil {
b.Fatal(err)
}
Expand Down Expand Up @@ -1321,5 +1345,5 @@ func decodeGzip(i io.Reader) ([]byte, error) {
if err != nil {
return nil, err
}
return ioutil.ReadAll(r)
return io.ReadAll(r)
}

0 comments on commit 28cfcd0

Please sign in to comment.