handle return_target more intelligently
landoskape committed May 20, 2024
1 parent 2002ce5 commit e53389f
Showing 2 changed files with 7 additions and 4 deletions.
dominoes/experiments/arglib.py (1 change: 0 additions & 1 deletion)
@@ -99,7 +99,6 @@ def add_checkpointing(parser):
 
 def add_dataset_parameters(parser):
     """add generic dataset parameters"""
-    parser.add_argument("--return_target", type=argbool, default=False, help="whether to return the target (default=False, True if supervised)")
     parser.add_argument("--threads", type=int, default=1, help="the number of threads to use for generating batches (default=1)")
     parser.add_argument("--ignore_index", type=int, default=-100, help="the index to ignore in the loss function (default=-100)")
     return parser
dominoes/experiments/base.py (10 changes: 7 additions & 3 deletions)
@@ -348,11 +348,9 @@ def prepare_dataset(self):
         dataset_parameters = vars(self.args).copy()
         dataset_parameters.pop("task", None)
         dataset_parameters.pop("device", None)
-        if self.args.learning_mode == "supervised":
-            dataset_parameters["return_target"] = True
         return get_dataset(self.args.task, build=True, device=self.device, **dataset_parameters)
 
-    def make_train_parameters(self, dataset, train=True):
+    def make_train_parameters(self, dataset, train=True, **parameter_updates):
         """simple method for getting training parameters"""
         # get the training parameters
         parameters = {}
@@ -372,6 +370,12 @@ def make_train_parameters(self, dataset, train=True):
         parameters["gamma"] = self.args.gamma
         parameters["save_loss"] = self.args.save_loss
         parameters["save_reward"] = self.args.save_reward
+        # Handle return_target
+        # (If required for supervised learning or saving loss, then return_target=True, otherwise False)
+        # (When testing but not using supervised or saving loss during training, use parameter_updates)
+        parameters["return_target"] = self.args.learning_mode == "supervised" or self.args.save_loss
+        # additionally update parameters directly if requested
+        parameters.update(parameter_updates)
         return parameters
 
     def plot_ready(self, name):
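Taken together, the two hunks move return_target off the command line: it is now derived from the run configuration (True for supervised learning or when save_loss is set), and **parameter_updates lets a caller override any derived value per call. A minimal usage sketch of that behavior; the experiment instance exp and its dataset are hypothetical stand-ins, not part of this commit:

    # training: return_target follows the config
    # (True when learning_mode == "supervised" or save_loss is set)
    train_parameters = exp.make_train_parameters(dataset, train=True)

    # testing: keywords forwarded through **parameter_updates override the
    # derived value, e.g. requesting targets for a non-supervised run
    test_parameters = exp.make_train_parameters(dataset, train=False, return_target=True)
    assert test_parameters["return_target"] is True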
