Source code for tsl.metrics.torch.pinball_loss

from tsl.metrics.torch import pinball_loss
from tsl.metrics.torch.metric_base import MaskedMetric


[docs]class MaskedPinballLoss(MaskedMetric): """Quantile loss. Args: q (float): Target quantile. mask_nans (bool, optional): Whether to automatically mask nan values. mask_inf (bool, optional): Whether to automatically mask infinite values. compute_on_step (bool, optional): Whether to compute the metric right-away or if accumulate the results. This should be :obj:`True` when using the metric to compute a loss function, :obj:`False` if the metric is used for logging the aggregate error across different mini-batches. at (int, optional): Whether to compute the metric only w.r.t. a certain time step. """ is_differentiable: bool = True higher_is_better: bool = False full_state_update: bool = False def __init__(self, q, mask_nans=False, mask_inf=False, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, at=None): super(MaskedPinballLoss, self).__init__(metric_fn=pinball_loss, mask_nans=mask_nans, mask_inf=mask_inf, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, dist_sync_fn=dist_sync_fn, metric_fn_kwargs={'q': q}, at=at)