Conversational dataset support for DPOTrainer #2131

Draft · wants to merge 21 commits into `main`

34 changes: 34 additions & 0 deletions docs/source/dataset_formats.mdx
@@ -180,6 +180,8 @@ preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
```

Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets.
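
For a quick look at what such a dataset contains, a preference dataset from the Hub can be loaded with the `datasets` library. The snippet below is only a sketch; `trl-lib/ultrafeedback_binarized` is used as an assumed example repository:

```python
from datasets import load_dataset

# Any preference dataset from the Hub can be inspected this way.
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
print(dataset[0].keys())  # typically "chosen" and "rejected", possibly also "prompt"
```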

### Unpaired preference

An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not.
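
For illustration, a minimal unpaired preference example in the standard format could look like this:

```python
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
```
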
@@ -710,3 +712,35 @@ dataset = dataset.remove_columns(["completion", "label"])
>>> dataset[0]
{'prompt': 'The sky is'}
```

## Vision datasets

Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.

A conversational vision dataset differs from a standard conversational dataset in two key ways:

1. The dataset must contain the key `images` with the image data.
2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`.

Example:

```python
# Textual dataset format:
"content": "What color is the sky?"

# Vision dataset format:
"content": [
{"type": "image"},
{"type": "text", "text": "What color is the sky in the image?"}
]
```
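
Putting both requirements together, a complete vision preference example might look like the following sketch (the image value is a placeholder for illustration; real datasets store actual images under the `images` key):

```python
from PIL import Image

image = Image.new("RGB", (64, 64), color="blue")  # placeholder image for illustration

vision_preference_example = {
    "images": [image],
    "prompt": [{"role": "user", "content": [{"type": "image"},
                                            {"type": "text", "text": "What color is the sky in the image?"}]}],
    "chosen": [{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}],
    "rejected": [{"role": "assistant", "content": [{"type": "text", "text": "It is green."}]}],
}
```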

An example of a conversational vision dataset is [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset). Below is an embedded view of the training split of [trl-lib/rlaif-v](https://huggingface.co/datasets/trl-lib/rlaif-v), a TRL-formatted version of this dataset, which you can explore directly:

<iframe
src="https://huggingface.co/datasets/trl-lib/rlaif-v/embed/viewer/default/train"
frameborder="0"
width="100%"
height="560px"
></iframe>

227 changes: 110 additions & 117 deletions docs/source/dpo_trainer.mdx

Large diffs are not rendered by default.

6 changes: 1 addition & 5 deletions docs/source/online_dpo_trainer.md
@@ -38,11 +38,7 @@ train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = OnlineDPOConfig(output_dir="online-dpo-qwen2", logging_steps=10)
trainer = OnlineDPOTrainer(
model=model,
reward_model=reward_model,
args=training_args,
tokenizer=tokenizer,
train_dataset=train_dataset,
model=model, reward_model=reward_model, args=training_args, tokenizer=tokenizer, train_dataset=train_dataset
)
trainer.train()
```
73 changes: 73 additions & 0 deletions examples/datasets/rlaif-v.py
@@ -0,0 +1,73 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional

from datasets import features, load_dataset
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
r"""
Arguments for the script.

Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/rlaif-v"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/rlaif-v"
dataset_num_proc: Optional[int] = None


def to_conversational(example):
"""
Convert prompt from "xxx" to [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "xxx"}]}]
and chosen and rejected from "xxx" to [{"role": "assistant", "content": [{"type": "text", "text": "xxx"}]}].
Images are wrapped into a list.
"""
prompt = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": example["question"]}]}]
chosen = [{"role": "assistant", "content": [{"type": "text", "text": example["chosen"]}]}]
rejected = [{"role": "assistant", "content": [{"type": "text", "text": example["rejected"]}]}]
return {"prompt": prompt, "images": [example["image"]], "chosen": chosen, "rejected": rejected}


if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

dataset = load_dataset("openbmb/RLAIF-V-Dataset", split="train")
dataset = dataset.map(
to_conversational,
num_proc=script_args.dataset_num_proc,
remove_columns=dataset.column_names,
writer_batch_size=128,
)

# Cast the images to Sequence[Image] to avoid bytes format
f = dataset.features
f["images"] = features.Sequence(features.Image(decode=True))
dataset = dataset.cast(f)

dataset = dataset.train_test_split(test_size=0.01, writer_batch_size=128)

if script_args.push_to_hub:
dataset.push_to_hub(script_args.repo_id)
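
As a rough usage sketch, assuming the script has been run with `--push_to_hub` (or using the existing `trl-lib/rlaif-v` repository embedded in the docs above), the converted dataset can be inspected like this:

```python
from datasets import load_dataset

dataset = load_dataset("trl-lib/rlaif-v", split="train")
example = dataset[0]
print(sorted(example.keys()))              # expected: ['chosen', 'images', 'prompt', 'rejected']
print(example["prompt"][0]["content"][0])  # expected: the {'type': 'image'} placeholder
```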
9 changes: 0 additions & 9 deletions examples/scripts/dpo.py
@@ -47,7 +47,6 @@
"""

import torch
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -60,8 +59,6 @@
get_kbit_device_map,
get_peft_config,
get_quantization_config,
maybe_apply_chat_template,
maybe_extract_prompt,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

@@ -115,12 +112,6 @@
################
dataset = load_dataset(script_args.dataset_name)

with PartialState().local_main_process_first():
dataset = dataset.map(maybe_extract_prompt, num_proc=training_args.dataset_num_proc)
dataset = dataset.map(
maybe_apply_chat_template, num_proc=training_args.dataset_num_proc, fn_kwargs={"tokenizer": tokenizer}
)

##########
# Training
################
60 changes: 49 additions & 11 deletions tests/test_data_utils.py
@@ -268,7 +268,7 @@ def test_maybe_unpair_preference_dataset_dict_already_paired(self):


class ExtractPromptTester(unittest.TestCase):
example_implicit_prompt = {
example_implicit_prompt_conversational = {
"chosen": [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
@@ -279,7 +279,7 @@ class ExtractPromptTester(unittest.TestCase):
],
}

example_explicit_prompt = {
example_explicit_prompt_conversational = {
"prompt": [
{"role": "user", "content": "What color is the sky?"},
],
@@ -291,30 +291,68 @@ class ExtractPromptTester(unittest.TestCase):
],
}

def test_extract_prompt(self):
example_implicit_prompt_standard = {
"chosen": "The sky is blue.",
"rejected": "The sky is green.",
}

example_explicit_prompt_standard = {
"prompt": "The sky is",
"chosen": " blue.",
"rejected": " green.",
}

def test_extract_prompt_conversational(self):
# Test that the prompt is correctly extracted from the dataset
example_extracted_prompt = extract_prompt(self.example_implicit_prompt_conversational)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt_conversational,
"The prompt is not correctly extracted from the dataset.",
)

def test_maybe_extract_prompt_conversational(self):
# Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_conversational)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt_conversational,
"The prompt is not correctly extracted from the dataset.",
)

def test_maybe_extract_prompt_conversational_already_explicit(self):
# Test that the prompt remains unchanged with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_conversational)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt_conversational,
"The prompt should remain unchanged.",
)

def test_extract_prompt_standard(self):
# Test that the prompt is correctly extracted from the dataset
example_extracted_prompt = extract_prompt(self.example_implicit_prompt)
example_extracted_prompt = extract_prompt(self.example_implicit_prompt_standard)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt,
self.example_explicit_prompt_standard,
"The prompt is not correctly extracted from the dataset.",
)

def test_maybe_extract_prompt(self):
def test_maybe_extract_prompt_standard(self):
# Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt)
example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_standard)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt,
self.example_explicit_prompt_standard,
"The prompt is not correctly extracted from the dataset.",
)

def test_maybe_extract_prompt_already_explicit(self):
def test_maybe_extract_prompt_standard_already_explicit(self):
# Test that the prompt remains unchanged with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt)
example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_standard)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt,
self.example_explicit_prompt_standard,
"The prompt should remain unchanged.",
)

22 changes: 13 additions & 9 deletions trl/data_utils.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, TypeVar
from typing import Any, Dict, List, Optional, Sequence, TypeVar

from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizer
@@ -280,15 +280,17 @@ def maybe_unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int
return dataset


def extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
def extract_prompt(example: Dict[str, Sequence]) -> Dict[str, Sequence]:
r"""
Extracts the shared prompt from a preference data example, where the prompt is implicit within both
the chosen and rejected completions.

For more details, see [`maybe_extract_prompt`].
"""
for idx in range(min(len(example["chosen"]), len(example["rejected"]))):
if example["chosen"][idx]["content"] != example["rejected"][idx]["content"]:
if example["chosen"][idx] != example["rejected"][idx]:
if example["chosen"][idx - 1] == " ": # remove space before the prompt
idx -= 1
break
return {
"prompt": example["chosen"][:idx],
@@ -303,15 +305,14 @@ def maybe_extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
the chosen and rejected completions.

If the example already contains a `"prompt"` key, the function returns the example as is. Else, the function

identifies the longest common sequence (prefix) of conversation turns between the "chosen" and "rejected"
completions and extracts this as the prompt. It then removes this prompt from the respective "chosen" and
"rejected" completions.

Args:
example (`Dict[str, List]`):
A dictionary representing a single data entry in the preference dataset. It must contain the keys
`"chosen"` and `"rejected"`, where each value is a list.
`"chosen"` and `"rejected"`, where each value is either conversational or standard (`str`).

Returns:
`Dict[str, List]`: A dictionary containing:
@@ -379,7 +380,10 @@ def maybe_extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
# "chosen": [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
# "rejected": [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}]}
# That's why we check if the prompt is also conversational before deciding not to extract it.
if "prompt" in example and is_conversational({"prompt": example["prompt"]}):
return example
else:
return extract_prompt({"chosen": example["chosen"], "rejected": example["rejected"]})
if "prompt" in example:
# Both conversational or both non-conversational
chosen_conv = is_conversational({"chosen": example["chosen"]})
prompt_conv = is_conversational({"prompt": example["prompt"]})
if (chosen_conv and prompt_conv) or (not chosen_conv and not prompt_conv):
return example
return extract_prompt({"chosen": example["chosen"], "rejected": example["rejected"]})
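
For reference, a small sketch of how the updated prompt extraction behaves on the standard (string) format, mirroring the new test cases above:

```python
from trl.data_utils import maybe_extract_prompt

implicit = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
print(maybe_extract_prompt(implicit))
# {'prompt': 'The sky is', 'chosen': ' blue.', 'rejected': ' green.'}

explicit = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
print(maybe_extract_prompt(explicit))  # already explicit, returned unchanged
```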
4 changes: 4 additions & 0 deletions trl/trainer/dpo_config.py
@@ -40,6 +40,9 @@ class DPOConfig(TrainingArguments):
command line.

Parameters:
learning_rate (`float`, *optional*, defaults to `1e-6`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
beta (`float`, *optional*, defaults to `0.1`):
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
@@ -130,6 +133,7 @@ class DPOConfig(TrainingArguments):
DPO loss. The paper recommends `rpo_alpha=1.0`.
"""

learning_rate: float = 1e-6
beta: float = 0.1
label_smoothing: float = 0.0
loss_type: Literal[
12 changes: 12 additions & 0 deletions trl/trainer/dpo_trainer.py
@@ -44,6 +44,7 @@
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from ..models import PreTrainedModelWrapper, create_reference_model
from .callbacks import SyncRefModelCallback
from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
@@ -815,6 +816,17 @@ def make_inputs_require_grad(module, input, output):
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
# Extract the prompt if needed, and apply the chat template if needed
train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
train_dataset = train_dataset.map(
maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer}, num_proc=args.dataset_num_proc
)
if eval_dataset is not None:
eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
eval_dataset = eval_dataset.map(
maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer}, num_proc=args.dataset_num_proc
)

# tokenize the dataset, lower writer batch size to avoid OOM (frequent in vision models)
fn_kwargs = {
"tokenizer": self.tokenizer,
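
The net effect of this change is that a conversational or implicit-prompt preference dataset can now be passed to `DPOTrainer` directly; prompt extraction and chat templating happen inside the trainer. A minimal sketch (the model and dataset names here are placeholders, not part of this PR):

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

# Any conversational preference dataset; no manual maybe_extract_prompt /
# maybe_apply_chat_template calls are needed anymore.
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = DPOConfig(output_dir="dpo-qwen2", logging_steps=10)
trainer = DPOTrainer(
    model=model, args=training_args, tokenizer=tokenizer, train_dataset=train_dataset
)
trainer.train()
```
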
3 changes: 2 additions & 1 deletion trl/trainer/kto_config.py
@@ -28,7 +28,8 @@ class KTOConfig(TrainingArguments):

Parameters:
learning_rate (`float`, *optional*, defaults to `5e-7`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of [`~transformers.TrainingArguments`].
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
max_length (`Optional[int]`, *optional*, defaults to `None`):
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
to use the default data collator.