Documentation dataset format #2020

Merged
merged 74 commits from `dataset_format` into `main`
Sep 11, 2024
Changes from 66 commits
74 commits
f8dffef
first piece of doc
qgallouedec Sep 5, 2024
6fe9e98
improve readibility
qgallouedec Sep 5, 2024
6f3d846
some data utils and doc
qgallouedec Sep 5, 2024
d83ead9
simplify prompt-only
qgallouedec Sep 5, 2024
e2709e5
format
qgallouedec Sep 5, 2024
5508d37
fix path data utils
qgallouedec Sep 5, 2024
6d18f9d
fix example format
qgallouedec Sep 5, 2024
cb6eaa3
simplify
qgallouedec Sep 5, 2024
1a7d77d
tests
qgallouedec Sep 5, 2024
929d43e
prompt-completion
qgallouedec Sep 5, 2024
3ccabea
update antropic hh
qgallouedec Sep 5, 2024
548222c
update dataset script
qgallouedec Sep 5, 2024
d820ef5
implicit prompt
qgallouedec Sep 7, 2024
2bbf766
additional content
qgallouedec Sep 7, 2024
4d894e0
`maybe_reformat_dpo_to_kto` -> `unpair_preference_dataset`
qgallouedec Sep 7, 2024
b16337e
Preference dataset with implicit prompt
qgallouedec Sep 7, 2024
8b32301
unpair preference dataset tests
qgallouedec Sep 7, 2024
27360ac
documentation
qgallouedec Sep 7, 2024
60fa768
...
qgallouedec Sep 7, 2024
97b7cde
doc
qgallouedec Sep 7, 2024
eb00261
changes applied to dpo example
qgallouedec Sep 7, 2024
3a9b7ab
better doc and better log error
qgallouedec Sep 7, 2024
74ed8e7
a bit more doc
qgallouedec Sep 7, 2024
444083d
improve doc
qgallouedec Sep 7, 2024
fe996c3
converting
qgallouedec Sep 8, 2024
cf5ef88
some subsections
qgallouedec Sep 8, 2024
6a891fe
Merge branch 'main' into dataset_format
qgallouedec Sep 8, 2024
7a3242c
converting section
qgallouedec Sep 8, 2024
ea7ddf6
further refinements
qgallouedec Sep 8, 2024
cb9a344
tldr
qgallouedec Sep 9, 2024
a0bf787
tldr preference
qgallouedec Sep 9, 2024
cb1083e
rename
qgallouedec Sep 9, 2024
da81e60
lm-human-preferences-sentiment
qgallouedec Sep 9, 2024
28312d7
`imdb` to `stanfordnlp/imdb`
qgallouedec Sep 9, 2024
8cf6347
Add script for LM human preferences descriptiveness
qgallouedec Sep 9, 2024
fa38e0b
Remove sentiment_descriptiveness.py script
qgallouedec Sep 9, 2024
ea06ea3
style
qgallouedec Sep 9, 2024
a5193be
example judge tlrd with new dataset
qgallouedec Sep 9, 2024
d8ff5e0
Syle
qgallouedec Sep 9, 2024
2b9aa71
Dataset conversion for TRL compatibility
qgallouedec Sep 9, 2024
ac0b8c8
further refinements
qgallouedec Sep 9, 2024
9bc6cf9
trainers in doc
qgallouedec Sep 9, 2024
6f5c249
top level for functions
qgallouedec Sep 9, 2024
d8dd465
stanfordnlp/imdb
qgallouedec Sep 9, 2024
81c60b5
Merge branch 'main' into dataset_format
qgallouedec Sep 10, 2024
157edf9
downgrade transformers
qgallouedec Sep 10, 2024
102dc9c
Merge branch 'dataset_format' of https://github.com/huggingface/trl i…
qgallouedec Sep 10, 2024
22249da
temp reduction of tests
qgallouedec Sep 10, 2024
bd0eb0e
next commit
qgallouedec Sep 10, 2024
7bde159
next commit
qgallouedec Sep 10, 2024
7197222
Merge branch 'main' into dataset_format
qgallouedec Sep 10, 2024
7b09202
additional content
qgallouedec Sep 10, 2024
504d64f
proper tick format
qgallouedec Sep 10, 2024
c4b8e46
precise the assistant start token
qgallouedec Sep 10, 2024
f366398
improve
qgallouedec Sep 10, 2024
fb9df62
lower case
qgallouedec Sep 10, 2024
bd8c95d
Update titles in _toctree.yml and data_utils.mdx
qgallouedec Sep 10, 2024
4c199e0
revert make change
qgallouedec Sep 10, 2024
69ffd4e
correct dataset ids
qgallouedec Sep 10, 2024
235c7fe
expand a bit dataset formats
qgallouedec Sep 10, 2024
1cd829c
skip gated repo tests
qgallouedec Sep 10, 2024
5d85b91
data utilities in API
qgallouedec Sep 10, 2024
d0932cb
Update docs/source/dataset_formats.mdx
qgallouedec Sep 11, 2024
fab82e1
Update docs/source/dataset_formats.mdx
qgallouedec Sep 11, 2024
b434444
Update docs/source/dataset_formats.mdx
qgallouedec Sep 11, 2024
a7ef8c3
Update docs/source/dataset_formats.mdx
qgallouedec Sep 11, 2024
040d4cf
tiny internal testing for chat template testing
qgallouedec Sep 11, 2024
af0ad76
Merge branch 'dataset_format' of https://github.com/huggingface/trl i…
qgallouedec Sep 11, 2024
c2b6574
precise type/format
qgallouedec Sep 11, 2024
b35ad53
exlude sft trainer in doc
qgallouedec Sep 11, 2024
16aa25e
Update trl/trainer/utils.py
qgallouedec Sep 11, 2024
2a7ef20
Merge branch 'main' into dataset_format
qgallouedec Sep 11, 2024
24a75c2
Merge branch 'dataset_format' of https://github.com/huggingface/trl i…
qgallouedec Sep 11, 2024
814cd38
XPO in the doc
qgallouedec Sep 11, 2024
8 changes: 6 additions & 2 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
- sections:
- local: index
title: TRL
- local: quickstart
title: Quickstart
- local: installation
title: Installation
- local: quickstart
title: Quickstart
- local: clis
title: Get started with Command Line Interfaces (CLIs)
- local: dataset_formats
title: Dataset Formats
- local: how_to_train
title: PPO Training FAQ
- local: use_model
@@ -59,6 +61,8 @@
title: Judges
- local: callbacks
title: Callbacks
- local: data_utils
title: Data Utilities
- local: text_environments
title: Text Environments
title: API
15 changes: 15 additions & 0 deletions docs/source/data_utils.mdx
@@ -0,0 +1,15 @@
## Data Utilities

[[autodoc]] is_conversational

[[autodoc]] apply_chat_template

[[autodoc]] maybe_apply_chat_template

[[autodoc]] extract_prompt

[[autodoc]] maybe_extract_prompt

[[autodoc]] unpair_preference_dataset

[[autodoc]] maybe_unpair_preference_dataset
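Among the utilities above, `unpair_preference_dataset` converts a paired preference dataset (one row holding both a `chosen` and a `rejected` completion) into an unpaired one (two rows, each with a `completion` and a boolean `label`). The helper below is an illustrative sketch of that transformation written with plain dicts, not TRL's implementation; the column names follow the documentation added in this PR.

```python
def unpair(rows):
    """Sketch: turn paired preference rows into unpaired, labeled rows."""
    out = []
    for row in rows:
        # The chosen completion becomes a positively labeled row...
        out.append({"prompt": row["prompt"], "completion": row["chosen"], "label": True})
        # ...and the rejected completion a negatively labeled one.
        out.append({"prompt": row["prompt"], "completion": row["rejected"], "label": False})
    return out

paired = [{"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}]
unpaired = unpair(paired)
```

Each paired row yields exactly two unpaired rows, which is the shape expected by trainers that consume unpaired data such as KTO (hence the old name `maybe_reformat_dpo_to_kto` seen in the commit list).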
709 changes: 709 additions & 0 deletions docs/source/dataset_formats.mdx

Large diffs are not rendered by default.

122 changes: 0 additions & 122 deletions examples/datasets/anthropic_hh.py

This file was deleted.

82 changes: 82 additions & 0 deletions examples/datasets/hh-rlhf-helpful-base.py
@@ -0,0 +1,82 @@
import re
from dataclasses import dataclass
from typing import Dict, List, Optional

from datasets import load_dataset
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
r"""
Arguments for the script.

Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/hh-rlhf-helpful-base"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/hh-rlhf-helpful-base"
dataset_num_proc: Optional[int] = None


def common_start(str1: str, str2: str) -> str:
# Zip the two strings and iterate over them together
common_chars = []
for c1, c2 in zip(str1, str2):
if c1 == c2:
common_chars.append(c1)
else:
break
# Join the common characters and return as a string
return "".join(common_chars)


def extract_dialogue(example: Dict[str, str]) -> Dict[str, List[Dict[str, str]]]:
# Extract the prompt, which corresponds to the common start of the chosen and rejected dialogues
prompt_text = common_start(example["chosen"], example["rejected"])

# The common start may run past the final generation prompt when the chosen and
# rejected responses begin with the same characters, so truncate it back to the
# last "\n\nAssistant: "
if not prompt_text.endswith("\n\nAssistant: "):
prompt_text = prompt_text[: prompt_text.rfind("\n\nAssistant: ")] + "\n\nAssistant: "

# Extract the chosen and rejected lines
chosen_line = example["chosen"][len(prompt_text) :]
rejected_line = example["rejected"][len(prompt_text) :]

# Remove the generation prompt ("\n\nAssistant: ") from the prompt
prompt_text = prompt_text[: -len("\n\nAssistant: ")]

# Split the string at every occurrence of "Human: " or "Assistant: "
prompt_lines = re.split(r"(\n\nAssistant: |\n\nHuman: )", prompt_text)

# Remove the first element as it's empty
prompt_lines = prompt_lines[1:]

prompt = []
for idx in range(0, len(prompt_lines), 2):
role = "user" if prompt_lines[idx] == "\n\nHuman: " else "assistant"
content = prompt_lines[idx + 1]
prompt.append({"role": role, "content": content})

# Remove the prompt from the chosen and rejected dialogues
chosen = [{"role": "assistant", "content": chosen_line}]
rejected = [{"role": "assistant", "content": rejected_line}]

return {"prompt": prompt, "chosen": chosen, "rejected": rejected}


if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
args = parser.parse_args_into_dataclasses()[0]

dataset = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base")
dataset = dataset.map(extract_dialogue, num_proc=args.dataset_num_proc)

if args.push_to_hub:
dataset.push_to_hub(args.repo_id)
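A toy walk-through of the prompt extraction in the script above (hypothetical strings, not real dataset rows): the prompt is the character-by-character common prefix of the chosen and rejected dialogues, ending at the final `"\n\nAssistant: "` marker.

```python
chosen = "\n\nHuman: Name a color.\n\nAssistant: Blue."
rejected = "\n\nHuman: Name a color.\n\nAssistant: Red."

# Character-by-character common prefix, as in common_start() above.
prefix_chars = []
for c1, c2 in zip(chosen, rejected):
    if c1 != c2:
        break
    prefix_chars.append(c1)
prompt_text = "".join(prefix_chars)

# Here the two responses differ at their first character, so the prefix
# ends exactly at the generation prompt; the remainders are the responses.
chosen_line = chosen[len(prompt_text):]
rejected_line = rejected[len(prompt_text):]
```

When the responses share leading characters, the prefix overshoots the marker, which is exactly the case the `endswith("\n\nAssistant: ")` check in the script guards against.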
67 changes: 67 additions & 0 deletions examples/datasets/lm-human-preferences-descriptiveness.py
@@ -0,0 +1,67 @@
from dataclasses import dataclass
from typing import Optional

from datasets import load_dataset
from transformers import AutoTokenizer, HfArgumentParser


@dataclass
class ScriptArguments:
r"""
Arguments for the script.

Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/lm-human-preferences-descriptiveness"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/lm-human-preferences-descriptiveness"
dataset_num_proc: Optional[int] = None


# Edge cases handling: remove the cases where all samples are the same
def samples_not_all_same(example):
return not all(example["sample0"] == example[f"sample{j}"] for j in range(1, 4))


def to_prompt_completion(example, tokenizer):
prompt = tokenizer.decode(example["query"]).strip()
best_idx = example["best"]
chosen = tokenizer.decode(example[f"sample{best_idx}"])
for rejected_idx in range(4): # take the first rejected sample that is different from the chosen one
rejected = tokenizer.decode(example[f"sample{rejected_idx}"])
if chosen != rejected:
break
assert chosen != rejected
return {"prompt": prompt, "chosen": chosen, "rejected": rejected}


if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
args = parser.parse_args_into_dataclasses()[0]

dataset = load_dataset(
"json",
data_files="https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/descriptiveness/offline_5k.json",
split="train",
)

dataset = dataset.filter(samples_not_all_same, num_proc=args.dataset_num_proc)

dataset = dataset.map(
to_prompt_completion,
num_proc=args.dataset_num_proc,
remove_columns=["query", "sample0", "sample1", "sample2", "sample3", "best"],
fn_kwargs={"tokenizer": AutoTokenizer.from_pretrained("gpt2")},
)

# train_size taken from https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/launch.py#L79
dataset = dataset.train_test_split(train_size=4992)

if args.push_to_hub:
dataset.push_to_hub(args.repo_id)
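The chosen/rejected selection in `to_prompt_completion` above can be sketched with plain strings standing in for decoded token sequences (toy data, no tokenizer needed): the human-preferred sample becomes `chosen`, and `rejected` is the first sample that differs from it, guarding against duplicated samples.

```python
# Toy example row: four candidate samples plus the index of the preferred one.
example = {"best": 2, "sample0": "a cat", "sample1": "a dog",
           "sample2": "a bird", "sample3": "a bird"}

# The human-preferred sample is the chosen completion...
chosen = example[f"sample{example['best']}"]

# ...and the rejected one is the first sample that differs from it
# (note sample3 duplicates the chosen text and would be skipped).
for rejected_idx in range(4):
    rejected = example[f"sample{rejected_idx}"]
    if rejected != chosen:
        break
```

The `assert chosen != rejected` in the script then catches the degenerate case where every sample matched the chosen one, which the `samples_not_all_same` filter should already have removed.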
60 changes: 60 additions & 0 deletions examples/datasets/lm-human-preferences-sentiment.py
@@ -0,0 +1,60 @@
from dataclasses import dataclass
from typing import Optional

from datasets import load_dataset
from transformers import AutoTokenizer, HfArgumentParser


@dataclass
class ScriptArguments:
r"""
Arguments for the script.

Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/lm-human-preferences-sentiment"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/lm-human-preferences-sentiment"
dataset_num_proc: Optional[int] = None


def to_prompt_completion(example, tokenizer):
prompt = tokenizer.decode(example["query"]).strip()
best_idx = example["best"]
chosen = tokenizer.decode(example[f"sample{best_idx}"])
for rejected_idx in range(4): # take the first rejected sample that is different from the chosen one
rejected = tokenizer.decode(example[f"sample{rejected_idx}"])
if chosen != rejected:
break
assert chosen != rejected
return {"prompt": prompt, "chosen": chosen, "rejected": rejected}


if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
args = parser.parse_args_into_dataclasses()[0]

dataset = load_dataset(
"json",
data_files="https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/sentiment/offline_5k.json",
split="train",
)

dataset = dataset.map(
to_prompt_completion,
num_proc=args.dataset_num_proc,
remove_columns=["query", "sample0", "sample1", "sample2", "sample3", "best"],
fn_kwargs={"tokenizer": AutoTokenizer.from_pretrained("gpt2")},
)

# train_size taken from https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/launch.py#L70
dataset = dataset.train_test_split(train_size=4992)

if args.push_to_hub:
dataset.push_to_hub(args.repo_id)