diff --git a/colbert/indexing/faiss.py b/colbert/indexing/faiss.py index c0d55642..f52e118b 100644 --- a/colbert/indexing/faiss.py +++ b/colbert/indexing/faiss.py @@ -37,11 +37,11 @@ def load_sample(samples_paths, sample_fraction=None): return sample -def prepare_faiss_index(slice_samples_paths, partitions, sample_fraction=None): +def prepare_faiss_index(slice_samples_paths, partitions, m=16, sample_fraction=None): training_sample = load_sample(slice_samples_paths, sample_fraction=sample_fraction) dim = training_sample.shape[-1] - index = FaissIndex(dim, partitions) + index = FaissIndex(dim, partitions, m) print_message("#> Training with the vectors...") @@ -84,7 +84,7 @@ def index_faiss(args): assert not os.path.exists(output_path), output_path - index = prepare_faiss_index(slice_samples_paths, args.partitions, args.sample) + index = prepare_faiss_index(slice_samples_paths, args.partitions, args.m, args.sample) loaded_parts = queue.Queue(maxsize=1) diff --git a/colbert/indexing/faiss_index.py b/colbert/indexing/faiss_index.py index 7bb3a50d..4e3c5ab7 100644 --- a/colbert/indexing/faiss_index.py +++ b/colbert/indexing/faiss_index.py @@ -11,9 +11,10 @@ class FaissIndex(): - def __init__(self, dim, partitions): + def __init__(self, dim, partitions, m=16): self.dim = dim self.partitions = partitions + self.m = m self.gpu = FaissIndexGPU() self.quantizer, self.index = self._create_index() @@ -21,7 +22,7 @@ def __init__(self, dim, partitions): def _create_index(self): quantizer = faiss.IndexFlatL2(self.dim) # faiss.IndexHNSWFlat(dim, 32) - index = faiss.IndexIVFPQ(quantizer, self.dim, self.partitions, 16, 8) + index = faiss.IndexIVFPQ(quantizer, self.dim, self.partitions, self.m, 8) return quantizer, index diff --git a/colbert/utils/parser.py b/colbert/utils/parser.py index 7acda986..5d95e12e 100644 --- a/colbert/utils/parser.py +++ b/colbert/utils/parser.py @@ -78,6 +78,7 @@ def add_index_use_input(self): self.add_argument('--index_root', dest='index_root', required=True) self.add_argument('--index_name', dest='index_name', required=True) self.add_argument('--partitions', dest='partitions', default=None, type=int) + self.add_argument('--m', dest='m', default=16, type=int) def add_retrieval_input(self): self.add_index_use_input()