Prototype Dataset Processor #1646
base: main
Conversation
very cool! thanks! checking
I think for Zephyr we used packing, and there is a concat token (= eos) that is not masked/ignored.
@edbeeching you are right, because the data collator is not called when using the packed dataset! The output below confirms it.
@@ -872,7 +872,9 @@ def print_rich_table(df: pd.DataFrame) -> Table:
 SIMPLE_SFT_CHAT_TEMPLATE = "{% for message in messages %}{{' ' + message['content']}}{% endfor %}{{ eos_token }}"
 # SIMPLE_SFT_CHAT_TEMPLATE simply ends things with an EOS token; this helps the SFT model learn to end completions with EOS tokens

-SIMPLE_QUERY_CHAT_TEMPLATE = "{% for message in messages %}{{' ' + message['content']}}{% endfor %}"
+SIMPLE_QUERY_CHAT_TEMPLATE = (
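The two templates are tiny Jinja loops; the only difference is whether the EOS token is appended. A plain-Python rendering of what they produce, for illustration only (in practice the tokenizer renders the Jinja string via `apply_chat_template`):

```python
# Plain-Python equivalents of the two Jinja chat templates above,
# for illustration; real code renders the Jinja template string.

def simple_sft_template(messages, eos_token):
    # Concatenate each message's content with a leading space, then
    # append the EOS token so the SFT model learns to end completions.
    return "".join(" " + m["content"] for m in messages) + eos_token

def simple_query_template(messages):
    # Same concatenation, but without EOS: the model must produce it.
    return "".join(" " + m["content"] for m in messages)

messages = [{"role": "user", "content": "POST: some reddit post"},
            {"role": "assistant", "content": "TL;DR: a summary"}]
print(simple_sft_template(messages, "<|endoftext|>"))
# → " POST: some reddit post TL;DR: a summary<|endoftext|>"
```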
Will refactor this later
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
    examples/scripts/rm/rm.py \
    --dataset_name trl-internal-testing/tldr-preference-trl-style \
    --dataset_train_split train \
    --dataset_eval_split validation \
    --model_name_or_path EleutherAI/pythia-1b-deduped \
    --chat_template simple_concat_with_space \
    --learning_rate 3e-6 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --gradient_accumulation_steps 4 \
    --logging_steps 1 \
    --eval_strategy steps \
    --max_token_length 1280 \
    --max_prompt_token_length 1024 \
    --remove_unused_columns False \
    --num_train_epochs 1 \
    --eval_steps 300 \
    --bf16 \
    --output_dir models/rm/rm_tldr_1b \
    --push_to_hub \
    --hub_model_id trl-internal-testing/rm_tldr_1b
cc @qgallouedec
This PR attempts to refactor all tokenization logic out of the Trainer classes. A separate tokenization step gives us higher visibility into what is being used in training, clarifies the logic, and reduces bugs. Concretely, it attempts to do the following.
Why?
Currently, the Trainer is also responsible for tokenization, which causes several issues:
Duplicate tokenization steps. For example, alignment-handbook calls apply_chat_template(tokenize=False) on the dataset, and then the SFT/DPO trainer tokenizes it again. To remove the duplication, we only need to go through the dataset once, calling apply_chat_template(tokenize=True).
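The duplication can be sketched with a toy tokenizer (the class below is an assumption standing in for a transformers tokenizer, whose `apply_chat_template` similarly accepts a `tokenize` flag):

```python
# Sketch of the duplicate-tokenization issue. ToyChatTokenizer is a
# stand-in for a transformers tokenizer; its fake ids are word lengths.

class ToyChatTokenizer:
    def apply_chat_template(self, messages, tokenize=False):
        text = "".join(" " + m["content"] for m in messages)
        if not tokenize:
            return text            # alignment-handbook path: text only
        return self._encode(text)  # single-pass path: ids directly

    def _encode(self, text):
        return [len(word) for word in text.split()]  # fake token ids

tok = ToyChatTokenizer()
msgs = [{"content": "hello world"}]

# Two passes: render to text, then tokenize again later in the trainer.
text = tok.apply_chat_template(msgs, tokenize=False)
two_pass = tok._encode(text)

# One pass: template and tokenize together.
one_pass = tok.apply_chat_template(msgs, tokenize=True)
assert one_pass == two_pass  # same result, one traversal of the data
```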
Truncation logic happens in various places and is hard to predict. SFTTrainer calls it max_seq_length, reward modeling calls it max_length, and the DPO/KTO trainers use max_length, max_prompt_length, and max_target_length. The truncation logic itself also differs between trainers; e.g., DPO truncates the prompt if prompt + chosen is too long (trl/trainer/dpo_trainer.py, lines 797 to 799 in 99f2c94).
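One such rule can be sketched as follows; parameter names mirror the options listed above, but the exact logic in dpo_trainer.py may differ (e.g., which end of the prompt is kept):

```python
# Sketch of DPO-style truncation: if prompt + chosen exceeds
# max_length, shrink the prompt first, then the response. Names are
# illustrative; the actual rule lives in trl's dpo_trainer.py.

def truncate_pair(prompt_ids, chosen_ids, max_length, max_prompt_length):
    if len(prompt_ids) + len(chosen_ids) > max_length:
        # Truncate the prompt first (here: keep its beginning).
        prompt_ids = prompt_ids[:max_prompt_length]
    if len(prompt_ids) + len(chosen_ids) > max_length:
        # Still too long: truncate the response from the right.
        chosen_ids = chosen_ids[: max_length - len(prompt_ids)]
    return prompt_ids, chosen_ids

p, c = truncate_pair(list(range(10)), list(range(6)),
                     max_length=12, max_prompt_length=8)
print(len(p), len(c))  # → 8 4
```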
Learning to generate EOS tokens. Learning to generate EOS tokens #1623 (comment) suggested that EOS tokens always 1) correspond to -100 in the labels, and 2) if the dataset contains the EOS token before collating, the attention mask of the EOS token is also 1. It's possible that the model may never learn to generate EOS tokens.
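The labeling pitfall can be sketched as follows (the helper and constants are illustrative, not TRL's actual code): when the final EOS label is set to -100, the loss never covers it, so the model gets no signal to stop.

```python
# Sketch of the EOS-labeling pitfall: with mask_eos=True, the final
# EOS is labeled -100 and is never supervised. Illustrative only.
IGNORE_INDEX = -100
EOS_ID = 2

def make_labels(input_ids, prompt_len, mask_eos):
    # Prompt tokens are always ignored; completion tokens are supervised.
    labels = [IGNORE_INDEX] * prompt_len + list(input_ids[prompt_len:])
    if mask_eos and labels[-1] == EOS_ID:
        labels[-1] = IGNORE_INDEX  # the problematic behaviour
    return labels

ids = [5, 6, 7, 8, EOS_ID]                  # prompt = [5, 6]
print(make_labels(ids, 2, mask_eos=True))   # → [-100, -100, 7, 8, -100]
print(make_labels(ids, 2, mask_eos=False))  # → [-100, -100, 7, 8, 2]
```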
dataset_num_proc is not uniformly applied; as a result, [ORPO] Enable batched tokenization & multiprocessing to process large datasets #1624 is needed. There is also the question of hashable tokenization.
Dataset mixer (e.g., in our h4 codebase): this should be more widely available for use in TRL and can be combined with this class.
The current design
The current design roughly looks like this. Note that we can still construct it inside Trainer.__init__, so users don't have to configure it directly.
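A minimal sketch of such a standalone processor, with a toy tokenizer so it runs by itself (all class and method names here are assumptions for illustration, not the PR's actual API):

```python
# Minimal sketch of a standalone dataset processor that owns
# tokenization and truncation before the Trainer sees the data.
# All names are illustrative assumptions, not the PR's actual API.

class ToyTokenizer:
    """Whitespace 'tokenizer' standing in for a transformers tokenizer."""
    eos_token_id = 0
    def __call__(self, text):
        return [len(w) for w in text.split()]  # fake token ids

class DatasetProcessor:
    def __init__(self, tokenizer, max_token_length=1280):
        self.tokenizer = tokenizer
        self.max_token_length = max_token_length

    def tokenize(self, example):
        # Truncate, then always terminate with EOS so the model can
        # learn to stop generating.
        ids = self.tokenizer(example["text"])[: self.max_token_length - 1]
        ids.append(self.tokenizer.eos_token_id)
        return {"input_ids": ids}

    def process(self, dataset, num_proc=1):
        # With the datasets library this would be
        # dataset.map(self.tokenize, num_proc=num_proc);
        # a plain list comprehension stands in here.
        return [self.tokenize(ex) for ex in dataset]

proc = DatasetProcessor(ToyTokenizer(), max_token_length=4)
print(proc.process([{"text": "a bb ccc dddd"}]))
# → [{'input_ids': [1, 2, 3, 0]}]
```

Because every tokenization, truncation, and EOS decision sits in one class, the behaviour is inspectable before training starts instead of being spread across trainer internals.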