diff --git a/biotransformers/wrappers/transformers_wrappers.py b/biotransformers/wrappers/transformers_wrappers.py
index ebeb0f0..ef30739 100755
--- a/biotransformers/wrappers/transformers_wrappers.py
+++ b/biotransformers/wrappers/transformers_wrappers.py
@@ -306,11 +306,11 @@ def compute_logits(
         # Remove padded logits
         logits = [
-            torch.from_numpy(logit.numpy().transpose()[:, :length].transpose())
+            torch.from_numpy(logit.numpy().transpose()[:, 1:length].transpose())
             for logit, length in zip(list(logits), lengths)
         ]
         labels = [
-            torch.from_numpy(label.numpy().transpose()[:, :length].transpose())
+            torch.from_numpy(label.numpy().transpose()[:, 1:length].transpose())
             for label, length in zip(list(labels), lengths)
         ]
@@ -377,7 +377,7 @@ def compute_probabilities(
         # Remove padded logits
         # Use transpose so that function works for MSA and sequence logits
         logits = [
-            torch.from_numpy(logit.numpy().transpose()[:, :length].transpose())
+            torch.from_numpy(logit.numpy().transpose()[:, 1:length].transpose())
             for logit, length in zip(list(logits), lengths)
         ]
         # Set to -inf logits that correspond to tokens that are not in tokens list
@@ -745,7 +745,7 @@ def finetune(
             random_token_prob,
             toks_per_batch,
             extra_toks_per_seq,
-            validation=False
+            validation=False,
         )
 
         if self._num_gpus == 0:
@@ -787,7 +787,7 @@ def finetune(
                 save_name = self._save_model(save_path, lightning_model)
         else:
             save_name = self._save_model(save_path, lightning_model)
-        
+
         # Load new model
         self._language_model._load_model(save_name)
         log.info("Training completed.")
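
For context (not part of the patch): a minimal sketch of what the new slicing does, assuming the tokenizer prepends a start token so that position 0 of each per-sequence logit/label tensor corresponds to that token and `length` counts the non-padded positions. All names, shapes, and values below are illustrative, not taken from the library.

import torch

vocab_size, padded_len, length = 5, 8, 6  # illustrative values only

# Fake per-sequence logits of shape (padded_len, vocab_size), right-padded past `length`.
logit = torch.zeros(padded_len, vocab_size)

# Old slicing: keeps positions 0..length-1, i.e. the start-token position as well.
old = torch.from_numpy(logit.numpy().transpose()[:, :length].transpose())
# New slicing: keeps positions 1..length-1, dropping the start-token position.
new = torch.from_numpy(logit.numpy().transpose()[:, 1:length].transpose())

print(old.shape)  # torch.Size([6, 5])
print(new.shape)  # torch.Size([5, 5])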