diff --git a/pkg/hnsw/hnsw.go b/pkg/hnsw/hnsw.go index 788c651..ff9cf6b 100644 --- a/pkg/hnsw/hnsw.go +++ b/pkg/hnsw/hnsw.go @@ -27,6 +27,15 @@ type Hnsw struct { M, Mmax0 int } +func (h *Hnsw) Neighborhood(id Id) (*Friends, error) { + friends, ok := h.friends[id] + if !ok { + return nil, ErrNodeNotFound + } + + return friends, nil +} + func NewHnsw(d int, efConstruction int, M int, entryPoint Point) *Hnsw { if d <= 0 || len(entryPoint) != d { panic("invalid vector dimensionality") diff --git a/pkg/vectorpage/manager.go b/pkg/vectorpage/manager.go index 2b434e3..5e2541c 100644 --- a/pkg/vectorpage/manager.go +++ b/pkg/vectorpage/manager.go @@ -11,36 +11,44 @@ import ( type HNSWAdjacencyPage [16][8]uint32 type VectorPageManager struct { - btree *btree.BTree - vectors []*hnsw.Point + btree *btree.BTree + // vectors []*hnsw.Point - bptree *bptree.BPTree - neighborhood map[hnsw.Id]*hnsw.Friends + bptree *bptree.BPTree + // neighborhood map[hnsw.Id]*hnsw.Friends + + hnsw *hnsw.Hnsw } -func NewVectorPageManager(btree *btree.BTree, bptree *bptree.BPTree, vectors []*hnsw.Point, neighborhood map[hnsw.Id]*hnsw.Friends) *VectorPageManager { +func NewVectorPageManager(btree *btree.BTree, bptree *bptree.BPTree, hnsw *hnsw.Hnsw) *VectorPageManager { if btree == nil || bptree == nil { panic("btree and bptree must not be nil") } - return &VectorPageManager{btree, vectors, bptree, neighborhood} + return &VectorPageManager{ + btree: btree, + bptree: bptree, + hnsw: hnsw, + } } -func (vp *VectorPageManager) AddNode(x hnsw.Id) error { - // we'll assume that this node id is the freshly inserted vector - xvector := *vp.vectors[x] - - if err := vp.btree.Insert(pointer.ReferencedId{Value: x}, xvector); err != nil { +func (vp *VectorPageManager) AddNode(x hnsw.Point) error { + xId, err := vp.hnsw.InsertVector(x) + if err != nil { return err } - xfriends, ok := vp.neighborhood[x] + // write point to btree + if err := vp.btree.Insert(pointer.ReferencedId{Value: xId}, x); err != nil { + return err + } - if !ok { + // write friends to bptree + xFriends, err := vp.hnsw.Neighborhood(xId) + if err != nil { return fmt.Errorf("vector id %v not found in hnsw neighborhood", x) } - - xfriendsBuf, err := xfriends.Flush(8) + xfriendsBuf, err := xFriends.Flush(8) if err != nil { return err }