From 6adbe4050b5ce68567a2386b06078feddaaf2ce3 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Sat, 30 Dec 2023 18:19:58 -0500 Subject: [PATCH] feat: add seekable buffer (#7) --- pkg/btree/buffer.go | 53 ++++++++++++++++++++++ pkg/btree/buffer_test.go | 96 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 pkg/btree/buffer.go create mode 100644 pkg/btree/buffer_test.go diff --git a/pkg/btree/buffer.go b/pkg/btree/buffer.go new file mode 100644 index 00000000..c468a7b1 --- /dev/null +++ b/pkg/btree/buffer.go @@ -0,0 +1,53 @@ +package btree + +import "io" + +// seekableBuffer is a buffer that can be seeked into. +// this replicates the behavior of a file on disk without having to write to disk +// which is useful for testing. +type seekableBuffer struct { + buf []byte + pos int +} + +func newSeekableBuffer() *seekableBuffer { + return &seekableBuffer{} +} + +func (b *seekableBuffer) Write(p []byte) (int, error) { + n := copy(b.buf[b.pos:], p) + if n < len(p) { + b.buf = append(b.buf, p[n:]...) + } + b.pos += len(p) + return len(p), nil +} + +func (b *seekableBuffer) Seek(offset int64, whence int) (int64, error) { + switch whence { + case io.SeekStart: + b.pos = int(offset) + case io.SeekCurrent: + b.pos += int(offset) + case io.SeekEnd: + b.pos = len(b.buf) + int(offset) + } + if b.pos < 0 { + b.pos = 0 + } + if b.pos > len(b.buf) { + b.pos = len(b.buf) + } + return int64(b.pos), nil +} + +func (b *seekableBuffer) Read(p []byte) (int, error) { + if b.pos >= len(b.buf) { + return 0, io.EOF + } + n := copy(p, b.buf[b.pos:]) + b.pos += n + return n, nil +} + +var _ io.ReadWriteSeeker = &seekableBuffer{} diff --git a/pkg/btree/buffer_test.go b/pkg/btree/buffer_test.go new file mode 100644 index 00000000..ca8f5d40 --- /dev/null +++ b/pkg/btree/buffer_test.go @@ -0,0 +1,96 @@ +package btree + +import ( + "io" + "testing" +) + +func TestSeekableBuffer(t *testing.T) { + t.Run("Write", func(t *testing.T) { + b := newSeekableBuffer() + n, err := b.Write([]byte("hello")) + if err != nil { + t.Fatal(err) + } + if n != 5 { + t.Fatalf("expected to write 5 bytes, wrote %d", n) + } + if string(b.buf) != "hello" { + t.Fatalf("expected to write 'hello', wrote %s", string(b.buf)) + } + }) + + t.Run("write to end", func(t *testing.T) { + b := newSeekableBuffer() + if _, err := b.Write([]byte("hello")); err != nil { + t.Fatal(err) + } + if _, err := b.Seek(-2, io.SeekEnd); err != nil { + t.Fatal(err) + } + if _, err := b.Write([]byte("world")); err != nil { + t.Fatal(err) + } + if string(b.buf) != "helworld" { + t.Fatalf("expected to write 'helworld', wrote %s", string(b.buf)) + } + }) + + t.Run("Seek", func(t *testing.T) { + b := newSeekableBuffer() + if _, err := b.Write([]byte("helloo")); err != nil { + t.Fatal(err) + } + if _, err := b.Seek(0, io.SeekStart); err != nil { + t.Fatal(err) + } + if _, err := b.Write([]byte("world")); err != nil { + t.Fatal(err) + } + if string(b.buf) != "worldo" { + t.Fatalf("expected to write 'worldo', wrote %s", string(b.buf)) + } + }) + + t.Run("Read", func(t *testing.T) { + b := newSeekableBuffer() + if _, err := b.Write([]byte("hello")); err != nil { + t.Fatal(err) + } + if _, err := b.Seek(0, io.SeekStart); err != nil { + t.Fatal(err) + } + buf := make([]byte, 5) + n, err := b.Read(buf) + if err != nil { + t.Fatal(err) + } + if n != 5 { + t.Fatalf("expected to read 5 bytes, read %d", n) + } + if string(buf) != "hello" { + t.Fatalf("expected to read 'hello', read %s", string(buf)) + } + }) + + t.Run("read from middle", func(t *testing.T) { + b := newSeekableBuffer() + if _, err := b.Write([]byte("hello")); err != nil { + t.Fatal(err) + } + if _, err := b.Seek(2, io.SeekStart); err != nil { + t.Fatal(err) + } + buf := make([]byte, 3) + n, err := b.Read(buf) + if err != nil { + t.Fatal(err) + } + if n != 3 { + t.Fatalf("expected to read 3 bytes, read %d", n) + } + if string(buf) != "llo" { + t.Fatalf("expected to read 'llo', read %s", string(buf)) + } + }) +}