# Source code for tsl.metrics.torch.metrics

from typing import Any

import torch
from torch.nn import functional as F
from torchmetrics.utilities.checks import _check_same_shape

import tsl

from .functional import mape, smape
from .metric_base import MaskedMetric


class MaskedMAE(MaskedMetric):
    """Masked Mean Absolute Error (MAE) metric.

    Element-wise absolute errors are obtained with ``F.l1_loss`` (no
    reduction) and aggregated by :class:`MaskedMetric`.

    Args:
        mask_nans (bool, optional): If ``True``, NaN values are masked out
            automatically.
        mask_inf (bool, optional): If ``True``, infinite values are masked
            out automatically.
        at (int, optional): If given, the metric is computed only w.r.t.
            this time step.
    """

    full_state_update: bool = False
    higher_is_better: bool = False
    is_differentiable: bool = True

    def __init__(self, mask_nans=False, mask_inf=False, at=None, **kwargs: Any):
        # Ask the loss for per-element values (no reduction).
        elementwise = {'reduction': 'none'}
        super().__init__(
            metric_fn=F.l1_loss,
            metric_fn_kwargs=elementwise,
            mask_nans=mask_nans,
            mask_inf=mask_inf,
            at=at,
            **kwargs,
        )
class MaskedMAPE(MaskedMetric):
    """Masked Mean Absolute Percentage Error (MAPE) metric.

    Infinite values are always masked (``mask_inf=True`` is fixed), so
    ``mask_inf`` is not exposed as a constructor argument.

    Args:
        mask_nans (bool, optional): If ``True``, NaN values are masked out
            automatically.
        at (int, optional): If given, the metric is computed only w.r.t.
            this time step.
    """

    full_state_update: bool = False
    higher_is_better: bool = False
    is_differentiable: bool = True

    def __init__(self, mask_nans=False, at=None, **kwargs: Any):
        # Percentage errors can presumably be infinite for zero targets
        # (depends on the ``mape`` helper), hence infinities are always masked.
        elementwise = {'reduction': 'none'}
        super().__init__(
            metric_fn=mape,
            metric_fn_kwargs=elementwise,
            mask_nans=mask_nans,
            mask_inf=True,
            at=at,
            **kwargs,
        )
class MaskedSMAPE(MaskedMetric):
    """Masked Symmetric Mean Absolute Percentage Error (sMAPE) metric.

    Infinite values are always masked (``mask_inf=True`` is fixed), so
    ``mask_inf`` is not exposed as a constructor argument.

    Args:
        mask_nans (bool, optional): If ``True``, NaN values are masked out
            automatically.
        at (int, optional): If given, the metric is computed only w.r.t.
            this time step.
    """

    full_state_update: bool = False
    higher_is_better: bool = False
    is_differentiable: bool = True

    def __init__(self, mask_nans=False, at=None, **kwargs: Any):
        # Infinite values (if produced by the ``smape`` helper) are always
        # masked out.
        elementwise = {'reduction': 'none'}
        super().__init__(
            metric_fn=smape,
            metric_fn_kwargs=elementwise,
            mask_nans=mask_nans,
            mask_inf=True,
            at=at,
            **kwargs,
        )
class MaskedMSE(MaskedMetric):
    """Masked Mean Squared Error (MSE) metric.

    Element-wise squared errors are obtained with ``F.mse_loss`` (no
    reduction) and aggregated by :class:`MaskedMetric`.

    Args:
        mask_nans (bool, optional): If ``True``, NaN values are masked out
            automatically.
        mask_inf (bool, optional): If ``True``, infinite values are masked
            out automatically.
        at (int, optional): If given, the metric is computed only w.r.t.
            this time step.
    """

    full_state_update: bool = False
    higher_is_better: bool = False
    is_differentiable: bool = True

    def __init__(self, mask_nans=False, mask_inf=False, at=None, **kwargs: Any):
        # Ask the loss for per-element values (no reduction).
        elementwise = {'reduction': 'none'}
        super().__init__(
            metric_fn=F.mse_loss,
            metric_fn_kwargs=elementwise,
            mask_nans=mask_nans,
            mask_inf=mask_inf,
            at=at,
            **kwargs,
        )
class MaskedMRE(MaskedMetric):
    """Mean Relative Error Metric.

    Computes the absolute error normalized by the L1 norm of the (valid)
    targets, i.e. ``sum(|y_hat - y|) / sum(|y|)``, accumulated over all
    ``update`` calls.

    Args:
        mask_nans (bool, optional): Whether to automatically mask nan values.
        mask_inf (bool, optional): Whether to automatically mask infinite
            values.
        at (int, optional): If given, compute the metric only w.r.t. this
            time step.
    """
    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False

    def __init__(self, mask_nans=False, mask_inf=False, at=None, **kwargs: Any):
        super(MaskedMRE, self).__init__(
            metric_fn=F.l1_loss,
            mask_nans=mask_nans,
            mask_inf=mask_inf,
            metric_fn_kwargs={'reduction': 'none'},
            at=at,
            **kwargs,
        )
        # Running sum of |y| over the valid entries: the normalizer.
        self.add_state(
            'tot', dist_reduce_fx='sum', default=torch.tensor(0.0, dtype=torch.float)
        )

    def _compute_masked(self, y_hat, y, mask):
        """Return (sum of masked abs errors, #valid entries, sum of masked |y|)."""
        _check_same_shape(y_hat, y)
        val = self.metric_fn(y_hat, y)
        mask = self._check_mask(mask, val)
        val = torch.where(mask, val, torch.zeros_like(val))
        # Use |y| so that targets of opposite sign cannot cancel out in the
        # normalizer (the signed sum could be ~0 or negative).
        y_masked = torch.where(mask, y.abs(), torch.zeros_like(y))
        return val.sum(), mask.sum(), y_masked.sum()

    def _compute_std(self, y_hat, y):
        """Return (sum of abs errors, #entries, sum of |y|) with no mask."""
        _check_same_shape(y_hat, y)
        val = self.metric_fn(y_hat, y)
        return val.sum(), val.numel(), y.abs().sum()

    def compute(self):
        """Return accumulated abs error divided by accumulated ``|y|``.

        If the normalizer does not exceed ``tsl.epsilon`` (e.g. all valid
        targets are zero), the raw accumulated absolute error is returned
        instead to avoid dividing by ~0.
        """
        if self.tot > tsl.epsilon:
            return self.value / self.tot
        return self.value

    def update(self, y_hat, y, mask=None):
        """Accumulate errors and normalizer for a batch.

        ``y_hat``, ``y`` and ``mask`` are sliced with ``[:, self.at]``
        (dimension 1 presumably being the time axis — TODO confirm).
        """
        y_hat = y_hat[:, self.at]
        y = y[:, self.at]
        if mask is not None:
            mask = mask[:, self.at]
        if self.is_masked(mask):
            val, numel, tot = self._compute_masked(y_hat, y, mask)
        else:
            val, numel, tot = self._compute_std(y_hat, y)
        self.value += val
        self.numel += numel
        self.tot += tot