Skip to content

Commit

Permalink
[FIX] fix interntrain get_loglikelihood (#1584)
Browse files Browse the repository at this point in the history
  • Loading branch information
x54-729 authored Oct 8, 2024
1 parent 89abcba commit 4d6349d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions opencompass/models/interntrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def _convert_dtype(self, default_dtype, model_dtype=None):
else:
raise NotImplementedError(f'Unknown model dtype {model_dtype}')

- def get_token_len(self, prompt: str) -> int:
+ def get_token_len(self, prompt: str, use_bos=None, use_eos=None) -> int:
"""Get lengths of the tokenized strings.
Args:
Expand All @@ -297,7 +297,7 @@ def get_token_len(self, prompt: str) -> int:
Returns:
int: Length of the input tokens
"""
- tokens = self.tokenizer(prompt, use_bos=True, use_eos=True)
+ tokens = self.tokenizer(prompt, use_bos=use_bos, use_eos=use_eos)
return len(tokens)

def generate(self,
Expand Down Expand Up @@ -391,7 +391,7 @@ def get_loglikelihood(self, input_texts: List[str],
for input_text, cont in zip(input_texts, conts)
]
replaced_lens = [
- len(self.encode(input_text)[0]) for input_text in replaced_texts
+ self.get_token_len(input_text) for input_text in replaced_texts
]
loglikelihoods = []
for nloss, nlen, rlen in zip(loss, lens, replaced_lens):
Expand Down

0 comments on commit 4d6349d

Please sign in to comment.