Source code for tsl.data.batch

from typing import List, Mapping, Optional, Sequence, Union

import numpy as np
import torch
from torch import Tensor
from torch.utils.data.dataloader import default_collate
from torch_geometric.data import Batch
from torch_geometric.data.collate import collate
from torch_geometric.data.separate import separate
from torch_geometric.typing import Adj

from ..typing import IndexSlice
from ..utils import ensure_list
from .data import Data
from .preprocessing import ScalerModule

# STATIC BATCH ################################################################


def static_scaler_collate(transform_list: List[Mapping[str, ScalerModule]]):
    transform = transform_list[0]
    for key, scaler in transform.items():
        params, pattern = scaler.params(), scaler.pattern
        # collate scalers only if time-varying, otherwise keep first one
        if scaler.t is not None:
            for p_name, param in params.items():
                # params can be broadcastable in some dimension
                if param.size(0) == scaler.t:
                    # batch time-varying param on new axis
                    params[p_name] = default_collate([
                        getattr(scl_dict[key], p_name)
                        for scl_dict in transform_list
                    ])
                else:
                    params[p_name] = param[None]  # unsqueeze first dimension
            # add batch dim in pattern
            pattern = 'b ' + pattern
        # create new scaler
        transform[key] = ScalerModule(**params, pattern=pattern)
    return transform
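
# Note (illustrative, with hypothetical shapes): for a batch of B samples
# whose scalers are time-varying, a param of shape [t, n, f] is stacked to
# [B, t, n, f], while a broadcastable param of shape [1, n, f] (taken from the
# first sample's scaler) is only unsqueezed to [1, 1, n, f]; the scaler
# pattern 't n f' becomes 'b t n f'.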


def get_static_scaler(transform: Mapping[str, ScalerModule], idx: int):
    out = dict()
    for key, scaler in transform.items():
        params, pattern = scaler.params(), scaler.pattern
        # index scalers only if time-varying
        if pattern.startswith('b'):
            for p_name, param in params.items():
                # params can have different shapes
                if param.size(0) > 1:  # check if this param has b > 1
                    # select the idx-th slice of the batched param
                    params[p_name] = param[idx]
                else:
                    params[p_name] = param[0]  # squeeze first dimension
            # remove batch dim from pattern
            pattern = pattern[2:]
        # create new scaler
        out[key] = ScalerModule(**params, pattern=pattern)
    return out
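
# Note: `get_static_scaler` is the inverse of `static_scaler_collate` above:
# for scalers whose pattern starts with 'b', it selects the idx-th slice of
# each batched param (or squeezes the broadcast ones) and strips the leading
# 'b' to restore the original per-sample pattern.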


def static_graph_collate(data_list: List[Data],
                         cls: Optional[type] = None) -> Data:
    data_list = ensure_list(data_list)

    # collate all sample-wise elements
    elem = data_list[0]
    if cls is None:
        cls = elem.__class__
    out = cls()
    out = out.stores_as(elem)

    pattern = elem.pattern

    for key in elem.keys:
        if key == 'transform':
            out[key] = static_scaler_collate([data[key] for data in data_list])
        elif key in pattern:
            if 't' in pattern[key]:
                out[key] = default_collate([data[key] for data in data_list])
                out.pattern[key] = 'b ' + pattern[key]
            else:
                out[key] = elem[key]
        else:
            # add warning ?
            out[key] = default_collate([data[key] for data in data_list])

    out._batch_size = len(data_list)
    return out


class StaticBatch(Data):
    r"""A batch of :class:`tsl.data.Data` objects for multiple spatiotemporal
    graphs sharing the same topology.

    The batch object extends :class:`~tsl.data.Data`, thus preserving all its
    functionalities.

    Args:
        input (Mapping, optional): Named mapping of :class:`~torch.Tensor` to
            be used as input to the model.
            (default: :obj:`None`)
        target (Mapping, optional): Named mapping of :class:`~torch.Tensor` to
            be used as target of the task.
            (default: :obj:`None`)
        edge_index (Adj, optional): Shared graph connectivity, either in COO
            format (a :class:`~torch.Tensor` of shape :obj:`[2, E]`) or as a
            :class:`torch_sparse.SparseTensor` with shape :obj:`[N, N]`.
            (default: :obj:`None`)
        edge_weight (Tensor, optional): Weights of the edges (if
            :attr:`edge_index` is not a :class:`torch_sparse.SparseTensor`).
            (default: :obj:`None`)
        mask (Tensor, optional): The optional mask associated with the target.
            (default: :obj:`None`)
        transform (Mapping, optional): Named mapping of
            :class:`~tsl.data.preprocessing.Scaler` associated with entries in
            :attr:`input` or :attr:`output`.
            (default: :obj:`None`)
        pattern (Mapping, optional): Map of the pattern of each entry in
            :attr:`input` or :attr:`output`.
            (default: :obj:`None`)
        size (int, optional): The batch size, i.e., the number of
            spatiotemporal graphs in the batch. The different samples in the
            batch share all the same topology, such that there is (at most)
            only one :obj:`edge_index` and :obj:`edge_weight`. If :obj:`None`,
            then the batch size is inferred from data (if possible).
            (default: :obj:`None`)
        **kwargs: Any keyword argument for :class:`~torch_geometric.data.Data`.
    """

    def __init__(self,
                 input: Optional[Mapping] = None,
                 target: Optional[Mapping] = None,
                 edge_index: Optional[Adj] = None,
                 edge_weight: Optional[Tensor] = None,
                 mask: Optional[Tensor] = None,
                 transform: Optional[Mapping] = None,
                 pattern: Optional[Mapping] = None,
                 size: Optional[int] = None,
                 **kwargs):
        super(StaticBatch, self).__init__(input=input,
                                          target=target,
                                          edge_index=edge_index,
                                          edge_weight=edge_weight,
                                          mask=mask,
                                          transform=transform,
                                          pattern=pattern,
                                          **kwargs)
        self._batch_size = size

    @classmethod
    def from_data_list(cls, data_list: List[Data]):
        r"""Constructs a :class:`~tsl.data.StaticBatch` object from a Python
        list of :class:`~tsl.data.Data` objects representing temporal signals
        on a static (shared) graph."""
        return static_graph_collate(data_list, cls)

    def get_example(self, idx: int) -> Data:
        r"""Gets the :class:`~tsl.data.Data` object at index :obj:`idx`.

        The :class:`~tsl.data.StaticBatch` object must have been created via
        :meth:`from_data_list` in order to be able to reconstruct the initial
        object."""
        out = Data()
        out.stores_as(self)
        for key, value in self:
            pattern = self.pattern.get(key)
            if key == 'transform':
                out.transform.update(get_static_scaler(self.transform, idx))
            elif pattern is not None and pattern.startswith('b'):
                out[key] = value[idx]
                out.pattern[key] = pattern[2:]
            else:
                out[key] = value
        return out

    def index_select(self, idx: IndexSlice) -> List[Data]:
        r"""Creates a subset of :class:`~tsl.data.Data` objects from specified
        indices :obj:`idx`.

        Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a
        list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type
        long or bool."""
        if isinstance(idx, slice):
            idx = list(range(self._batch_size)[idx])
        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
            idx = idx.flatten().tolist()
        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
            idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist()
        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
            idx = idx.flatten().tolist()
        elif isinstance(idx, np.ndarray) and idx.dtype == bool:
            idx = idx.flatten().nonzero()[0].flatten().tolist()
        elif isinstance(idx, Sequence) and not isinstance(idx, str):
            pass
        else:
            raise IndexError(
                f"Only slices (':'), list, tuples, torch.tensor and "
                f"np.ndarray of dtype long or bool are valid indices (got "
                f"'{type(idx).__name__}')")
        return [self.get_example(i) for i in idx]

    def __getitem__(self, idx: Union[int, np.integer, str, IndexSlice]):
        if (isinstance(idx, (int, np.integer))
                or (isinstance(idx, Tensor) and idx.dim() == 0)
                or (isinstance(idx, np.ndarray) and np.isscalar(idx))):
            return self.get_example(idx)
        elif isinstance(idx, str) or (isinstance(idx, tuple)
                                      and isinstance(idx[0], str)):
            # Accessing attributes or node/edge types:
            return super().__getitem__(idx)
        else:
            return self.index_select(idx)

    def to_data_list(self) -> List[Data]:
        r"""Reconstructs the list of :class:`~tsl.data.Data` objects from the
        :class:`~tsl.data.StaticBatch` object."""
        return [self.get_example(i) for i in range(self.batch_size)]

    @property
    def batch_size(self) -> int:
        """The batch size, i.e., the number of spatiotemporal graphs in the
        batch."""
        if self._batch_size is not None:
            return self._batch_size
        if self.pattern is not None:
            for key, pattern in self.pattern.items():
                if pattern.startswith('b'):
                    return self[key].size(0)

    @property
    def num_graphs(self) -> int:
        return self.batch_size
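

# Example (illustrative sketch, not part of the library): assuming that `Data`
# accepts the same keyword arguments that `StaticBatch` forwards to it, a
# static batch could be built and unpacked roughly as follows. Key names,
# shapes and patterns below are hypothetical.
#
#   samples = [
#       Data(input={'x': torch.randn(12, 20, 1)},       # window  [t, n, f]
#            target={'y': torch.randn(4, 20, 1)},       # horizon [t, n, f]
#            edge_index=torch.randint(0, 20, (2, 50)),  # shared topology
#            pattern={'x': 't n f', 'y': 't n f'})
#       for _ in range(8)
#   ]
#   batch = StaticBatch.from_data_list(samples)  # 'x' collated to [8, 12, 20, 1]
#   example = batch.get_example(0)               # back to a single Data object
#   subset = batch[2:5]                          # list of Data (index_select)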


# DISJOINT BATCH ##############################################################


def collate_scaler_params(values: List[Tensor],
                          cat_dim: Optional[int] = None,
                          batch_index: Optional[Tensor] = None):
    elem = values[0]
    # Stack scaler params on new batch dimension.
    if cat_dim is None:
        value = torch.stack(values, dim=0)
    # Concatenate a list of `torch.Tensor` along the `cat_dim`.
    else:
        value = torch.cat(values, dim=cat_dim)
    if batch_index is not None and (cat_dim is None
                                    or elem.size(cat_dim) == 1):
        value = value.index_select(cat_dim or 0, batch_index)
        return value, True
    return value, False


def separate_scaler_params(value: Tensor,
                           slices,
                           idx: int,
                           is_repeated: bool,
                           cat_dim: Optional[int] = None):
    if not is_repeated:
        start, end = slices[idx], slices[idx + 1]
        out = value.narrow(cat_dim or 0, start, end - start)
    else:
        out = value.index_select(cat_dim or 0, slices[idx])
    return out


def disjoint_scaler_collate(data_list: List[Data],
                            batch_index: Optional[Tensor] = None,
                            force_batch: bool = False):
    elem = data_list[0]
    transform = data_list[0].transform
    out = dict()
    for key, scaler in transform.items():
        cat_dim = elem.__cat_dim__(key, elem[key])
        # collate each param in the scaler
        params, rep_params = dict(), dict()
        for param in scaler.params():
            param_list = [
                getattr(item.transform[key], param) for item in data_list
            ]
            value, is_repeated = collate_scaler_params(param_list,
                                                       cat_dim=cat_dim,
                                                       batch_index=batch_index)
            params[param] = value
            rep_params[param] = is_repeated
        # set pattern of new collated scaler
        if scaler.pattern is not None:
            pattern = scaler.pattern
        else:
            pattern = elem.pattern.get(key)
        if pattern is not None:
            if cat_dim is None:
                pattern = 'n ' + pattern
            if 't' in pattern and force_batch:
                params = {k: param[None] for k, param in params.items()}
                pattern = 'b ' + pattern
        out[key] = ScalerModule(**params, pattern=pattern)
        out[key]._repeated_params = rep_params
    return out


def get_disjoint_scaler(batch, idx):
    transform = batch.transform
    out = dict()
    for key, scaler in transform.items():
        if not hasattr(scaler, '_repeated_params'):
            raise RuntimeError("Cannot separate 'ScalerModule' because it "
                               "was not created via "
                               "'disjoint_scaler_collate()'.")
        cat_dim = batch.__cat_dim__(key, batch[key])
        # separate each param in the scaler
        params = {
            param: separate_scaler_params(
                getattr(scaler, param),
                slices=batch._slice_dict[key],
                idx=idx,
                cat_dim=cat_dim,
                is_repeated=scaler._repeated_params[param])
            for param in scaler.params()
        }
        # set pattern of new separated scaler
        if scaler.pattern is not None:
            pattern = scaler.pattern
        else:
            pattern = batch.pattern.get(key)
        if cat_dim is None and pattern is not None:
            pattern = pattern[2:]
        out[key] = ScalerModule(**params, pattern=pattern)
    return out
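

# Note: `collate_scaler_params` and `separate_scaler_params` form a round
# trip. Params that are stacked (or broadcast and then repeated through
# `batch_index`) at collate time are flagged in the `_repeated_params` dict
# attached to the collated `ScalerModule`, so that `get_disjoint_scaler` knows
# whether to narrow a param with the stored slices or to re-index it when
# reconstructing a single sample.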


class DisjointBatch(Batch):
    r"""A data object describing a batch of graphs as one big (disconnected)
    graph.

    Inherits from :class:`tsl.data.Data`. In addition, single graphs can be
    identified via the assignment vector :obj:`batch`, which maps each node to
    its respective graph identifier.
    """

    @classmethod
    def from_data_list(cls,
                       data_list: List[Data],
                       force_batch: bool = False,
                       follow_batch: Optional[List[str]] = None,
                       exclude_keys: Optional[List[str]] = None,
                       graph_attributes: Optional[List[str]] = None):
        r"""Constructs a :class:`~tsl.data.DisjointBatch` object from a list
        of :class:`~tsl.data.Data` objects.

        The assignment vector :obj:`batch` is created on the fly. In addition,
        creates assignment vectors for each key in :obj:`follow_batch`. Will
        exclude any keys given in :obj:`exclude_keys`.

        Args:
            data_list (list): The list of :class:`tsl.data.Data` objects.
            force_batch (bool): If :obj:`True`, then add a dummy batch
                dimension for time-varying elements.
                (default: :obj:`False`)
            follow_batch (list, optional): Create an assignment vector for
                each key in :attr:`follow_batch`.
                (default: :obj:`None`)
            exclude_keys (list, optional): Exclude the keys in
                :attr:`exclude_keys` from collate.
                (default: :obj:`None`)
            graph_attributes: Keys in :attr:`graph_attributes` with no node
                dimension will be added to the batch as graph attributes,
                i.e., the tensors will be stacked on a new dimension (the
                first one). Note that all graph attributes indexed by a key
                which is not in this list are repeated along a new node
                dimension (the second one, if the attribute is time-varying,
                otherwise the first one).
                (default: :obj:`None`)
        """
        if graph_attributes is None:
            graph_attributes = []

        if exclude_keys is None:
            exclude_keys = []
        exclude_keys.append('transform')

        batch, slice_dict, inc_dict = collate(
            cls,
            data_list=data_list,
            increment=True,
            add_batch=not isinstance(data_list[0], StaticBatch),
            follow_batch=follow_batch,
            exclude_keys=exclude_keys,
        )

        scalers = disjoint_scaler_collate(data_list,
                                          batch_index=batch.batch,
                                          force_batch=force_batch)
        batch.transform.update(scalers)

        # repeat node-invariant elements along the node dimension to not lose
        # the coupling with their sample in the batch
        repeated_keys = dict()
        for key, value in batch.items():
            if key in ['transform', 'batch', 'ptr']:
                continue
            if batch.__cat_dim__(key, value) is None:
                if key not in graph_attributes:
                    batch[key] = value[batch.batch]
                    slice_dict[key] = batch.ptr
                    repeated_keys[key] = 0
                if key in batch.pattern:
                    dims = batch.pattern[key].split(' ')
                    # '... -> b ...'
                    if key in graph_attributes:
                        dims.insert(0, 'b')
                    # 'n t ... -> t n ...'
                    elif 't' in dims:
                        # invert batch dimension (after collate) with time to
                        # follow time_then_node convention
                        batch[key] = torch.transpose(batch[key], 0, 1). \
                            contiguous()
                        # adjust pattern and repeated axis
                        dims.insert(1, 'n')
                        repeated_keys[key] = 1
                    # '... -> n ...'
                    else:
                        dims.insert(0, 'n')
                    batch.pattern[key] = ' '.join(dims)
            # if `force_batch`, add dummy batch dimension for time-varying elems
            if force_batch and key in batch.pattern:
                if 't' in batch.pattern[key]:
                    batch[key] = batch[key][None]
                    batch.pattern[key] = 'b ' + batch.pattern[key]
                    if key in repeated_keys:
                        repeated_keys[key] = [0, repeated_keys[key] + 1]
                    else:
                        repeated_keys[key] = [0]

        batch._num_graphs = len(data_list)
        batch._slice_dict = slice_dict
        batch._inc_dict = inc_dict
        batch._repeated_keys = repeated_keys

        return batch

    def get_example(self, idx: int) -> Data:
        r"""Gets the :class:`~tsl.data.Data` object at index :obj:`idx`.

        The :class:`~tsl.data.DisjointBatch` object must have been created via
        :meth:`from_data_list` in order to be able to reconstruct the initial
        object."""
        if not hasattr(self, '_slice_dict') and \
                not hasattr(self, '_repeated_keys'):
            raise RuntimeError(
                "Cannot reconstruct 'Data' object from 'DisjointBatch' "
                "because 'DisjointBatch' was not created via "
                "'DisjointBatch.from_data_list()'.")

        data = separate(
            cls=self.__class__.__bases__[-1],
            batch=self,
            idx=idx,
            slice_dict=self._slice_dict,
            inc_dict=self._inc_dict,
            decrement=True,
        )

        scalers = get_disjoint_scaler(self, idx=idx)
        data.transform.update(scalers)

        for key, dims in self._repeated_keys.items():
            idx = [
                0 if i in dims else slice(None)  # e.g. dims=[1] -> [:, 0, :]
                for i in range(data[key].ndim)
            ]
            data[key] = data[key][idx]
            if key in data.pattern:
                pattern_dims = data.pattern[key].split(' ')
                pattern_dims = [
                    tkn for i, tkn in enumerate(pattern_dims) if i not in dims
                ]
                data.pattern[key] = ' '.join(pattern_dims)

        return data

    @property
    def batch_size(self) -> int:
        return self.num_graphs
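

# Example (illustrative sketch, with hypothetical key names and shapes):
# collating samples whose graphs have different node sets into one big
# disconnected graph.
#
#   graphs = [
#       Data(input={'x': torch.randn(12, n, 1)},          # [t, n, f]
#            edge_index=torch.randint(0, n, (2, 4 * n)),
#            pattern={'x': 't n f'})
#       for n in (10, 15, 20)
#   ]
#   big = DisjointBatch.from_data_list(graphs)
#   big.batch                     # node-to-graph assignment vector
#   big.num_graphs                # 3
#   first = big.get_example(0)    # reconstructs the first Data object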