From e6bc4f88d8b34e72e5eec84c9234d491e5fa4899 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Wed, 5 Jun 2024 13:27:29 -0400 Subject: [PATCH 1/5] feat: select neighbor --- pkg/hnsw/hnsw.go | 14 ++++++++ pkg/hnsw/hnsw_test.go | 77 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) diff --git a/pkg/hnsw/hnsw.go b/pkg/hnsw/hnsw.go index 8e604bc..fc43ce4 100644 --- a/pkg/hnsw/hnsw.go +++ b/pkg/hnsw/hnsw.go @@ -148,6 +148,20 @@ func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item { return epItem } +func (h *Hnsw) selectNeighbors(nearestNeighbors *BaseQueue) (*BaseQueue, error) { + maxNearestNeighbors := FromBaseQueue(nearestNeighbors, MaxComparator{}) + + for maxNearestNeighbors.Len() > h.M { + _, err := maxNearestNeighbors.PopItem() + + if err != nil { + return nil, err + } + } + + return FromBaseQueue(maxNearestNeighbors, MinComparator{}), nil +} + func (h *Hnsw) InsertVector(q Point) error { if !h.validatePoint(q) { return fmt.Errorf("invalid vector dimensionality") diff --git a/pkg/hnsw/hnsw_test.go b/pkg/hnsw/hnsw_test.go index 9228564..e903bcd 100644 --- a/pkg/hnsw/hnsw_test.go +++ b/pkg/hnsw/hnsw_test.go @@ -328,3 +328,80 @@ func TestHnsw_FindCloserEntryPoint(t *testing.T) { } }) } + +func TestHnsw_SelectNeighbors(t *testing.T) { + + t.Run("selects neighbors given overflow", func(t *testing.T) { + nearestNeighbors := NewBaseQueue(MinComparator{}) + + M := 4 + + h := NewHnsw(2, 4, M, Point{0, 0}) + + // since M is 4 + for i := 5; i >= 0; i-- { + nearestNeighbors.Insert(Id(i), float32(i)) + } + + neighbors, err := h.selectNeighbors(nearestNeighbors) + + if err != nil { + t.Fatal(err) + } + + if neighbors.Len() != M { + t.Fatalf("select neighbors should have at most M friends") + } + + expectedId := Id(0) + for !neighbors.IsEmpty() { + nn, err := neighbors.PopItem() + + if err != nil { + t.Fatal(err) + } + + if nn.id != expectedId { + t.Fatalf("expected item to be %v, got %v", expectedId, nn.id) + } + + expectedId += 1 + } + }) + + t.Run("selects neighbors given lower bound", func(t *testing.T) { + M := 10 + h := NewHnsw(2, 10, M, Point{0, 0}) + + nnQueue := NewBaseQueue(MinComparator{}) + + for i := 0; i < 3; i++ { + nnQueue.Insert(Id(i), float32(i)) + } + + neighbors, err := h.selectNeighbors(nnQueue) + + if err != nil { + t.Fatal(err) + } + + if neighbors.Len() != 3 { + t.Fatalf("select neighbors should have at least 3 neighbors, got: %v", neighbors.Len()) + } + + expectedId := Id(0) + for !neighbors.IsEmpty() { + nn, err := neighbors.PopItem() + + if err != nil { + t.Fatal(err) + } + + if nn.id != expectedId { + t.Fatalf("expected item to be %v, got %v", expectedId, nn.id) + } + + expectedId += 1 + } + }) +} From 2b30fbfda78dc82d88aa987b42f21c177b55e2eb Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Wed, 5 Jun 2024 13:43:00 -0400 Subject: [PATCH 2/5] feat: insertion --- pkg/hnsw/hnsw.go | 77 +++++++++++++++++++++++++++++++++++++++---- pkg/hnsw/hnsw_test.go | 19 ++++++++++- 2 files changed, 88 insertions(+), 8 deletions(-) diff --git a/pkg/hnsw/hnsw.go b/pkg/hnsw/hnsw.go index fc43ce4..be4877e 100644 --- a/pkg/hnsw/hnsw.go +++ b/pkg/hnsw/hnsw.go @@ -20,14 +20,14 @@ type Hnsw struct { levelMultiplier float64 - // efConstruction is the size of the dynamic candIdate list - efConstruction uint + // efConstruction is the size of the dynamic candidate list + efConstruction int // default number of connections M, Mmax0 int } -func NewHnsw(d int, efConstruction uint, M int, entryPoint Point) *Hnsw { +func NewHnsw(d int, efConstruction int, M int, entryPoint Point) *Hnsw { if d <= 0 || len(entryPoint) != d { panic("invalid vector dimensionality") } @@ -56,6 +56,10 @@ func (h *Hnsw) SpawnLevel() int { return int(math.Floor(-math.Log(rand.Float64() * h.levelMultiplier))) } +func (h *Hnsw) GenerateId() Id { + return Id(len(h.points)) +} + func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, level int) (*BaseQueue, error) { visited := make([]bool, len(h.friends)+1) @@ -163,17 +167,76 @@ func (h *Hnsw) selectNeighbors(nearestNeighbors *BaseQueue) (*BaseQueue, error) } func (h *Hnsw) InsertVector(q Point) error { - if !h.validatePoint(q) { + if !h.isValidPoint(q) { return fmt.Errorf("invalid vector dimensionality") } + topLevel := h.friends[h.entryPointId].TopLevel() + + qId := h.GenerateId() qTopLevel := h.SpawnLevel() qFriends := NewFriends(qTopLevel) + h.friends[qId] = qFriends + h.points[qId] = &q + + entryItem := h.findCloserEntryPoint(&q, qFriends) + + for level := min(topLevel, qTopLevel); level >= 0; level-- { + nnToQAtLevel, err := h.searchLevel(&q, entryItem, h.efConstruction, level) + + if err != nil { + return fmt.Errorf("failed to search for nearest neighbors to Q at level %v: %w", level, err) + } + + neighbors, err := h.selectNeighbors(nnToQAtLevel) + if err != nil { + return fmt.Errorf("failed to select for nearest neighbors to Q at level %v: %w", level, err) + } + + // add bidirectional connections from neighbors to q at layer c + for _, neighbor := range neighbors.items { + neighborPoint := h.points[neighbor.id] + + distNeighToQ := EuclidDistance(*neighborPoint, q) + + h.friends[neighbor.id].InsertFriendsAtLevel(level, qId, distNeighToQ) + h.friends[qId].InsertFriendsAtLevel(level, neighbor.id, distNeighToQ) + } + + for _, neighbor := range neighbors.items { + neighborFriendsAtLevel, err := h.friends[neighbor.id].GetFriendsAtLevel(level) + + if err != nil { + return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) + } + + maxNeighborsFriendsAtLevel := FromBaseQueue(neighborFriendsAtLevel, MaxComparator{}) + + for maxNeighborsFriendsAtLevel.Len() > h.M { + _, err = maxNeighborsFriendsAtLevel.PopItem() + if err != nil { + return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) + } + } + + h.friends[neighbor.id].friends[level] = FromBaseQueue(maxNeighborsFriendsAtLevel, MinComparator{}) + } + + newEntryItem, err := nnToQAtLevel.PopItem() + if err != nil { + return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) + } + + entryItem = newEntryItem + } + + if qTopLevel > topLevel { + h.entryPointId = qId + } - _ = h.findCloserEntryPoint(&q, qFriends) return nil } -func (h *Hnsw) validatePoint(point Point) bool { - return len(point) != h.vectorDimensionality +func (h *Hnsw) isValidPoint(point Point) bool { + return len(point) == h.vectorDimensionality } diff --git a/pkg/hnsw/hnsw_test.go b/pkg/hnsw/hnsw_test.go index e903bcd..c7eda22 100644 --- a/pkg/hnsw/hnsw_test.go +++ b/pkg/hnsw/hnsw_test.go @@ -36,7 +36,7 @@ var clusterB = []Point{ } func SetupClusterHnsw(cluster []Point) (*Hnsw, error) { - efc := uint(4) + efc := 4 entryPoint := Point{0, 0} g := NewHnsw(2, efc, 4, entryPoint) @@ -405,3 +405,20 @@ func TestHnsw_SelectNeighbors(t *testing.T) { } }) } + +func TestHnsw_InsertVector(t *testing.T) { + t.Run("basic insert", func(t *testing.T) { + h := NewHnsw(2, 3, 4, Point{0, 0}) + q := Point{3, 3} + + if len(q) != 2 { + t.Fatal("insert vector should have 2 elements") + } + + err := h.InsertVector(q) + + if err != nil { + t.Fatal(err) + } + }) +} From a10d1aa61362d1c27a87a91dc306e1399e061ef4 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Wed, 5 Jun 2024 14:03:17 -0400 Subject: [PATCH 3/5] feat: knnsearch --- pkg/hnsw/hnsw.go | 47 +++++++++++++++++++++++++++++++++++++++++ pkg/hnsw/hnsw_test.go | 49 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/pkg/hnsw/hnsw.go b/pkg/hnsw/hnsw.go index be4877e..21ead39 100644 --- a/pkg/hnsw/hnsw.go +++ b/pkg/hnsw/hnsw.go @@ -240,3 +240,50 @@ func (h *Hnsw) InsertVector(q Point) error { func (h *Hnsw) isValidPoint(point Point) bool { return len(point) == h.vectorDimensionality } + +func (h *Hnsw) KnnSearch(q Point, nnToReturn int) (*BaseQueue, error) { + + entryPoint, ok := h.points[h.entryPointId] + if !ok { + return nil, fmt.Errorf("no point found for entryPointId %v", h.entryPointId) + } + + entryPointFriends, ok := h.friends[h.entryPointId] + if !ok { + return nil, fmt.Errorf("no friends found for entryPointId %v", h.entryPointId) + } + + topLevel := entryPointFriends.TopLevel() + + entryItem := &Item{ + id: h.entryPointId, + dist: EuclidDistance(*entryPoint, q), + } + + for level := topLevel; level >= 1; level-- { + nnToQAtLevel, err := h.searchLevel(&q, entryItem, 1, level) + + if err != nil { + return nil, fmt.Errorf("failed to search for nearest neighbors to Q at level %v: %w", level, err) + } + + entryItem = nnToQAtLevel.Top() + } + + nnToQAtLevel0, err := h.searchLevel(&q, entryItem, h.efConstruction, 0) + + if err != nil { + return nil, fmt.Errorf("failed to search at level %v: %w", h.entryPointId, err) + } + + maxNNToQAtLevel0 := FromBaseQueue(nnToQAtLevel0, MaxComparator{}) + + for maxNNToQAtLevel0.Len() > nnToReturn { + _, err = maxNNToQAtLevel0.PopItem() + if err != nil { + return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", h.entryPointId, err) + } + } + + return FromBaseQueue(maxNNToQAtLevel0, MinComparator{}), nil +} diff --git a/pkg/hnsw/hnsw_test.go b/pkg/hnsw/hnsw_test.go index c7eda22..6fd6ffa 100644 --- a/pkg/hnsw/hnsw_test.go +++ b/pkg/hnsw/hnsw_test.go @@ -422,3 +422,52 @@ func TestHnsw_InsertVector(t *testing.T) { } }) } + +func TestHnsw_KnnSearch(t *testing.T) { + t.Run("basic search", func(t *testing.T) { + h := NewHnsw(2, 2, 2, Point{0, 0}) + + err := h.InsertVector(Point{3, 3}) + if err != nil { + t.Fatal(err) + } + + err = h.InsertVector(Point{4, 4}) + if err != nil { + t.Fatal(err) + } + + err = h.InsertVector(Point{5, 5}) + if err != nil { + t.Fatal(err) + } + + kNearestToQ, err := h.KnnSearch(Point{5, 5}, 2) + if err != nil { + t.Fatal(err) + } + + if kNearestToQ.IsEmpty() { + t.Fatal("kNearestToQ should not be empty") + } + + if kNearestToQ.Len() != 2 { + t.Fatalf("since K is 2, we expect two closest neighbors, got: %v", kNearestToQ.Len()) + } + + expectedId := Id(3) + for !kNearestToQ.IsEmpty() { + nnItem, err := kNearestToQ.PopItem() + if err != nil { + t.Fatal(err) + } + + if nnItem.id != expectedId { + t.Fatalf("expected item to be %v, got id: %v, point: %v", expectedId, nnItem.id, h.points[nnItem.id]) + } + + expectedId -= 1 + } + + }) +} From 219f3667ecbba3d21d545a99144c7aa17e8daddc Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Wed, 5 Jun 2024 18:37:46 -0400 Subject: [PATCH 4/5] removed extra two FromBaseQueue conversions, create bq.FromItems --- pkg/hnsw/hnsw.go | 19 +++++++++++++++---- pkg/hnsw/pq.go | 39 ++++++++++++--------------------------- pkg/hnsw/pq_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 31 deletions(-) diff --git a/pkg/hnsw/hnsw.go b/pkg/hnsw/hnsw.go index 21ead39..d979dc2 100644 --- a/pkg/hnsw/hnsw.go +++ b/pkg/hnsw/hnsw.go @@ -210,16 +210,27 @@ func (h *Hnsw) InsertVector(q Point) error { return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) } - maxNeighborsFriendsAtLevel := FromBaseQueue(neighborFriendsAtLevel, MaxComparator{}) + if neighborFriendsAtLevel.Len() <= h.M { + continue + } + + var items []*Item + + for !neighborFriendsAtLevel.IsEmpty() { + if len(items) != h.M { + continue + } + + neighborFriend, err := neighborFriendsAtLevel.PopItem() - for maxNeighborsFriendsAtLevel.Len() > h.M { - _, err = maxNeighborsFriendsAtLevel.PopItem() if err != nil { return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) } + + items = append(items, neighborFriend) } - h.friends[neighbor.id].friends[level] = FromBaseQueue(maxNeighborsFriendsAtLevel, MinComparator{}) + h.friends[neighbor.id].friends[level] = FromItems(items, neighborFriendsAtLevel.comparator) } newEntryItem, err := nnToQAtLevel.PopItem() diff --git a/pkg/hnsw/pq.go b/pkg/hnsw/pq.go index cc62372..2a8b198 100644 --- a/pkg/hnsw/pq.go +++ b/pkg/hnsw/pq.go @@ -36,7 +36,6 @@ type Heapy interface { Len() int PopItem() *Item Top() *Item - Take(count int) (*BaseQueue, error) update(item *Item, id Id, dist float32) } @@ -48,32 +47,6 @@ type BaseQueue struct { comparator Comparator } -func (bq *BaseQueue) Take(count int, comparator Comparator) (*BaseQueue, error) { - if len(bq.items) < count { - return nil, fmt.Errorf("queue only has %v items, but want to take %v", len(bq.items), count) - } - - pq := NewBaseQueue(comparator) - - ct := 0 - for { - if ct == count { - break - } - - peeled, err := bq.PopItem() - if err != nil { - return nil, err - } - - pq.Insert(peeled.id, peeled.dist) - - ct++ - } - - return pq, nil -} - func (bq BaseQueue) Len() int { return len(bq.items) } func (bq BaseQueue) Swap(i, j int) { pq := bq.items @@ -150,6 +123,18 @@ func (bq *BaseQueue) update(item *Item, id Id, dist float32) { heap.Fix(bq, item.index) } +func FromItems(items []*Item, comparator Comparator) *BaseQueue { + bq := &BaseQueue{ + visitedIds: map[Id]*Item{}, + items: items, + comparator: comparator, + } + + heap.Init(bq) + + return bq +} + func FromBaseQueue(bq *BaseQueue, comparator Comparator) *BaseQueue { newBq := NewBaseQueue(comparator) diff --git a/pkg/hnsw/pq_test.go b/pkg/hnsw/pq_test.go index 5bc1539..0912d78 100644 --- a/pkg/hnsw/pq_test.go +++ b/pkg/hnsw/pq_test.go @@ -69,6 +69,42 @@ func TestPQ(t *testing.T) { i -= 1 } }) + + t.Run("from items", func(t *testing.T) { + items := []*Item{ + {id: 0, dist: 30}, + {id: 1, dist: 29}, + {id: 2, dist: 28}, + {id: 3, dist: 27}, + {id: 4, dist: 26}, + {id: 5, dist: 25}, + } + + bq := FromItems(items, MinComparator{}) + + if bq.IsEmpty() { + t.Fatal("empty queue") + } + + if bq.Len() != len(items) { + t.Fatalf("got %d, want %d", bq.Len(), len(items)) + } + + expectedId := Id(5) + for !bq.IsEmpty() { + bqItem, err := bq.PopItem() + if err != nil { + t.Fatal(err) + } + + if bqItem.id != expectedId { + t.Fatalf("got %d, want %d", bqItem.id, expectedId) + } + + expectedId -= 1 + } + + }) } func furthestBuildings(heights []int, bricks, ladders int) (int, error) { From 77eac5ea262fab0a18110c020d906e4dbe8b7a3f Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Thu, 6 Jun 2024 11:35:06 -0400 Subject: [PATCH 5/5] fmt --- pkg/hnsw/hnsw.go | 2 -- pkg/hnsw/hnsw_test.go | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pkg/hnsw/hnsw.go b/pkg/hnsw/hnsw.go index cf05a32..a085a20 100644 --- a/pkg/hnsw/hnsw.go +++ b/pkg/hnsw/hnsw.go @@ -152,7 +152,6 @@ func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item { return epItem } - func (h *Hnsw) selectNeighbors(nearestNeighbors *BaseQueue) ([]*Item, error) { if nearestNeighbors.Len() <= h.M { return nearestNeighbors.items, nil @@ -217,7 +216,6 @@ func (h *Hnsw) InsertVector(q Point) error { return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) } - if neighborFriendsAtLevel.Len() <= h.M { continue } diff --git a/pkg/hnsw/hnsw_test.go b/pkg/hnsw/hnsw_test.go index ea5437b..d1d819b 100644 --- a/pkg/hnsw/hnsw_test.go +++ b/pkg/hnsw/hnsw_test.go @@ -392,7 +392,6 @@ func TestHnsw_SelectNeighbors(t *testing.T) { t.Fatal(err) } - if len(neighbors) != 3 { t.Fatalf("select neighbors should have at least 3 neighbors, got: %v", len(neighbors)) } @@ -435,8 +434,8 @@ func TestHnsw_InsertVector(t *testing.T) { t.Fatal(err) } }) - - t.Run("bulk insert", func(t *testing.T) { + + t.Run("bulk insert", func(t *testing.T) { items := 1 h := NewHnsw(3, 4, 4, Point{0, 0, 0}) @@ -515,5 +514,6 @@ func TestHnsw_KnnSearch(t *testing.T) { } expectedId -= 1 - }) + } + }) }