From 92f61e8b5a5b79e3960526012f14d10bde75a62e Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Tue, 24 Sep 2024 20:18:09 +0200
Subject: [PATCH 1/4] initial score stage 1

---
 trl/trainer/score_config.py  |  69 +++++++++++++++
 trl/trainer/score_trainer.py | 161 +++++++++++++++++++++++++++++++++++
 2 files changed, 230 insertions(+)
 create mode 100644 trl/trainer/score_config.py
 create mode 100644 trl/trainer/score_trainer.py

diff --git a/trl/trainer/score_config.py b/trl/trainer/score_config.py
new file mode 100644
index 0000000000..2e9cff56a6
--- /dev/null
+++ b/trl/trainer/score_config.py
@@ -0,0 +1,69 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import List
+
+from trl.trainer.online_dpo_config import OnlineDPOConfig
+
+
+@dataclass
+class SCoREConfig(OnlineDPOConfig):
+    r"""
+    Configuration class for the [`SCoRETrainer`].
+
+    Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following:
+
+    """
+    # Stage I specific parameters
+    kl_coef: float = field(
+        default=0.1,
+        metadata={"help": "Coefficient for KL divergence loss in Stage I"}
+    )
+
+    # Prompts
+    correction_instruction: str = field(
+        default="The previous response may contain errors. Please review and correct any mistakes: ",
+        metadata={"help": "Instruction for self-correction in the second attempt"}
+    )
+
+    first_attempt_prefix: str = field(
+        default="First attempt: ",
+        metadata={"help": "Prefix for the first attempt in the second attempt prompt"}
+    )
+
+    second_attempt_prefix: str = field(
+        default="Improved response: ",
+        metadata={"help": "Prefix for the second attempt in the model output"}
+    )
+
+    # Training stages
+    num_stage1_epochs: int = field(
+        default=1,
+        metadata={"help": "Number of epochs to train in Stage I"}
+    )
+
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        # Ensure that the correction instruction ends with a space
+        if not self.correction_instruction.endswith(" "):
+            self.correction_instruction += " "
+
+        # Ensure that the prefixes end with a space
+        if not self.first_attempt_prefix.endswith(" "):
+            self.first_attempt_prefix += " "
+        if not self.second_attempt_prefix.endswith(" "):
+            self.second_attempt_prefix += " "
\ No newline at end of file
diff --git a/trl/trainer/score_trainer.py b/trl/trainer/score_trainer.py
new file mode 100644
index 0000000000..9c39367ed3
--- /dev/null
+++ b/trl/trainer/score_trainer.py
@@ -0,0 +1,161 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Union, Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .online_dpo_trainer import OnlineDPOTrainer
+from .score_config import SCoREConfig
+
+
+class SCoRETrainer(OnlineDPOTrainer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.score_config = SCoREConfig()
+
+        # Add SCoRE-specific statistics
+        self.stats.update({
+            "loss/stage1": [],
+            "kl_div/first_attempt": [],
+            "reward/second_attempt": [],
+        })
+
+    def _generate_completions(self, model, prompts):
+        with self.accelerator.unwrap_model(model) as unwrapped_model:
+            # Generate first attempt
+            first_attempt = unwrapped_model.generate(
+                input_ids=prompts["input_ids"],
+                attention_mask=prompts["attention_mask"],
+                generation_config=self.generation_config,
+            )
+
+            # Prepare input for second attempt
+            second_attempt_prompt = self._prepare_second_attempt_prompt(prompts, first_attempt)
+
+            # Generate second attempt
+            second_attempt = unwrapped_model.generate(
+                input_ids=second_attempt_prompt["input_ids"],
+                attention_mask=second_attempt_prompt["attention_mask"],
+                generation_config=self.generation_config,
+            )
+
+        return first_attempt, second_attempt
+
+    def _prepare_second_attempt_prompt(self, prompts, first_attempt):
+        context_length = prompts["input_ids"].shape[1]
+        first_completion = first_attempt[:, context_length:]
+        correction_instruction = self.tokenizer.encode(
+            self.score_config.correction_instruction,
+            return_tensors="pt",
+            add_special_tokens=False
+        ).repeat(prompts["input_ids"].shape[0], 1).to(first_attempt.device)
+
+        second_attempt_input_ids = torch.cat([
+            prompts["input_ids"],
+            first_completion,
+            correction_instruction
+        ], dim=1)
+
+        second_attempt_attention_mask = torch.ones_like(second_attempt_input_ids)
+
+        return {
+            "input_ids": second_attempt_input_ids,
+            "attention_mask": second_attempt_attention_mask
+        }
+
+    def _compute_stage1_loss(self, model, ref_model, first_attempt, second_attempt, prompts):
+        context_length = prompts["input_ids"].shape[1]
+
+        # Compute logprobs for first attempt
+        first_attempt_logits = model(first_attempt["input_ids"], attention_mask=first_attempt["attention_mask"]).logits
+        first_attempt_logprobs = F.log_softmax(first_attempt_logits[:, context_length-1:-1], dim=-1)
+        first_attempt_token_logprobs = torch.gather(
+            first_attempt_logprobs, 2, first_attempt["input_ids"][:, context_length:].unsqueeze(-1)
+        ).squeeze(-1)
+
+        # Compute KL divergence for first attempt
+        with torch.no_grad():
+            ref_first_attempt_logits = ref_model(first_attempt["input_ids"], attention_mask=first_attempt["attention_mask"]).logits
+            ref_first_attempt_logprobs = F.log_softmax(ref_first_attempt_logits[:, context_length-1:-1], dim=-1)
+
+        kl_div = F.kl_div(first_attempt_logprobs, ref_first_attempt_logprobs.exp(), reduction='none').sum(-1)
+
+        # Compute reward for second attempt
+        second_attempt_reward = self._compute_rewards(second_attempt, context_length)
+
+        # Compute loss
+        kl_loss = self.score_config.kl_coef * kl_div.mean()
+        reward_loss = -second_attempt_reward.mean()
+
+        loss = kl_loss + reward_loss
+
+        # Log statistics
+        self.stats["loss/stage1"].append(loss.item())
+        self.stats["kl_div/first_attempt"].append(kl_div.mean().item())
+        self.stats["reward/second_attempt"].append(second_attempt_reward.mean().item())
+
+        return loss
+
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+        model.train()
+        ref_model = self.ref_model
+        ref_model.eval()
+
+        inputs = self._prepare_inputs(inputs)
+        prompts = {
+            "input_ids": inputs["prompt_input_ids"],
+            "attention_mask": inputs["prompt_attention_mask"],
+        }
+
+        # Generate completions (both first and second attempts)
+        first_attempt, second_attempt = self._generate_completions(model, prompts)
+
+        # Process completions
+        first_attempt_data = self._process_completion(first_attempt, prompts)
+        second_attempt_data = self._process_completion(second_attempt, prompts)
+
+        # Compute Stage I loss
+        loss = self._compute_stage1_loss(model, ref_model, first_attempt_data, second_attempt_data, prompts)
+
+        if self.args.n_gpu > 1:
+            loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+        if self.args.gradient_accumulation_steps > 1:
+            loss = loss / self.args.gradient_accumulation_steps
+
+        self.accelerator.backward(loss)
+
+        return loss.detach()
+
+    def _process_completion(self, completion, prompts):
+        context_length = prompts["input_ids"].shape[1]
+        completion_ids = completion[:, context_length:]
+        completion_ids, completion_mask = self.truncate_right(
+            completion_ids, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id
+        )
+        return {
+            "input_ids": torch.cat((prompts["input_ids"], completion_ids), dim=1),
+            "attention_mask": torch.cat((prompts["attention_mask"], completion_mask), dim=1),
+        }
+
+    @staticmethod
+    def truncate_right(tokens, eos_token_id, pad_token_id):
+        eos_index = (tokens == eos_token_id).long().argmax(dim=-1)
+        eos_index = torch.where(eos_index > 0, eos_index, tokens.shape[1])
+        mask = torch.arange(tokens.shape[1], device=tokens.device)[None, :] < eos_index[:, None]
+        tokens = tokens.masked_fill(~mask, pad_token_id)
+        return tokens, mask
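For orientation between the patches: `_prepare_second_attempt_prompt` above builds the second-attempt input by concatenating the original prompt, the model's first completion, and the configured `correction_instruction` (a token-level `torch.cat` in the trainer). A throwaway sketch with plain strings, not part of the patches; the question and the deliberately wrong first answer are made-up values:

prompt = "Q: What is 13 * 7? A:"
first_completion = " 13 * 7 = 84."
correction_instruction = (
    "The previous response may contain errors. Please review and correct any mistakes: "
)

# Same ordering as torch.cat([prompt, first_completion, correction_instruction]).
second_attempt_prompt = prompt + first_completion + correction_instruction
print(second_attempt_prompt)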
From 34eff84dd8b8cf2d556d84d48f3d3ebb96606596 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Tue, 24 Sep 2024 21:04:19 +0200
Subject: [PATCH 2/4] formatting

---
 trl/trainer/score_config.py  | 27 +++++--------
 trl/trainer/score_trainer.py | 75 +++++++++++++++++++++---------------
 2 files changed, 52 insertions(+), 50 deletions(-)

diff --git a/trl/trainer/score_config.py b/trl/trainer/score_config.py
index 2e9cff56a6..9fa51f24d0 100644
--- a/trl/trainer/score_config.py
+++ b/trl/trainer/score_config.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from dataclasses import dataclass, field
-from typing import List
 
 from trl.trainer.online_dpo_config import OnlineDPOConfig
 
@@ -26,44 +25,36 @@ class SCoREConfig(OnlineDPOConfig):
     Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following:
 
     """
+
     # Stage I specific parameters
-    kl_coef: float = field(
-        default=0.1,
-        metadata={"help": "Coefficient for KL divergence loss in Stage I"}
-    )
+    kl_coef: float = field(default=0.1, metadata={"help": "Coefficient for KL divergence loss in Stage I"})
 
     # Prompts
     correction_instruction: str = field(
         default="The previous response may contain errors. Please review and correct any mistakes: ",
-        metadata={"help": "Instruction for self-correction in the second attempt"}
+        metadata={"help": "Instruction for self-correction in the second attempt"},
     )
 
     first_attempt_prefix: str = field(
-        default="First attempt: ",
-        metadata={"help": "Prefix for the first attempt in the second attempt prompt"}
+        default="First attempt: ", metadata={"help": "Prefix for the first attempt in the second attempt prompt"}
    )
 
     second_attempt_prefix: str = field(
-        default="Improved response: ",
-        metadata={"help": "Prefix for the second attempt in the model output"}
+        default="Improved response: ", metadata={"help": "Prefix for the second attempt in the model output"}
     )
 
     # Training stages
-    num_stage1_epochs: int = field(
-        default=1,
-        metadata={"help": "Number of epochs to train in Stage I"}
-    )
-
+    num_stage1_epochs: int = field(default=1, metadata={"help": "Number of epochs to train in Stage I"})
 
     def __post_init__(self):
         super().__post_init__()
-
+
         # Ensure that the correction instruction ends with a space
         if not self.correction_instruction.endswith(" "):
             self.correction_instruction += " "
-
+
         # Ensure that the prefixes end with a space
         if not self.first_attempt_prefix.endswith(" "):
             self.first_attempt_prefix += " "
         if not self.second_attempt_prefix.endswith(" "):
-            self.second_attempt_prefix += " "
\ No newline at end of file
+            self.second_attempt_prefix += " "
diff --git a/trl/trainer/score_trainer.py b/trl/trainer/score_trainer.py
index 9c39367ed3..a1882dd8c4 100644
--- a/trl/trainer/score_trainer.py
+++ b/trl/trainer/score_trainer.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Union, Any
+from typing import Any, Dict, Union
 
 import torch
 import torch.nn as nn
@@ -20,19 +20,22 @@
 
 from .online_dpo_trainer import OnlineDPOTrainer
 from .score_config import SCoREConfig
+from .utils import get_reward
 
 
 class SCoRETrainer(OnlineDPOTrainer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.score_config = SCoREConfig()
-
+        self.score_config = SCoREConfig()
+
         # Add SCoRE-specific statistics
-        self.stats.update({
-            "loss/stage1": [],
-            "kl_div/first_attempt": [],
-            "reward/second_attempt": [],
-        })
+        self.stats.update(
+            {
+                "loss/stage1": [],
+                "kl_div/first_attempt": [],
+                "reward/second_attempt": [],
+            }
+        )
 
     def _generate_completions(self, model, prompts):
         with self.accelerator.unwrap_model(model) as unwrapped_model:
@@ -45,7 +48,7 @@ def _generate_completions(self, model, prompts):
 
             # Prepare input for second attempt
             second_attempt_prompt = self._prepare_second_attempt_prompt(prompts, first_attempt)
-
+
             # Generate second attempt
             second_attempt = unwrapped_model.generate(
                 input_ids=second_attempt_prompt["input_ids"],
@@ -58,41 +61,37 @@
     def _prepare_second_attempt_prompt(self, prompts, first_attempt):
         context_length = prompts["input_ids"].shape[1]
         first_completion = first_attempt[:, context_length:]
-        correction_instruction = self.tokenizer.encode(
-            self.score_config.correction_instruction,
-            return_tensors="pt",
-            add_special_tokens=False
-        ).repeat(prompts["input_ids"].shape[0], 1).to(first_attempt.device)
-
-        second_attempt_input_ids = torch.cat([
-            prompts["input_ids"],
-            first_completion,
-            correction_instruction
-        ], dim=1)
+        correction_instruction = (
+            self.tokenizer.encode(
+                self.score_config.correction_instruction, return_tensors="pt", add_special_tokens=False
+            )
+            .repeat(prompts["input_ids"].shape[0], 1)
+            .to(first_attempt.device)
+        )
+
+        second_attempt_input_ids = torch.cat([prompts["input_ids"], first_completion, correction_instruction], dim=1)
 
         second_attempt_attention_mask = torch.ones_like(second_attempt_input_ids)
 
-        return {
-            "input_ids": second_attempt_input_ids,
-            "attention_mask": second_attempt_attention_mask
-        }
+        return {"input_ids": second_attempt_input_ids, "attention_mask": second_attempt_attention_mask}
 
     def _compute_stage1_loss(self, model, ref_model, first_attempt, second_attempt, prompts):
         context_length = prompts["input_ids"].shape[1]
 
         # Compute logprobs for first attempt
         first_attempt_logits = model(first_attempt["input_ids"], attention_mask=first_attempt["attention_mask"]).logits
-        first_attempt_logprobs = F.log_softmax(first_attempt_logits[:, context_length-1:-1], dim=-1)
-        first_attempt_token_logprobs = torch.gather(
-            first_attempt_logprobs, 2, first_attempt["input_ids"][:, context_length:].unsqueeze(-1)
-        ).squeeze(-1)
+        first_attempt_logprobs = F.log_softmax(first_attempt_logits[:, context_length - 1 : -1], dim=-1)
 
         # Compute KL divergence for first attempt
         with torch.no_grad():
-            ref_first_attempt_logits = ref_model(first_attempt["input_ids"], attention_mask=first_attempt["attention_mask"]).logits
-            ref_first_attempt_logprobs = F.log_softmax(ref_first_attempt_logits[:, context_length-1:-1], dim=-1)
+            ref_first_attempt_logits = ref_model(
+                first_attempt["input_ids"], attention_mask=first_attempt["attention_mask"]
+            ).logits
+            ref_first_attempt_logprobs = F.log_softmax(ref_first_attempt_logits[:, context_length - 1 : -1], dim=-1)
 
-        kl_div = F.kl_div(first_attempt_logprobs, ref_first_attempt_logprobs.exp(), reduction='none').sum(-1)
+        kl_div = F.kl_div(ref_first_attempt_logprobs, first_attempt_logprobs, reduction="none", log_target=True).sum(
+            -1
+        )
 
         # Compute reward for second attempt
         second_attempt_reward = self._compute_rewards(second_attempt, context_length)
@@ -100,7 +99,7 @@ def _compute_stage1_loss(self, model, ref_model, first_attempt, second_attempt,
         # Compute loss
         kl_loss = self.score_config.kl_coef * kl_div.mean()
         reward_loss = -second_attempt_reward.mean()
-
+
         loss = kl_loss + reward_loss
 
         # Log statistics
@@ -141,6 +140,18 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
 
         return loss.detach()
 
+    def _compute_rewards(self, completions, context_length):
+        with torch.no_grad():
+            _, scores, _ = get_reward(
+                self.reward_model, completions["input_ids"], self.tokenizer.pad_token_id, context_length
+            )
+
+        if self.args.missing_eos_penalty is not None:
+            contain_eos = torch.any(completions["input_ids"] == self.tokenizer.eos_token_id, dim=-1)
+            scores[~contain_eos] -= self.args.missing_eos_penalty
+
+        return scores
+
     def _process_completion(self, completion, prompts):
         context_length = prompts["input_ids"].shape[1]
         completion_ids = completion[:, context_length:]
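After this formatting pass, the Stage-I objective that `_compute_stage1_loss` encodes is a KL term that keeps the first attempt close to the reference policy plus the negated reward of the second attempt. A self-contained sketch on random tensors, not TRL code; the shapes and the 0.1 coefficient are illustrative:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, completion_len, vocab_size = 2, 5, 11
kl_coef = 0.1  # illustrative weight (kl_coef here, beta in the later patch)

policy_logprobs = F.log_softmax(torch.randn(batch_size, completion_len, vocab_size), dim=-1)
ref_logprobs = F.log_softmax(torch.randn(batch_size, completion_len, vocab_size), dim=-1)

# F.kl_div(input=log q, target=log p, log_target=True) returns p * (log p - log q),
# so passing the reference log-probs as `input` gives per-token KL(policy || ref).
kl_per_token = F.kl_div(ref_logprobs, policy_logprobs, reduction="none", log_target=True).sum(-1)

second_attempt_reward = torch.randn(batch_size)  # stand-in for a reward-model score

loss = kl_coef * kl_per_token.mean() - second_attempt_reward.mean()
print(loss)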
From b7ae4e10b6c3fabf50a641cd2726ebd251b7a976 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Wed, 25 Sep 2024 13:56:35 +0200
Subject: [PATCH 3/4] also return the completion input_ids and use them to calculate score

---
 trl/trainer/score_trainer.py | 63 ++++++++++++++++++++++++++++++++----
 1 file changed, 56 insertions(+), 7 deletions(-)

diff --git a/trl/trainer/score_trainer.py b/trl/trainer/score_trainer.py
index a1882dd8c4..dee0700364 100644
--- a/trl/trainer/score_trainer.py
+++ b/trl/trainer/score_trainer.py
@@ -17,6 +17,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers import PreTrainedTokenizerBase
 
 from .online_dpo_trainer import OnlineDPOTrainer
 from .score_config import SCoREConfig
@@ -75,7 +76,7 @@ def _prepare_second_attempt_prompt(self, prompts, first_attempt):
 
         return {"input_ids": second_attempt_input_ids, "attention_mask": second_attempt_attention_mask}
 
-    def _compute_stage1_loss(self, model, ref_model, first_attempt, second_attempt, prompts):
+    def _compute_stage1_loss(self, model, ref_model, first_attempt, second_attempt, prompts, ground_truth_completions):
         context_length = prompts["input_ids"].shape[1]
 
         # Compute logprobs for first attempt
@@ -94,7 +95,7 @@ def _compute_stage1_loss(self, model, ref_model, first_attempt, second_attempt,
         )
 
         # Compute reward for second attempt
-        second_attempt_reward = self._compute_rewards(second_attempt, context_length)
+        second_attempt_reward = self._compute_rewards(second_attempt, prompts, ground_truth_completions)
 
         # Compute loss
         kl_loss = self.score_config.kl_coef * kl_div.mean()
@@ -119,6 +120,10 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             "input_ids": inputs["prompt_input_ids"],
             "attention_mask": inputs["prompt_attention_mask"],
         }
+        ground_truth_completions = {
+            "input_ids": inputs["completion_input_ids"],
+            "attention_mask": inputs["completion_attention_mask"],
+        }
 
         # Generate completions (both first and second attempts)
         first_attempt, second_attempt = self._generate_completions(model, prompts)
@@ -128,7 +133,9 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         second_attempt_data = self._process_completion(second_attempt, prompts)
 
         # Compute Stage I loss
-        loss = self._compute_stage1_loss(model, ref_model, first_attempt_data, second_attempt_data, prompts)
+        loss = self._compute_stage1_loss(
+            model, ref_model, first_attempt_data, second_attempt_data, prompts, ground_truth_completions
+        )
 
         if self.args.n_gpu > 1:
             loss = loss.mean()  # mean() to average on multi-gpu parallel training
@@ -140,17 +147,24 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
 
         return loss.detach()
 
-    def _compute_rewards(self, completions, context_length):
+    def _compute_rewards(self, completions, prompts, ground_truth_completions):
+        context_length = prompts["input_ids"].shape[1]
         with torch.no_grad():
-            _, scores, _ = get_reward(
+            _, generated_scores, _ = get_reward(
                 self.reward_model, completions["input_ids"], self.tokenizer.pad_token_id, context_length
             )
 
+            # Compute scores for ground-truth completions
+            ground_truth_input_ids = torch.cat([prompts["input_ids"], ground_truth_completions["input_ids"]], dim=1)
+            _, ground_truth_scores, _ = get_reward(
+                self.reward_model, ground_truth_input_ids, self.tokenizer.pad_token_id, context_length
+            )
+
        if self.args.missing_eos_penalty is not None:
             contain_eos = torch.any(completions["input_ids"] == self.tokenizer.eos_token_id, dim=-1)
-            scores[~contain_eos] -= self.args.missing_eos_penalty
+            generated_scores[~contain_eos] -= self.args.missing_eos_penalty
 
-        return scores
+        return generated_scores - ground_truth_scores
 
     def _process_completion(self, completion, prompts):
         context_length = prompts["input_ids"].shape[1]
@@ -170,3 +184,38 @@ def truncate_right(tokens, eos_token_id, pad_token_id):
         mask = torch.arange(tokens.shape[1], device=tokens.device)[None, :] < eos_index[:, None]
         tokens = tokens.masked_fill(~mask, pad_token_id)
         return tokens, mask
+
+    @staticmethod
+    def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> Dict[str, Any]:
+        """Tokenize a single row from a DPO specific dataset."""
+        if not is_encoder_decoder:
+            prompt_tokens = tokenizer(feature["prompt"], add_special_tokens=False)
+            # Add BOS token to head of prompt. Avoid adding if it's already there
+            if tokenizer.bos_token_id is not None:
+                prompt_len_input_ids = len(prompt_tokens["input_ids"])
+                if prompt_len_input_ids == 0 or tokenizer.bos_token_id != prompt_tokens["input_ids"][0]:
+                    prompt_tokens["input_ids"] = [tokenizer.bos_token_id] + prompt_tokens["input_ids"]
+                    prompt_tokens["attention_mask"] = [1] + prompt_tokens["attention_mask"]
+
+            # Tokenize the ground-truth completion
+            completion_tokens = tokenizer(feature["completion"], add_special_tokens=False)
+
+            # Combine prompt and completion
+            batch = {
+                "prompt_input_ids": prompt_tokens["input_ids"],
+                "prompt_attention_mask": prompt_tokens["attention_mask"],
+                "completion_input_ids": completion_tokens["input_ids"],
+                "completion_attention_mask": completion_tokens["attention_mask"],
+            }
+        else:
+            prompt_tokens = tokenizer(feature["prompt"], add_special_tokens=True)
+            completion_tokens = tokenizer(feature["completion"], add_special_tokens=False)
+
+            batch = {
+                "prompt_input_ids": prompt_tokens["input_ids"],
+                "prompt_attention_mask": prompt_tokens["attention_mask"],
+                "completion_input_ids": completion_tokens["input_ids"],
+                "completion_attention_mask": completion_tokens["attention_mask"],
+            }
+
+        return batch
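Because `tokenize_row` is a `@staticmethod`, it can be exercised on its own to see the columns the trainer expects. A quick check under assumptions: the tokenizer checkpoint is a placeholder and the example row is made up.

from transformers import AutoTokenizer

from trl.trainer.score_trainer import SCoRETrainer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")  # placeholder checkpoint
row = SCoRETrainer.tokenize_row(
    {"prompt": "What is 13 * 7?", "completion": " 13 * 7 = 91."},
    is_encoder_decoder=False,
    tokenizer=tokenizer,
)
print(row.keys())
# dict_keys(['prompt_input_ids', 'prompt_attention_mask',
#            'completion_input_ids', 'completion_attention_mask'])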
From bc3e1e1363d8b45551308247fd773dd97b33d4a8 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Wed, 25 Sep 2024 14:13:14 +0200
Subject: [PATCH 4/4] use beta and mask padded kl_div

---
 trl/trainer/score_config.py  | 6 ------
 trl/trainer/score_trainer.py | 9 +++++++--
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/trl/trainer/score_config.py b/trl/trainer/score_config.py
index 9fa51f24d0..0d69d7bb4f 100644
--- a/trl/trainer/score_config.py
+++ b/trl/trainer/score_config.py
@@ -26,9 +26,6 @@ class SCoREConfig(OnlineDPOConfig):
 
     """
 
-    # Stage I specific parameters
-    kl_coef: float = field(default=0.1, metadata={"help": "Coefficient for KL divergence loss in Stage I"})
-
     # Prompts
     correction_instruction: str = field(
         default="The previous response may contain errors. Please review and correct any mistakes: ",
@@ -43,9 +40,6 @@ class SCoREConfig(OnlineDPOConfig):
         default="Improved response: ", metadata={"help": "Prefix for the second attempt in the model output"}
     )
 
-    # Training stages
-    num_stage1_epochs: int = field(default=1, metadata={"help": "Number of epochs to train in Stage I"})
-
     def __post_init__(self):
         super().__post_init__()
 
diff --git a/trl/trainer/score_trainer.py b/trl/trainer/score_trainer.py
index dee0700364..18996aa438 100644
--- a/trl/trainer/score_trainer.py
+++ b/trl/trainer/score_trainer.py
@@ -90,17 +90,22 @@ def _compute_stage1_loss(self, model, ref_model, first_attempt, second_attempt,
             ).logits
             ref_first_attempt_logprobs = F.log_softmax(ref_first_attempt_logits[:, context_length - 1 : -1], dim=-1)
 
+        # Create a mask for non-padding tokens
+        non_padding_mask = (first_attempt["input_ids"][:, context_length:] != self.tokenizer.pad_token_id).float()
+
         kl_div = F.kl_div(ref_first_attempt_logprobs, first_attempt_logprobs, reduction="none", log_target=True).sum(
             -1
         )
+        kl_div = (kl_div * non_padding_mask).sum() / non_padding_mask.sum()
 
-        # Compute reward for second attempt
+        # Compute reward for second attempt against ground truth
         second_attempt_reward = self._compute_rewards(second_attempt, prompts, ground_truth_completions)
 
         # Compute loss
-        kl_loss = self.score_config.kl_coef * kl_div.mean()
+        kl_loss = self.score_config.beta * kl_div
         reward_loss = -second_attempt_reward.mean()
 
+        # reinforce loss
         loss = kl_loss + reward_loss
 
         # Log statistics
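Taken together, the four patches leave a Stage-I trainer that could be wired up roughly as follows. This is a hedged sketch, not an example shipped with the patches: the constructor keywords mirror the parent `OnlineDPOTrainer` signature as assumed here, and the model, reward-model, and dataset values are placeholders.

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from trl.trainer.score_config import SCoREConfig
from trl.trainer.score_trainer import SCoRETrainer

model_id = "Qwen/Qwen2-0.5B-Instruct"    # placeholder policy checkpoint
reward_id = "trl-lib/Qwen2-0.5B-Reward"  # placeholder reward checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
ref_model = AutoModelForCausalLM.from_pretrained(model_id)
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_id, num_labels=1)

# tokenize_row expects "prompt" and "completion" (ground-truth answer) columns.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["What is 13 * 7?"],
        "completion": [" 13 * 7 = 91."],
    }
)

# beta weights the first-attempt KL term; missing_eos_penalty is read in _compute_rewards.
args = SCoREConfig(output_dir="score-stage1", beta=0.1, missing_eos_penalty=1.0)

trainer = SCoRETrainer(
    model=model,
    ref_model=ref_model,        # assumed keyword, per self.ref_model in the trainer
    reward_model=reward_model,  # assumed keyword, per self.reward_model in the trainer
    args=args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,        # assumed keyword, per self.tokenizer in the trainer
)
trainer.train()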