fix: fix length loglikelihood

delfosseaurelien committed Jul 8, 2021
1 parent d7f86a2 commit 29ef0b8
Showing 1 changed file with 5 additions and 5 deletions.

10 changes: 5 additions & 5 deletions biotransformers/wrappers/transformers_wrappers.py
@@ -306,11 +306,11 @@ def compute_logits(

         # Remove padded logits
         logits = [
-            torch.from_numpy(logit.numpy().transpose()[:, :length].transpose())
+            torch.from_numpy(logit.numpy().transpose()[:, 1:length].transpose())
             for logit, length in zip(list(logits), lengths)
         ]
         labels = [
-            torch.from_numpy(label.numpy().transpose()[:, :length].transpose())
+            torch.from_numpy(label.numpy().transpose()[:, 1:length].transpose())
             for label, length in zip(list(labels), lengths)
         ]
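Both comprehensions strip the padding from each per-sequence tensor; the fix also moves the slice start from 0 to 1 so that position 0, which holds the beginning-of-sequence token, is dropped from logits and labels alike. A minimal sketch of the effect, with toy shapes and lengths assumed for illustration:

```python
import torch

# Toy batch: 2 sequences padded to 6 positions over a 4-token vocabulary.
logits = torch.randn(2, 6, 4)
lengths = [5, 3]  # true lengths, counting the BOS token at position 0

trimmed = [
    # The double transpose puts the length axis at index 1, so the same
    # slice works for 2-D (sequence) and 3-D (MSA) tensors; 1:length
    # drops the BOS position and the padding in one step.
    torch.from_numpy(logit.numpy().transpose()[:, 1:length].transpose())
    for logit, length in zip(list(logits), lengths)
]
print([t.shape for t in trimmed])  # [torch.Size([4, 4]), torch.Size([2, 4])]
```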
@@ -377,7 +377,7 @@ def compute_probabilities(
         # Remove padded logits
         # Use transpose so that function works for MSA and sequence
         logits = [
-            torch.from_numpy(logit.numpy().transpose()[:, :length].transpose())
+            torch.from_numpy(logit.numpy().transpose()[:, 1:length].transpose())
             for logit, length in zip(list(logits), lengths)
         ]
         # Set to -inf logits that correspond to tokens that are not in tokens list
@@ -745,7 +745,7 @@ def finetune(
             random_token_prob,
             toks_per_batch,
             extra_toks_per_seq,
-            validation=False
+            validation=False,
         )

         if self._num_gpus == 0:
@@ -787,7 +787,7 @@ def finetune(
             save_name = self._save_model(save_path, lightning_model)
         else:
             save_name = self._save_model(save_path, lightning_model)

         # Load new model
         self._language_model._load_model(save_name)
         log.info("Training completed.")
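For context on the commit title: the per-sequence log-likelihood sums the log-probability of each true token over real positions only, so a slice that kept the BOS position added one spurious term to every sequence's score. A sketch of that computation, reusing the illustrative shapes from the examples above:

```python
import torch

logit = torch.randn(4, 8)          # trimmed logits: (tokens, vocab)
label = torch.randint(0, 8, (4,))  # trimmed labels: true token ids

log_probs = torch.log_softmax(logit, dim=-1)
loglik = log_probs[torch.arange(4), label].sum()  # sequence log-likelihood
```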
