diff --git a/decoder.go b/decoder.go index 5df4e2a..28ff803 100644 --- a/decoder.go +++ b/decoder.go @@ -152,11 +152,15 @@ func (d *Decoder) DecodeBytes() ([]byte, error) { if n == 0 { return nil, ErrInvalidVarintData } - if d.offset+n+int(l) > len(d.p) { + nb := int(l) + if nb < 0 { + return nil, fmt.Errorf("csproto: bad byte length %d", nb) + } + if d.offset+n+nb > len(d.p) { return nil, io.ErrUnexpectedEOF } - b := d.p[d.offset+n : d.offset+n+int(l)] - d.offset += n + int(l) + b := d.p[d.offset+n : d.offset+n+nb] + d.offset += n + nb return b, nil } @@ -900,7 +904,7 @@ func (d *Decoder) Skip(tag int, wt WireType) ([]byte, error) { default: return nil, fmt.Errorf("unsupported wire type value: %v", wt) } - if d.offset+skipped >= len(d.p) { + if d.offset+skipped > len(d.p) { return nil, io.ErrUnexpectedEOF } d.offset += skipped diff --git a/decoder_test.go b/decoder_test.go index 8aeceb2..6fb1ee9 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -110,14 +110,6 @@ func TestDecodeBytes(t *testing.T) { wt: csproto.WireTypeLengthDelimited, expected: []byte{0x42, 0x11, 0x38}, }, - { - name: "invalid data", - fieldNum: 2, - v: []byte{0x12, 0x3, 0x42, 0x11}, // field data is truncated - wt: csproto.WireTypeLengthDelimited, - expected: []byte{0x42, 0x11, 0x38}, - expectedErr: io.ErrUnexpectedEOF, - }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { @@ -136,6 +128,28 @@ func TestDecodeBytes(t *testing.T) { } }) } + + // separate t.Run() because we're not checking for sentinel errors + t.Run("corrupt messages", func(t *testing.T) { + t.Run("truncated data", func(t *testing.T) { + // length-delimited field value with a length of 3 but only 2 bytes + data := []byte{0x12, 0x3, 0x42, 0x11} + dec := csproto.NewDecoder(data) + + got, err := dec.DecodeBytes() + assert.Error(t, err) + assert.Nil(t, got) + }) + t.Run("negative length", func(t *testing.T) { + // length-delimited field value with a length of -50 + data := []byte{0xCE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01, 0x3, 0x42, 0x11, 0x38} + dec := csproto.NewDecoder(data) + + got, err := dec.DecodeBytes() + assert.Error(t, err) + assert.Nil(t, got) + }) + }) } func TestDecodeUInt32(t *testing.T) {