Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Knnsearch #330

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 62 additions & 4 deletions pkg/hnsw/hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,27 @@ func (h *Hnsw) InsertVector(q Point) error {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}

maxNeighborsFriendsAtLevel := FromBaseQueue(neighborFriendsAtLevel, MaxComparator{})
if neighborFriendsAtLevel.Len() <= h.M {
continue
}

var items []*Item

for !neighborFriendsAtLevel.IsEmpty() {
if len(items) != h.M {
continue
}

neighborFriend, err := neighborFriendsAtLevel.PopItem()

for maxNeighborsFriendsAtLevel.Len() > h.M {
_, err = maxNeighborsFriendsAtLevel.PopItem()
if err != nil {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}

items = append(items, neighborFriend)
}

h.friends[neighbor.id].friends[level] = FromBaseQueue(maxNeighborsFriendsAtLevel, MinComparator{})
h.friends[neighbor.id].friends[level] = FromItems(items, neighborFriendsAtLevel.comparator)
}

newEntryItem, err := nnToQAtLevel.PopItem()
Expand All @@ -246,3 +257,50 @@ func (h *Hnsw) InsertVector(q Point) error {
func (h *Hnsw) isValidPoint(point Point) bool {
return len(point) == h.vectorDimensionality
}

func (h *Hnsw) KnnSearch(q Point, nnToReturn int) (*BaseQueue, error) {

entryPoint, ok := h.points[h.entryPointId]
if !ok {
return nil, fmt.Errorf("no point found for entryPointId %v", h.entryPointId)
}

entryPointFriends, ok := h.friends[h.entryPointId]
if !ok {
return nil, fmt.Errorf("no friends found for entryPointId %v", h.entryPointId)
}

topLevel := entryPointFriends.TopLevel()

entryItem := &Item{
id: h.entryPointId,
dist: EuclidDistance(*entryPoint, q),
}

for level := topLevel; level >= 1; level-- {
nnToQAtLevel, err := h.searchLevel(&q, entryItem, 1, level)

if err != nil {
return nil, fmt.Errorf("failed to search for nearest neighbors to Q at level %v: %w", level, err)
}

entryItem = nnToQAtLevel.Top()
}

nnToQAtLevel0, err := h.searchLevel(&q, entryItem, h.efConstruction, 0)

if err != nil {
return nil, fmt.Errorf("failed to search at level %v: %w", h.entryPointId, err)
}

maxNNToQAtLevel0 := FromBaseQueue(nnToQAtLevel0, MaxComparator{})

for maxNNToQAtLevel0.Len() > nnToReturn {
_, err = maxNNToQAtLevel0.PopItem()
if err != nil {
return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", h.entryPointId, err)
}
}

return FromBaseQueue(maxNNToQAtLevel0, MinComparator{}), nil
}
47 changes: 47 additions & 0 deletions pkg/hnsw/hnsw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,53 @@ func TestHnsw_InsertVector(t *testing.T) {

items += 1
}
})
}

func TestHnsw_KnnSearch(t *testing.T) {
t.Run("basic search", func(t *testing.T) {
h := NewHnsw(2, 2, 2, Point{0, 0})

err := h.InsertVector(Point{3, 3})
if err != nil {
t.Fatal(err)
}

err = h.InsertVector(Point{4, 4})
if err != nil {
t.Fatal(err)
}

err = h.InsertVector(Point{5, 5})
if err != nil {
t.Fatal(err)
}

kNearestToQ, err := h.KnnSearch(Point{5, 5}, 2)
if err != nil {
t.Fatal(err)
}

if kNearestToQ.IsEmpty() {
t.Fatal("kNearestToQ should not be empty")
}

if kNearestToQ.Len() != 2 {
t.Fatalf("since K is 2, we expect two closest neighbors, got: %v", kNearestToQ.Len())
}

expectedId := Id(3)
for !kNearestToQ.IsEmpty() {
nnItem, err := kNearestToQ.PopItem()
if err != nil {
t.Fatal(err)
}

if nnItem.id != expectedId {
t.Fatalf("expected item to be %v, got id: %v, point: %v", expectedId, nnItem.id, h.points[nnItem.id])
}

expectedId -= 1
}
})
}
39 changes: 12 additions & 27 deletions pkg/hnsw/pq.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ type Heapy interface {
Len() int
PopItem() *Item
Top() *Item
Take(count int) (*BaseQueue, error)
update(item *Item, id Id, dist float32)
}

Expand All @@ -48,32 +47,6 @@ type BaseQueue struct {
comparator Comparator
}

func (bq *BaseQueue) Take(count int, comparator Comparator) (*BaseQueue, error) {
if len(bq.items) < count {
return nil, fmt.Errorf("queue only has %v items, but want to take %v", len(bq.items), count)
}

pq := NewBaseQueue(comparator)

ct := 0
for {
if ct == count {
break
}

peeled, err := bq.PopItem()
if err != nil {
return nil, err
}

pq.Insert(peeled.id, peeled.dist)

ct++
}

return pq, nil
}

func (bq BaseQueue) Len() int { return len(bq.items) }
func (bq BaseQueue) Swap(i, j int) {
pq := bq.items
Expand Down Expand Up @@ -150,6 +123,18 @@ func (bq *BaseQueue) update(item *Item, id Id, dist float32) {
heap.Fix(bq, item.index)
}

func FromItems(items []*Item, comparator Comparator) *BaseQueue {
bq := &BaseQueue{
visitedIds: map[Id]*Item{},
items: items,
comparator: comparator,
}

heap.Init(bq)

return bq
}

func FromBaseQueue(bq *BaseQueue, comparator Comparator) *BaseQueue {
newBq := NewBaseQueue(comparator)

Expand Down
36 changes: 36 additions & 0 deletions pkg/hnsw/pq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,42 @@ func TestPQ(t *testing.T) {
i -= 1
}
})

t.Run("from items", func(t *testing.T) {
items := []*Item{
{id: 0, dist: 30},
{id: 1, dist: 29},
{id: 2, dist: 28},
{id: 3, dist: 27},
{id: 4, dist: 26},
{id: 5, dist: 25},
}

bq := FromItems(items, MinComparator{})

if bq.IsEmpty() {
t.Fatal("empty queue")
}

if bq.Len() != len(items) {
t.Fatalf("got %d, want %d", bq.Len(), len(items))
}

expectedId := Id(5)
for !bq.IsEmpty() {
bqItem, err := bq.PopItem()
if err != nil {
t.Fatal(err)
}

if bqItem.id != expectedId {
t.Fatalf("got %d, want %d", bqItem.id, expectedId)
}

expectedId -= 1
}

})
}

func furthestBuildings(heights []int, bricks, ladders int) (int, error) {
Expand Down
Loading