Source code for tsl.metrics.torch.metric_base

import inspect
from copy import deepcopy
from functools import partial
from typing import Any

import torch
from torchmetrics import Metric
from torchmetrics.utilities.checks import _check_same_shape


def convert_to_masked_metric(metric_fn, **kwargs):
    """Simple utility function to transform a callable into a
    `MaskedMetric`.

    Args:
        metric_fn: Callable to be wrapped.
        **kwargs: Keyword arguments that will be passed to the callable.

    Returns:
        MaskedMetric: The wrapped callable. If `metric_fn` is already a
            `MaskedMetric`, a deep copy of it is returned instead.
    """
    if not isinstance(metric_fn, MaskedMetric):
        if 'reduction' in inspect.getfullargspec(metric_fn).args:
            metric_kwargs = {'reduction': 'none'}
        else:
            metric_kwargs = dict()
        return MaskedMetric(metric_fn,
                            metric_fn_kwargs=metric_kwargs,
                            **kwargs)
    assert not len(kwargs)
    return deepcopy(metric_fn)
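# Usage sketch (illustrative, not part of the original module): wrapping
# ``torch.nn.functional.l1_loss``. Since ``l1_loss`` exposes a ``reduction``
# argument, it is automatically called with ``reduction='none'`` so that the
# point-wise errors remain available for masking.
#
#   >>> import torch.nn.functional as F
#   >>> metric = convert_to_masked_metric(F.l1_loss, mask_nans=True)
#   >>> isinstance(metric, MaskedMetric)
#   True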
class MaskedMetric(Metric):
    r"""Base class to implement the metrics used in `tsl`.

    In particular, a `MaskedMetric` accounts for missing values in the input
    sequences by accepting a boolean mask as additional input.

    Args:
        metric_fn: Base function to compute the metric point-wise.
        mask_nans (bool, optional): Whether to automatically mask NaN values.
        mask_inf (bool, optional): Whether to automatically mask infinite
            values.
        metric_fn_kwargs (dict, optional): Keyword arguments passed to
            `metric_fn` at every call.
        at (int, optional): If not None, compute the metric only at this
            time step.
    """

    is_differentiable: bool = None
    higher_is_better: bool = None
    full_state_update: bool = None

    def __init__(self,
                 metric_fn,
                 mask_nans=False,
                 mask_inf=False,
                 metric_fn_kwargs=None,
                 at=None,
                 full_state_update: bool = None,
                 **kwargs: Any):
        # set 'full_state_update' before Metric instantiation
        if full_state_update is not None:
            self.__dict__['full_state_update'] = full_state_update
        super(MaskedMetric, self).__init__(**kwargs)
        if metric_fn_kwargs is None:
            metric_fn_kwargs = dict()
        self.metric_fn = partial(metric_fn, **metric_fn_kwargs)
        self.mask_nans = mask_nans
        self.mask_inf = mask_inf
        if at is None:
            self.at = slice(None)
        else:
            self.at = slice(at, at + 1)
        self.add_state('value',
                       dist_reduce_fx='sum',
                       default=torch.tensor(0., dtype=torch.float))
        self.add_state('numel',
                       dist_reduce_fx='sum',
                       default=torch.tensor(0., dtype=torch.float))

    def _check_mask(self, mask, val):
        if mask is None:
            mask = torch.ones_like(val, dtype=torch.bool)
        else:
            mask = mask.bool()
            _check_same_shape(mask, val)
        if self.mask_nans:
            mask = mask & ~torch.isnan(val)
        if self.mask_inf:
            mask = mask & ~torch.isinf(val)
        return mask

    def _compute_masked(self, y_hat, y, mask):
        _check_same_shape(y_hat, y)
        val = self.metric_fn(y_hat, y)
        mask = self._check_mask(mask, val)
        val = torch.where(mask, val, torch.zeros_like(val))
        return val.sum(), mask.sum()

    def _compute_std(self, y_hat, y):
        _check_same_shape(y_hat, y)
        val = self.metric_fn(y_hat, y)
        return val.sum(), val.numel()
    def is_masked(self, mask):
        return self.mask_inf or self.mask_nans or (mask is not None)
    def update(self, y_hat, y, mask=None):
        y_hat = y_hat[:, self.at]
        y = y[:, self.at]
        if mask is not None:
            mask = mask[:, self.at]
        if self.is_masked(mask):
            val, numel = self._compute_masked(y_hat, y, mask)
        else:
            val, numel = self._compute_std(y_hat, y)
        self.value += val
        self.numel += numel
    def compute(self):
        if self.numel > 0:
            return self.value / self.numel
        return self.value
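# Usage sketch (illustrative, not part of the original module): accumulating
# a masked mean absolute error over inputs of shape [batch, steps, nodes].
# Masked-out entries contribute neither to the error sum (``value``) nor to
# the count of valid elements (``numel``).
#
#   >>> import torch
#   >>> import torch.nn.functional as F
#   >>> metric = MaskedMetric(F.l1_loss,
#   ...                       metric_fn_kwargs={'reduction': 'none'})
#   >>> y_hat = torch.randn(8, 12, 4)
#   >>> y = torch.randn(8, 12, 4)
#   >>> mask = torch.rand(8, 12, 4) > 0.2  # True where values are observed
#   >>> metric.update(y_hat, y, mask)
#   >>> mae = metric.compute()  # value / numel over valid entries only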