From c67684f4bcea631c0e01d6429226d6bc5616e480 Mon Sep 17 00:00:00 2001
From: Shahules786
Date: Tue, 8 Aug 2023 19:10:39 +0530
Subject: [PATCH 1/2] added RAG datasets

---
 .../custom_datasets/__init__.py | 9 +++++-
 .../custom_datasets/instruction.py | 29 +++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/model/model_training/custom_datasets/__init__.py b/model/model_training/custom_datasets/__init__.py
index 29eb24f3b8..50f0a1bf6e 100644
--- a/model/model_training/custom_datasets/__init__.py
+++ b/model/model_training/custom_datasets/__init__.py
@@ -5,7 +5,12 @@
 
 import numpy as np
 from model_training.custom_datasets.extra_rm_datasets import load_anthropic_rlhf, load_hellaswag, load_shp
-from model_training.custom_datasets.instruction import INSTRUCTION_DATASETS, InstructionDataset
+from model_training.custom_datasets.instruction import (
+    INSTRUCTION_DATASETS,
+    RAG_DATASETS,
+    InstructionDataset,
+    RAGDataset,
+)
 from model_training.custom_datasets.oasst_dataset import load_oasst_export
 from model_training.custom_datasets.pretrain_datasets import FanFics, RedPajama
 from model_training.custom_datasets.prompt_dialogue import DolphinMix, Gpt4All, OrcaChat, load_oig_file
@@ -181,6 +186,8 @@ def get_one_dataset(
         dataset = OrcaChat(cache_dir=data_path, **kwargs)
     elif dataset_name == "dolphin-mix":
         dataset = DolphinMix(cache_dir=data_path, **kwargs)
+    elif dataset_name in RAG_DATASETS:
+        dataset = RAGDataset(dataset_name, cache_dir=data_path, **kwargs)
     else:
         raise ValueError(f"Unknown dataset {dataset_name}")
 
diff --git a/model/model_training/custom_datasets/instruction.py b/model/model_training/custom_datasets/instruction.py
index 7b6ad39787..33be76c382 100644
--- a/model/model_training/custom_datasets/instruction.py
+++ b/model/model_training/custom_datasets/instruction.py
@@ -124,3 +124,32 @@ def __getitem__(self, idx) -> DatasetEntry:
             answers=answers,
             lang=lang,
         )
+
+
+RAG_DATASETS = {
+    "multi-chapter-summaries": "shahules786/Multi-chapter-summaries",
+}
+
+
+class RAGDataset(Dataset):
+    def __init__(
+        self,
+        dataset,
+        split: str = "train",
+        cache_dir: str = ".cache/",
+    ):
+        if dataset not in RAG_DATASETS:
+            raise ValueError(f"Invalid dataset {dataset}")
+
+        if dataset == "multi-chapter-summaries":
+            self.prompt, self.context, self.response = "prompt", "context", "response"
+
+        self.dataset = load_dataset(RAG_DATASETS[dataset], cache_dir=cache_dir)[split]
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        prompt, context, response = [self.dataset[idx][key] for key in [self.prompt, self.context, self.response]]
+
+        return create_dataset_entry_qa(mode="sft", questions=[prompt + context], answers=[response])

From 0ceb0f3d7039aeee6ac5b423de4441ded3318a7e Mon Sep 17 00:00:00 2001
From: Shahules786
Date: Wed, 9 Aug 2023 12:07:39 +0530
Subject: [PATCH 2/2] fix column name

---
 model/model_training/custom_datasets/instruction.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model/model_training/custom_datasets/instruction.py b/model/model_training/custom_datasets/instruction.py
index 33be76c382..e932d26259 100644
--- a/model/model_training/custom_datasets/instruction.py
+++ b/model/model_training/custom_datasets/instruction.py
@@ -142,7 +142,7 @@ def __init__(
             raise ValueError(f"Invalid dataset {dataset}")
 
         if dataset == "multi-chapter-summaries":
-            self.prompt, self.context, self.response = "prompt", "context", "response"
+            self.prompt, self.context, self.response = "prompt", "context", "summary"
 
         self.dataset = load_dataset(RAG_DATASETS[dataset], cache_dir=cache_dir)[split]
 