diff --git a/pkg/hnsw/hnsw.go b/pkg/hnsw/hnsw.go index 9416df6..a96b5f8 100644 --- a/pkg/hnsw/hnsw.go +++ b/pkg/hnsw/hnsw.go @@ -213,16 +213,13 @@ func (h *Hnsw) InsertVector(q Point) error { // add bidirectional connections from neighbors to q at layer c for _, neighbor := range neighbors { neighborPoint := h.points[neighbor.id] - distNeighToQ := EuclidDistance(*neighborPoint, q) - h.friends[neighbor.id].InsertFriendsAtLevel(level, qId, distNeighToQ) h.friends[qId].InsertFriendsAtLevel(level, neighbor.id, distNeighToQ) } for _, neighbor := range neighbors { neighborFriendsAtLevel, err := h.friends[neighbor.id].GetFriendsAtLevel(level) - if err != nil { return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) } diff --git a/pkg/hnsw/hnsw_test.go b/pkg/hnsw/hnsw_test.go index 905bb1c..2573c4a 100644 --- a/pkg/hnsw/hnsw_test.go +++ b/pkg/hnsw/hnsw_test.go @@ -436,7 +436,7 @@ func TestHnsw_InsertVector(t *testing.T) { t.Run("bulk insert", func(t *testing.T) { items := 1 - h := NewHnsw(3, 4, 4, Point{0, 0, 0}) + h := NewHnsw(3, 4, 10, Point{0, 0, 0}) for i := 100; i >= 1; i-- { j := float32(i) @@ -465,6 +465,34 @@ func TestHnsw_InsertVector(t *testing.T) { items += 1 } + + // ensure every friend pq is of max length 4 + var allNodeIds []Id + for id := range h.friends { + allNodeIds = append(allNodeIds, id) + } + + for _, nodeId := range allNodeIds { + nodeFriends, ok := h.friends[nodeId] + if !ok { + t.Fatalf("expected to find point for node %v", nodeId) + } + + for level, friendsAtLevel := range nodeFriends.friends { + if level == 0 { + if friendsAtLevel.Len() > h.Mmax0 { + t.Fatalf("node id %v, num friends at level 0 cannot be greater than max number of connections M = %v. Got %v", nodeId, h.M, friendsAtLevel.Len()) + } + + continue + } + + if friendsAtLevel.Len() > h.M { + t.Fatalf("num friends at level %v cannot be greater than max number of connections M: %v. Got: %v", level, h.M, friendsAtLevel.Len()) + } + } + + } }) t.Run("basic cluster insertion", func(t *testing.T) {