Skip to content

Commit

Permalink
revert changes to Model
Browse files Browse the repository at this point in the history
  • Loading branch information
radekosmulski committed Jul 5, 2023
1 parent c7070dc commit bc6825d
Showing 1 changed file with 10 additions and 22 deletions.
32 changes: 10 additions & 22 deletions merlin/models/torch/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,19 @@ def forward(
def training_step(self, batch, batch_idx):
"""Performs a training step with a single batch."""
del batch_idx
loss_and_metrics = self._step(batch)
if isinstance(batch, Batch):
features = batch.features
targets = batch.targets
else:
features, targets = batch

return loss_and_metrics["loss"]
predictions = self(features, batch=Batch(features, targets))

def validation_step(self, batch, batch_idx):
"""Performs a validation step with a single batch."""
del batch_idx
loss_and_metrics = self._step(batch)
loss_and_metrics = compute_loss(predictions, targets, self.model_outputs())
for name, value in loss_and_metrics.items():
self.log(f"train_{name}", value)

return loss_and_metrics
return loss_and_metrics["loss"]

def configure_optimizers(self):
"""Configures the optimizer for the model."""
Expand All @@ -115,21 +118,6 @@ def last(self) -> nn.Module:
"""Returns the last block in the model."""
return self.values[-1]

def _step(self, batch):
if isinstance(batch, Batch):
features = batch.features
targets = batch.targets
else:
features, targets = batch
predictions = self(features, batch=Batch(features, targets))

loss_and_metrics = compute_loss(predictions, targets, self.model_outputs())
train_or_eval = 'train' if self.training else 'valid'
for name, value in loss_and_metrics.items():
self.log(f"{train_or_eval}_{name}", value)

return loss_and_metrics


def compute_loss(
predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
Expand Down

0 comments on commit bc6825d

Please sign in to comment.