from typing import Any
import torch
from torch.nn import functional as F
from torchmetrics.utilities.checks import _check_same_shape
import tsl
from .functional import mape, smape
from .metric_base import MaskedMetric


class MaskedMAE(MaskedMetric):
    """Mean Absolute Error Metric.

    Args:
        mask_nans (bool, optional): Whether to automatically mask nan values.
        mask_inf (bool, optional): Whether to automatically mask infinite
            values.
        at (int, optional): If given, compute the metric only w.r.t. this
            time step.
    """

    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False

    def __init__(self, mask_nans=False, mask_inf=False, at=None,
                 **kwargs: Any):
        super().__init__(metric_fn=F.l1_loss,
                         mask_nans=mask_nans,
                         mask_inf=mask_inf,
                         metric_fn_kwargs={'reduction': 'none'},
                         at=at,
                         **kwargs)
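
# A minimal usage sketch (not part of the library; the shapes below are
# hypothetical and assume a [batch, time, nodes, features] layout, since
# ``update`` indexes the ``at`` time step along dim 1, as in
# ``MaskedMRE.update`` further down):
#
#   metric = MaskedMAE(mask_nans=True)
#   y_hat = torch.randn(32, 12, 207, 1)        # predictions
#   y = torch.randn(32, 12, 207, 1)            # targets
#   mask = torch.rand(32, 12, 207, 1) > 0.2    # True = valid observation
#   metric.update(y_hat, y, mask)              # accumulate masked errors
#   mae = metric.compute()                     # mean over valid entries only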


class MaskedMAPE(MaskedMetric):
    """Mean Absolute Percentage Error Metric.

    Args:
        mask_nans (bool, optional): Whether to automatically mask nan values.
        at (int, optional): If given, compute the metric only w.r.t. this
            time step.
    """

    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False

    def __init__(self, mask_nans=False, at=None, **kwargs: Any):
        super().__init__(metric_fn=mape,
                         mask_nans=mask_nans,
                         # division by zero targets can yield inf, so
                         # infinite values are always masked
                         mask_inf=True,
                         metric_fn_kwargs={'reduction': 'none'},
                         at=at,
                         **kwargs)
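
# For reference, a sketch of the element-wise error that the ``mape``
# functional is expected to return with ``reduction='none'`` (an assumption;
# the actual implementation lives in ``.functional``):
#
#   def mape_sketch(y_hat, y):
#       # absolute error relative to the target; inf where y == 0
#       return torch.abs((y_hat - y) / y)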


class MaskedSMAPE(MaskedMetric):
    """Symmetric Mean Absolute Percentage Error Metric.

    Args:
        mask_nans (bool, optional): Whether to automatically mask nan values.
        at (int, optional): If given, compute the metric only w.r.t. this
            time step.
    """

    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False

    def __init__(self, mask_nans=False, at=None, **kwargs: Any):
        super().__init__(metric_fn=smape,
                         mask_nans=mask_nans,
                         mask_inf=True,  # infinite errors are always masked
                         metric_fn_kwargs={'reduction': 'none'},
                         at=at,
                         **kwargs)
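
# Likewise, a sketch of the symmetric percentage error that the ``smape``
# functional presumably computes (an assumption; see ``.functional`` for the
# actual definition):
#
#   def smape_sketch(y_hat, y):
#       # bounded in [0, 2]; undefined only when y_hat == y == 0
#       return 2 * torch.abs(y_hat - y) / (torch.abs(y_hat) + torch.abs(y))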


class MaskedMSE(MaskedMetric):
    """Mean Squared Error Metric.

    Args:
        mask_nans (bool, optional): Whether to automatically mask nan values.
        mask_inf (bool, optional): Whether to automatically mask infinite
            values.
        at (int, optional): If given, compute the metric only w.r.t. this
            time step.
    """

    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False

    def __init__(self, mask_nans=False, mask_inf=False, at=None,
                 **kwargs: Any):
        super().__init__(metric_fn=F.mse_loss,
                         mask_nans=mask_nans,
                         mask_inf=mask_inf,
                         metric_fn_kwargs={'reduction': 'none'},
                         at=at,
                         **kwargs)
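
# Note: ``compute`` returns the mean of the squared errors over the valid
# entries, so a masked RMSE can be obtained downstream as
# ``torch.sqrt(metric.compute())``.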


class MaskedMRE(MaskedMetric):
    """Mean Relative Error Metric.

    Args:
        mask_nans (bool, optional): Whether to automatically mask nan values.
        mask_inf (bool, optional): Whether to automatically mask infinite
            values.
        at (int, optional): If given, compute the metric only w.r.t. this
            time step.
    """

    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False

    def __init__(self, mask_nans=False, mask_inf=False, at=None,
                 **kwargs: Any):
        super().__init__(metric_fn=F.l1_loss,
                         mask_nans=mask_nans,
                         mask_inf=mask_inf,
                         metric_fn_kwargs={'reduction': 'none'},
                         at=at,
                         **kwargs)
        # Running sum of the (masked) target values, used as the
        # normalization term in ``compute``.
        self.add_state('tot',
                       dist_reduce_fx='sum',
                       default=torch.tensor(0.0, dtype=torch.float))

    def _compute_masked(self, y_hat, y, mask):
        _check_same_shape(y_hat, y)
        val = self.metric_fn(y_hat, y)
        mask = self._check_mask(mask, val)
        # Zero out invalid entries so they contribute neither to the error
        # sum nor to the normalization term.
        val = torch.where(mask, val, torch.zeros_like(val))
        y_masked = torch.where(mask, y, torch.zeros_like(y))
        return val.sum(), mask.sum(), y_masked.sum()

    def _compute_std(self, y_hat, y):
        _check_same_shape(y_hat, y)
        val = self.metric_fn(y_hat, y)
        return val.sum(), val.numel(), y.sum()

    def compute(self):
        # Normalize by the target sum, guarding against division by
        # (near-)zero.
        if self.tot > tsl.epsilon:
            return self.value / self.tot
        return self.value

    def update(self, y_hat, y, mask=None):
        # Restrict all tensors to the selected time step(s) along dim 1.
        y_hat = y_hat[:, self.at]
        y = y[:, self.at]
        if mask is not None:
            mask = mask[:, self.at]
        if self.is_masked(mask):
            val, numel, tot = self._compute_masked(y_hat, y, mask)
        else:
            val, numel, tot = self._compute_std(y_hat, y)
        self.value += val
        self.numel += numel
        self.tot += tot
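
# In effect, the metric reduces to sum(|y_hat - y|) / sum(y) over the valid
# entries. A tiny worked example (hypothetical values, not from the library):
#
#   y_hat = torch.tensor([[2.0, 4.0]])   # shape [batch=1, time=2]
#   y = torch.tensor([[1.0, 5.0]])
#   metric = MaskedMRE()
#   metric.update(y_hat, y)
#   metric.compute()   # (|2-1| + |4-5|) / (1 + 5) = 2 / 6 ≈ 0.3333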