[SCoRE] initial score stage 1 #2115

Draft · wants to merge 5 commits into base: main
54 changes: 54 additions & 0 deletions trl/trainer/score_config.py
@@ -0,0 +1,54 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from trl.trainer.online_dpo_config import OnlineDPOConfig


@dataclass
class SCoREConfig(OnlineDPOConfig):
r"""
Configuration class for the [`SCoRETrainer`].

    Subclass of [`OnlineDPOConfig`]; we can use all of its arguments and add the following:

    Args:
        correction_instruction (`str`, *optional*, defaults to `"The previous response may contain errors. Please review and correct any mistakes: "`):
            Instruction for self-correction in the second attempt.
        first_attempt_prefix (`str`, *optional*, defaults to `"First attempt: "`):
            Prefix for the first attempt in the second-attempt prompt.
        second_attempt_prefix (`str`, *optional*, defaults to `"Improved response: "`):
            Prefix for the second attempt in the model output.
"""

# Prompts
correction_instruction: str = field(
default="The previous response may contain errors. Please review and correct any mistakes: ",
metadata={"help": "Instruction for self-correction in the second attempt"},
)

first_attempt_prefix: str = field(
default="First attempt: ", metadata={"help": "Prefix for the first attempt in the second attempt prompt"}
)

second_attempt_prefix: str = field(
default="Improved response: ", metadata={"help": "Prefix for the second attempt in the model output"}
)

def __post_init__(self):
super().__post_init__()

# Ensure that the correction instruction ends with a space
if not self.correction_instruction.endswith(" "):
self.correction_instruction += " "

# Ensure that the prefixes end with a space
if not self.first_attempt_prefix.endswith(" "):
self.first_attempt_prefix += " "
if not self.second_attempt_prefix.endswith(" "):
self.second_attempt_prefix += " "
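
For reference, a minimal usage sketch of the new config. `output_dir` comes from the inherited `TrainingArguments`, and the value passed for `correction_instruction` is an illustrative placeholder, not a recommended prompt.

from trl.trainer.score_config import SCoREConfig

# Only the SCoRE-specific prompt is overridden here; everything else falls back
# to the OnlineDPOConfig / TrainingArguments defaults.
config = SCoREConfig(
    output_dir="score-stage1",
    correction_instruction="Your previous answer may be wrong. Please revise it:",
)
print(config.correction_instruction)  # __post_init__ guarantees a trailing space
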
226 changes: 226 additions & 0 deletions trl/trainer/score_trainer.py
@@ -0,0 +1,226 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedTokenizerBase

from ..models.utils import unwrap_model_for_generation
from .online_dpo_trainer import OnlineDPOTrainer
from .score_config import SCoREConfig
from .utils import get_reward


class SCoRETrainer(OnlineDPOTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
        # Reuse the SCoRE-specific options passed through `args`; fall back to the defaults otherwise.
        self.score_config = self.args if isinstance(self.args, SCoREConfig) else SCoREConfig()

# Add SCoRE-specific statistics
self.stats.update(
{
"loss/stage1": [],
"kl_div/first_attempt": [],
"reward/second_attempt": [],
}
)

def _generate_completions(self, model, prompts):
        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
# Generate first attempt
first_attempt = unwrapped_model.generate(
input_ids=prompts["input_ids"],
attention_mask=prompts["attention_mask"],
generation_config=self.generation_config,
)

# Prepare input for second attempt
second_attempt_prompt = self._prepare_second_attempt_prompt(prompts, first_attempt)

# Generate second attempt
second_attempt = unwrapped_model.generate(
input_ids=second_attempt_prompt["input_ids"],
attention_mask=second_attempt_prompt["attention_mask"],
generation_config=self.generation_config,
)

return first_attempt, second_attempt

def _prepare_second_attempt_prompt(self, prompts, first_attempt):
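        # Second-attempt prompt layout: [original prompt | first-attempt completion | correction instruction]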
context_length = prompts["input_ids"].shape[1]
first_completion = first_attempt[:, context_length:]
correction_instruction = (
self.tokenizer.encode(
self.score_config.correction_instruction, return_tensors="pt", add_special_tokens=False
)
.repeat(prompts["input_ids"].shape[0], 1)
.to(first_attempt.device)
)

second_attempt_input_ids = torch.cat([prompts["input_ids"], first_completion, correction_instruction], dim=1)

second_attempt_attention_mask = torch.ones_like(second_attempt_input_ids)

return {"input_ids": second_attempt_input_ids, "attention_mask": second_attempt_attention_mask}

def _compute_stage1_loss(self, model, ref_model, first_attempt, second_attempt, prompts, ground_truth_completions):
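        # SCoRE Stage I objective: constrain the first attempt to stay close to the reference
        # policy (KL term) while pushing the second attempt to score above the ground truth.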
context_length = prompts["input_ids"].shape[1]

# Compute logprobs for first attempt
first_attempt_logits = model(first_attempt["input_ids"], attention_mask=first_attempt["attention_mask"]).logits
first_attempt_logprobs = F.log_softmax(first_attempt_logits[:, context_length - 1 : -1], dim=-1)

# Compute KL divergence for first attempt
with torch.no_grad():
ref_first_attempt_logits = ref_model(
first_attempt["input_ids"], attention_mask=first_attempt["attention_mask"]
).logits
ref_first_attempt_logprobs = F.log_softmax(ref_first_attempt_logits[:, context_length - 1 : -1], dim=-1)

# Create a mask for non-padding tokens
non_padding_mask = (first_attempt["input_ids"][:, context_length:] != self.tokenizer.pad_token_id).float()

        # Token-level KL(pi_theta || pi_ref): with log_target=True and the reference log-probs as
        # input, F.kl_div sums exp(policy_logprob) * (policy_logprob - ref_logprob) over the vocabulary.
        kl_div = F.kl_div(
            ref_first_attempt_logprobs, first_attempt_logprobs, reduction="none", log_target=True
        ).sum(-1)
kl_div = (kl_div * non_padding_mask).sum() / non_padding_mask.sum()

# Compute reward for second attempt against ground truth
second_attempt_reward = self._compute_rewards(second_attempt, prompts, ground_truth_completions)

# Compute loss
kl_loss = self.score_config.beta * kl_div
reward_loss = -second_attempt_reward.mean()

        # Stage I loss: KL regularization on the first attempt plus the negative mean second-attempt reward
loss = kl_loss + reward_loss

# Log statistics
self.stats["loss/stage1"].append(loss.item())
self.stats["kl_div/first_attempt"].append(kl_div.mean().item())
self.stats["reward/second_attempt"].append(second_attempt_reward.mean().item())

return loss

def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
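        # One optimization step: generate first/second attempts, score them, and backprop the Stage I loss.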
model.train()
ref_model = self.ref_model
ref_model.eval()

inputs = self._prepare_inputs(inputs)
prompts = {
"input_ids": inputs["prompt_input_ids"],
"attention_mask": inputs["prompt_attention_mask"],
}
ground_truth_completions = {
"input_ids": inputs["completion_input_ids"],
"attention_mask": inputs["completion_attention_mask"],
}

# Generate completions (both first and second attempts)
first_attempt, second_attempt = self._generate_completions(model, prompts)

# Process completions
first_attempt_data = self._process_completion(first_attempt, prompts)
second_attempt_data = self._process_completion(second_attempt, prompts)

# Compute Stage I loss
loss = self._compute_stage1_loss(
model, ref_model, first_attempt_data, second_attempt_data, prompts, ground_truth_completions
)

if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training

if self.args.gradient_accumulation_steps > 1:
loss = loss / self.args.gradient_accumulation_steps

self.accelerator.backward(loss)

return loss.detach()

def _compute_rewards(self, completions, prompts, ground_truth_completions):
context_length = prompts["input_ids"].shape[1]
with torch.no_grad():
_, generated_scores, _ = get_reward(
self.reward_model, completions["input_ids"], self.tokenizer.pad_token_id, context_length
)

# Compute scores for ground-truth completions
ground_truth_input_ids = torch.cat([prompts["input_ids"], ground_truth_completions["input_ids"]], dim=1)
_, ground_truth_scores, _ = get_reward(
self.reward_model, ground_truth_input_ids, self.tokenizer.pad_token_id, context_length
)

if self.args.missing_eos_penalty is not None:
contain_eos = torch.any(completions["input_ids"] == self.tokenizer.eos_token_id, dim=-1)
generated_scores[~contain_eos] -= self.args.missing_eos_penalty

        # Relative reward: positive when the generated completion scores higher than the ground truth.
        return generated_scores - ground_truth_scores

def _process_completion(self, completion, prompts):
context_length = prompts["input_ids"].shape[1]
completion_ids = completion[:, context_length:]
completion_ids, completion_mask = self.truncate_right(
completion_ids, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id
)
return {
"input_ids": torch.cat((prompts["input_ids"], completion_ids), dim=1),
"attention_mask": torch.cat((prompts["attention_mask"], completion_mask), dim=1),
}

@staticmethod
def truncate_right(tokens, eos_token_id, pad_token_id):
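        """Replace the first EOS token and everything after it with the pad token.

        Sequences that contain no EOS are kept in full; the returned mask marks the retained tokens.
        """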
eos_index = (tokens == eos_token_id).long().argmax(dim=-1)
eos_index = torch.where(eos_index > 0, eos_index, tokens.shape[1])
mask = torch.arange(tokens.shape[1], device=tokens.device)[None, :] < eos_index[:, None]
tokens = tokens.masked_fill(~mask, pad_token_id)
return tokens, mask

@staticmethod
def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> Dict[str, Any]:
"""Tokenize a single row from a DPO specific dataset."""
if not is_encoder_decoder:
prompt_tokens = tokenizer(feature["prompt"], add_special_tokens=False)
# Add BOS token to head of prompt. Avoid adding if it's already there
if tokenizer.bos_token_id is not None:
prompt_len_input_ids = len(prompt_tokens["input_ids"])
if prompt_len_input_ids == 0 or tokenizer.bos_token_id != prompt_tokens["input_ids"][0]:
prompt_tokens["input_ids"] = [tokenizer.bos_token_id] + prompt_tokens["input_ids"]
prompt_tokens["attention_mask"] = [1] + prompt_tokens["attention_mask"]

# Tokenize the ground-truth completion
completion_tokens = tokenizer(feature["completion"], add_special_tokens=False)

# Combine prompt and completion
batch = {
"prompt_input_ids": prompt_tokens["input_ids"],
"prompt_attention_mask": prompt_tokens["attention_mask"],
"completion_input_ids": completion_tokens["input_ids"],
"completion_attention_mask": completion_tokens["attention_mask"],
}
else:
prompt_tokens = tokenizer(feature["prompt"], add_special_tokens=True)
completion_tokens = tokenizer(feature["completion"], add_special_tokens=False)

batch = {
"prompt_input_ids": prompt_tokens["input_ids"],
"prompt_attention_mask": prompt_tokens["attention_mask"],
"completion_input_ids": completion_tokens["input_ids"],
"completion_attention_mask": completion_tokens["attention_mask"],
}

return batch
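
For reviewers, a minimal end-to-end sketch of how the trainer could be wired up. It assumes the constructor keywords inherited from `OnlineDPOTrainer` (`model`, `ref_model`, `reward_model`, `args`, `train_dataset`, `tokenizer`), which may differ across TRL versions; the model checkpoint and dataset are placeholders, and the dataset is expected to expose "prompt"/"completion" columns as consumed by `tokenize_row`.

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from trl.trainer.score_config import SCoREConfig
from trl.trainer.score_trainer import SCoRETrainer

model_name = "Qwen/Qwen2-0.5B-Instruct"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
ref_model = AutoModelForCausalLM.from_pretrained(model_name)
reward_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

train_dataset = load_dataset("trl-lib/tldr", split="train")  # placeholder prompt/completion dataset

config = SCoREConfig(output_dir="score-stage1")
trainer = SCoRETrainer(
    model=model,
    ref_model=ref_model,
    reward_model=reward_model,
    args=config,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()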