Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prototype Dataset Processor #1646

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft

Prototype Dataset Processor #1646

wants to merge 17 commits into from

Conversation

vwxyzjn
Copy link
Contributor

@vwxyzjn vwxyzjn commented May 16, 2024

This PR attempts to refactor and pull all tokenization logic out of the Trainer class. Having a separate tokenization process gives us higher visibility into what's being used in training, providing more clarified logic and reducing bugs. It attempts to do the following things.

# 1. PPO (prompt)
# 2. SFT (prompt + demonstration), there is also packing.
# 3. ✅ RM / DPO (chosen and rejected)
# 4. ✅ Visualization of length distributions?
# 5. ✅ Filter?
#   * Smart truncation?
# 6. ✅ dataset_num_proc
# 7. check EOS token
# 8. dataset mixer?
# 9. ✅ pretty print that show tokenization?
# 10. hashable tokneization?
# 11. inputs / labels / attention_mask
# 12. always set a `tokenizer.pad_token_id`?

why?

Currently, the Trainer is also responsible for tokenization. It causes several issues:

  1. duplicate tokenization steps. For example, alignment-handbook calls apply_chat_template(tokenize=False) for the dataset, followed by SFT/DPO trainer calling tokenized again. To remove duplication, we only needed to go through the dataset once by calling apply_chat_template(tokenize=True)

  2. truncation logic happens in various places and is hard to predict. SFTTrainer calls it the max_seq_length, RewardModeling calls it max_length, DPO/KTOTrainers call it max_length, max_prompt_length, max_target_length. There are also different truncation logics. E.g., [(truncate the prompt if prompt + chosen is too long)]
    (

    # if combined sequence is too long, truncate the prompt
    for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
    if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
    ). This causes issue like https://huggingface.slack.com/archives/C04EX6W3QSY/p1715255460198239 as raised by @abhishekkrthakur.

    • the hard truncation logic seems debatable: if the sequence length is too long, shouldn't we filter them out instead of giving a truncated response? The truncated response could be an incomplete code snippet / summaries (basically bad data). If truncation is really desired, we should do some kind of smart truncation like truncate at the last paragraph, so the sentences are still complete.
  3. learning to generate EOS tokens. Learning to generate EOS tokens  #1623 (comment) suggested that EOS tokens always 1) correspond to -100 in the labels and 2) if the dataset contains the EOS token before collating, then the attention mask of EOS token is also 1. It's possible that the model may never learn to generate EOS tokens.

    • what's a bit unclear to me is how zephyr learns to output EOS tokens, despite all the labels of EOS token are marked with -100 and are being masked out. My suspicion is that the attention_mask=1 plays some roles in it.
  4. dataset_num_proc is not uniformly applied, as a result [ORPO] Enable batched tokenization & multiprocessing to process large datasets #1624 is needed. There is also the question of hashable tokenization

  5. Dataset mixer (e.g., in our h4 codebase), that should be more widely available to use in TRL and can be combined with this class.

The current design

The current design roughly looks like this. Note that we can still put it in Trainer.__init__ so users don't have to configure it directly.

dataset_config = DatasetConfig(max_token_length=1024, max_prompt_token_lenth=128)
dataset_processor = PreferenceDatasetProcessor(tokenizer=tok, config=dataset_config)
train_dataset = dataset_processor.tokenize(preference_datasets["train"])
stats = dataset_processor.get_token_length_stats(train_dataset)
pprint.pp(stats)
train_dataset = dataset_processor.filter(train_dataset)
stats = dataset_processor.get_token_length_stats(train_dataset)
pprint.pp(stats)
dataset_processor.get_token_length_visualization(train_dataset)
print(tok.decode(train_dataset[0]["chosen"]))
visualize_token(train_dataset[0]["chosen"], tok)
image image

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif
Copy link
Collaborator

kashif commented May 16, 2024

very cool! thanks! checking

@edbeeching
Copy link
Collaborator

what's a bit unclear to me is how zephyr learns to output EOS tokens, despite all the labels of EOS token are marked with -100 and are being masked out. My suspicion is that the attention_mask=1 plays some roles in it.

I think for zephyr we used packing and there is a concat token=eos that is not masked / ignored.

@vwxyzjn
Copy link
Contributor Author

vwxyzjn commented May 18, 2024

@edbeeching you are right, because the datacollator is not called when using the packed dataset! The output below confirms it.

image

@vwxyzjn vwxyzjn marked this pull request as ready for review June 21, 2024 15:21
@@ -872,7 +872,9 @@ def print_rich_table(df: pd.DataFrame) -> Table:
SIMPLE_SFT_CHAT_TEMPLATE = "{% for message in messages %}{{' ' + message['content']}}{% endfor %}{{ eos_token }}"
# SIMPLE_SFT_CHAT_TEMPLATE simply ends things with an EOS token, this helps the SFT model learn to end the completions with EOS tokens

SIMPLE_QUERY_CHAT_TEMPLATE = "{% for message in messages %}{{' ' + message['content']}}{% endfor %}"
SIMPLE_QUERY_CHAT_TEMPLATE = (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will refactor this later

Comment on lines +322 to +343
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/rm/rm.py \
--dataset_name trl-internal-testing/tldr-preference-trl-style \
--dataset_train_split train \
--dataset_eval_split validation \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--chat_template simple_concat_with_space \
--learning_rate 3e-6 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 4 \
--logging_steps 1 \
--eval_strategy steps \
--max_token_length 1280 \
--max_prompt_token_lenth 1024 \
--remove_unused_columns False \
--num_train_epochs 1 \
--eval_steps=300 \
--bf16 \
--output_dir models/rm/rm_tldr_1b \
--push_to_hub \
--hub_model_id trl-internal-testing/rm_tldr_1b
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link

github-actions bot commented Aug 8, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@github-actions github-actions bot closed this Aug 16, 2024
@kashif kashif reopened this Aug 16, 2024
@github-actions github-actions bot closed this Aug 25, 2024
@lewtun
Copy link
Member

lewtun commented Aug 26, 2024

Bot begone!

@lewtun lewtun reopened this Aug 26, 2024
@lewtun lewtun mentioned this pull request Sep 12, 2024
5 tasks
@qgallouedec qgallouedec marked this pull request as draft September 23, 2024 21:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants