From 7310165a18546263767e3436cbbae946d42392fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Benjamin=20Clavi=C3=A9?= Date: Sat, 13 Jul 2024 13:10:54 +0200 Subject: [PATCH 1/3] schedule-free --- colbert/infra/config/settings.py | 10 ++++ colbert/training/training.py | 97 ++++++++++++++++++++++++-------- colbert/training/utils.py | 24 ++++++-- 3 files changed, 103 insertions(+), 28 deletions(-) diff --git a/colbert/infra/config/settings.py b/colbert/infra/config/settings.py index 3e1f805a..78cc5f0f 100644 --- a/colbert/infra/config/settings.py +++ b/colbert/infra/config/settings.py @@ -156,6 +156,16 @@ class TrainingSettings: model_name: str = DefaultVal(None) # DefaultVal('bert-base-uncased') + # V2.5 + + schedule_free: bool = DefaultVal(False) + + quant_aware: bool = DefaultVal(False) + + highest_quant_level: int = DefaultVal(8) + + lowest_quant_level: int = DefaultVal(2) + @dataclass class IndexingSettings: diff --git a/colbert/training/training.py b/colbert/training/training.py index b409b2c5..701b0db9 100644 --- a/colbert/training/training.py +++ b/colbert/training/training.py @@ -19,9 +19,8 @@ from colbert.training.utils import print_progress, manage_checkpoints - def train(config: ColBERTConfig, triples, queries=None, collection=None): - config.checkpoint = config.checkpoint or 'bert-base-uncased' + config.checkpoint = config.checkpoint or "bert-base-uncased" if config.rank < 1: config.help() @@ -34,13 +33,32 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None): assert config.bsize % config.nranks == 0, (config.bsize, config.nranks) config.bsize = config.bsize // config.nranks - print("Using config.bsize =", config.bsize, "(per process) and config.accumsteps =", config.accumsteps) + print( + "Using config.bsize =", + config.bsize, + "(per process) and config.accumsteps =", + config.accumsteps, + ) if collection is not None: if config.reranker: - reader = RerankBatcher(config, triples, queries, collection, (0 if config.rank == -1 else config.rank), config.nranks) + reader = RerankBatcher( + config, + triples, + queries, + collection, + (0 if config.rank == -1 else config.rank), + config.nranks, + ) else: - reader = LazyBatcher(config, triples, queries, collection, (0 if config.rank == -1 else config.rank), config.nranks) + reader = LazyBatcher( + config, + triples, + queries, + collection, + (0 if config.rank == -1 else config.rank), + config.nranks, + ) else: raise NotImplementedError() @@ -52,18 +70,38 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None): colbert = colbert.to(DEVICE) colbert.train() - colbert = torch.nn.parallel.DistributedDataParallel(colbert, device_ids=[config.rank], - output_device=config.rank, - find_unused_parameters=True) - - optimizer = AdamW(filter(lambda p: p.requires_grad, colbert.parameters()), lr=config.lr, eps=1e-8) + colbert = torch.nn.parallel.DistributedDataParallel( + colbert, + device_ids=[config.rank], + output_device=config.rank, + find_unused_parameters=True, + ) + + if not config.schedule_free: + optimizer = AdamW( + filter(lambda p: p.requires_grad, colbert.parameters()), + lr=config.lr, + eps=1e-8, + ) + else: + optimizer = AdamWScheduleFree( + filter(lambda p: p.requires_grad, colbert.parameters()), + lr=config.lr, + warmup_steps=config.warmup, + ) + optimizer.train() optimizer.zero_grad() scheduler = None if config.warmup is not None: - print(f"#> LR will use {config.warmup} warmup steps and linear decay over {config.maxsteps} steps.") - scheduler = get_linear_schedule_with_warmup(optimizer, 
num_warmup_steps=config.warmup, - num_training_steps=config.maxsteps) + print( + f"#> LR will use {config.warmup} warmup steps and linear decay over {config.maxsteps} steps." + ) + scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=config.warmup, + num_training_steps=config.maxsteps, + ) warmup_bert = config.warmup_bert if warmup_bert is not None: @@ -100,7 +138,10 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None): encoding, target_scores = batch encoding = [encoding.to(DEVICE)] - scores = colbert(*encoding) + if not config.quant_aware: + scores = colbert(*encoding) + else: + raise NotImplementedError if config.use_ib_negatives: scores, ib_loss = scores @@ -108,18 +149,24 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None): scores = scores.view(-1, config.nway) if len(target_scores) and not config.ignore_scores: - target_scores = torch.tensor(target_scores).view(-1, config.nway).to(DEVICE) + target_scores = ( + torch.tensor(target_scores).view(-1, config.nway).to(DEVICE) + ) target_scores = target_scores * config.distillation_alpha - target_scores = torch.nn.functional.log_softmax(target_scores, dim=-1) + target_scores = torch.nn.functional.log_softmax( + target_scores, dim=-1 + ) log_scores = torch.nn.functional.log_softmax(scores, dim=-1) - loss = torch.nn.KLDivLoss(reduction='batchmean', log_target=True)(log_scores, target_scores) + loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)( + log_scores, target_scores + ) else: - loss = nn.CrossEntropyLoss()(scores, labels[:scores.size(0)]) + loss = nn.CrossEntropyLoss()(scores, labels[: scores.size(0)]) if config.use_ib_negatives: if config.rank < 1: - print('\t\t\t\t', loss.item(), ib_loss.item()) + print("\t\t\t\t", loss.item(), ib_loss.item()) loss += ib_loss @@ -139,16 +186,22 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None): if config.rank < 1: print_message(batch_idx, train_loss) - manage_checkpoints(config, colbert, optimizer, batch_idx+1, savepath=None) + manage_checkpoints(config, colbert, optimizer, batch_idx + 1, savepath=None) if config.rank < 1: print_message("#> Done with all triples!") - ckpt_path = manage_checkpoints(config, colbert, optimizer, batch_idx+1, savepath=None, consumed_all_triples=True) + ckpt_path = manage_checkpoints( + config, + colbert, + optimizer, + batch_idx + 1, + savepath=None, + consumed_all_triples=True, + ) return ckpt_path # TODO: This should validate and return the best checkpoint, not just the last one. - def set_bert_grad(colbert, value): try: for p in colbert.bert.parameters(): diff --git a/colbert/training/utils.py b/colbert/training/utils.py index 1d3c7fe6..5a14d197 100644 --- a/colbert/training/utils.py +++ b/colbert/training/utils.py @@ -8,16 +8,23 @@ def print_progress(scores): - positive_avg, negative_avg = round(scores[:, 0].mean().item(), 2), round(scores[:, 1].mean().item(), 2) - print("#>>> ", positive_avg, negative_avg, '\t\t|\t\t', positive_avg - negative_avg) + positive_avg, negative_avg = ( + round(scores[:, 0].mean().item(), 2), + round(scores[:, 1].mean().item(), 2), + ) + print( + "#>>> ", positive_avg, negative_avg, "\t\t|\t\t", positive_avg - negative_avg + ) -def manage_checkpoints(args, colbert, optimizer, batch_idx, savepath=None, consumed_all_triples=False): +def manage_checkpoints( + args, colbert, optimizer, batch_idx, savepath=None, consumed_all_triples=False +): # arguments = dict(args) # TODO: Call provenance() on the values that support it?? 
- checkpoints_path = savepath or os.path.join(Run().path_, 'checkpoints') + checkpoints_path = savepath or os.path.join(Run().path_, "checkpoints") name = None try: @@ -27,24 +34,26 @@ def manage_checkpoints(args, colbert, optimizer, batch_idx, savepath=None, consu if not os.path.exists(checkpoints_path): os.makedirs(checkpoints_path) - + path_save = None if consumed_all_triples or (batch_idx % 2000 == 0): # name = os.path.join(path, "colbert.dnn") # save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments) path_save = os.path.join(checkpoints_path, "colbert") + optimizer.eval() if batch_idx in SAVED_CHECKPOINTS: # name = os.path.join(path, "colbert-{}.dnn".format(batch_idx)) # save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments) path_save = os.path.join(checkpoints_path, f"colbert-{batch_idx}") + optimizer.eval() if path_save: print(f"#> Saving a checkpoint to {path_save} ..") checkpoint = {} - checkpoint['batch'] = batch_idx + checkpoint["batch"] = batch_idx # checkpoint['epoch'] = 0 # checkpoint['model_state_dict'] = model.state_dict() # checkpoint['optimizer_state_dict'] = optimizer.state_dict() @@ -52,4 +61,7 @@ def manage_checkpoints(args, colbert, optimizer, batch_idx, savepath=None, consu save(path_save) + if not consumed_all_triples: + optimizer.train() + return path_save From fe9aaec6571240e7e293ef3739206450f553f67c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Benjamin=20Clavi=C3=A9?= Date: Sat, 13 Jul 2024 20:49:42 +0200 Subject: [PATCH 2/3] config update --- colbert/infra/config/settings.py | 17 ++++++++++++ colbert/training/training.py | 47 +++++++++++++++++++++++++++----- 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/colbert/infra/config/settings.py b/colbert/infra/config/settings.py index 78cc5f0f..14fd5c04 100644 --- a/colbert/infra/config/settings.py +++ b/colbert/infra/config/settings.py @@ -160,6 +160,23 @@ class TrainingSettings: schedule_free: bool = DefaultVal(False) + kldiv_loss: bool = DefaultVal(True) + + marginse_loss: bool = DefaultVal(False) + + kldiv_weight: float = DefaultVal(1.0) + + marginse_weight: float = DefaultVal(0.05) + + ib_loss_weight: float = DefaultVal(1.0) + + normalise_training_scores: bool = DefaultVal(False) + + # Can be 'minmax', 'querylen' + normalization_method: str = DefaultVal("minmax") + + # TODO + quant_aware: bool = DefaultVal(False) highest_quant_level: int = DefaultVal(8) diff --git a/colbert/training/training.py b/colbert/training/training.py index 701b0db9..d856aaff 100644 --- a/colbert/training/training.py +++ b/colbert/training/training.py @@ -145,22 +145,55 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None): if config.use_ib_negatives: scores, ib_loss = scores + ib_loss = ib_loss * config.ib_loss_weight scores = scores.view(-1, config.nway) + if config.normalise_training_scores: + if config.normalization_method == "minmax": + scores = (scores - scores.min(dim=-1, keepdim=True)[0]) / ( + scores.max(dim=-1, keepdim=True)[0] + - scores.min(dim=-1, keepdim=True)[0] + + 1e-8 + ) + elif config.normalization_method == "querylen": + scores = scores / ( + queries.shape[1] + 1e-8 + ) # Divide by the number of tokens in the queries if len(target_scores) and not config.ignore_scores: target_scores = ( torch.tensor(target_scores).view(-1, config.nway).to(DEVICE) ) target_scores = target_scores * config.distillation_alpha - target_scores = torch.nn.functional.log_softmax( - target_scores, dim=-1 - ) - log_scores = torch.nn.functional.log_softmax(scores, dim=-1) - loss = 
torch.nn.KLDivLoss(reduction="batchmean", log_target=True)( - log_scores, target_scores - ) + if config.kldiv_loss: + target_scores = torch.nn.functional.log_softmax( + target_scores, dim=-1 + ) + + log_scores = torch.nn.functional.log_softmax(scores, dim=-1) + kldivloss = torch.nn.KLDivLoss( + reduction="batchmean", log_target=True + )(log_scores, target_scores) + + if config.marginmse_loss: + margin = scores[:, 0] - scores[:, 1:] + target_margin = target_scores[:, 0] - target_scores[:, 1:] + marginmse_loss = torch.nn.MSELoss()(margin, target_margin) + + if config.kldiv_loss and config.marginmse_loss: + loss = ( + kldivloss * config.kldiv_weight + + marginmse_loss * config.marginmse_weight + ) + elif config.kldiv_loss: + loss = kldivloss + elif config.marginmse_loss: + loss = marginmse_loss + else: + raise ValueError( + "One or both of config.kldiv_loss and config.marginmse_loss must be True if distillation is enabled!" + ) else: loss = nn.CrossEntropyLoss()(scores, labels[: scores.size(0)]) From dd0b1e2ee94dcbb6c8705fa272bbd813f1c6e5a6 Mon Sep 17 00:00:00 2001 From: bclavie Date: Fri, 2 Aug 2024 12:57:00 +0000 Subject: [PATCH 3/3] v2.5 training --- colbert/infra/config/settings.py | 10 +- .../tokenization/query_tokenization.py | 191 ++++++++++++++++++ colbert/parameters.py | 1 + colbert/searcher.py | 2 +- colbert/training/training.py | 44 +++- colbert/training/utils.py | 10 +- colbert/utils/amp.py | 6 +- 7 files changed, 247 insertions(+), 17 deletions(-) diff --git a/colbert/infra/config/settings.py b/colbert/infra/config/settings.py index 14fd5c04..56d1c508 100644 --- a/colbert/infra/config/settings.py +++ b/colbert/infra/config/settings.py @@ -119,6 +119,10 @@ class QuerySettings: query_maxlen: int = DefaultVal(32) attend_to_mask_tokens: bool = DefaultVal(False) interaction: str = DefaultVal("colbert") + # V2.5 + cap_padding: int = DefaultVal(0) + dynamic_query_maxlen: bool = DefaultVal(False) + dynamic_querylen_multiples: int = DefaultVal(32) @dataclass @@ -160,13 +164,15 @@ class TrainingSettings: schedule_free: bool = DefaultVal(False) + schedule_free_wd: float = DefaultVal(0.0) + kldiv_loss: bool = DefaultVal(True) - marginse_loss: bool = DefaultVal(False) + marginmse_loss: bool = DefaultVal(False) kldiv_weight: float = DefaultVal(1.0) - marginse_weight: float = DefaultVal(0.05) + marginmse_weight: float = DefaultVal(0.05) ib_loss_weight: float = DefaultVal(1.0) diff --git a/colbert/modeling/tokenization/query_tokenization.py b/colbert/modeling/tokenization/query_tokenization.py index 78668fce..a602e18e 100644 --- a/colbert/modeling/tokenization/query_tokenization.py +++ b/colbert/modeling/tokenization/query_tokenization.py @@ -77,9 +77,30 @@ def tensorize(self, batch_text, bsize=None, context=None, full_length_search=Fal ids, mask = obj['input_ids'], obj['attention_mask'] # postprocess for the [Q] marker and the [MASK] augmentation + # Log original size ids[:, 1] = self.Q_marker_token_id + unpadded_sizes = (ids != self.pad_token_id).sum(dim=1) + # Log original sizes + original_sizes = unpadded_sizes.clone() ids[ids == self.pad_token_id] = self.mask_token_id + # Shorten ids and mask if necessary + if self.config.cap_padding > 0: + for i in range(ids.size(0)): + unpadded_size = unpadded_sizes[i].item() + # Add 8 to the query size itself, per query + max_allowed_length = unpadded_size + self.config.cap_padding + if ids.size(1) > max_allowed_length: + ids[i, max_allowed_length:] = self.pad_token_id + mask[i, max_allowed_length:] = 0 + # Trim the batch to the maximum allowed length 
across all queries + max_length = max(unpadded_size + self.config.cap_padding for unpadded_size in unpadded_sizes) + max_length = min(max_length, ids.size(1)) + ids = ids[:, :max_length] + mask = mask[:, :max_length] + # Note: This implementation already adds 8 (or the value of cap_padding) to each query individually + + if context is not None: assert len(context) == len(batch_text), (len(context), len(batch_text)) @@ -116,3 +137,173 @@ def tensorize(self, batch_text, bsize=None, context=None, full_length_search=Fal # Ensure that query_maxlen <= length <= 500 tokens def max_len(self, length): return min(500, max(self.query_maxlen, length)) + + +import torch +# import math + +# from colbert.modeling.hf_colbert import class_factory +# from colbert.infra import ColBERTConfig +# from colbert.modeling.tokenization.utils import _split_into_batches +# from colbert.utils.utils import batch +# from colbert.parameters import DEVICE + + +# class QueryTokenizer(): +# def __init__(self, config: ColBERTConfig, verbose: int = 3): +# HF_ColBERT = class_factory(config.checkpoint) +# self.tok = HF_ColBERT.raw_tokenizer_from_pretrained(config.checkpoint) +# self.verbose = verbose + +# self.config = config +# self.query_maxlen = config.query_maxlen +# self.background_maxlen = 512 - self.query_maxlen + 1 # FIXME: Make this configurable + +# self.Q_marker_token, self.Q_marker_token_id = config.query_token, self.tok.convert_tokens_to_ids(config.query_token_id) +# self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id +# self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id +# self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id +# self.pad_token,self.pad_token_id = self.tok.pad_token,self.tok.pad_token_id +# self.used = False + +# def tokenize(self, batch_text, add_special_tokens=False): +# assert type(batch_text) in [list, tuple], (type(batch_text)) + +# tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text] + +# if not add_special_tokens: +# return tokens + +# prefix, suffix = [self.cls_token, self.Q_marker_token], [self.sep_token] +# tokens = [prefix + lst + suffix + [self.mask_token] * (self.query_maxlen - (len(lst)+3)) for lst in tokens] + +# return tokens + +# def encode(self, batch_text, add_special_tokens=False): +# assert type(batch_text) in [list, tuple], (type(batch_text)) + +# ids = self.tok(batch_text, add_special_tokens=False).to(DEVICE)['input_ids'] + +# if not add_special_tokens: +# return ids + +# prefix, suffix = [self.cls_token_id, self.Q_marker_token_id], [self.sep_token_id] +# ids = [prefix + lst + suffix + [self.mask_token_id] * (self.query_maxlen - (len(lst)+3)) for lst in ids] + +# return ids + +# def tensorize(self, batch_text, bsize=None, context=None, full_length_search=False): +# assert type(batch_text) in [list, tuple], (type(batch_text)) + +# # add placehold for the [Q] marker +# batch_text = ['. 
' + x for x in batch_text] + +# # Full length search is only available for single inference (for now) +# # Batched full length search requires far deeper changes to the code base +# assert(full_length_search == False or (type(batch_text) == list and len(batch_text) == 1)) + +# if full_length_search: +# # Tokenize each string in the batch +# un_truncated_ids = self.tok(batch_text, add_special_tokens=False).to(DEVICE)['input_ids'] +# # Get the longest length in the batch +# max_length_in_batch = max(len(x) for x in un_truncated_ids) +# # Set the max length +# max_length = self.max_len(max_length_in_batch) +# else: +# # Max length is the default max length from the config +# max_length = self.query_maxlen + +# if self.config.dynamic_query_maxlen: +# max_length = self.config.doc_maxlen +# obj = self.tok(batch_text, padding=False, truncation=True, +# return_tensors='pt', max_length=max_length).to(DEVICE) + +# ids, mask = obj['input_ids'], obj['attention_mask'] + +# # postprocess for the [Q] marker and the [MASK] augmentation +# # Log original size +# ids[:, 1] = self.Q_marker_token_id +# unpadded_sizes = (ids != self.pad_token_id).sum(dim=1) +# # Log original sizes +# original_sizes = unpadded_sizes.clone() +# ids[ids == self.pad_token_id] = self.mask_token_id + +# # Shorten ids and mask if necessary +# if self.config.cap_padding > 0: +# for i in range(ids.size(0)): +# unpadded_size = unpadded_sizes[i].item() +# # Add 8 to the query size itself, per query +# max_allowed_length = unpadded_size + self.config.cap_padding +# if ids.size(1) > max_allowed_length: +# ids[i, max_allowed_length:] = self.pad_token_id +# mask[i, max_allowed_length:] = 0 +# # Trim the batch to the maximum allowed length across all queries +# max_length = max(unpadded_size + self.config.cap_padding for unpadded_size in unpadded_sizes) +# max_length = min(max_length, ids.size(1)) +# ids = ids[:, :max_length] +# mask = mask[:, :max_length] +# # Note: This implementation already adds 8 (or the value of cap_padding) to each query individually + +# if self.config.dynamic_query_maxlen: +# new_ids = [] +# new_mask = [] +# for i in range(ids.size(0)): +# original_length = original_sizes[i].item() +# if original_length % self.config.dynamic_querylen_multiples <= 8: +# QLEN = original_length + 8 +# else: +# QLEN = math.ceil(original_length / self.config.dynamic_querylen_multiples) * self.config.dynamic_querylen_multiples + +# if original_length < QLEN: +# print("Entering padding") +# print("Original length: ", original_length) +# print("QLEN: ", QLEN) +# pad_length = QLEN - original_length +# padded_ids = ids[i, :original_length].tolist() + [self.mask_token_id] * pad_length +# padded_mask = mask[i, :original_length].tolist() + [0] * pad_length +# else: +# padded_ids = ids.tolist() +# padded_mask = mask.tolist() + +# new_ids.append(padded_ids) +# new_mask.append(padded_mask) + +# ids = torch.tensor(new_ids, device=DEVICE) +# mask = torch.tensor(new_mask, device=DEVICE) + +# if context is not None: +# assert len(context) == len(batch_text), (len(context), len(batch_text)) + +# obj_2 = self.tok(context, padding='longest', truncation=True, +# return_tensors='pt', max_length=self.background_maxlen).to(DEVICE) + +# ids_2, mask_2 = obj_2['input_ids'][:, 1:], obj_2['attention_mask'][:, 1:] # Skip the first [SEP] + +# ids = torch.cat((ids, ids_2), dim=-1) +# mask = torch.cat((mask, mask_2), dim=-1) + +# if self.config.attend_to_mask_tokens: +# mask[ids == self.mask_token_id] = 1 +# assert mask.sum().item() == mask.size(0) * mask.size(1), mask 
+ +# if bsize: +# batches = _split_into_batches(ids, mask, bsize) +# return batches + +# if self.used is False: +# self.used = True + +# firstbg = (context is None) or context[0] +# if self.verbose > 1: +# print() +# print("#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==") +# print(f"#> Input: {batch_text[0]}, \t\t {firstbg}, \t\t {bsize}") +# print(f"#> Output IDs: {ids[0].size()}, {ids[0]}") +# print(f"#> Output Mask: {mask[0].size()}, {mask[0]}") +# print() + +# return ids, mask + +# # Ensure that query_maxlen <= length <= 500 tokens +# def max_len(self, length): +# return min(500, max(self.query_maxlen, length)) diff --git a/colbert/parameters.py b/colbert/parameters.py index beaafd0e..4f64f3c4 100644 --- a/colbert/parameters.py +++ b/colbert/parameters.py @@ -5,6 +5,7 @@ SAVED_CHECKPOINTS = [32*1000, 100*1000, 150*1000, 200*1000, 250*1000, 300*1000, 400*1000] SAVED_CHECKPOINTS += [10*1000, 20*1000, 30*1000, 40*1000, 50*1000, 60*1000, 70*1000, 80*1000, 90*1000] SAVED_CHECKPOINTS += [25*1000, 50*1000, 75*1000] +SAVED_CHECKPOINTS += [2000, 5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000, 45000] SAVED_CHECKPOINTS = set(SAVED_CHECKPOINTS) diff --git a/colbert/searcher.py b/colbert/searcher.py index 8bc07c50..1bed271b 100644 --- a/colbert/searcher.py +++ b/colbert/searcher.py @@ -55,7 +55,7 @@ def configure(self, **kw_args): def encode(self, text: TextQueries, full_length_search=False): queries = text if type(text) is list else [text] - bsize = 128 if len(queries) > 128 else None + bsize = 512 if len(queries) > 512 else None self.checkpoint.query_tokenizer.query_maxlen = self.config.query_maxlen Q = self.checkpoint.queryFromText(queries, bsize=bsize, to_cpu=True, full_length_search=full_length_search) diff --git a/colbert/training/training.py b/colbert/training/training.py index d856aaff..9d3766b6 100644 --- a/colbert/training/training.py +++ b/colbert/training/training.py @@ -18,6 +18,8 @@ from colbert.utils.utils import print_message from colbert.training.utils import print_progress, manage_checkpoints +from schedulefree import AdamWScheduleFree + def train(config: ColBERTConfig, triples, queries=None, collection=None): config.checkpoint = config.checkpoint or "bert-base-uncased" @@ -77,23 +79,30 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None): find_unused_parameters=True, ) - if not config.schedule_free: + if config.schedule_free is False: optimizer = AdamW( filter(lambda p: p.requires_grad, colbert.parameters()), lr=config.lr, eps=1e-8, ) else: + print("WARNING, USING SCHEDULE FREE") + print("WARNING, USING SCHEDULE FREE") + print("WARNING, USING SCHEDULE FREE") + print("WARNING, USING SCHEDULE FREE") + print("WARNING, USING SCHEDULE FREE") optimizer = AdamWScheduleFree( filter(lambda p: p.requires_grad, colbert.parameters()), lr=config.lr, warmup_steps=config.warmup, + weight_decay=config.schedule_free_wd, ) - optimizer.train() + if config.schedule_free: + optimizer.train() optimizer.zero_grad() scheduler = None - if config.warmup is not None: + if config.warmup is not None and config.schedule_free is False: print( f"#> LR will use {config.warmup} warmup steps and linear decay over {config.maxsteps} steps." 
) @@ -150,6 +159,7 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None): scores = scores.view(-1, config.nway) if config.normalise_training_scores: if config.normalization_method == "minmax": + print('norm') scores = (scores - scores.min(dim=-1, keepdim=True)[0]) / ( scores.max(dim=-1, keepdim=True)[0] - scores.min(dim=-1, keepdim=True)[0] @@ -157,7 +167,7 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None): ) elif config.normalization_method == "querylen": scores = scores / ( - queries.shape[1] + 1e-8 + config.query_maxlen + 1e-8 ) # Divide by the number of tokens in the queries if len(target_scores) and not config.ignore_scores: @@ -177,14 +187,16 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None): )(log_scores, target_scores) if config.marginmse_loss: - margin = scores[:, 0] - scores[:, 1:] - target_margin = target_scores[:, 0] - target_scores[:, 1:] + margin = scores[:, 0].unsqueeze(1) - scores[:, 1:] + target_margin = target_scores[:, 0].unsqueeze(1) - target_scores[:, 1:] marginmse_loss = torch.nn.MSELoss()(margin, target_margin) if config.kldiv_loss and config.marginmse_loss: + weighted_kldiv = kldivloss * config.kldiv_weight + weighted_marginmse = marginmse_loss * config.marginmse_weight loss = ( - kldivloss * config.kldiv_weight - + marginmse_loss * config.marginmse_weight + weighted_kldiv + + weighted_marginmse ) elif config.kldiv_loss: loss = kldivloss @@ -195,12 +207,14 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None): "One or both of config.kldiv_loss and config.marginmse_loss must be True if distillation is enabled!" ) else: + raise ValueError("crossentropy loss shouldn't be used here") loss = nn.CrossEntropyLoss()(scores, labels[: scores.size(0)]) if config.use_ib_negatives: if config.rank < 1: print("\t\t\t\t", loss.item(), ib_loss.item()) + og_loss = loss loss += ib_loss loss = loss / config.accumsteps @@ -215,9 +229,22 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None): train_loss = this_batch_loss if train_loss is None else train_loss train_loss = train_loss_mu * train_loss + (1 - train_loss_mu) * this_batch_loss + if config.schedule_free: + assert scheduler is None + amp.step(colbert, optimizer, scheduler) if config.rank < 1: + if config.use_ib_negatives: + print_message(f"IB Loss: {ib_loss}") + print_message(f"KL-D loss: {og_loss}") + if config.kldiv_loss and config.marginmse_loss: + TOTAL = weighted_kldiv + weighted_marginmse + kldiv_proportion = weighted_kldiv / TOTAL + marginmse_proportion = weighted_marginmse / TOTAL + print_message(f"Weighted KL-D loss: {weighted_kldiv:.4f}") + print_message(f"Weighted MarginMSE loss: {weighted_marginmse:.4f}") + print_message(f"Respective proportions: KL-D {kldiv_proportion:.2%}, MarginMSE {marginmse_proportion:.2%}") print_message(batch_idx, train_loss) manage_checkpoints(config, colbert, optimizer, batch_idx + 1, savepath=None) @@ -230,6 +257,7 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None): batch_idx + 1, savepath=None, consumed_all_triples=True, + is_schedule_free=config.schedule_free, ) return ckpt_path # TODO: This should validate and return the best checkpoint, not just the last one. 
diff --git a/colbert/training/utils.py b/colbert/training/utils.py index 5a14d197..b3926187 100644 --- a/colbert/training/utils.py +++ b/colbert/training/utils.py @@ -18,7 +18,7 @@ def print_progress(scores): def manage_checkpoints( - args, colbert, optimizer, batch_idx, savepath=None, consumed_all_triples=False + args, colbert, optimizer, batch_idx, savepath=None, consumed_all_triples=False, is_schedule_free=False ): # arguments = dict(args) @@ -41,13 +41,15 @@ def manage_checkpoints( # name = os.path.join(path, "colbert.dnn") # save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments) path_save = os.path.join(checkpoints_path, "colbert") - optimizer.eval() + if is_schedule_free: + optimizer.eval() if batch_idx in SAVED_CHECKPOINTS: # name = os.path.join(path, "colbert-{}.dnn".format(batch_idx)) # save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments) path_save = os.path.join(checkpoints_path, f"colbert-{batch_idx}") - optimizer.eval() + if is_schedule_free: + optimizer.eval() if path_save: print(f"#> Saving a checkpoint to {path_save} ..") @@ -61,7 +63,7 @@ def manage_checkpoints( save(path_save) - if not consumed_all_triples: + if not consumed_all_triples and is_schedule_free: optimizer.train() return path_save diff --git a/colbert/utils/amp.py b/colbert/utils/amp.py index 2a0dadc1..29d7a6a1 100644 --- a/colbert/utils/amp.py +++ b/colbert/utils/amp.py @@ -23,12 +23,14 @@ def backward(self, loss): def step(self, colbert, optimizer, scheduler=None): if self.activated: self.scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0, error_if_nonfinite=False) + if scheduler is not None: + torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0, error_if_nonfinite=False) self.scaler.step(optimizer) self.scaler.update() else: - torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0) + if scheduler is not None: + torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0) optimizer.step() if scheduler is not None:
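
The sketches below illustrate the techniques this series introduces; they are not part of the patches themselves, and any names not taken from the diffs (such as `model` and `train_loader` here) are stand-ins.

The series replaces the linear-warmup AdamW with `AdamWScheduleFree` from the `schedulefree` package and skips the LR scheduler when it is active. Schedule-free optimizers keep averaged weights that must be swapped in before evaluating or saving, which is why the hunks above call `optimizer.train()` right after construction and `optimizer.eval()` around checkpoint saves (switching back to train mode afterwards). A minimal sketch of that usage pattern:

```python
# Illustrative only: `model` and `train_loader` are stand-ins, not objects from this patch.
import torch
from schedulefree import AdamWScheduleFree

model = torch.nn.Linear(128, 32)
train_loader = [(torch.randn(8, 128), torch.randn(8, 32)) for _ in range(20)]

optimizer = AdamWScheduleFree(
    (p for p in model.parameters() if p.requires_grad),
    lr=1e-5,
    warmup_steps=5,       # warmup is handled inside the optimizer, so no LR scheduler is needed
    weight_decay=0.0,     # exposed in the patch as config.schedule_free_wd
)

optimizer.train()         # schedule-free optimizers need explicit train/eval mode switches
for step, (x, y) in enumerate(train_loader):
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if (step + 1) % 10 == 0:
        optimizer.eval()  # swap in the averaged weights before saving or evaluating
        torch.save(model.state_dict(), "checkpoint.pt")
        optimizer.train() # swap back before continuing training
```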
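
Patch 2 also adds optional score normalization before the loss (`normalise_training_scores`, with `normalization_method` set to 'minmax' or 'querylen'). The min-max variant rescales each row of the (batch, nway) score matrix into [0, 1]; a standalone restatement on made-up values:

```python
import torch

scores = torch.randn(4, 8)  # (batch, nway) student scores; values are made up

mins = scores.min(dim=-1, keepdim=True)[0]
maxs = scores.max(dim=-1, keepdim=True)[0]
scores = (scores - mins) / (maxs - mins + 1e-8)  # each row now lies in [0, 1]

# The 'querylen' variant instead divides by a query-length constant
# (queries.shape[1] in patch 2, config.query_maxlen after patch 3).
```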
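
The distillation path in patches 2 and 3 combines a KL-divergence loss over softmaxed scores with a MarginMSE loss over the margin between the first (positive) column and each negative, weighted by `kldiv_weight` and `marginmse_weight`. (Note that patch 2 registers the settings as `marginse_loss` / `marginse_weight` while training.py reads `config.marginmse_loss` / `config.marginmse_weight`; patch 3 renames the settings to match.) The function below is a simplified, self-contained restatement so the shapes are easier to follow; it always applies the weights and takes the margin over raw teacher scores, so it is not a line-for-line copy of the hunks above:

```python
import torch

def distillation_loss(scores, target_scores, use_kldiv=True, use_marginmse=False,
                      kldiv_weight=1.0, marginmse_weight=0.05):
    """scores, target_scores: (batch, nway) student / teacher scores, column 0 = positive."""
    if not (use_kldiv or use_marginmse):
        raise ValueError("At least one of the distillation losses must be enabled.")

    loss = scores.new_zeros(())
    if use_kldiv:
        log_scores = torch.nn.functional.log_softmax(scores, dim=-1)
        log_targets = torch.nn.functional.log_softmax(target_scores, dim=-1)
        loss = loss + kldiv_weight * torch.nn.KLDivLoss(
            reduction="batchmean", log_target=True
        )(log_scores, log_targets)
    if use_marginmse:
        margin = scores[:, :1] - scores[:, 1:]                # student margins over the negatives
        target_margin = target_scores[:, :1] - target_scores[:, 1:]
        loss = loss + marginmse_weight * torch.nn.MSELoss()(margin, target_margin)
    return loss

# e.g. loss = distillation_loss(scores, target_scores, use_marginmse=True)
```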
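
Finally, the new `cap_padding` query setting bounds ColBERT's [MASK] query augmentation: each query keeps its own unpadded length plus at most `cap_padding` augmentation tokens, and the batch is then trimmed to the longest allowed length. A small sketch of that trimming on toy tensors (the token IDs and lengths are made up; the real logic runs on the tokenizer output inside `QueryTokenizer.tensorize`):

```python
import torch

pad_id, mask_id, cap_padding = 0, 103, 8   # illustrative token IDs and cap

# two queries padded out to a fixed query_maxlen of 16
ids = torch.tensor([
    [101, 1, 5, 6, 102] + [pad_id] * 11,
    [101, 1, 7, 8, 9, 10, 102] + [pad_id] * 9,
])
mask = (ids != pad_id).long()

unpadded_sizes = (ids != pad_id).sum(dim=1)   # true length of each query
ids[ids == pad_id] = mask_id                  # ColBERT's [MASK] augmentation of the padding

# cap the augmentation per query, then zero out the attention mask beyond the cap
for i in range(ids.size(0)):
    allowed = unpadded_sizes[i].item() + cap_padding
    if ids.size(1) > allowed:
        ids[i, allowed:] = pad_id
        mask[i, allowed:] = 0

# trim the whole batch to the longest allowed length
max_length = min(int((unpadded_sizes + cap_padding).max().item()), ids.size(1))
ids, mask = ids[:, :max_length], mask[:, :max_length]
```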