Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
friendlymatthew committed Jun 28, 2024
1 parent 0063103 commit d9e740c
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 39 deletions.
35 changes: 18 additions & 17 deletions pkg/btree/btree.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ func (t *BTree) Insert(key pointer.ReferencedId, vector hnsw.Point) error {
}

if root == nil {
- node := &BTreeNode{Width: t.Width}
- node.Keys = []pointer.ReferencedId{key}
+ node := &BTreeNode{Width: t.Width, VectorDim: t.VectorDim}
+ node.Ids = []pointer.ReferencedId{key}
node.Vectors = []hnsw.Point{vector}
+ node.Offsets = make([]uint64, 0)

buf, err := node.MarshalBinary()
if err != nil {
Expand All @@ -76,7 +77,7 @@ func (t *BTree) Insert(key pointer.ReferencedId, vector hnsw.Point) error {

parent, parentOffset := root, rootOffset.Offset
for !parent.Leaf() {
- index, found := slices.BinarySearchFunc(parent.Keys, key, pointer.CompareReferencedIds)
+ index, found := slices.BinarySearchFunc(parent.Ids, key, pointer.CompareReferencedIds)

if found {
panic("cannot insert duplicate key")
Expand All @@ -90,16 +91,16 @@ func (t *BTree) Insert(key pointer.ReferencedId, vector hnsw.Point) error {

if int(child.Size()) > t.PageFile.PageSize() {
// split node here
- mid := len(child.Keys) / 2
- midKey := child.Keys[mid]
+ mid := len(child.Ids) / 2
+ midKey := child.Ids[mid]

- rightChild := &BTreeNode{Width: t.Width}
+ rightChild := &BTreeNode{Width: t.Width, VectorDim: t.VectorDim}
if !child.Leaf() {
rightChild.Offsets = child.Offsets[mid+1:]
child.Offsets = child.Offsets[:mid]
}
rightChild.Vectors = child.Vectors[mid+1:]
- rightChild.Keys = child.Keys[mid+1:]
+ rightChild.Ids = child.Ids[mid+1:]

rbuf, err := rightChild.MarshalBinary()
if err != nil {
Expand All @@ -111,7 +112,7 @@ func (t *BTree) Insert(key pointer.ReferencedId, vector hnsw.Point) error {
}

// shrink left child (child)
- child.Keys = child.Keys[:mid]
+ child.Ids = child.Ids[:mid]
child.Vectors = child.Vectors[:mid]
if _, err := t.PageFile.Seek(int64(loffset), io.SeekStart); err != nil {
return err
Expand All @@ -122,11 +123,11 @@ func (t *BTree) Insert(key pointer.ReferencedId, vector hnsw.Point) error {
}

// update parent to include new key and store left right offsets
- if index == len(parent.Keys) {
- parent.Keys = append(parent.Keys, midKey)
+ if index == len(parent.Ids) {
+ parent.Ids = append(parent.Ids, midKey)
} else {
- parent.Keys = append(parent.Keys[:index+1], parent.Keys[index:]...)
- parent.Keys[index] = midKey
+ parent.Ids = append(parent.Ids[:index+1], parent.Ids[index:]...)
+ parent.Ids[index] = midKey
}

parent.Offsets = append(parent.Offsets[:index+2], parent.Offsets[:index+1]...)
Expand Down Expand Up @@ -154,13 +155,13 @@ func (t *BTree) Insert(key pointer.ReferencedId, vector hnsw.Point) error {
}
}

- index, found := slices.BinarySearchFunc(parent.Keys, key, pointer.CompareReferencedIds)
+ index, found := slices.BinarySearchFunc(parent.Ids, key, pointer.CompareReferencedIds)
if found {
panic("cannot insert duplicate key")
}

- parent.Keys = append(parent.Keys[:index+1], parent.Keys[index:]...)
- parent.Keys[index] = key
+ parent.Ids = append(parent.Ids[:index+1], parent.Ids[index:]...)
+ parent.Ids[index] = key

parent.Vectors = append(parent.Vectors[:index+1], parent.Vectors[index:]...)
parent.Vectors[index] = vector
Expand All @@ -186,10 +187,10 @@ func (t *BTree) Find(key pointer.ReferencedId) (pointer.ReferencedId, pointer.Me
return pointer.ReferencedId{}, pointer.MemoryPointer{}, nil
}

- index, found := slices.BinarySearchFunc(node.Keys, key, pointer.CompareReferencedIds)
+ index, found := slices.BinarySearchFunc(node.Ids, key, pointer.CompareReferencedIds)

if found {
- return node.Keys[index-1], pointer.MemoryPointer{Offset: node.Offsets[index]}, nil
+ return node.Ids[index], pointer.MemoryPointer{Offset: node.Ids[index].DataPointer.Offset}, nil
}

// no key found
Expand Down
16 changes: 6 additions & 10 deletions pkg/btree/btree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ func TestBTree(t *testing.T) {
}
tree := &BTree{PageFile: p, MetaPage: newTestMetaPage(t, p)}
// find a key that doesn't exist
- k, _, err := tree.Find(pointer.ReferencedId{Id: hnsw.Id(0)})
+ k, _, err := tree.Find(pointer.ReferencedId{Value: hnsw.Id(0)})
if err != nil {
t.Fatal(err)
}

- if k.Id != hnsw.Id(0) {
+ if k.Value != hnsw.Id(0) {
t.Fatalf("expected id 0, got %d", k)
}

Expand All @@ -75,23 +75,19 @@ func TestBTree(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- tree := &BTree{PageFile: p, MetaPage: newTestMetaPage(t, p), Width: uint16(0)}
- if err := tree.Insert(pointer.ReferencedId{Id: 1}, hnsw.Point{1}); err != nil {
+ tree := &BTree{PageFile: p, MetaPage: newTestMetaPage(t, p), Width: uint16(0), VectorDim: 1}
+ if err := tree.Insert(pointer.ReferencedId{Value: 1}, hnsw.Point{1}); err != nil {
t.Fatal(err)
}
- k, v, err := tree.Find(pointer.ReferencedId{Id: 1})
+ k, _, err := tree.Find(pointer.ReferencedId{Value: 1})

if err != nil {
t.Fatal(err)
}

- if k.Id != hnsw.Id(1) {
+ if k.Value != hnsw.Id(1) {
t.Fatalf("expected id 1, got %d", k)
}
-
- if v.Offset != 1 {
- t.Fatalf("expected value 1, got %d", v)
- }
})

}
6 changes: 1 addition & 5 deletions pkg/btree/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (n *BTreeNode) Size() int64 {
}

func (n *BTreeNode) Leaf() bool {
- return n.Offsets == nil || len(n.Offsets) == 0
+ return len(n.Offsets) == 0
}

func (n *BTreeNode) MarshalBinary() ([]byte, error) {
Expand Down Expand Up @@ -155,7 +155,3 @@ func (n *BTreeNode) WriteTo(w io.Writer) (int64, error) {
m, err := w.Write(buf)
return int64(m), err
}

- func (n *BTreeNode) Leaf() bool {
- return len(n.Offsets) == 0
- }
14 changes: 7 additions & 7 deletions pkg/pointer/referenced_value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,34 @@ import (

func TestReferencedValue(t *testing.T) {
t.Run("compare referenced value", func(t *testing.T) {
- keys := []ReferencedValue{
+ keys := []ReferencedId{
{
- Value: []byte{1},
+ Value: 1,
DataPointer: MemoryPointer{
Offset: 100,
Length: 0,
},
},
{
- Value: []byte{2},
+ Value: 2,
DataPointer: MemoryPointer{
Offset: 200,
Length: 0,
},
},
{
- Value: []byte{3},
+ Value: 3,
DataPointer: MemoryPointer{
Offset: 300,
Length: 0,
},
},
}

- index, found := slices.BinarySearchFunc(keys, ReferencedValue{
+ index, found := slices.BinarySearchFunc(keys, ReferencedId{
DataPointer: MemoryPointer{},
- Value: []byte{1},
- }, CompareUniqueReferencedValues)
+ Value: 1,
+ }, CompareReferencedIds)

if !found {
t.Fatal("expected to find key 1")
Expand Down

0 comments on commit d9e740c

Please sign in to comment.