diff --git a/pkg/hnsw/hnsw_test.go b/pkg/hnsw/hnsw_test.go index 6e55f87..9896acb 100644 --- a/pkg/hnsw/hnsw_test.go +++ b/pkg/hnsw/hnsw_test.go @@ -2,6 +2,7 @@ package hnsw import ( "fmt" + "reflect" "testing" ) @@ -492,3 +493,80 @@ func TestSpawnLevelDistribution(t *testing.T) { fmt.Printf("levels distribution: %v\n", levels) }) } + +func TestHnsw_KnnCluster(t *testing.T) { + + var clusterC = []Vector{ + {0.2, 0.5}, + {0.2, 0.7}, + {0.3, 0.8}, + {0.5, 0.5}, + {0.4, 0.1}, + } + + var clusterCNodes = map[NodeId][]NodeId{ + 1: {2, 4, 3, 5}, + 2: {3, 1, 4, 5}, + 3: {2, 1, 4, 5}, + 4: {1, 2, 3, 5}, + 5: {4, 1, 2, 3}, + } + + var clusterCVisited = map[NodeId][]bool{ + 1: {false, true, true, true, true, true}, + 2: {false, false, true, true, true, true}, + 3: {false, false, false, true, true, true}, + 4: {false, true, true, true, false, true}, + 5: {false, true, true, true, true, false}, + } + + t.Run("cluster c insert", func(t *testing.T) { + h := NewHNSW(2, 4, 4, []float32{0, 0}) + + for i, q := range clusterC { + if err := h.Insert(q); err != nil { + t.Fatalf("failed to insert item %d: %v", i, err) + } + } + + fmt.Printf("%v", h.Nodes) + + if reflect.DeepEqual(h.Nodes, clusterCNodes) { + t.Fatalf("expected all node keys to be the same as clusterC") + } + + if len(h.Nodes) != 6 { + t.Fatalf("expected 6 nodes, got %d", len(h.Nodes)) + } + + for i := 1; i <= 5; i++ { + nodeId := NodeId(i) + node := h.Nodes[nodeId] + + var nodeNN []NodeId + visitedNN := make([]bool, 6) // counting entry + + for level := node.level; level >= 0; level-- { + friendsAtLevel := node.friends[level] + + for !friendsAtLevel.IsEmpty() { + peeled, err := friendsAtLevel.Peel() + if err != nil { + t.Fatal(err) + } + + if !visitedNN[peeled.id] { + nodeNN = append(nodeNN, peeled.id) + visitedNN[peeled.id] = true + } + } + } + + if reflect.DeepEqual(clusterCVisited[nodeId], visitedNN) { + t.Fatalf("expected all node keys to be the same as clusterC") + } + } + + }) + +}