# Source code for tsl.metrics.torch.pinball_loss

from tsl.metrics.torch import pinball_loss
from tsl.metrics.torch.metric_base import MaskedMetric


class MaskedPinballLoss(MaskedMetric):
    """Quantile (pinball) loss with optional masking.

    Wraps the functional :func:`pinball_loss` in a :class:`MaskedMetric`,
    forwarding the target quantile ``q`` to it as a keyword argument.

    Args:
        q (float): Target quantile, passed to :func:`pinball_loss` as the
            ``q`` keyword argument.
        mask_nans (bool, optional): Whether to automatically mask nan values.
            (default: :obj:`False`)
        mask_inf (bool, optional): Whether to automatically mask infinite
            values. (default: :obj:`False`)
        at (int, optional): If not :obj:`None`, compute the metric only with
            respect to this time step. (default: :obj:`None`)
        **kwargs: Additional keyword arguments forwarded unchanged to
            :class:`MaskedMetric`. For example, ``compute_on_step`` controls
            whether to compute the metric right away or to accumulate the
            results. It should be :obj:`True` when using the metric to
            compute a loss function, and :obj:`False` if the metric is used
            for logging the aggregate error across different mini-batches.
    """
    # Metric flags; presumably read by the torchmetrics-style base class
    # behind MaskedMetric (NOTE(review): confirm against metric_base).
    is_differentiable: bool = True
    higher_is_better: bool = False  # a lower pinball loss is better
    full_state_update: bool = False

    def __init__(self, q, mask_nans=False, mask_inf=False, at=None, **kwargs):
        # All masking/aggregation logic lives in MaskedMetric; this class
        # only binds the metric function and its quantile argument.
        super().__init__(metric_fn=pinball_loss,
                         mask_nans=mask_nans,
                         mask_inf=mask_inf,
                         metric_fn_kwargs={'q': q},
                         at=at,
                         **kwargs)