diff --git a/src/hnsw/utils.c b/src/hnsw/utils.c
index 82616a24..77dd47a8 100644
--- a/src/hnsw/utils.c
+++ b/src/hnsw/utils.c
@@ -78,8 +78,13 @@ void CheckMem(int limit, Relation index, usearch_index_t uidx, uint32 n_nodes, c
         double     M = ldb_HnswGetM(index);
         double     mL = 1 / log(M);
         metadata_t meta = usearch_index_metadata(uidx, &error);
-        // todo:: update sizeof(float) to correct vector size once #19 is merged
-        node_size = UsearchNodeBytes(&meta, meta.dimensions * sizeof(float), (int)round(mL + 1));
+        int vector_bytes_num = divide_round_up(meta.dimensions * GetUsearchBitsPerScalar(GetUsearchScalarKindFromIndexMeta(meta)), 8);
+
+        // use the node size at level `(int)round(mL + 1)` as the average node size;
+        // the sizes of nodes at different levels are not linearly related, but
+        // since nodes are distributed exponentially across levels, dominated by the bottom level,
+        // this is a reasonably good approximation.
+        node_size = UsearchNodeBytes(&meta, vector_bytes_num, (int)round(mL + 1));
     }
     // todo:: there's figure out a way to check this in pg <= 12
 #if PG_VERSION_NUM >= 130000
diff --git a/src/hnsw/utils.h b/src/hnsw/utils.h
index 99f6ec3b..02f0303c 100644
--- a/src/hnsw/utils.h
+++ b/src/hnsw/utils.h
@@ -17,6 +17,25 @@
 uint32                 EstimateRowCount(Relation heap);
 int32                  GetColumnAttributeNumber(Relation rel, const char *columnName);
 usearch_metric_kind_t  GetMetricKindFromStr(char *metric_kind_str);
+inline size_t divide_round_up(size_t num, size_t denominator) {
+    return (num + denominator - 1) / denominator;
+}
+
+inline size_t GetUsearchBitsPerScalar(usearch_scalar_kind_t scalar_kind) {
+    switch (scalar_kind) {
+        case usearch_scalar_f64_k: return 64;
+        case usearch_scalar_f32_k: return 32;
+        case usearch_scalar_f16_k: return 16;
+        case usearch_scalar_i8_k: return 8;
+        case usearch_scalar_b1_k: return 1;
+        default: return 0;
+    }
+}
+
+inline usearch_scalar_kind_t GetUsearchScalarKindFromIndexMeta(metadata_t meta) {
+    return meta.init_options.quantization;
+}
+
 // hoping to throw the error via an assertion, if those are on, before elog(ERROR)-ing as a last resort
 // We prefer Assert() because this function is used in contexts where the stack contains non-POD types
 // in which case elog-s long jumps cause undefined behaviour.
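
For context, the patch derives the per-vector byte count from the index's quantization kind instead of assuming `sizeof(float)`. The following is a small standalone sketch (not part of the patch) that mirrors `divide_round_up` and the bits-per-scalar mapping with stand-in types, to show what `vector_bytes_num` works out to for a few quantization kinds. The local `scalar_kind_t` enum and the 1536-dimension value are illustrative assumptions, not usearch's actual API.

/* Standalone sketch: reproduces the byte-count arithmetic from the patch
 * using a stand-in enum instead of usearch_scalar_kind_t. */
#include <stddef.h>
#include <stdio.h>

typedef enum { SCALAR_F64, SCALAR_F32, SCALAR_F16, SCALAR_I8, SCALAR_B1 } scalar_kind_t;

static size_t divide_round_up(size_t num, size_t denominator) { return (num + denominator - 1) / denominator; }

static size_t bits_per_scalar(scalar_kind_t kind)
{
    switch(kind) {
        case SCALAR_F64: return 64;
        case SCALAR_F32: return 32;
        case SCALAR_F16: return 16;
        case SCALAR_I8:  return 8;
        case SCALAR_B1:  return 1;
        default:         return 0;
    }
}

int main(void)
{
    size_t        dims = 1536; /* assumed example dimensionality */
    scalar_kind_t kinds[] = {SCALAR_F32, SCALAR_F16, SCALAR_B1};
    const char   *names[] = {"f32", "f16", "b1"};

    for(int i = 0; i < 3; i++) {
        size_t bytes = divide_round_up(dims * bits_per_scalar(kinds[i]), 8);
        /* for 1536 dims: f32 -> 6144 bytes, f16 -> 3072 bytes, b1 -> 192 bytes */
        printf("%s: %zu bytes per vector\n", names[i], bytes);
    }
    return 0;
}

The rounding up matters mainly for the bit-packed `b1` kind: a dimension count that is not a multiple of 8 still occupies whole bytes, so truncating division would undercount the vector's storage footprint.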