Source code for tsl.data.spatiotemporal_dataset

from contextlib import contextmanager
from copy import deepcopy
from typing import (Callable, Iterable, List, Literal, Mapping, Optional,
                    Tuple, Union)

import numpy as np
import pandas as pd
import torch
from pandas import DatetimeIndex
from torch import Tensor
from torch.utils.data import Dataset
from torch_geometric.typing import Adj
from torch_sparse import SparseTensor

from tsl.typing import (DataArray, IndexSlice, SparseTensArray, TemporalIndex,
                        TensArray)

from ..ops.connectivity import reduce_graph
from ..ops.pattern import broadcast, check_pattern, outer_pattern, take
from ..utils.casting import parse_index
from . import StaticBatch
from .batch_map import BatchMap, BatchMapItem
from .data import Data
from .mixin import DataParsingMixin
from .preprocessing.scalers import Scaler, ScalerModule
from .synch_mode import HORIZON, STATIC, WINDOW, SynchMode

_WINDOWING_ = {
    'window': ['window', 'window_lag'],
    'horizon': ['window', 'delay', 'horizon', 'horizon_lag'],
    'indices': ['target', 'window', 'delay', 'horizon', 'stride'],
    'all': [
        'target', 'window', 'delay', 'horizon', 'stride', 'horizon_lag',
        'window_lag'
    ]
}


[docs]class SpatioTemporalDataset(Dataset, DataParsingMixin): r"""Base class for structures that are bridges between Datasets and Models. A :class:`SpatioTemporalDataset` takes as input a :class:`~tsl.datasets.prototypes.Dataset` and build a proper structure to feed deep models. Args: target (DataArray): Data relative to the primary channels. index (TemporalIndex, optional): Temporal indices for the data. (default: :obj:`None`) mask (DataArray, optional): Boolean mask denoting if signal in data is valid (1) or not (0). (default: :obj:`None`) connectivity (SparseTensArray, tuple, optional): The adjacency matrix defining nodes' relational information. It can be either a dense/sparse matrix :math:`\mathbf{A} \in \mathbb{R}^{N \times N}` or an (:obj:`edge_index` :math:`\in \mathbb{N}^{2 \times E}`, :obj:`edge_weight` :math:`\in \mathbb{R}^{E})` tuple. The input layout will be preserved (e.g., a sparse matrix will be stored as a :class:`torch_sparse.SparseTensor`). In any case, the connectivity will be stored in the attribute :obj:`edge_index`, and the weights will be eventually stored as :obj:`edge_weight`. (default: :obj:`None`) covariates (dict, optional): Dictionary of exogenous channels with label. An :obj:`exogenous` element is a temporal array with node- or graph-level channels which are covariates to the main signal. The temporal dimension must be equal to the temporal dimension of data, as well as the number of nodes if the exogenous is node-level. (default: :obj:`None`) input_map (BatchMap or dict, optional): Defines how data (i.e., the target and the covariates) are mapped to dataset sample input. Keys in the mapping are keys in both :obj:`item` and :obj:`item.input`, while values are :obj:`~tsl.data.new.BatchMapItem`. (default: :obj:`None`) target_map (BatchMap or dict, optional): Defines how data (i.e., the target and the covariates) are mapped to dataset sample target. Keys in the mapping are keys in both :obj:`item` and :obj:`item.target`, while values are :obj:`~tsl.data.new.BatchMapItem`. (default: :obj:`None`) auxiliary_map (BatchMap or dict, optional): Defines how data (i.e., the target and the covariates) are added as additional attributes to the dataset sample. Keys in the mapping are keys only in :obj:`item`, while values are :obj:`~tsl.data.new.BatchMapItem`. (default: :obj:`None`) scalers (Mapping or None): Dictionary of scalers that must be used for data preprocessing. (default: :obj:`None`) trend (DataArray, optional): Trend paired with main signal. Must be of the same shape of `data`. (default: :obj:`None`) transform (callable, optional): A function/transform that takes in a :class:`tsl.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) window (int): Length (in number of steps) of the lookback window. (default: 12) horizon (int): Length (in number of steps) of the prediction horizon. (default: 1) delay (int): Offset (in number of steps) between end of window and start of horizon. (default: 0) stride (int): Offset (in number of steps) between a sample and the next one. (default: 1) window_lag (int): Sampling frequency (in number of steps) in lookback window. (default: 1) horizon_lag (int): Sampling frequency (in number of steps) in prediction horizon. (default: 1) precision (int or str, optional): The float precision to store the data. Can be expressed as number (16, 32, or 64) or string ("half", "full", "double"). (default: 32) name (str, optional): The (optional) name of the dataset. 
""" def __init__(self, target: DataArray, index: Optional[TemporalIndex] = None, mask: Optional[DataArray] = None, connectivity: Optional[Union[SparseTensArray, Tuple[DataArray]]] = None, covariates: Optional[Mapping[str, DataArray]] = None, input_map: Optional[Union[Mapping, BatchMap]] = None, target_map: Optional[Union[Mapping, BatchMap]] = None, auxiliary_map: Optional[Union[Mapping, BatchMap]] = None, scalers: Optional[Mapping[str, Scaler]] = None, trend: Optional[DataArray] = None, transform: Optional[Callable] = None, window: int = 12, horizon: int = 1, delay: int = 0, stride: int = 1, window_lag: int = 1, horizon_lag: int = 1, precision: Union[int, str] = 32, name: Optional[str] = None): super(SpatioTemporalDataset, self).__init__() # Set info self.name = name if name is not None else self.__class__.__name__ self.precision = precision # Store windowing information self.window = window self.delay = delay self.horizon = horizon self.stride = stride self.window_lag = window_lag self.horizon_lag = horizon_lag # Initialize private data holders self._covariates = dict() self._indices = None # Initialize batch maps self.input_map = BatchMap() self.target_map = BatchMap() self.auxiliary_map = BatchMap() # Store preprocessing modules self.scalers: dict = dict() self.trend: Optional[Tensor] = None self.transform = transform # cache scalers for composed batch items self._batch_scalers: dict = dict() # handle trend as bias in target scaler, so cache actual one (if any) # to restore after trend update/deletion self.__target_bias: Optional[Tensor] = None # Store time information # if index is not explicitly passed and data is a dataframe, defaults to # data's index if index is None and isinstance(target, pd.DataFrame): if isinstance(target.index, DatetimeIndex): index = target.index self.index = index # Set dataset's target signals self.target: Tensor = self._parse_target(target) # Store mask self.mask: Optional[Tensor] = None self.set_mask(mask) # Store adj self.edge_index: Optional[Adj] = None self.edge_weight: Optional[Tensor] = None self.set_connectivity(connectivity) # Store covariates (e.g., exogenous and attributes) self.reset_input_map() if covariates is not None: for name, value in covariates.items(): self.add_covariate(name, **self._value_to_kwargs(value)) # Updated input map (i.e., how to map data, exogenous and attribute # inside item.input) if input_map is not None: self.set_input_map(input_map) # Updated target map (i.e., how to map data, exogenous and attribute # inside item.target) self.reset_target_map() if target_map is not None: self.set_target_map(target_map) # Updated auxiliary map (i.e., how to map data, exogenous and attribute # inside item) if auxiliary_map is not None: self.set_auxiliary_map(auxiliary_map) # A scaler is a module that transforms data with a linear operation if scalers is not None: for k, v in scalers.items(): self.add_scaler(k, v) # Target's trend (i.e., 't n f' tensor to be removed when target # is preprocessed) if trend is not None: self.set_trend(trend) def __repr__(self): return "{}(n_samples={}, n_nodes={}, n_channels={})" \ .format(self.name, self.n_samples, self.n_nodes, self.n_channels) def __getitem__(self, item) -> Data: if isinstance(item, int) and item < 0: # compute item's actual index item = self._indices.size(0) + item else: # convert slice to indexes item = parse_index(item, length=self.n_samples, layout='index') item = self.get(item) if self.transform is not None: item = self.transform(item) return item def __contains__(self, item) -> 
bool: return item in self.keys def __len__(self) -> int: return self.n_samples def __getattr__(self, item): if '_covariates' in self.__dict__ and item in self._covariates: return self._covariates[item]['value'] if 'input_map' in self.__dict__ and item in self.input_map: return self.collate_item_elem(item)[0] raise AttributeError(f"'{self.__class__.__name__}' object has no " f"attribute '{item}'") def __setattr__(self, key, value): super(SpatioTemporalDataset, self).__setattr__(key, value) if key in _WINDOWING_['window'] and \ all([hasattr(self, attr) for attr in _WINDOWING_['window']]): self._window_range = torch.arange(0, self.window, self.window_lag) if key in _WINDOWING_['horizon'] and \ all([hasattr(self, attr) for attr in _WINDOWING_['horizon']]): self._horizon_range = torch.arange( self.horizon_offset, self.horizon_offset + self.horizon, self.horizon_lag) if key in _WINDOWING_['indices'] and \ all([hasattr(self, attr) for attr in _WINDOWING_['indices']]): self._indices = torch.arange(0, self.n_steps - self.sample_span + 1, self.stride) def __delattr__(self, item): if item in _WINDOWING_['all']: raise AttributeError(f"Cannot delete attribute '{item}'.") elif item == 'mask': self.set_mask(None) elif item == 'trend': self.set_trend(None) elif item in self._covariates: self.remove_covariate(item) else: super(SpatioTemporalDataset, self).__delattr__(item) # Dataset properties ###################################################### @property def n_steps(self) -> int: """Total number of time steps in the dataset.""" return self.target.size(0) @property def n_nodes(self) -> int: """Number of nodes in the dataset.""" return self.target.size(1) @property def n_channels(self) -> int: """Number of channels in dataset's target.""" return self.target.size(-1) @property def n_edges(self) -> Optional[int]: """Number of edges in the dataset, if a connectivity is set.""" if isinstance(self.edge_index, Tensor): return self.edge_index.size(1) elif isinstance(self.edge_index, SparseTensor): return self.edge_index.numel() return None @property def shape(self) -> tuple: """Shape of the target tensor.""" return tuple(self.target.size()) @property def patterns(self) -> dict: """Shows the dimension of dataset's tensors in a more informative way. The pattern mapping can be useful to glimpse on how data are arranged. The convention we use is the following: * 't' stands for “number of time steps” * 'n' stands for “number of nodes” * 'f' stands for “number of features” (per node) * 'e' stands for “number edges” """ patterns = {'target': 't n f'} # add mask pattern if self.mask is not None: patterns['mask'] = 't n f' # add connectivity patterns if self.edge_index is not None: patterns['edge_index'] = '2 e' if isinstance( self.edge_index, Tensor) else 'n n' if self.edge_weight is not None: patterns['edge_weight'] = 'e' # add covariates patterns patterns.update( {name: attr['pattern'] for name, attr in self._covariates.items()}) return patterns @property def batch_patterns(self) -> dict: """Shows the dimension of dataset's tensors in a more informative way. The pattern mapping can be useful to glimpse on how data are arranged. 
The convention we use is the following: * 't' stands for “number of time steps” * 'n' stands for “number of nodes” * 'f' stands for “number of features” (per node) * 'e' stands for “number edges” """ # add input map patterns patterns = { name: attr.pattern for name, attr in self.input_map.items() } # add mask pattern if self.mask is not None: patterns['mask'] = 't n f' # add connectivity patterns if self.edge_index is not None: patterns['edge_index'] = '2 e' if isinstance( self.edge_index, Tensor) else 'n n' if self.edge_weight is not None: patterns['edge_weight'] = 'e' # add target map patterns patterns.update( {name: attr.pattern for name, attr in self.target_map.items()}) return patterns @property def keys(self) -> list: """Keys in dataset.""" return list(self.patterns.keys()) @property def batch_keys(self) -> list: """Keys in dataset item.""" return list(self.batch_patterns.keys()) # Indexing @property def horizon_offset(self) -> int: """Lag of starting step of horizon.""" return self.window + self.delay @property def sample_span(self) -> int: """Total number of steps of an item, including window and horizon.""" return max(self.horizon_offset + self.horizon, self.window) @property def samples_offset(self) -> int: """Difference (in number of steps) between two items.""" return int(np.ceil(self.window / self.stride)) @property def indices(self) -> Tensor: """Indices of the dataset. The :obj:`i`-th item is mapped to :obj:`indices[i]`""" return self._indices @property def n_samples(self) -> int: """Number of samples (i.e., items) in the dataset.""" return len(self._indices) # Covariates properties @property def covariates(self) -> dict: """Mapping of dataset's covariates.""" return {name: attr['value'] for name, attr in self._covariates.items()} @property def exogenous(self) -> dict: """Time-varying covariates of the dataset's target.""" return { name: attr['value'] for name, attr in self._covariates.items() if 't' in attr['pattern'] } @property def attributes(self) -> dict: """Static features related to the dataset.""" return { name: attr['value'] for name, attr in self._covariates.items() if 't' not in attr['pattern'] } @property def n_covariates(self) -> int: """Number of covariates in the dataset.""" return len(self._covariates) # flags @property def has_connectivity(self) -> bool: """Whether the dataset has a connectivity.""" return self.edge_index is not None @property def has_mask(self) -> bool: """Whether the dataset has a mask denoting valid values in target.""" return self.mask is not None @property def has_covariates(self) -> bool: """Whether the dataset has covariates to the target tensor.""" return self.n_covariates > 0 # Map Dataset to item ##################################################### @property def targets(self) -> BatchMap: return self.target_map def reset_input_map(self): self._clear_batch_map('input') self.input_map['x'] = BatchMapItem('target', SynchMode.WINDOW, preprocess=True, cat_dim=None, pattern='t n f', shape=self.shape) for name, attr in self._covariates.items(): self.input_map[name] = BatchMapItem(name, SynchMode.WINDOW, preprocess=True, cat_dim=None, pattern=attr['pattern'], shape=attr['value'].size()) def reset_target_map(self): self._clear_batch_map('target') self.target_map['y'] = BatchMapItem('target', SynchMode.HORIZON, preprocess=False, cat_dim=None, pattern='t n f', shape=self.shape) def reset_auxiliary_map(self): self._clear_batch_map('auxiliary') if self.mask is not None: self.auxiliary_map['mask'] = BatchMapItem('mask', SynchMode.HORIZON, 
preprocess=False, cat_dim=None, pattern='t n f', shape=self.mask.shape) def set_input_map(self, input_map=None, **kwargs): self._clear_batch_map('input') self.update_input_map(input_map, **kwargs) def set_target_map(self, target_map=None, **kwargs): self._clear_batch_map('target') self.update_target_map(target_map, **kwargs) def set_auxiliary_map(self, auxiliary_map=None, **kwargs): self._clear_batch_map('auxiliary') self.update_auxiliary_map(auxiliary_map, **kwargs) def update_input_map(self, input_map=None, **kwargs): self._update_batch_map('input', input_map, **kwargs) def update_target_map(self, target_map=None, **kwargs): self._update_batch_map('target', target_map, **kwargs) def update_auxiliary_map(self, auxiliary_map=None, **kwargs): self._update_batch_map('auxiliary', auxiliary_map, **kwargs) def _clear_batch_map(self, endpoint): batch_map = getattr(self, f"{endpoint}_map") for key in batch_map: if key in self._batch_scalers: del self._batch_scalers[key] setattr(self, f"{endpoint}_map", BatchMap()) def _update_batch_map(self, endpoint, batch_map=None, **kwargs): # check batch map if batch_map is None: batch_map = BatchMap() elif not isinstance(batch_map, Mapping): raise TypeError(f"Type {type(batch_map)} is not valid for " f"`{endpoint}_map` (must be a mapping).") # update from kwargs batch_map.update(**kwargs) # update endpoint_map endpoint_map = getattr(self, f"{endpoint}_map") endpoint_map.update(**batch_map) # update 'pattern' and 'shape' to added/updated keys for name in batch_map: item: BatchMapItem = endpoint_map[name] # keys sanity check and compute pattern and shape keys = item.keys if len(keys) > 1: tensor, scaler, pattern = self.collate_keys( keys, cat_dim=item.cat_dim, return_pattern=True) item.shape, item.pattern = tuple(tensor.size()), pattern if scaler is not None: self._batch_scalers[name] = scaler else: item.shape = tuple(getattr(self, keys[0]).size()) item.pattern = self.patterns[keys[0]] # Getters ################################################################# def get(self, item): # check if item is scalar or vector ndim = item.ndim if isinstance(item, Tensor) else 0 if ndim == 0: # get a single item sample = Data(pattern=self.batch_patterns) elif ndim == 1: # get batch of items pattern = { name: ('b ' + pattern) if 't' in pattern else pattern for name, pattern in self.batch_patterns.items() } sample = StaticBatch(pattern=pattern, size=item.size(0)) else: raise RuntimeError(f"Too many dimensions for index ({ndim}).") # get input synchronized with window if self.window > 0: wdw_idxs = self.get_window_indices(item) self._add_to_sample(sample, WINDOW, 'input', time_index=wdw_idxs) self._add_to_sample(sample, WINDOW, 'target', time_index=wdw_idxs) self._add_to_sample(sample, WINDOW, 'auxiliary', time_index=wdw_idxs) # get input synchronized with horizon hrz_idxs = self.get_horizon_indices(item) self._add_to_sample(sample, HORIZON, 'input', time_index=hrz_idxs) self._add_to_sample(sample, HORIZON, 'target', time_index=hrz_idxs) self._add_to_sample(sample, HORIZON, 'auxiliary', time_index=hrz_idxs) # get static data self._add_to_sample(sample, STATIC, 'input') self._add_to_sample(sample, STATIC, 'target') self._add_to_sample(sample, STATIC, 'auxiliary') # get connectivity if self.edge_index is not None: sample.input['edge_index'] = self.edge_index if self.edge_weight is not None: sample.input['edge_weight'] = self.edge_weight return sample def expand_scaler(self, key: str, pattern: Optional[str] = None, time_index: Union[List, Tensor] = None, node_index: Union[List, 
Tensor] = None) \ -> Optional[ScalerModule]: # check if there is a scaler if key not in self.keys: raise KeyError(f"{key} not in {self.name}.") elif key not in self.scalers: return None # convert indices time_index = self._get_time_index(time_index, layout='index') node_index = self._get_node_index(node_index, layout='index') # get params if pattern is None: return self.scalers[key] # if there is an out-pattern, create new scaler scaler = ScalerModule(self.scalers[key], pattern=pattern) pattern = self.patterns[key] + ' -> ' + pattern scaler.bias = broadcast(scaler.bias, pattern, backend=torch, time_index=time_index, node_index=node_index) scaler.scale = broadcast(scaler.scale, pattern, backend=torch, time_index=time_index, node_index=node_index) return scaler def expand_tensor(self, key: str, pattern: str, time_index: Union[List, Tensor] = None, node_index: Union[List, Tensor] = None): x = getattr(self, key) pattern = self.patterns[key] + ' -> ' + pattern x = broadcast(x, pattern, t=self.n_steps, n=self.n_nodes, backend=torch, time_index=time_index, node_index=node_index) return x def get_tensor(self, key: str, preprocess: bool = False, time_index: Union[List, Tensor] = None, node_index: Union[List, Tensor] = None) \ -> Tuple[Tensor, Optional[ScalerModule]]: # get dataset item if key not in self.keys: raise KeyError(f"{key} not in dataset {self.name}.") # convert indices time_index = self._get_time_index(time_index, layout='index') node_index = self._get_node_index(node_index, layout='index') x = take(getattr(self, key), self.patterns[key], backend=torch, time_index=time_index, node_index=node_index) # get scaler (if any) scaler = None if key in self.scalers: scaler = self.scalers[key].slice(time_index=time_index, node_index=node_index) if preprocess: # transform tensor x = scaler.transform(x) return x, scaler def collate_item_elem(self, key: str, time_index: Union[List, Tensor] = None, node_index: Union[List, Tensor] = None) \ -> Tuple[Tensor, Optional[ScalerModule]]: # get batch item if key in self.input_map: itm = self.input_map[key] elif key in self.target_map: itm = self.target_map[key] else: raise KeyError(f"{key} not in any batch map of {self.name}.") # expand and concatenate tensors x = torch.cat([ self.expand_tensor(k, itm.pattern, time_index, node_index) for k in itm.keys ], dim=itm.cat_dim) # get scaler (if any) scaler = None if key in self._batch_scalers: scaler = self._batch_scalers[key].slice(time_index=time_index, node_index=node_index) if itm.preprocess: # transform tensor x = scaler.transform(x) return x, scaler def collate_keys(self, keys: Iterable, preprocess: bool = False, time_index: Union[List, Tensor] = None, node_index: Union[List, Tensor] = None, cat_dim: Optional[int] = None, return_pattern: bool = False): if any([key not in self.keys for key in keys]): unmatch = set(keys).difference(self.keys) raise KeyError(f"{unmatch} not in {self.name}.") pattern = outer_pattern([self.patterns[key] for key in keys]) tensors, scalers = list(), list() for key in keys: tensor = self.expand_tensor(key, pattern, time_index, node_index) scaler = self.expand_scaler(key, pattern, time_index, node_index) if preprocess and scaler is not None: tensor = scaler(tensor) tensors.append(tensor) scalers.append(scaler) if len(tensors) == 1: if return_pattern: return tensors[0], scalers[0], pattern return tensors[0], scalers[0] if cat_dim is not None: scalers = ScalerModule.cat(scalers, dim=cat_dim, sizes=[t.size() for t in tensors]) tensors = torch.cat(tensors, dim=cat_dim) if
return_pattern: return tensors, scalers, pattern return tensors, scalers def get_mask(self, dtype: Union[type, str, np.dtype] = None) -> Tensor: mask = self.mask if self.has_mask else ~torch.isnan(self.target) if dtype is not None: assert dtype in ['bool', 'uint8', bool, torch.bool, torch.uint8] mask = mask.to(dtype) return mask # Getters helpers def _get_time_index(self, time_index: IndexSlice = None, layout: Literal['index', 'slice', 'mask'] = 'index'): return parse_index(time_index, length=self.n_steps, layout=layout) def _get_node_index(self, node_index: IndexSlice = None, layout: Literal['index', 'slice', 'mask'] = 'index'): return parse_index(node_index, length=self.n_nodes, layout=layout) def _add_to_sample(self, out, synch_mode, endpoint='input', time_index=None, node_index=None): batch_map = getattr(self, f"{endpoint}_map") for key, item in batch_map.by_synch_mode(synch_mode).items(): if len(item.keys) > 1: tensor, scaler = self.collate_item_elem(key, time_index=time_index, node_index=node_index) else: tensor, scaler = self.get_tensor(item.keys[0], preprocess=item.preprocess, time_index=time_index, node_index=node_index) if endpoint == 'auxiliary': out[key] = tensor else: getattr(out, endpoint)[key] = tensor if scaler is not None: out.transform[key] = scaler # Setters #################################################################
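    # A minimal usage sketch tying together the constructor arguments and the
    # windowing properties above (the tensor and all printed values below are
    # illustrative, not taken from the library):
    #
    #     >>> import torch
    #     >>> from tsl.data import SpatioTemporalDataset
    #     >>> target = torch.rand(100, 5, 1)      # 100 steps, 5 nodes, 1 channel
    #     >>> stds = SpatioTemporalDataset(target, window=12, horizon=3,
    #     ...                              delay=0, stride=1)
    #     >>> stds.horizon_offset                 # window + delay
    #     12
    #     >>> stds.sample_span                    # max(window + delay + horizon, window)
    #     15
    #     >>> stds.n_samples                      # n_steps - sample_span + 1, stride 1
    #     86
    #     >>> item = stds[0]
    #     >>> item.x.shape, item.y.shape          # lookback window and horizon of item 0
    #     (torch.Size([12, 5, 1]), torch.Size([3, 5, 1]))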
[docs]    def set_data(self, data: DataArray):
        r"""Set the sequence of primary channels at :obj:`self.target`."""
        self.target = self._parse_target(data)
[docs]    def set_mask(self, mask: Optional[DataArray],
                 add_to_auxiliary_map: bool = True):
        r"""Set the mask of the target channels, i.e., a boolean for each
        (time step, node, channel) triplet denoting whether the corresponding
        value in the target is observed (:obj:`True`) or not (:obj:`False`)."""
        if mask is not None:
            mask = self._parse_target(mask)
            if mask.dtype not in [torch.bool, torch.uint8]:
                raise RuntimeError("Mask is not boolean, accepted types are "
                                   f"{[torch.bool, torch.uint8]}.")
            # check mask length is equal to target's length
            self._check_same_dim(mask.size(0), 'n_steps', 'mask')
            # check mask nodes/features are broadcastable to
            # target's nodes/features
            self._check_same_dim(mask.size(1), 'n_nodes', 'mask',
                                 allow_broadcasting=True)
            self._check_same_dim(mask.size(-1), 'n_channels', 'mask',
                                 allow_broadcasting=True)
            if add_to_auxiliary_map:
                self.auxiliary_map['mask'] = BatchMapItem(
                    'mask', SynchMode.HORIZON, preprocess=False, cat_dim=None,
                    pattern='t n f', shape=mask.shape)
        self.mask = mask
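    # Continuing the sketch above: attaching a validity mask (the random mask
    # is a placeholder):
    #
    #     >>> valid = torch.rand(stds.n_steps, stds.n_nodes, stds.n_channels) > 0.05
    #     >>> stds.set_mask(valid)
    #     >>> 'mask' in stds.auxiliary_map        # mapped to items, synched with horizon
    #     True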
[docs] def set_connectivity(self, connectivity: Union[SparseTensArray, Tuple[DataArray]], target_layout: Optional[str] = None): r"""Set dataset connectivity. The input can be either a dense/sparse matrix :math:`\mathbf{A} \in \mathbb{R}^{N \times N}` or an (:obj:`edge_index` :math:`\in \mathbb{N}^{2 \times E}`, :obj:`edge_weight` :math:`\in \mathbb{R}^{E})` tuple. If :obj:`target_layout` is :obj:`None`, the input layout will be preserved (e.g., a sparse matrix will be stored as a :class:`torch_sparse.SparseTensor`), otherwise the connectivity is converted to the specified layout. In any case, the connectivity will be stored in the attribute :obj:`edge_index`, and the weights will be eventually stored as :obj:`edge_weight`. Args: connectivity (SparseTensArray, tuple, optional): The connectivity target_layout (str, optional): If specified, the input connectivity is converted to this layout. Possible options are [dense, sparse, edge_index]. If :obj:`None`, the target layout is inferred from the input. (default: :obj:`None`) """ self.edge_index, self.edge_weight = self._parse_connectivity( connectivity, target_layout)
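    # Sketch of setting the connectivity from an (edge_index, edge_weight)
    # tuple (toy values):
    #
    #     >>> edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # three directed edges
    #     >>> edge_weight = torch.tensor([0.5, 0.7, 0.9])
    #     >>> stds.set_connectivity((edge_index, edge_weight))
    #     >>> stds.n_edges
    #     3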
# Setter for covariates
[docs] def add_covariate(self, name: str, value: DataArray, pattern: Optional[str] = None, add_to_input_map: bool = True, synch_mode: Optional[SynchMode] = None, preprocess: bool = True, convert_precision: bool = True): r"""Add covariate to the dataset. Examples of covariate are exogenous signals (in the form of dynamic multidimensional data) or static attributes (e.g., graph/node metadata). Parameter :obj:`pattern` specifies what each axis refers to: - 't': temporal dimension; - 'n': node dimension; - 'c'/'f': channels/features dimension. For instance, the pattern of a node-level covariate is 't n f', while a pairwise metric between nodes has pattern 'n n'. Args: name (str): the name of the object. You can then access the added object as :obj:`dataset.{name}`. value (DataArray): the object to be added. Can be a :class:`~pandas.DataFrame`, a :class:`~numpy.ndarray` or a :class:`~torch.Tensor`. pattern (str, optional): the pattern of the object. A pattern specifies what each axis refers to: - 't': temporal dimension; - 'n': node dimension; - 'c'/'f': channels/features dimension. If :obj:`None`, the pattern is inferred from the shape. (default :obj:`None`) add_to_input_map (bool): Whether to map the covariate to dataset item when calling :obj:`get` methods. (default: :obj:`True`) synch_mode (SynchMode): How to synchronize the exogenous variable inside dataset item, i.e., with the window slice (:obj:`SynchMode.WINDOW`) or horizon (:obj:`SynchMode.HORIZON`). It applies only for time-varying covariates. (default: :obj:`SynchMode.WINDOW`) preprocess (bool): If :obj:`True` and the dataset has a scaler with same key, then data are scaled when calling :obj:`get` methods. (default: :obj:`True`) convert_precision (bool): If :obj:`True`, then cast :attr:`value` with dataset's precision. (default: :obj:`True`) """ # validate name. name cannot be an attribute of self, but allow override self._check_name(name) value, pattern = self._parse_covariate( value, pattern, name=name, convert_precision=convert_precision) self._covariates[name] = dict(value=value, pattern=pattern) if add_to_input_map: self.input_map[name] = BatchMapItem(name, synch_mode, preprocess, cat_dim=None, pattern=pattern, shape=value.size())
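    # Sketch of covariates with different patterns (names and shapes are
    # illustrative):
    #
    #     >>> # node-level exogenous signal: one value per step, node and channel
    #     >>> stds.add_covariate('u', torch.rand(stds.n_steps, stds.n_nodes, 4),
    #     ...                    pattern='t n f')
    #     >>> # static node attribute: no temporal dimension
    #     >>> stds.add_covariate('node_type', torch.rand(stds.n_nodes, 2),
    #     ...                    pattern='n f')
    #     >>> # pairwise node metric
    #     >>> stds.add_covariate('dist', torch.rand(stds.n_nodes, stds.n_nodes),
    #     ...                    pattern='n n')
    #     >>> stds.patterns['u'], stds.patterns['node_type'], stds.patterns['dist']
    #     ('t n f', 'n f', 'n n')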
[docs] def update_covariate(self, name: str, value: Optional[DataArray] = None, pattern: Optional[str] = None, add_to_input_map: bool = True, synch_mode: Optional[SynchMode] = None, preprocess: Optional[bool] = None): r"""Update a covariate already in the dataset. Args: name (str): the name of the object. You can then access the added object as :obj:`dataset.{name}`. value (DataArray, optional): the object to be added. Can be a :class:`~pandas.DataFrame`, a :class:`~numpy.ndarray` or a :class:`~torch.Tensor`. pattern (str, optional): the pattern of the object. A pattern specifies what each axis refers to: - 't': temporal dimension; - 'n': node dimension; - 'c'/'f': channels/features dimension. If :obj:`None`, the pattern is inferred from the shape. (default: :obj:`None`) add_to_input_map (bool): Whether to map the covariate to dataset item when calling :obj:`get` methods. (default: :obj:`True`) synch_mode (SynchMode): How to synchronize the exogenous variable inside dataset item, i.e., with the window slice (:obj:`SynchMode.WINDOW`) or horizon (:obj:`SynchMode.HORIZON`). It applies only for time-varying covariates. (default: :obj:`SynchMode.WINDOW`) preprocess (bool): If :obj:`True` and the dataset has a scaler with same key, then data are scaled when calling :obj:`get` methods. (default: :obj:`True`) """ # check that the covariate exists if name not in self._covariates: raise RuntimeError(f"There is no covariate named {name} in " f"{self.__class__.__name__}") # override value (with or without explicit pattern) if value is not None: value, pattern = self._parse_covariate(value, pattern, name=name) # override pattern, rearranging the stored covariate elif pattern is not None: from einops import rearrange old_value = self._covariates[name]['value'] old_pattern = self._covariates[name]['pattern'] pattern = check_pattern(pattern, ndim=old_value.ndim) value = rearrange(old_value, f'{old_pattern} -> {pattern}') # update the covariate self._covariates[name] = dict(value=value, pattern=pattern) # update input map if add_to_input_map: if name in self.input_map: kwargs = dict(synch_mode=self.input_map[name].synch_mode, preprocess=self.input_map[name].preprocess) else: kwargs = dict(synch_mode=None, preprocess=True) if preprocess is not None: kwargs['preprocess'] = preprocess if synch_mode is not None: kwargs['synch_mode'] = synch_mode shape = tuple(self._covariates[name]['value'].size()) pattern = self._covariates[name]['pattern'] self.input_map[name] = BatchMapItem(name, **kwargs, cat_dim=None, pattern=pattern, shape=shape)
[docs]    def remove_covariate(self, name: str):
        r"""Delete a covariate from the dataset.

        Args:
            name (str): the name of the covariate to be deleted.
        """
        try:
            # remove covariate
            del self._covariates[name]
            # remove associated scaler
            if name in self.scalers:
                del self.scalers[name]
            # ATTENTION! remove entirely any map item having the covariate
            # among its keys
            for _map in [self.input_map, self.target_map, self.auxiliary_map]:
                for _map_key, _map_item in list(_map.items()):
                    if name in _map_item.keys:
                        del _map[_map_key]
        except Exception as e:
            raise e
[docs] def add_exogenous(self, name: str, value: DataArray, node_level: bool = True, add_to_input_map: bool = True, synch_mode: SynchMode = WINDOW, preprocess: bool = True): r"""Shortcut method to add a time-varying covariate. Exogenous variables are time-varying covariates of the dataset's target. They can either be graph-level (i.e., with same temporal length as :obj:`target` but with no node dimension) or node-level (i.e., with same temporal and node size as :obj:`target`). Args: name (str): The name of the exogenous variable. If the name starts with :obj:`"global_"`, the variable is assumed to be graph-level (overriding parameter :obj:`node_level`), and the :obj:`"global_"` prefix is removed from the name. value (DataArray): The data sequence. Can be a :class:`~pandas.DataFrame`, a :class:`~numpy.ndarray` or a :class:`~torch.Tensor`. node_level (bool): Whether the input variable is node- or graph-level. If a 2-dimensional array is given and node-level is :obj:`True`, it is assumed that the input has one channel. (default: :obj:`True`) add_to_input_map (bool): Whether to map the exogenous variable to dataset item when calling :obj:`get` methods. (default: :obj:`True`) synch_mode (SynchMode): How to synchronize the exogenous variable inside dataset item, i.e., with the window slice (:obj:`SynchMode.WINDOW`) or horizon (:obj:`SynchMode.HORIZON`). (default: :obj:`SynchMode.WINDOW`) preprocess (bool): If :obj:`True` and the dataset has a scaler with same key, then data are scaled when calling :obj:`get` methods. (default: :obj:`True`) Returns: SpatioTemporalDataset: the dataset with added exogenous. """ if name.startswith('global_'): name = name[7:] node_level = False pattern = 't n f' if node_level else 't f' self.add_covariate(name, value, pattern, synch_mode=synch_mode, add_to_input_map=add_to_input_map, preprocess=preprocess)
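    # Sketch of the "global_" shortcut: the exogenous is stored as graph-level
    # (pattern 't f') and the prefix is stripped from its name:
    #
    #     >>> stds.add_exogenous('global_holiday', torch.rand(stds.n_steps, 1))
    #     >>> stds.patterns['holiday']
    #     't f'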
# Setters for preprocessing
[docs]    def set_trend(self, trend: Optional[DataArray]):
        r"""Set the trend of the dataset's target data."""
        if trend is not None:
            trend = self._parse_target(trend)
            # check trend length is equal to target's length
            self._check_same_dim(trend.size(0), 'n_steps', 'trend')
            # check trend nodes/features are broadcastable to
            # target's nodes/features
            self._check_same_dim(trend.size(1), 'n_nodes', 'trend',
                                 allow_broadcasting=True)
            self._check_same_dim(trend.size(-1), 'n_channels', 'trend',
                                 allow_broadcasting=True)
            if 'target' in self.scalers:
                if self.__target_bias is None:
                    self.__target_bias = self.scalers['target'].bias
                self.scalers['target'].bias = trend
        else:
            if 'target' in self.scalers:
                self.scalers['target'].bias = self.__target_bias
        # refresh maps
        self.update_input_map(self.input_map)
        self.update_target_map(self.target_map)
        self.trend = trend
[docs] def add_scaler(self, key: str, scaler: Union[Scaler, ScalerModule]): r"""Add a :class:`tsl.data.preprocessing.Scaler` for the object indexed by :obj:`key` in the dataset. Args: key (str): The name of the variable associated to the scaler. It must be a temporal variable, i.e., :obj:`data` or an exogenous. scaler (Scaler): The :class:`~tsl.data.preprocessing.Scaler`. """ if key not in self.keys: raise KeyError(f"{key} not in {self.name}.") # copy to ScalerModule scaler = ScalerModule(scaler) pattern = self.patterns[key] self._check_pattern(scaler.bias, pattern, name=f"scaler ({key})", allow_broadcasting=True) self._check_pattern(scaler.scale, pattern, name=f"scaler ({key})", allow_broadcasting=True) if key == 'target' and self.trend is not None: self.__target_bias = scaler.bias scaler.bias = scaler.bias + self.trend scaler.pattern = pattern self.scalers[key] = scaler # cache batch scaler if target tensor is in a multi-key batch item for bm in [self.input_map, self.target_map, self.auxiliary_map]: for bm_key, bm_item in bm.items(): if key in bm_item.keys and len(bm_item.keys) > 1: tensor, scaler = self.collate_keys(bm_item.keys, cat_dim=bm_item.cat_dim, return_pattern=False) self._batch_scalers[bm_key] = scaler
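    # Sketch of pairing the target with a scaler. StandardScaler and its
    # axis/fit arguments refer to tsl.data.preprocessing and are assumptions
    # here, not part of this module:
    #
    #     >>> from tsl.data.preprocessing import StandardScaler
    #     >>> scaler = StandardScaler(axis=(0, 1)).fit(stds.target)
    #     >>> stds.add_scaler('target', scaler)
    #     >>> 'y' in stds[0].transform            # scaler travels with the item
    #     True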
# Dataset trimming ########################################################
[docs]    def reduce(self, time_index: Optional[IndexSlice] = None,
               node_index: Optional[IndexSlice] = None):
        """Reduce the dataset in terms of number of steps and nodes. Returns a
        copy of the reduced dataset.

        If the dataset has a connectivity, edges ending to or starting from
        removed nodes will be removed as well.

        Args:
            time_index (IndexSlice, optional): index or mask of the time steps
                to keep after reduction.
                (default: :obj:`None`)
            node_index (IndexSlice, optional): index or mask of the nodes to
                keep after reduction.
                (default: :obj:`None`)
        """
        return deepcopy(self).reduce_(time_index, node_index)
[docs]    def reduce_(self, time_index: Optional[IndexSlice] = None,
                node_index: Optional[IndexSlice] = None):
        """Reduce the dataset in terms of number of steps and nodes. This is
        an inplace operation.

        If the dataset has a connectivity, edges ending to or starting from
        removed nodes will be removed as well.

        Args:
            time_index (IndexSlice, optional): index or mask of the time steps
                to keep after reduction.
                (default: :obj:`None`)
            node_index (IndexSlice, optional): index or mask of the nodes to
                keep after reduction.
                (default: :obj:`None`)
        """
        # use slice to reduce known tensors
        time_slice = self._get_time_index(time_index, layout='slice')
        node_slice = self._get_node_index(node_index, layout='slice')
        # use index to reduce using index-fed functions
        time_index = self._get_time_index(time_index, layout='index')
        node_index = self._get_node_index(node_index, layout='index')
        try:
            if self.edge_index is not None and node_index is not None:
                self.edge_index, self.edge_weight = reduce_graph(
                    node_index, self.edge_index, self.edge_weight,
                    num_nodes=self.n_nodes)
            self.target = self.target[time_slice, node_slice]
            if self.index is not None:
                self.index = self.index[time_index.numpy()]
            if self.mask is not None:
                self.mask = self.mask[time_slice, node_slice]
            if self.trend is not None:
                self.trend = self.trend[time_slice, node_slice]
            for name, attr in self._covariates.items():
                x, scaler = self.get_tensor(name, time_index=time_index,
                                            node_index=node_index)
                attr['value'] = x
                if scaler is not None:
                    self.scalers[name] = scaler
        except Exception as e:
            raise e
        return self
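    # Continuing the sketch: trimming to the first 50 steps and three nodes
    # (edges touching removed nodes are dropped as well):
    #
    #     >>> sub = stds.reduce(time_index=slice(0, 50), node_index=[0, 1, 2])
    #     >>> sub.n_steps, sub.n_nodes
    #     (50, 3)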
@contextmanager def change_windowing(self, **kwargs): default, indices = dict(), self._indices try: assert all([k in _WINDOWING_['all'][1:] for k in kwargs]) for k, v in kwargs.items(): default[k] = getattr(self, k) setattr(self, k, v) yield self finally: for k, v in default.items(): setattr(self, k, v) self._indices = indices # Indexing ################################################################ def get_window_indices(self, item): idx = self._indices[item] if idx.dim(): # idx is list of indices return idx[:, None] + self._window_range[None] else: # idx is one index return idx + self._window_range def get_horizon_indices(self, item): idx = self._indices[item] if idx.dim(): # idx is list of indices return idx[:, None] + self._horizon_range[None] else: # idx is one index return idx + self._horizon_range def set_indices(self, indices: TensArray): indices = torch.as_tensor(indices, dtype=torch.long) max_index = self.n_steps - self.sample_span assert all((indices >= 0) & (indices <= max_index)), \ f"indices must be in the range [0, {max_index}] for {self.name}." self._indices = indices def expand_indices(self, indices=None, unique=False, merge=False): indices = torch.arange(self.n_samples) if indices is None else indices ds_indices = dict() if self.window > 0: ds_indices['window'] = self.get_window_indices(indices) ds_indices['horizon'] = self.get_horizon_indices(indices) if merge: return torch.unique( torch.cat( [v.contiguous().view(-1) for v in ds_indices.values()])) if unique: ds_indices = {k: torch.unique(v) for k, v in ds_indices.items()} return ds_indices def overlapping_indices(self, idxs1, idxs2, synch_mode: Union[SynchMode, str] = WINDOW, as_mask=False): if isinstance(synch_mode, SynchMode): synch_mode = synch_mode.name idxs1, idxs2 = np.asarray(idxs1), np.asarray(idxs2) ts1 = self.expand_indices(idxs1)[synch_mode.lower()].numpy() ts2 = self.expand_indices(idxs2)[synch_mode.lower()].numpy() common_ts = np.intersect1d(ts1, ts2) def is_overlapping(sample): return np.any(np.in1d(sample, common_ts)) m1 = np.apply_along_axis(is_overlapping, 1, ts1) m2 = np.apply_along_axis(is_overlapping, 1, ts2) if as_mask: return m1, m2 return idxs1[m1], idxs2[m2] def data_timestamps(self, indices=None, unique=False) -> Optional[dict]: if self.index is None: return None ds_indices = self.expand_indices(indices, unique=unique) index = self.index if unique else self.index.to_numpy() ds_timestamps = {k: index[v.numpy()] for k, v in ds_indices.items()} return ds_timestamps # Representation ########################################################## def numpy(self): return np.asarray(self.target) def dataframe(self): columns = pd.MultiIndex.from_product( [range(self.n_nodes), range(self.n_channels)], names=['nodes', 'channels']) data = self.numpy().reshape((-1, self.n_nodes * self.n_channels)) return pd.DataFrame(data=data, index=self.index, columns=columns) # Utilities ###############################################################
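    # Sketch of the change_windowing context manager defined above: the
    # windowing is modified only inside the context and restored on exit:
    #
    #     >>> with stds.change_windowing(window=24, horizon=6) as ds:
    #     ...     print(len(ds), ds.horizon_offset)
    #     71 24
    #     >>> stds.window, stds.horizon           # original values restored
    #     (12, 3)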
[docs] def save(self, filename: str) -> None: """Save :obj:`SpatioTemporalDataset` to disk. Args: filename (str): path to filename for storage. """ torch.save(self, filename)
[docs] @classmethod def load(cls, filename: str) -> "SpatioTemporalDataset": """Load instance of :obj:`SpatioTemporalDataset` from disk. Args: filename (str): path of :obj:`SpatioTemporalDataset`. """ obj = torch.load(filename) if not isinstance(obj, cls): raise TypeError(f"Loaded file is not of class {cls}.") return obj
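    # Sketch of a save/load round trip (the file name is a placeholder):
    #
    #     >>> stds.save('stds.pt')
    #     >>> stds2 = SpatioTemporalDataset.load('stds.pt')
    #     >>> len(stds2) == len(stds)
    #     True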
[docs]    @classmethod
    def from_dataset(cls, dataset,
                     connectivity: Optional[Union[SparseTensArray,
                                                  Tuple[DataArray]]] = None,
                     covariate_keys: Optional[List[str]] = None,
                     input_map: Optional[Union[Mapping, BatchMap]] = None,
                     target_map: Optional[Union[Mapping, BatchMap]] = None,
                     auxiliary_map: Optional[Union[Mapping, BatchMap]] = None,
                     scalers: Optional[Mapping[str, Scaler]] = None,
                     trend: Optional[DataArray] = None,
                     window: int = 12,
                     horizon: int = 1,
                     delay: int = 0,
                     stride: int = 1,
                     window_lag: int = 1,
                     horizon_lag: int = 1) -> "SpatioTemporalDataset":
        """Create a :class:`~tsl.data.SpatioTemporalDataset` from a
        :class:`~tsl.datasets.prototypes.TabularDataset`."""
        covariates = dataset._covariates
        if covariate_keys is not None:
            covariates = {
                k: v
                for k, v in covariates.items() if k in covariate_keys
            }
        return cls(target=dataset.target,
                   index=dataset.index,
                   mask=dataset.mask,
                   covariates=covariates,
                   name=dataset.name,
                   precision=dataset.precision,
                   connectivity=connectivity,
                   input_map=input_map,
                   target_map=target_map,
                   auxiliary_map=auxiliary_map,
                   scalers=scalers,
                   trend=trend,
                   window=window,
                   horizon=horizon,
                   delay=delay,
                   stride=stride,
                   window_lag=window_lag,
                   horizon_lag=horizon_lag)
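    # Sketch of the typical end-to-end use: build the torch-ready dataset from
    # a tsl tabular dataset. MetrLA and its get_connectivity arguments are
    # assumptions about tsl.datasets, not part of this module:
    #
    #     >>> from tsl.datasets import MetrLA
    #     >>> metr_la = MetrLA()
    #     >>> adj = metr_la.get_connectivity(threshold=0.1, layout='edge_index')
    #     >>> torch_dataset = SpatioTemporalDataset.from_dataset(
    #     ...     metr_la, connectivity=adj, window=12, horizon=12)
    #     >>> torch_dataset.n_nodes
    #     207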