Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Knnsearch #330

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 131 additions & 7 deletions pkg/hnsw/hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ type Hnsw struct {

levelMultiplier float64

// efConstruction is the size of the dynamic candIdate list
efConstruction uint
// efConstruction is the size of the dynamic candidate list
efConstruction int

// default number of connections
M, Mmax0 int
}

func NewHnsw(d int, efConstruction uint, M int, entryPoint Point) *Hnsw {
func NewHnsw(d int, efConstruction int, M int, entryPoint Point) *Hnsw {
if d <= 0 || len(entryPoint) != d {
panic("invalid vector dimensionality")
}
Expand Down Expand Up @@ -56,6 +56,10 @@ func (h *Hnsw) SpawnLevel() int {
return int(math.Floor(-math.Log(rand.Float64() * h.levelMultiplier)))
}

func (h *Hnsw) GenerateId() Id {
return Id(len(h.points))
}

func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, level int) (*BaseQueue, error) {
visited := make([]bool, len(h.friends)+1)

Expand Down Expand Up @@ -148,18 +152,138 @@ func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item {
return epItem
}

func (h *Hnsw) selectNeighbors(nearestNeighbors *BaseQueue) (*BaseQueue, error) {
maxNearestNeighbors := FromBaseQueue(nearestNeighbors, MaxComparator{})

for maxNearestNeighbors.Len() > h.M {
_, err := maxNearestNeighbors.PopItem()

if err != nil {
return nil, err
}
}

return FromBaseQueue(maxNearestNeighbors, MinComparator{}), nil
friendlymatthew marked this conversation as resolved.
Show resolved Hide resolved
}

func (h *Hnsw) InsertVector(q Point) error {
if !h.validatePoint(q) {
if !h.isValidPoint(q) {
return fmt.Errorf("invalid vector dimensionality")
}

topLevel := h.friends[h.entryPointId].TopLevel()

qId := h.GenerateId()
qTopLevel := h.SpawnLevel()
qFriends := NewFriends(qTopLevel)
h.friends[qId] = qFriends
h.points[qId] = &q

entryItem := h.findCloserEntryPoint(&q, qFriends)

for level := min(topLevel, qTopLevel); level >= 0; level-- {
nnToQAtLevel, err := h.searchLevel(&q, entryItem, h.efConstruction, level)

if err != nil {
return fmt.Errorf("failed to search for nearest neighbors to Q at level %v: %w", level, err)
}

neighbors, err := h.selectNeighbors(nnToQAtLevel)
if err != nil {
return fmt.Errorf("failed to select for nearest neighbors to Q at level %v: %w", level, err)
}

// add bidirectional connections from neighbors to q at layer c
for _, neighbor := range neighbors.items {
neighborPoint := h.points[neighbor.id]

distNeighToQ := EuclidDistance(*neighborPoint, q)

h.friends[neighbor.id].InsertFriendsAtLevel(level, qId, distNeighToQ)
h.friends[qId].InsertFriendsAtLevel(level, neighbor.id, distNeighToQ)
}

for _, neighbor := range neighbors.items {
neighborFriendsAtLevel, err := h.friends[neighbor.id].GetFriendsAtLevel(level)

if err != nil {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}

maxNeighborsFriendsAtLevel := FromBaseQueue(neighborFriendsAtLevel, MaxComparator{})

for maxNeighborsFriendsAtLevel.Len() > h.M {
_, err = maxNeighborsFriendsAtLevel.PopItem()
if err != nil {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}
}

h.friends[neighbor.id].friends[level] = FromBaseQueue(maxNeighborsFriendsAtLevel, MinComparator{})
}

newEntryItem, err := nnToQAtLevel.PopItem()
if err != nil {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}

entryItem = newEntryItem
}

if qTopLevel > topLevel {
h.entryPointId = qId
}

_ = h.findCloserEntryPoint(&q, qFriends)
return nil
}

func (h *Hnsw) validatePoint(point Point) bool {
return len(point) != h.vectorDimensionality
func (h *Hnsw) isValidPoint(point Point) bool {
return len(point) == h.vectorDimensionality
}

func (h *Hnsw) KnnSearch(q Point, nnToReturn int) (*BaseQueue, error) {

entryPoint, ok := h.points[h.entryPointId]
if !ok {
return nil, fmt.Errorf("no point found for entryPointId %v", h.entryPointId)
}

entryPointFriends, ok := h.friends[h.entryPointId]
if !ok {
return nil, fmt.Errorf("no friends found for entryPointId %v", h.entryPointId)
}

topLevel := entryPointFriends.TopLevel()

entryItem := &Item{
id: h.entryPointId,
dist: EuclidDistance(*entryPoint, q),
}

for level := topLevel; level >= 1; level-- {
nnToQAtLevel, err := h.searchLevel(&q, entryItem, 1, level)

if err != nil {
return nil, fmt.Errorf("failed to search for nearest neighbors to Q at level %v: %w", level, err)
}

entryItem = nnToQAtLevel.Top()
}

nnToQAtLevel0, err := h.searchLevel(&q, entryItem, h.efConstruction, 0)

if err != nil {
return nil, fmt.Errorf("failed to search at level %v: %w", h.entryPointId, err)
}

maxNNToQAtLevel0 := FromBaseQueue(nnToQAtLevel0, MaxComparator{})

for maxNNToQAtLevel0.Len() > nnToReturn {
_, err = maxNNToQAtLevel0.PopItem()
if err != nil {
return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", h.entryPointId, err)
}
}

return FromBaseQueue(maxNNToQAtLevel0, MinComparator{}), nil
}
145 changes: 144 additions & 1 deletion pkg/hnsw/hnsw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ var clusterB = []Point{
}

func SetupClusterHnsw(cluster []Point) (*Hnsw, error) {
efc := uint(4)
efc := 4

entryPoint := Point{0, 0}
g := NewHnsw(2, efc, 4, entryPoint)
Expand Down Expand Up @@ -328,3 +328,146 @@ func TestHnsw_FindCloserEntryPoint(t *testing.T) {
}
})
}

func TestHnsw_SelectNeighbors(t *testing.T) {

t.Run("selects neighbors given overflow", func(t *testing.T) {
nearestNeighbors := NewBaseQueue(MinComparator{})

M := 4

h := NewHnsw(2, 4, M, Point{0, 0})

// since M is 4
for i := 5; i >= 0; i-- {
nearestNeighbors.Insert(Id(i), float32(i))
}

neighbors, err := h.selectNeighbors(nearestNeighbors)

if err != nil {
t.Fatal(err)
}

if neighbors.Len() != M {
t.Fatalf("select neighbors should have at most M friends")
}

expectedId := Id(0)
for !neighbors.IsEmpty() {
nn, err := neighbors.PopItem()

if err != nil {
t.Fatal(err)
}

if nn.id != expectedId {
t.Fatalf("expected item to be %v, got %v", expectedId, nn.id)
}

expectedId += 1
}
})

t.Run("selects neighbors given lower bound", func(t *testing.T) {
M := 10
h := NewHnsw(2, 10, M, Point{0, 0})

nnQueue := NewBaseQueue(MinComparator{})

for i := 0; i < 3; i++ {
nnQueue.Insert(Id(i), float32(i))
}

neighbors, err := h.selectNeighbors(nnQueue)

if err != nil {
t.Fatal(err)
}

if neighbors.Len() != 3 {
t.Fatalf("select neighbors should have at least 3 neighbors, got: %v", neighbors.Len())
}

expectedId := Id(0)
for !neighbors.IsEmpty() {
nn, err := neighbors.PopItem()

if err != nil {
t.Fatal(err)
}

if nn.id != expectedId {
t.Fatalf("expected item to be %v, got %v", expectedId, nn.id)
}

expectedId += 1
}
})
}

func TestHnsw_InsertVector(t *testing.T) {
t.Run("basic insert", func(t *testing.T) {
h := NewHnsw(2, 3, 4, Point{0, 0})
q := Point{3, 3}

if len(q) != 2 {
t.Fatal("insert vector should have 2 elements")
}

err := h.InsertVector(q)

if err != nil {
t.Fatal(err)
}
})
}

func TestHnsw_KnnSearch(t *testing.T) {
t.Run("basic search", func(t *testing.T) {
h := NewHnsw(2, 2, 2, Point{0, 0})

err := h.InsertVector(Point{3, 3})
if err != nil {
t.Fatal(err)
}

err = h.InsertVector(Point{4, 4})
if err != nil {
t.Fatal(err)
}

err = h.InsertVector(Point{5, 5})
if err != nil {
t.Fatal(err)
}

kNearestToQ, err := h.KnnSearch(Point{5, 5}, 2)
if err != nil {
t.Fatal(err)
}

if kNearestToQ.IsEmpty() {
t.Fatal("kNearestToQ should not be empty")
}

if kNearestToQ.Len() != 2 {
t.Fatalf("since K is 2, we expect two closest neighbors, got: %v", kNearestToQ.Len())
}

expectedId := Id(3)
for !kNearestToQ.IsEmpty() {
nnItem, err := kNearestToQ.PopItem()
if err != nil {
t.Fatal(err)
}

if nnItem.id != expectedId {
t.Fatalf("expected item to be %v, got id: %v, point: %v", expectedId, nnItem.id, h.points[nnItem.id])
}

expectedId -= 1
}

})
}
Loading