diff --git a/llm_eval/handler.py b/llm_eval/handler.py index 288f936..a822264 100644 --- a/llm_eval/handler.py +++ b/llm_eval/handler.py @@ -67,19 +67,19 @@ def post_process_output(self, prompt, output): def prepare_output(self): rows = [] - column_names = ['id', 'model'] # start with 'id' and 'model' columns + column_names = ['id', 'model'] for prompt in self.prompts: column_names.append(prompt["name"] + '.input') column_names.append(prompt["name"] + '.output') df = pd.DataFrame(columns=column_names) - for model_name in self.models.items(): - for data_id, data_value in self.dataset.items(): + for model_name in self.models: + for data_id, data_value in self.dataset.items(): row = {'id': data_id, 'model': model_name} for prompt in self.prompts: input_column_name = prompt["name"] + '.input' output_column_name = prompt["name"] + '.output' - row[input_column_name] = f"### Conversation ###\n{data_value}\n\n### Instruction ###\n{prompt['prompt']}\n\n### Output ###\n" + row[input_column_name] = f"### Conversation ###\n{data_value}\n\n### Instruction ###\n{prompt['prompt']}\n\n### Output ###\n" row[output_column_name] = "" rows.append(row) df = pd.concat([df, pd.DataFrame(rows)], ignore_index=True) @@ -90,8 +90,11 @@ def process_dataset(self): for model_name, model_handler in self.models.items(): self.current_model = model_name print(f"Loading {model_name}...") - handler = model_handler() - self.tokenizer, self.model = handler.load_model_and_tokenizer(self.device, model_name) + if model_handler == ModelHandler: + self.tokenizer, self.model = self.load_model_and_tokenizer(model_name) + else: + handler = model_handler() + self.tokenizer, self.model = handler.load_model_and_tokenizer(self.device, model_name) for index, row in df[df['model'] == model_name].iterrows(): print(f"Generating outputs for {row['id']}") for col in df.columns: diff --git a/pyproject.toml b/pyproject.toml index b2d4721..ece0d8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" [project] name = "llm-eval" -version = "0.6.8" +version = "0.6.9" authors = [ {name = "Jonathan Eisenzopf", email = "jonathan.eisenzopf@talkmap.com"}, ]