From d9f3e88dfed0e8c9bc601118d891a83a33ab2ba4 Mon Sep 17 00:00:00 2001
From: Hubert <42952108+yingfhu@users.noreply.github.com>
Date: Wed, 27 Sep 2023 16:32:57 +0800
Subject: [PATCH] [Fix] fix clp potential error and support bs>1 (#439)

* [Fix] fix clp potential error and support bs>1

* [Fix] fix clp potential error and support bs>1

* minor fix

* minor fix
---
 .../icl_inferencer/icl_clp_inferencer.py      | 83 +++++++++++++------
 1 file changed, 59 insertions(+), 24 deletions(-)

diff --git a/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py b/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
index 727506484..3134371a6 100644
--- a/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
@@ -119,7 +119,7 @@ def inference(self,
         if self.single_token:
             index = 0
             prompt_list = []
-            choice_target_ids = []
+            target_pos = []
             # TODO: Hard code temperaily, need to modified here
             choices = retriever.test_ds[0]['choices']
             try:
@@ -142,6 +142,13 @@
 
             get_token_len = self.model.get_token_len
 
+            if hasattr(self.model.tokenizer, 'padding_side'):
+                # get padding_side for huggingface model
+                padding_side = self.model.tokenizer.padding_side
+            else:
+                # defaults to left for internal model
+                padding_side = 'left'
+
             # prepare in context for each example and control the length
             for idx in range(len(ice_idx_list)):
                 prompt = retriever.generate_prompt_for_generate_task(
@@ -149,7 +156,7 @@
                     ice[idx],
                     ice_template=ice_template,
                     prompt_template=prompt_template)
-                prompt = self.model.parse_template(prompt, mode='ppl')
+                prompt = self.model.parse_template(prompt, mode='gen')
                 if self.max_seq_len is not None:
                     prompt_token_num = get_token_len(prompt)
                     # add one because additional token will be added in the end
@@ -165,15 +172,19 @@
                             ice_template=ice_template,
                             prompt_template=prompt_template)
                         prompt_token_num = get_token_len(prompt)
-                # Add single token for prompt, this token can be any token
-                prompt += 'yes'
                 prompt_list.append(prompt)
-                # in case prompt token num reaches
+                # in case prompt token num reaches max
                 if self.max_seq_len is not None and \
                         prompt_token_num + 1 > self.max_seq_len:
                     prompt_token_num = self.max_seq_len - 1
-                # minus the bos token
-                choice_target_ids.append(prompt_token_num - 1)
+
+                # get the target position index
+                if padding_side == 'left':
+                    # always the last position
+                    target_pos.append(-1)
+                else:
+                    # the last position of the original prompt
+                    target_pos.append(prompt_token_num - 1)
 
             # 4.1 Fetch and zip prompt & gold answer if output column exists
             ds_reader = retriever.dataset_reader
@@ -182,19 +193,36 @@
             else:
                 gold_ans = [None] * len(prompt_list)
 
+            if hasattr(self.model, 'batch_padding'):
+                # get batch padding for huggingface model
+                batch_padding = self.model.batch_padding
+            else:
+                # defaults to False for internal model
+                batch_padding = False
+
             logger.info('Calculating conditional log probability for prompts.')
             for idx in trange(0,
                               len(prompt_list),
                               self.batch_size,
                               disable=not self.is_main_process):
+                # get batch data
                 sub_prompt_list = prompt_list[idx:idx + self.batch_size]
                 sub_golds = gold_ans[idx:idx + self.batch_size]
-                sub_choice_target_ids = choice_target_ids[idx:idx +
-                                                          self.batch_size]
-                sub_res = self.__get_cond_prob(sub_prompt_list,
-                                               sub_choice_target_ids,
-                                               choice_ids)
+                sub_target_pos = target_pos[idx:idx + self.batch_size]
+
+                # get probability result
+                if batch_padding and self.batch_size > 1:
+                    sub_res = self._get_cond_prob(sub_prompt_list,
+                                                  sub_target_pos, choice_ids)
+                else:
+                    sub_res = []
+                    for prompt, position in zip(sub_prompt_list,
+                                                sub_target_pos):
+                        sub_res.extend(
+                            self._get_cond_prob([prompt], [position],
+                                                choice_ids))
 
+                # save all the result
                 for res, prompt, gold in zip(sub_res, sub_prompt_list,
                                              sub_golds):
                     example_input = prompt.replace(ice[idx], '')
@@ -217,22 +245,29 @@
             for sample in output_handler.results_dict.values()
         ]
 
-    def __get_cond_prob(self,
-                        input_texts: List[str],
-                        sub_choice_target_ids,
-                        choice_ids,
-                        mask_length=None):
-        # TODO: support multiple tokens
+    def _get_cond_prob(self, input_texts: List[str], target_pos: List[int],
+                       choice_ids: List[int]):
+        """Get the condition probability of next token.
+
+        Args:
+            input_texts (List[str]): All the input prompt to be tested.
+            target_pos (List[int]): Target position of next token.
+            choice_ids (List[int]): Choice ids of target tokens.
+        """
         if hasattr(self.model, 'generator'):
-            outputs, _ = self.model.generator.get_logits(input_texts)
+            get_logits = self.model.generator.get_logits
         else:
-            outputs, _ = self.model.get_logits(input_texts)
+            get_logits = self.model.get_logits
+
+        outputs, _ = get_logits(input_texts)
 
-        shift_logits = outputs[..., :-1, :].contiguous().float()
+        # we want get the next token probability
+        # therefore no shift here
+        logits = outputs.contiguous().float()
 
-        shift_logits = F.log_softmax(shift_logits, dim=-1)
+        logits = F.log_softmax(logits, dim=-1)
         log_probs = []
-        for logits, target_ids in zip(shift_logits, sub_choice_target_ids):
+        for logit, target_ids in zip(logits, target_pos):
             log_probs.append(
-                F.softmax(logits[target_ids, choice_ids], dim=-1).tolist())
+                F.softmax(logit[target_ids, choice_ids], dim=-1).tolist())
         return log_probs
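
Note on the change above: _get_cond_prob now reads the logits at a per-prompt target position (always -1 with left padding, len(prompt tokens) - 1 with right padding) and renormalises them over the choice token ids, instead of shifting the whole logit tensor. Below is a minimal, self-contained sketch of that idea, not the OpenCompass API; the names cond_prob_over_choices, fake_logits and the example choice_ids are illustrative only.

import torch
import torch.nn.functional as F


def cond_prob_over_choices(logits, target_pos, choice_ids):
    # logits: [batch, seq_len, vocab] next-token logits for a batch of prompts
    # target_pos: one position per prompt (-1 for left padding,
    #             len(prompt_tokens) - 1 for right padding)
    # choice_ids: token ids of the answer options, e.g. "A"/"B"/"C"/"D"
    log_probs = F.log_softmax(logits.float(), dim=-1)
    results = []
    for per_prompt, pos in zip(log_probs, target_pos):
        # per_prompt[pos] is the distribution over the token that follows the
        # prompt; keep only the choice tokens and renormalise them
        results.append(F.softmax(per_prompt[pos, choice_ids], dim=-1).tolist())
    return results


# hypothetical usage with random logits for two prompts and four choices
fake_logits = torch.randn(2, 8, 100)
print(cond_prob_over_choices(fake_logits, [-1, 5], [11, 22, 33, 44]))

The final softmax over only the choice tokens mirrors what the patched loop does with F.softmax(logit[target_ids, choice_ids], dim=-1).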