Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ljcui committed Sep 19, 2024
1 parent 0a42fd2 commit e0f4f46
Show file tree
Hide file tree
Showing 10 changed files with 96 additions and 39 deletions.
24 changes: 3 additions & 21 deletions src/core/field_extractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class FieldExtractor {
// fulltext index
bool fulltext_indexed_ = false;
// vector index
std::unique_ptr<VectorIndex> vector_index_;
std::shared_ptr<VectorIndex> vector_index_;

public:
FieldExtractor() : null_bit_off_(0), vertex_index_(nullptr),
Expand All @@ -74,16 +74,7 @@ class FieldExtractor {
vertex_index_.reset(rhs.vertex_index_ ? new VertexIndex(*rhs.vertex_index_) : nullptr);
edge_index_.reset(rhs.edge_index_ ? new EdgeIndex(*rhs.edge_index_) : nullptr);
fulltext_indexed_ = rhs.fulltext_indexed_;
if (rhs.vector_index_ != nullptr) {
if (rhs.vector_index_->GetIndexType() == "HNSW") {
vector_index_.reset(new HNSW(
dynamic_cast<HNSW&>(*rhs.vector_index_)));
} else {
vector_index_.reset(nullptr);
}
} else {
vector_index_.reset(nullptr);
}
vector_index_ = rhs.vector_index_;
}

FieldExtractor& operator=(const FieldExtractor& rhs) {
Expand All @@ -97,16 +88,7 @@ class FieldExtractor {
vertex_index_.reset(rhs.vertex_index_ ? new VertexIndex(*rhs.vertex_index_) : nullptr);
edge_index_.reset(rhs.edge_index_ ? new EdgeIndex(*rhs.edge_index_) : nullptr);
fulltext_indexed_ = rhs.fulltext_indexed_;
if (rhs.vector_index_ != nullptr) {
if (rhs.vector_index_->GetIndexType() == "HNSW") {
vector_index_.reset(new HNSW(
dynamic_cast<HNSW&>(*rhs.vector_index_)));
} else {
vector_index_.reset(nullptr);
}
} else {
vector_index_.reset(nullptr);
}
vector_index_ = rhs.vector_index_;
return *this;
}

Expand Down
4 changes: 1 addition & 3 deletions src/core/index_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,8 @@ bool IndexManager::AddVectorIndex(KvTransaction& txn, const std::string& label,
Value idxv;
StoreVectorIndex(idx, idxv);
it->AddKeyValue(Value::ConstRef(table_name), idxv);
if (index_type == "hnsw") {
vector_index = std::make_unique<HNSW>(label, field, distance_type,
vector_index = std::make_unique<HNSW>(label, field, distance_type,
index_type, vec_dimension, index_spec);
}
return true;
}

Expand Down
2 changes: 2 additions & 0 deletions src/core/lightning_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,8 @@ bool LightningGraph::AlterLabelDelFields(const std::string& label,
// delete fulltext index
index_manager_->DeleteFullTextIndex(txn.GetTxn(), is_vertex, label,
extractor->Name());
} else if (extractor->GetVectorIndex()) {
index_manager_->DeleteVectorIndex(txn.GetTxn(), label, extractor->Name());
}
}
auto composite_index_key = curr_schema->GetRelationalCompositeIndexKey(fids);
Expand Down
3 changes: 1 addition & 2 deletions src/core/schema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,9 @@ void Schema::AddEdgeToIndex(KvTransaction& txn, const EdgeUid& euid, const Value
}

void Schema::AddVectorToVectorIndex(KvTransaction& txn, VertexId vid, const Value& record) {
LOG_INFO() << "Schema::AddVectorToVectorIndex " << vector_index_fields_.size();
for (auto& idx : vector_index_fields_) {
LOG_INFO() << "Schema::AddVectorToVectorIndex-1";
auto& fe = fields_[idx];
std::string a = fe.Name();
if (fe.GetIsNull(record)) continue;
VectorIndex* index = fe.GetVectorIndex();
auto dim = index->GetVecDimension();
Expand Down
7 changes: 0 additions & 7 deletions src/core/vsag_hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ HNSW::HNSW(const std::string& label, const std::string& name,
vec_dimension, std::move(index_spec)),
createindex_(nullptr), index_(createindex_.get()) {}

HNSW::HNSW(const HNSW& rhs)
: VectorIndex(rhs),
createindex_(rhs.createindex_),
index_(createindex_.get()) {}

// add vector to index
void HNSW::Add(const std::vector<std::vector<float>>& vectors,
const std::vector<int64_t>& vids, int64_t num_vectors) {
Expand Down Expand Up @@ -190,11 +185,9 @@ HNSW::Search(const std::vector<float>& query, int64_t num_results, int ef_search
nlohmann::json parameters{
{"hnsw", {{"ef_search", ef_search}}},
};
LOG_INFO() << "index_->GetNumElements(): " << index_->GetNumElements();
std::vector<std::pair<int64_t, float>> ret;
auto result = index_->KnnSearch(dataset, num_results, parameters.dump());
if (result.has_value()) {
LOG_INFO() << "result.value()->GetDim():" << result.value()->GetDim();
for (int64_t i = 0; i < result.value()->GetDim(); ++i) {
ret.emplace_back(result.value()->GetIds()[i], result.value()->GetDistances()[i]);
}
Expand Down
2 changes: 1 addition & 1 deletion src/core/vsag_hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class HNSW : public VectorIndex {
const std::string& distance_type, const std::string& index_type,
int vec_dimension, std::vector<int> index_spec);

HNSW(const HNSW& rhs);
HNSW(const HNSW& rhs) = delete;

HNSW(HNSW&& rhs) = delete;

Expand Down
40 changes: 40 additions & 0 deletions test/resource/unit_test/vector_index/cypher/vector_index.result
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
CALL db.createVertexLabelByJson('{"label":"person","primary":"id","type":"VERTEX","detach_property":true,"properties":[{"name":"id","type":"INT32","optional":false},{"name":"name","type":"STRING","optional":false,"index":false},{"name":"embedding1","type":"FLOAT_VECTOR","optional":false}, {"name":"embedding2","type":"FLOAT_VECTOR","optional":false}]}');
[]
CALL db.addVertexVectorIndex('person','embedding1', {dimension:4});
[]
CALL db.addVertexVectorIndex('person','embedding2', {dimension:4});
[]
CALL db.addVertexVectorIndex('person','name', {dimension:4});
[VectorIndexException] Only FLOAT_VECTOR type supports vector index
CALL db.showVertexVectorIndex();
[{"dimension":4,"distance_type":"l2","field_name":"embedding1","hnsm.ef_construction":100,"hnsm.m":16,"index_type":"hnsw","label_name":"person"},{"dimension":4,"distance_type":"l2","field_name":"embedding2","hnsm.ef_construction":100,"hnsm.m":16,"index_type":"hnsw","label_name":"person"}]
CREATE (n:person {id:1, name:'name1', embedding1: [1.0,1.0,1.0,1.0], embedding2: [11.0,11.0,11.0,11.0]});
[{"<SUMMARY>":"created 1 vertices, created 0 edges."}]
CREATE (n:person {id:2, name:'name2', embedding1: [2.0,2.0,2.0,2.0], embedding2: [12.0,12.0,12.0,12.0]});
[{"<SUMMARY>":"created 1 vertices, created 0 edges."}]
CALL db.upsertVertex('person', [{id:3, name:'name3', embedding1: [3.0,3.0,3.0,3.0], embedding2: [13.0,13.0,13.0,13.0]}, {id:4, name:'name4', embedding1: [4.0,4.0,4.0,4.0], embedding2: [14.0,14.0,14.0,14.0]}]);
[{"data_error":0,"index_conflict":0,"insert":2,"total":2,"update":0}]
CALL db.vertexVectorIndexQuery('person','embedding1',[1,2,3,4], {top_k:2, hnsw_ef_search:10});
[{"node":{"identity":1,"label":"person","properties":{"embedding1":[2.0,2.0,2.0,2.0],"embedding2":[12.0,12.0,12.0,12.0],"id":2,"name":"name2"}},"score":6.0},{"node":{"identity":2,"label":"person","properties":{"embedding1":[3.0,3.0,3.0,3.0],"embedding2":[13.0,13.0,13.0,13.0],"id":3,"name":"name3"}},"score":6.0}]
CALL db.vertexVectorIndexQuery('person','embedding1',[1,2,3,4], {top_k:2, hnsw_ef_search:10}) yield node return node.id;
[{"node.id":2},{"node.id":3}]
CALL db.vertexVectorIndexQuery('person','embedding1',[1,2,3,4], {top_k:20, hnsw_ef_search:100}) yield node return node.id;
[{"node.id":2},{"node.id":3},{"node.id":1},{"node.id":4}]
CALL db.vertexVectorIndexQuery('person','embedding2',[1,2,3,4], {top_k:2, hnsw_ef_search:10}) yield node return node.id;
[{"node.id":1},{"node.id":2}]
CALL db.upsertVertex('person', [{id:1, embedding1: [33.0,33.0,33.0,33.0]}]);
[{"data_error":0,"index_conflict":0,"insert":0,"total":1,"update":1}]
CALL db.vertexVectorIndexQuery('person','embedding1',[1,2,3,4], {top_k:2, hnsw_ef_search:10}) yield node return node.id;
[{"node.id":2},{"node.id":3}]
match(n:person {id:2}) delete n;
[{"<SUMMARY>":"deleted 1 vertices, deleted 0 edges."}]
CALL db.vertexVectorIndexQuery('person','embedding1',[1,2,3,4], {top_k:2, hnsw_ef_search:10}) yield node return node.id;
[{"node.id":3},{"node.id":4}]
CALL db.alterLabelDelFields('vertex', 'person', ['embedding1']);
[{"record_affected":3}]
CALL db.showVertexVectorIndex();
[{"dimension":4,"distance_type":"l2","field_name":"embedding2","hnsm.ef_construction":100,"hnsm.m":16,"index_type":"hnsw","label_name":"person"}]
CALL db.vertexVectorIndexQuery('person','embedding1',[1,2,3,4], {top_k:2, hnsw_ef_search:10}) yield node return node.id;
[FieldNotFound] Field [embedding1] does not exist.
CALL db.vertexVectorIndexQuery('person','embedding2',[1,2,3,4], {top_k:2, hnsw_ef_search:10}) yield node return node.id;
[{"node.id":1},{"node.id":3}]
20 changes: 20 additions & 0 deletions test/resource/unit_test/vector_index/cypher/vector_index.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
CALL db.createVertexLabelByJson('{"label":"person","primary":"id","type":"VERTEX","detach_property":true,"properties":[{"name":"id","type":"INT32","optional":false},{"name":"name","type":"STRING","optional":false,"index":false},{"name":"embedding1","type":"FLOAT_VECTOR","optional":false}, {"name":"embedding2","type":"FLOAT_VECTOR","optional":false}]}');
CALL db.addVertexVectorIndex('person','embedding1', {dimension:4});
CALL db.addVertexVectorIndex('person','embedding2', {dimension:4});
CALL db.addVertexVectorIndex('person','name', {dimension:4});
CALL db.showVertexVectorIndex();
CREATE (n:person {id:1, name:'name1', embedding1: [1.0,1.0,1.0,1.0], embedding2: [11.0,11.0,11.0,11.0]});
CREATE (n:person {id:2, name:'name2', embedding1: [2.0,2.0,2.0,2.0], embedding2: [12.0,12.0,12.0,12.0]});
CALL db.upsertVertex('person', [{id:3, name:'name3', embedding1: [3.0,3.0,3.0,3.0], embedding2: [13.0,13.0,13.0,13.0]}, {id:4, name:'name4', embedding1: [4.0,4.0,4.0,4.0], embedding2: [14.0,14.0,14.0,14.0]}]);
CALL db.vertexVectorIndexQuery('person','embedding1',[1,2,3,4], {top_k:2, hnsw_ef_search:10});
CALL db.vertexVectorIndexQuery('person','embedding1',[1,2,3,4], {top_k:2, hnsw_ef_search:10}) yield node return node.id;
CALL db.vertexVectorIndexQuery('person','embedding1',[1,2,3,4], {top_k:20, hnsw_ef_search:100}) yield node return node.id;
CALL db.vertexVectorIndexQuery('person','embedding2',[1,2,3,4], {top_k:2, hnsw_ef_search:10}) yield node return node.id;
CALL db.upsertVertex('person', [{id:1, embedding1: [33.0,33.0,33.0,33.0]}]);
CALL db.vertexVectorIndexQuery('person','embedding1',[1,2,3,4], {top_k:2, hnsw_ef_search:10}) yield node return node.id;
match(n:person {id:2}) delete n;
CALL db.vertexVectorIndexQuery('person','embedding1',[1,2,3,4], {top_k:2, hnsw_ef_search:10}) yield node return node.id;
CALL db.alterLabelDelFields('vertex', 'person', ['embedding1']);
CALL db.showVertexVectorIndex();
CALL db.vertexVectorIndexQuery('person','embedding1',[1,2,3,4], {top_k:2, hnsw_ef_search:10}) yield node return node.id;
CALL db.vertexVectorIndexQuery('person','embedding2',[1,2,3,4], {top_k:2, hnsw_ef_search:10}) yield node return node.id;
7 changes: 7 additions & 0 deletions test/test_cypher_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,10 @@ TEST_F(TestCypherV2, TestEdgeIdQuery) {
std::string dir = test_suite_dir_ + "/edge_id_query/cypher";
test_files(dir);
}

TEST_F(TestCypherV2, TestVectorIndex) {
set_graph_type(GraphFactory::GRAPH_DATASET_TYPE::EMPTY);
set_query_type(lgraph::ut::QUERY_TYPE::NEWCYPHER);
std::string dir = test_suite_dir_ + "/vector_index/cypher";
test_files(dir);
}
26 changes: 21 additions & 5 deletions test/test_vsag_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class TestVsag : public TuGraphTest {
}
}
vector_index =
std::make_unique<lgraph::HNSW>("label", "name", "L2", "HNSW", dim, index_spec);
std::make_unique<lgraph::HNSW>("label", "name", "l2", "hnsw", dim, index_spec);
}
void TearDown() override {}
};
Expand All @@ -79,7 +79,7 @@ TEST_F(TestVsag, SaveAndLoadIndex) {
EXPECT_NO_THROW(vector_index->Add(vectors, vids, num_vectors));
std::vector<uint8_t> serialized_index = vector_index->Save();
ASSERT_FALSE(serialized_index.empty());
lgraph::HNSW vector_index_loaded("label", "name", "L2", "HNSW", dim, index_spec);
lgraph::HNSW vector_index_loaded("label", "name", "l2", "hnsw", dim, index_spec);
ASSERT_TRUE(vector_index_loaded.Build());
vector_index_loaded.Load(serialized_index);
std::vector<float> query(vectors[0].begin(), vectors[0].end());
Expand Down Expand Up @@ -148,7 +148,7 @@ TEST_F(TestVsag, restart) {
"'person','id','id','int64',false,'vector','float_vector',true)");
UT_EXPECT_TRUE(ret);
ret = client.CallCypher(
str, "CALL db.AddVertexVectorIndex('person','vector', {dimension:4})");
str, "CALL db.addVertexVectorIndex('person','vector', {dimension:4})");
UT_EXPECT_TRUE(ret);
ret = client.CallCypher(str, "CREATE (n:person {id:1, vector: [1.0,1.0,1.0,1.0]})");
UT_EXPECT_TRUE(ret);
Expand All @@ -157,8 +157,9 @@ TEST_F(TestVsag, restart) {
ret = client.CallCypher(str,
"CALL db.upsertVertex('person', [{id:3, vector: [3.0,3.0,3.0,3.0]},"
"{id:4, vector: [4.0,4.0,4.0,4.0]}])");
UT_EXPECT_TRUE(ret);
ret = client.CallCypher(str,"CALL db.vertexVectorIndexQuery" //NOLINT
"('person','vector',[1,2,3,4], 4, 10) YIELD node RETURN node.id");
"('person','vector',[1,2,3,4], {top_k:4, hnsw_ef_search:10}) YIELD node RETURN node.id");
UT_EXPECT_EQ(str, R"([{"node.id":2},{"node.id":3},{"node.id":1},{"node.id":4}])");
UT_EXPECT_TRUE(ret);
server->Kill();
Expand All @@ -171,10 +172,25 @@ TEST_F(TestVsag, restart) {
_detail::DEFAULT_ADMIN_NAME, _detail::DEFAULT_ADMIN_PASS);
std::string str;
auto ret = client.CallCypher(str, "CALL db.vertexVectorIndexQuery"
"('person','vector',[1,2,3,4], 4, 10) "
"('person','vector',[1,2,3,4], {top_k:4, hnsw_ef_search:10}) "
"YIELD node RETURN node.id");
UT_EXPECT_EQ(str, R"([{"node.id":2},{"node.id":3},{"node.id":1},{"node.id":4}])");
UT_EXPECT_TRUE(ret);
ret = client.CallCypher(str, "CALL db.alterLabelDelFields('vertex', 'person', ['vector'])");
UT_EXPECT_TRUE(ret);
server->Kill();
server->Wait();
}
{
auto server = StartLGraphServer(conf);
// create graphs
RpcClient client(UT_FMT("{}:{}", conf.bind_host, conf.rpc_port),
_detail::DEFAULT_ADMIN_NAME, _detail::DEFAULT_ADMIN_PASS);
std::string str;
auto ret = client.CallCypher(str, "CALL db.vertexVectorIndexQuery"
"('person','vector',[1,2,3,4], 4, 10) "
"YIELD node RETURN node.id");
UT_EXPECT_FALSE(ret);
server->Kill();
server->Wait();
}
Expand Down

0 comments on commit e0f4f46

Please sign in to comment.