from typing import Callable, Mapping, Optional, Tuple, Union
import torch
from tsl.typing import DataArray, SparseTensArray, TemporalIndex
from .batch_map import BatchMap, BatchMapItem
from .preprocessing import Scaler
from .spatiotemporal_dataset import SpatioTemporalDataset
from .synch_mode import HORIZON, WINDOW
class ImputationDataset(SpatioTemporalDataset):
    r"""A dataset for imputation tasks. It is a subclass of
    :class:`~tsl.data.SpatioTemporalDataset` and shares most of its
    attributes. The main difference is the addition of an :obj:`eval_mask`
    attribute, a boolean mask denoting which values must be used to
    evaluate imputations.

    Args:
        target (DataArray): Data relative to the primary channels.
        eval_mask (DataArray): Boolean mask denoting values that can be used for
            evaluating imputations. The mask is :obj:`True` if the corresponding
            value must be used for evaluation and :obj:`False` otherwise.
        index (TemporalIndex, optional): Temporal indices for the data.
            (default: :obj:`None`)
        mask (DataArray, optional): Boolean mask denoting if signal in data is
            valid (:obj:`True`) or not (:obj:`False`).
            (default: :obj:`None`)
        connectivity (SparseTensArray, tuple, optional): The adjacency matrix
            defining nodes' relational information. It can be either a
            dense/sparse matrix :math:`\mathbf{A} \in \mathbb{R}^{N \times N}`
            or an (:obj:`edge_index` :math:`\in \mathbb{N}^{2 \times E}`,
            :obj:`edge_weight` :math:`\in \mathbb{R}^{E})` tuple. The input
            layout will be preserved (e.g., a sparse matrix will be stored as a
            :class:`torch_sparse.SparseTensor`). In any case, the connectivity
            will be stored in the attribute :obj:`edge_index`, and the weights
            will be eventually stored as :obj:`edge_weight`.
            (default: :obj:`None`)
        covariates (dict, optional): Dictionary of exogenous channels with
            label. An :obj:`exogenous` element is a temporal array with node- or
            graph-level channels which are covariates to the main signal. The
            temporal dimension must be equal to the temporal dimension of data,
            as well as the number of nodes if the exogenous is node-level.
            (default: :obj:`None`)
        input_map (BatchMap or dict, optional): Defines how data (i.e., the
            target and the covariates) are mapped to dataset sample input. Keys
            in the mapping are keys in both :obj:`item` and :obj:`item.input`,
            while values are :obj:`~tsl.data.new.BatchMapItem`.
            (default: :obj:`None`)
        target_map (BatchMap or dict, optional): Defines how data (i.e., the
            target and the covariates) are mapped to dataset sample target. Keys
            in the mapping are keys in both :obj:`item` and :obj:`item.target`,
            while values are :obj:`~tsl.data.new.BatchMapItem`.
            (default: :obj:`None`)
        auxiliary_map (BatchMap or dict, optional): Defines how data (i.e., the
            target and the covariates) are added as additional attributes to the
            dataset sample. Keys in the mapping are keys only in :obj:`item`,
            while values are :obj:`~tsl.data.new.BatchMapItem`.
            (default: :obj:`None`)
        scalers (Mapping or None): Dictionary of scalers that must be used for
            data preprocessing.
            (default: :obj:`None`)
        trend (DataArray, optional): Trend paired with main signal. Must be of
            the same shape of `data`.
            (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in a
            :class:`tsl.data.Data` object and returns a transformed version.
            The data object will be transformed before every access.
            (default: :obj:`None`)
        window (int): Length (in number of steps) of the lookback window.
            (default: 12)
        stride (int): Offset (in number of steps) between a sample and the next
            one.
            (default: 1)
        window_lag (int): Sampling frequency (in number of steps) in lookback
            window.
            (default: 1)
        precision (int or str, optional): The float precision to store the data.
            Can be expressed as number (16, 32, or 64) or string ("half",
            "full", "double").
            (default: 32)
        name (str, optional): The (optional) name of the dataset.
    """

    def __init__(self,
                 target: DataArray,
                 eval_mask: DataArray,
                 index: Optional[TemporalIndex] = None,
                 mask: Optional[DataArray] = None,
                 connectivity: Optional[Union[SparseTensArray,
                                              Tuple[DataArray]]] = None,
                 covariates: Optional[Mapping[str, DataArray]] = None,
                 input_map: Optional[Union[Mapping, BatchMap]] = None,
                 target_map: Optional[Union[Mapping, BatchMap]] = None,
                 auxiliary_map: Optional[Union[Mapping, BatchMap]] = None,
                 scalers: Optional[Mapping[str, Scaler]] = None,
                 trend: Optional[DataArray] = None,
                 transform: Optional[Callable] = None,
                 window: int = 12,
                 stride: int = 1,
                 window_lag: int = 1,
                 precision: Union[int, str] = 32,
                 name: Optional[str] = None):
        # In imputation, the model reconstructs the lookback window itself:
        # the horizon has the same length as the window, and delay = -window
        # shifts the horizon back so that it coincides with the window.
        horizon = window
        delay = -window
        horizon_lag = window_lag
        super(ImputationDataset, self).__init__(target,
                                                index=index,
                                                mask=None,
                                                connectivity=connectivity,
                                                covariates=covariates,
                                                input_map=input_map,
                                                target_map=target_map,
                                                auxiliary_map=auxiliary_map,
                                                trend=trend,
                                                transform=transform,
                                                scalers=scalers,
                                                window=window,
                                                horizon=horizon,
                                                delay=delay,
                                                stride=stride,
                                                window_lag=window_lag,
                                                horizon_lag=horizon_lag,
                                                precision=precision,
                                                name=name)
        # Register eval_mask as a (non-input, non-preprocessed) covariate.
        self.add_covariate(
            name='eval_mask',
            value=eval_mask,
            pattern='t n f',
            add_to_input_map=False,  # NB: exposed only via the auxiliary map
            synch_mode=HORIZON,
            preprocess=False)
        # Expose eval_mask on each sample through the auxiliary map.
        self._map_eval_mask_to_auxiliary()
        # Ensure evaluation datapoints are removed from the input: the
        # valid-input mask is the valid-data mask minus the eval points.
        if mask is None:
            # No explicit mask: consider non-NaN target entries as valid.
            mask = ~torch.isnan(self.target)
        mask = torch.logical_not(self.eval_mask) & mask
        # Set mask and add it to the input map.
        self.set_mask(mask, add_to_input_map=True)

    def _map_eval_mask_to_auxiliary(self):
        """Add the :obj:`eval_mask` covariate to the auxiliary batch map."""
        self.auxiliary_map['eval_mask'] = BatchMapItem('eval_mask',
                                                       synch_mode=HORIZON,
                                                       pattern='t n f',
                                                       preprocess=False)

    def reset_auxiliary_map(self):
        """Reset the auxiliary map, always keeping :obj:`eval_mask` in it."""
        self._clear_batch_map('auxiliary')
        self._map_eval_mask_to_auxiliary()

    def reset_input_map(self):
        """Reset the input map, re-adding :obj:`mask` if one is set."""
        super().reset_input_map()
        if self.mask is not None:
            self.input_map['mask'] = BatchMapItem('mask',
                                                  synch_mode=WINDOW,
                                                  pattern='t n f',
                                                  preprocess=False)

    def set_mask(self,
                 mask: Optional[DataArray],
                 add_to_input_map: bool = True):
        """Set the valid-data mask and (optionally) map it to sample input.

        Unlike the parent class, the mask is never added to the auxiliary
        map: in imputation it is an input feature, not an evaluation aid.
        """
        super().set_mask(mask, add_to_auxiliary_map=False)
        if mask is not None and add_to_input_map:
            self.input_map['mask'] = BatchMapItem('mask',
                                                  synch_mode=WINDOW,
                                                  pattern='t n f',
                                                  preprocess=False,
                                                  shape=self.mask.shape)