diff --git a/dominoes/train.py b/dominoes/train.py index 71b52d8..209efba 100644 --- a/dominoes/train.py +++ b/dominoes/train.py @@ -52,27 +52,27 @@ def train(nets, optimizers, dataset, **parameters): # get input for batch input = batch["input"] - # add context inputs for batch if requested (use *context_inputs for consistent handling) - context_inputs = [] - if "context" in batch: - context_inputs.append(batch["context"]) - if "multimode" in batch: - context_inputs.append(batch["multimode"]) - # get current max output for batch max_output = batch.get("max_output", max_possible_output) # get kwargs for forward pass net_kwargs = dict( mask=batch.get("mask", None), - context_mask=batch.get("context_mask", None), - mm_mask=batch.get("mm_mask", None), init=batch.get("init", None), temperature=temperature, thompson=thompson, max_output=max_output, ) + # add context inputs for batch if requested (use *context_inputs for consistent handling) + context_inputs = [] + if "context" in batch: + context_inputs.append(batch["context"]) + net_kwargs["context_mask"] = batch.get("context_mask", None) + if "multimode" in batch: + context_inputs.append(batch["multimode"]) + net_kwargs["mm_mask"] = batch.get("mm_mask", None) + # zero gradients for opt in optimizers: opt.zero_grad()