diff --git a/pkg/hnsw/heap.go b/pkg/hnsw/heap.go index 499ab83..5dcff52 100644 --- a/pkg/hnsw/heap.go +++ b/pkg/hnsw/heap.go @@ -2,6 +2,7 @@ package hnsw import ( "fmt" + "maps" "math/bits" ) @@ -127,6 +128,18 @@ func NewDistHeap() *DistHeap { return d } +func (d *DistHeap) Clone() *DistHeap { + n := &DistHeap{ + items: make([]*Item, len(d.items)), + visited: make(map[Id]int, len(d.visited)), + } + + copy(n.items, d.items) + maps.Copy(n.visited, d.visited) + + return n +} + func (d *DistHeap) PeekMinItem() (*Item, error) { if d.IsEmpty() { return nil, EmptyHeapError diff --git a/pkg/hnsw/heap_test.go b/pkg/hnsw/heap_test.go index ec131c7..f7997cc 100644 --- a/pkg/hnsw/heap_test.go +++ b/pkg/hnsw/heap_test.go @@ -1,6 +1,9 @@ package hnsw -import "testing" +import ( + "reflect" + "testing" +) func TestHeap(t *testing.T) { @@ -138,6 +141,34 @@ func TestHeap(t *testing.T) { } } }) + + t.Run("copy", func(t *testing.T) { + m := NewDistHeap() + + for i := 0; i <= 10; i++ { + m.Insert(Id(i), float32(10-i)) + } + + n := m.Clone() + + reflect.DeepEqual(m.items, n.items) + reflect.DeepEqual(m.visited, n.visited) + + expectedId := Id(10) + + for !n.IsEmpty() { + item, err := n.PopMinItem() + if err != nil { + return + } + + if item.id != expectedId { + t.Fatalf("expected id to be %v, got %v", expectedId, item.id) + } + + expectedId -= 1 + } + }) } func furthestBuildings(heights []int, bricks, ladders int) (int, error) {