Skip to content

Commit

Permalink
allow extra batch parameters unless specified
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed May 10, 2024
1 parent e43a5fa commit b33e876
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
14 changes: 8 additions & 6 deletions dominoes/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_class_parameters(self):
"""
pass

def _check_parameters(self, reference=None, init=False, **parameters):
def _check_parameters(self, reference=None, init=False, raise_for_extra=False, **parameters):
"""
check if parameters provided in the parameters are valid (and complete)
Expand All @@ -45,17 +45,19 @@ def _check_parameters(self, reference=None, init=False, **parameters):
"""
if reference is None:
reference = self.get_class_parameters()
for param in parameters:
if param not in reference:
raise ValueError(f"parameter {param} not recognized for task {self.task}")
if raise_for_extra:
# check if extra parameters are provided and raise error if they are
for param in parameters:
if param not in reference:
raise ValueError(f"parameter {param} not recognized for task {self.task}")
# if init==True, then this is being called by the constructor's __init__ method and
# we need to check if any required parameters without defaults are set properly
if init:
for param in reference:
if param not in parameters and reference[param] is None:
raise ValueError(f"parameter {param} not provided for task {self.task}")

def parameters(self, **prms):
def parameters(self, raise_for_extra=False, **prms):
"""
Helper method for handling default parameters for each task
Expand All @@ -70,7 +72,7 @@ def parameters(self, **prms):
# get registered parameters
prms_to_use = copy(self.prms)
# check if updates are valid
self._check_parameters(reference=prms_to_use, init=False, **prms)
self._check_parameters(reference=prms_to_use, init=False, raise_for_extra=raise_for_extra, **prms)
# update parameters
prms_to_use.update(prms)
# return to caller function
Expand Down
4 changes: 2 additions & 2 deletions dominoes/datasets/dominoe_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def __init__(self, task, device="cpu", **parameters):
self.task = task

# check parameters
self._check_parameters(init=True, **parameters)
self._check_parameters(init=True, raise_for_extra=True, **parameters)

# set parameters to required defaults first, then update
self.prms = self.get_class_parameters()
self.prms = self.parameters(**parameters)
self.prms = self.parameters(raise_for_extra=True, **parameters)

# create base dominoe set
self.dominoe_set = get_dominoes(self.prms["highest_dominoe"], as_torch=True)
Expand Down
4 changes: 2 additions & 2 deletions dominoes/datasets/tsp_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ def __init__(self, device="cpu", **parameters):
self.set_device(device)

# check parameters
self._check_parameters(init=True, **parameters)
self._check_parameters(init=True, raise_for_extra=True, **parameters)

# set parameters to required defaults first, then update
self.prms = self.get_class_parameters()
self.prms = self.parameters(**parameters)
self.prms = self.parameters(raise_for_extra=True, **parameters)

@classmethod
def get_class_parameters(cls):
Expand Down

0 comments on commit b33e876

Please sign in to comment.