"""Masked-metric wrappers (source of module ``tsl.metrics.torch.metric_wrappers``)."""

from typing import Any

from torch.nn import Identity

from .metric_base import MaskedMetric
from ...nn.layers import Select


class MaskedMetricWrapper(MaskedMetric):
    """Wrap a :class:`MaskedMetric`, transforming its inputs before use.

    Predictions, targets and (when given) the mask are each passed through
    an optional callable before being handed to the wrapped metric. A
    transform left as ``None`` defaults to an identity module.

    :meth:`compute` and :meth:`reset` delegate to the wrapped metric, so the
    accumulated state lives in ``self.metric``.

    Args:
        metric (MaskedMetric): The metric to wrap.
        input_preprocessing (callable, optional): Applied to ``y_hat``.
        target_preprocessing (callable, optional): Applied to ``y``.
        mask_preprocessing (callable, optional): Applied to ``mask`` when a
            mask is provided.
    """

    def __init__(self,
                 metric: MaskedMetric,
                 input_preprocessing=None,
                 target_preprocessing=None,
                 mask_preprocessing=None):
        super(MaskedMetricWrapper, self).__init__(None)
        self.metric = metric

        # Default to an *instance* of Identity. Assigning the class itself
        # would make ``self.input_preprocessing(x)`` construct a new
        # ``nn.Identity`` module (its constructor ignores all arguments)
        # instead of returning ``x`` unchanged.
        if input_preprocessing is None:
            input_preprocessing = Identity()
        if target_preprocessing is None:
            target_preprocessing = Identity()
        if mask_preprocessing is None:
            mask_preprocessing = Identity()

        self.input_preprocessing = input_preprocessing
        self.target_preprocessing = target_preprocessing
        self.mask_preprocessing = mask_preprocessing

    def _preprocess(self, y_hat, y, mask=None):
        """Apply the configured transforms; a ``None`` mask stays ``None``."""
        y_hat = self.input_preprocessing(y_hat)
        y = self.target_preprocessing(y)
        if mask is not None:
            mask = self.mask_preprocessing(mask)
        return y_hat, y, mask

    def forward(self, y_hat, y, mask=None, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapped metric on the preprocessed inputs.

        Preprocessing is applied here as well as in :meth:`update`. This
        keeps the value returned by a direct call, and any state the wrapped
        metric accumulates while handling it, consistent with :meth:`update`.
        Extra positional/keyword arguments are forwarded unchanged.
        """
        y_hat, y, mask = self._preprocess(y_hat, y, mask)
        return self.metric(y_hat, y, mask, *args, **kwargs)

    def update(self, y_hat, y, mask=None):
        """Preprocess the inputs and update the wrapped metric's state."""
        y_hat, y, mask = self._preprocess(y_hat, y, mask)
        return self.metric.update(y_hat, y, mask)

    def compute(self):
        """Return the wrapped metric's ``compute()`` result."""
        return self.metric.compute()

    def reset(self):
        """Reset the wrapped metric's state."""
        self.metric.reset()
class SelectMetricWrapper(MaskedMetricWrapper):
    """Masked metric wrapper whose preprocessing is a ``Select`` layer.

    For each of the input, target and mask, a ``Select(dim, idx)`` layer is
    built when an index is given. Otherwise ``None`` is passed on, so the
    parent wrapper applies its own default for that argument.
    (``Select`` comes from ``tsl.nn.layers``; presumably it picks ``idx``
    along ``dim`` — confirm against its definition.)

    Args:
        metric: The metric to wrap.
        dim: Dimension handed to every ``Select`` layer.
        input_idx: Index for the predictions, or ``None``.
        target_idx: Index for the targets, or ``None``.
        mask_idx: Index for the mask, or ``None``.
    """

    def __init__(self, metric, dim, input_idx=None, target_idx=None,
                 mask_idx=None):
        def make_selector(index):
            # No index -> no custom preprocessing for this argument.
            return None if index is None else Select(dim, index)

        # Keyword arguments are evaluated left to right, so the Select
        # layers are still built in input, target, mask order.
        super(SelectMetricWrapper, self).__init__(
            metric,
            input_preprocessing=make_selector(input_idx),
            target_preprocessing=make_selector(target_idx),
            mask_preprocessing=make_selector(mask_idx),
        )