Skip to content

Commit

Permalink
Merge branch 'main' into torch/seq_padding
Browse files Browse the repository at this point in the history
  • Loading branch information
edknv authored Jul 4, 2023
2 parents db763cd + 4fbf821 commit bb6add9
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 18 deletions.
2 changes: 1 addition & 1 deletion merlin/models/torch/outputs/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class BinaryOutput(ModelOutput):
The metrics used for evaluation. Default includes Accuracy, AUROC, Precision, and Recall.
"""

DEFAULT_LOSS_CLS = nn.BCEWithLogitsLoss
DEFAULT_LOSS_CLS = nn.BCELoss
DEFAULT_METRICS_CLS = (Accuracy, AUROC, Precision, Recall)

def __init__(
Expand Down
33 changes: 17 additions & 16 deletions tests/unit/torch/models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_training_step_values(self):
loss = model.training_step((features, targets), 0)
(weights, bias) = model.parameters()
expected_outputs = nn.Sigmoid()(torch.matmul(features["feature"], weights.T) + bias)
expected_loss = nn.BCEWithLogitsLoss()(expected_outputs, targets["target"])
expected_loss = nn.BCELoss()(expected_outputs, targets["target"])
assert torch.allclose(loss, expected_loss)

def test_training_step_with_dataloader(self):
Expand Down Expand Up @@ -228,11 +228,11 @@ def test_train_classification_with_lightning_trainer(self, music_streaming_data,

class TestComputeLoss:
def test_tensor_inputs(self):
predictions = torch.randn(2, 1)
predictions = torch.sigmoid(torch.randn(2, 1))
targets = torch.randint(2, (2, 1), dtype=torch.float32)
model_outputs = [mm.BinaryOutput(ColumnSchema("a"))]
results = compute_loss(predictions, targets, model_outputs)
expected_loss = nn.BCEWithLogitsLoss()(predictions, targets)
expected_loss = nn.BCELoss()(predictions, targets)
expected_auroc = AUROC(task="binary")(predictions, targets)
expected_acc = Accuracy(task="binary")(predictions, targets)
expected_prec = Precision(task="binary")(predictions, targets)
Expand All @@ -253,48 +253,49 @@ def test_tensor_inputs(self):
assert torch.allclose(results["binary_recall"], expected_rec)

def test_no_metrics(self):
predictions = torch.randn(2, 1)
predictions = torch.sigmoid(torch.randn(2, 1))
targets = torch.randint(2, (2, 1), dtype=torch.float32)
model_outputs = [mm.BinaryOutput(ColumnSchema("a"))]
results = compute_loss(predictions, targets, model_outputs, compute_metrics=False)
assert sorted(results.keys()) == ["loss"]

def test_dict_inputs(self):
predictions = {"a": torch.randn(2, 1)}
outputs = mm.ParallelBlock({"a": mm.BinaryOutput(ColumnSchema("a"))})
predictions = outputs(torch.randn(2, 1))
targets = {"a": torch.randint(2, (2, 1), dtype=torch.float32)}
model_outputs = (mm.BinaryOutput(ColumnSchema("a")),)
results = compute_loss(predictions, targets, model_outputs)
expected_loss = nn.BCEWithLogitsLoss()(predictions["a"], targets["a"])

results = compute_loss(predictions, targets, outputs.find(mm.ModelOutput))
expected_loss = nn.BCELoss()(predictions["a"], targets["a"])
assert torch.allclose(results["loss"], expected_loss)

def test_mixed_inputs(self):
predictions = {"a": torch.randn(2, 1)}
targets = torch.randint(2, (2, 1), dtype=torch.float32)
model_outputs = (mm.BinaryOutput(ColumnSchema("a")),)
model_outputs = (mm.RegressionOutput(ColumnSchema("a")),)
results = compute_loss(predictions, targets, model_outputs)
expected_loss = nn.BCEWithLogitsLoss()(predictions["a"], targets)
expected_loss = nn.MSELoss()(predictions["a"], targets)
assert torch.allclose(results["loss"], expected_loss)

def test_single_model_output(self):
predictions = {"foo": torch.randn(2, 1)}
targets = {"foo": torch.randint(2, (2, 1), dtype=torch.float32)}
model_outputs = [mm.BinaryOutput(ColumnSchema("foo"))]
model_outputs = [mm.RegressionOutput(ColumnSchema("foo"))]
results = compute_loss(predictions, targets, model_outputs)
expected_loss = nn.BCEWithLogitsLoss()(predictions["foo"], targets["foo"])
expected_loss = nn.MSELoss()(predictions["foo"], targets["foo"])
assert torch.allclose(results["loss"], expected_loss)

def test_tensor_input_no_targets(self):
predictions = torch.randn(2, 1)
binary_output = mm.BinaryOutput(ColumnSchema("foo"))
binary_output = mm.RegressionOutput(ColumnSchema("foo"))
results = compute_loss(predictions, None, (binary_output,))
expected_loss = nn.BCEWithLogitsLoss()(predictions, torch.zeros(2, 1))
expected_loss = nn.MSELoss()(predictions, torch.zeros(2, 1))
assert torch.allclose(results["loss"], expected_loss)

def test_dict_input_no_targets(self):
predictions = {"foo": torch.randn(2, 1)}
binary_output = mm.BinaryOutput(ColumnSchema("foo"))
binary_output = mm.RegressionOutput(ColumnSchema("foo"))
results = compute_loss(predictions, None, (binary_output,))
expected_loss = nn.BCEWithLogitsLoss()(predictions["foo"], torch.zeros(2, 1))
expected_loss = nn.MSELoss()(predictions["foo"], torch.zeros(2, 1))
assert torch.allclose(results["loss"], expected_loss)

def test_no_target_raises_error(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/torch/outputs/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_init(self):
binary_output = mm.BinaryOutput()

assert isinstance(binary_output, mm.BinaryOutput)
assert isinstance(binary_output.loss, nn.BCEWithLogitsLoss)
assert isinstance(binary_output.loss, nn.BCELoss)
assert binary_output.metrics == [
Accuracy(task="binary"),
AUROC(task="binary"),
Expand Down

0 comments on commit bb6add9

Please sign in to comment.