Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update loss.py #315

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 45 additions & 8 deletions torchts/nn/loss.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import torch


def masked_mae_loss(y_pred, y_true):
def masked_mae_loss(y_pred: torch.tensor, y_true: torch.tensor) -> torch.tensor:
"""Calculate masked mean absolute error loss

Args:
y_pred (torch.Tensor): Predicted values
y_true (torch.Tensor): True values

Returns:
torch.Tensor: Loss
torch.Tensor: output loss
"""
mask = (y_true != 0).float()
mask /= mask.mean()
Expand All @@ -21,9 +21,7 @@ def masked_mae_loss(y_pred, y_true):
return loss.mean()


def mis_loss(
y_pred: torch.tensor, y_true: torch.tensor, interval: float
) -> torch.tensor:
def mis_loss(y_pred: torch.tensor, y_true: torch.tensor, interval: float) -> torch.tensor:
"""Calculate MIS loss

Args:
Expand All @@ -32,7 +30,7 @@ def mis_loss(
interval (float): confidence interval (e.g. 0.95 for 95% confidence interval)

Returns:
torch.tensor: output losses
torch.tensor: output loss
"""
alpha = 1 - interval
lower = y_pred[:, 0::2]
Expand All @@ -46,7 +44,7 @@ def mis_loss(
return loss


def quantile_loss(y_pred: torch.tensor, y_true: torch.tensor, quantile: float) -> float:
def quantile_loss(y_pred: torch.tensor, y_true: torch.tensor, quantile: float) -> torch.tensor:
"""Calculate quantile loss

Args:
Expand All @@ -55,9 +53,48 @@ def quantile_loss(y_pred: torch.tensor, y_true: torch.tensor, quantile: float) -
quantile (float): quantile (e.g. 0.5 for median)

Returns:
float: output losses
torch.tensor: output loss
"""
assert 0 < quantile < 1, "Quantile must be in (0, 1)"
errors = y_true - y_pred
loss = torch.max((quantile - 1) * errors, quantile * errors)
loss = torch.mean(loss)
return loss


def log_loss(y_pred: torch.tensor, y_true: torch.tensor) -> torch.tensor:
"""Ensure the predictions are in the range (0, 1)
Args:
y_pred (torch.tensor): Predicted values
y_true (torch.tensor): True values

Returns:
torch.tensor: output loss
"""
y_pred = torch.clamp(y_pred, 1e-7, 1 - 1e-7)
return -torch.mean(y_true * torch.log(y_pred) + (1 - y_true) * torch.log(1 - y_pred))


def mape_loss(y_pred: torch.tensor, y_true: torch.tensor) -> torch.tensor:
"""Calculate mean absolute percentage loss
Args:
y_pred (torch.tensor): Predicted values
y_true (torch.tensor): True values

Returns:
torch.tensor: output loss
"""
return torch.mean(torch.abs((y_true - y_pred) / y_true)) * 100


def smape_loss(y_pred: torch.tensor, y_true: torch.tensor) -> torch.tensor:
"""Calculate symmetric mean absolute percentage loss
Args:
y_pred (torch.tensor): Predicted values
y_true (torch.tensor): True values

Returns:
torch.tensor: output loss
"""
denominator = (torch.abs(y_true) + torch.abs(y_pred)) / 2.0
return torch.mean(torch.abs(y_true - y_pred) / denominator) * 100
Loading