From dfab6aa9f036268193bdba4d86b328aaf8a523b5 Mon Sep 17 00:00:00 2001 From: Andrew Landau Date: Fri, 10 May 2024 13:07:35 +0100 Subject: [PATCH] correct name --- dominoes/datasets/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dominoes/datasets/base.py b/dominoes/datasets/base.py index 1dff1ba..aa73a3b 100644 --- a/dominoes/datasets/base.py +++ b/dominoes/datasets/base.py @@ -257,9 +257,9 @@ def get_choice_score(self, choices, scores): """ return torch.gather(scores, 2, choices.unsqueeze(2)).squeeze(2) - def process_reward(self, rewards, scores, choices, gamma_transform): + def process_rewards(self, rewards, scores, choices, gamma_transform): """ - process the reward for performing policy gradient + process the rewards for performing policy gradient args: rewards: list of torch.Tensor, the rewards for each network (precomputed using `reward_function`)