Source code for tsl.data.data

import copy
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
                    Optional, Tuple, Union)

import torch
from einops import rearrange
from torch import Tensor
from torch_geometric.data.data import Data as PyGData
from torch_geometric.data.storage import BaseStorage
from torch_geometric.data.view import ItemsView, KeysView, ValuesView
from torch_geometric.typing import Adj
from torch_sparse import SparseTensor

from tsl.ops.connectivity import reduce_graph
from tsl.ops.pattern import take
from tsl.utils.python_utils import ensure_list


def get_size(x: Union[Tensor, SparseTensor]) -> Tuple:
    if isinstance(x, Tensor):
        return tuple(x.size())
    elif isinstance(x, SparseTensor):
        return tuple(x.sizes())


def pattern_size_repr(key: str,
                      x: Union[Tensor, SparseTensor],
                      pattern: Optional[str] = None):
    if pattern is not None:
        pattern = pattern.replace(' ', '')
        out = str([
            f'{token}={size}' if not token.isnumeric() else str(size)
            for token, size in zip(pattern, get_size(x))
        ])
    else:
        out = str(list(get_size(x)))
    out = f"{key}={out}".replace("'", '')
    return out
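

# Example (illustrative, not part of the module): for a tensor with pattern
# 't n f' (time, nodes, features), pattern_size_repr pairs each pattern token
# with the corresponding size:
#
#   >>> pattern_size_repr('x', torch.zeros(12, 4, 2), 't n f')
#   'x=[t=12, n=4, f=2]'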


class StorageView(BaseStorage):

    def __init__(self, store, keys: Optional[Iterable] = None):
        self.__keys = tuple()
        super(StorageView, self).__init__()
        self._mapping = store
        self._keys = keys  # noqa

    def __len__(self) -> int:
        return len(self._keys)

    def __repr__(self) -> str:
        cls = self.__class__.__name__
        info = [pattern_size_repr(k, v) for k, v in self.items()]
        return '{}({})'.format(cls, ', '.join(info))

    def __setattr__(self, key, value):
        if key == '_keys':
            if value is None:
                keys = []
            else:
                keys = ensure_list(value)
            self.__keys = tuple(keys)
        else:
            super(StorageView, self).__setattr__(key, value)

    def __getitem__(self, item: str) -> Any:
        if item in self._keys:
            return self._mapping[item]
        else:
            raise KeyError(item)

    def __setitem__(self, key, value):
        super(StorageView, self).__setitem__(key, value)
        self.add_keys(key)

    def __delitem__(self, key):
        super(StorageView, self).__delitem__(key)
        self.del_keys(key)

    def __iter__(self) -> Iterator:
        return iter(self.values())

    # Override methods to account for filtering keys  #########################

    def _filter_keys(self, args: Tuple):
        keys = self._keys
        if len(args):
            keys = [arg for arg in args if arg in self._keys]
        return keys

    def keys(self, *args: List[str]) -> KeysView:
        keys = self._filter_keys(args)
        if len(keys) > 0:
            return super(StorageView, self).keys(*keys)
        return KeysView({})

    def values(self, *args: List[str]) -> ValuesView:
        keys = self._filter_keys(args)
        if len(keys) > 0:
            return super(StorageView, self).values(*keys)
        return ValuesView({})

    def items(self, *args: List[str]) -> ItemsView:
        keys = self._filter_keys(args)
        if len(keys) > 0:
            return super(StorageView, self).items(*keys)
        return ItemsView({})

    def apply_(self, func: Callable, *args: List[str]):
        keys = self._filter_keys(args)
        if len(keys) > 0:
            return super(StorageView, self).apply_(func, *keys)
        return self

    def apply(self, func: Callable, *args: List[str]):
        keys = self._filter_keys(args)
        if len(keys) > 0:
            return super(StorageView, self).apply(func, *keys)
        return self

    def to_dict(self) -> Dict[str, Any]:
        return copy.copy({k: self._mapping[k] for k in self._keys})

    def numpy(self, *args: List[str]):
        r"""Transform all tensors to numpy arrays, either for all
        attributes or only the ones given in :obj:`*args`."""
        self.detach().cpu()
        return self.apply(lambda x: x.numpy(), *args)

    @property
    def _keys(self) -> tuple:
        return tuple(k for k in self.__keys if k in self._mapping)

    def add_keys(self, *keys):
        keys = set(keys).difference(self.__keys)
        self.__keys = tuple([*self.__keys, *keys])

    def del_keys(self, *keys):
        keys = tuple(k for k in self.__keys if k not in keys)
        self.__keys = keys
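

# Usage sketch (illustrative): a StorageView is a live, filtered view over a
# shared mapping -- only whitelisted keys are visible, while values are read
# from (and written to) the underlying store:
#
#   >>> store = {'x': torch.zeros(3), 'y': torch.ones(3)}
#   >>> view = StorageView(store, keys=['x'])
#   >>> list(view.keys())
#   ['x']
#   >>> view['y']  # raises KeyError: 'y' is not among the view's keys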


class Data(PyGData):
    r"""A data object describing a spatiotemporal graph, i.e., a graph with
    time series of equal length associated with every node.

    The data object extends :class:`torch_geometric.data.Data`, thus
    preserving all its functionalities (see also the `accompanying tutorial
    <https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html#data-handling-of-graphs>`_).

    Args:
        input (Mapping, optional): Named mapping of :class:`~torch.Tensor` to
            be used as input to the model.
            (default: :obj:`None`)
        target (Mapping, optional): Named mapping of :class:`~torch.Tensor` to
            be used as target of the task.
            (default: :obj:`None`)
        edge_index (Adj, optional): Graph connectivity either in COO format (a
            :class:`~torch.Tensor` of shape :obj:`[2, E]`) or as a
            :class:`torch_sparse.SparseTensor` with shape :obj:`[N, N]`. For
            dynamic graphs -- with time-varying topology -- can be a Python
            list of :class:`~torch.Tensor`.
            (default: :obj:`None`)
        edge_weight (Tensor, optional): Weights of the edges (if
            :attr:`edge_index` is not a :class:`torch_sparse.SparseTensor`).
            (default: :obj:`None`)
        mask (Tensor, optional): The optional mask associated with the target.
            (default: :obj:`None`)
        transform (Mapping, optional): Named mapping of
            :class:`~tsl.data.preprocessing.Scaler` associated with entries in
            :attr:`input` or :attr:`target`.
            (default: :obj:`None`)
        pattern (Mapping, optional): Map of the pattern of each entry in
            :attr:`input` or :attr:`target`.
            (default: :obj:`None`)
        **kwargs: Any keyword argument for
            :class:`~torch_geometric.data.Data`.
    """

    input: StorageView
    target: StorageView
    pattern: dict

    def __init__(self,
                 input: Optional[Mapping] = None,
                 target: Optional[Mapping] = None,
                 edge_index: Optional[Adj] = None,
                 edge_weight: Optional[Tensor] = None,
                 mask: Optional[Tensor] = None,
                 transform: Optional[Mapping] = None,
                 pattern: Optional[Mapping] = None,
                 **kwargs):
        input = input if input is not None else dict()
        target = target if target is not None else dict()
        super(Data, self).__init__(**input,
                                   **target,
                                   edge_index=edge_index,
                                   edge_weight=edge_weight,
                                   **kwargs)
        # Set 'input' as view on input keys
        self.__dict__['input'] = StorageView(self._store, input.keys())
        # Set 'target' as view on target keys
        self.__dict__['target'] = StorageView(self._store, target.keys())
        # Add mask
        self.mask = mask  # noqa
        # Add transform modules
        transform = transform if transform is not None else dict()
        self.transform: dict = transform  # noqa
        # Add patterns
        self.__dict__['pattern'] = dict()
        if pattern is not None:
            self.pattern.update(pattern)

    def __repr__(self) -> str:
        cls = self.__class__.__name__
        inputs = [
            pattern_size_repr(k, v, self.pattern.get(k))
            for k, v in self.input.items()
        ]
        inputs = 'input=({})'.format(', '.join(inputs))
        targets = [
            pattern_size_repr(k, v, self.pattern.get(k))
            for k, v in self.target.items()
        ]
        targets = 'target=({})'.format(', '.join(targets))
        info = [inputs, targets, "has_mask={}".format(self.has_mask)]
        if self.has_transform:
            info += ["transform=[{}]".format(', '.join(self.transform.keys()))]
        return '{}(\n  {}\n)'.format(cls, ',\n  '.join(info))

    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
        if key in self.pattern:
            if 'n' in self.pattern[key]:  # cat along node dimension
                if isinstance(value, SparseTensor):
                    # allow for multi-dim cat for SparseTensor (e.g., adj)
                    return tuple(
                        dim
                        for dim, tkn in enumerate(self.pattern[key].split(' '))
                        if tkn == 'n')
                return self.pattern[key].split(' ').index('n')
            elif 'e' in self.pattern[key]:  # cat along edge dimension
                return self.pattern[key].split(' ').index('e')
            else:  # stack on batch dimension
                return None
        return super(Data, self).__cat_dim__(key, value, *args, **kwargs)

    def stores_as(self, data: 'Data'):
        # copy input and target keys in self with no check that keys are in
        # self; used when batching Data objects
        self.input._keys = data.input._keys  # noqa
        self.target._keys = data.target._keys  # noqa
        self.pattern.clear()
        self.pattern.update(data.pattern)
        return self

    @property
    def edge_weight(self) -> Any:
        return self['edge_weight'] if 'edge_weight' in self._store else None

    @property
    def mask(self) -> Any:
        return self['mask'] if 'mask' in self._store else None

    @property
    def transform(self) -> Any:
        return self['transform'] if 'transform' in self._store else None

    @property
    def has_transform(self):
        return 'transform' in self._store and len(self.transform) > 0

    @property
    def has_mask(self):
        return self['mask'] is not None if 'mask' in self._store else False
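
    # Usage sketch (illustrative): build a spatiotemporal graph with 12 time
    # steps, 4 nodes and 2 channels. The patterns ('t n f' = time, nodes,
    # features) drive both the repr and the batching logic in __cat_dim__:
    #
    #   >>> x = torch.randn(12, 4, 2)
    #   >>> data = Data(input={'x': x}, target={'y': torch.randn(12, 4, 1)},
    #   ...             edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),
    #   ...             pattern={'x': 't n f', 'y': 't n f'})
    #   >>> data.input.x.shape
    #   torch.Size([12, 4, 2])
    #   >>> data.__cat_dim__('x', x)  # 'n' is at position 1 in 't n f'
    #   1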
    def numpy(self, *args: List[str]):
        r"""Transform tensors to numpy arrays, either for all attributes or
        only for those given in :obj:`*args`."""
        self.detach().cpu()
        return self.apply(lambda x: x.numpy(), *args)
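
    # Example (illustrative, assuming a Data object holding tensors): detach
    # from the autograd graph, move to CPU, and convert every entry to a
    # numpy array in place:
    #
    #   >>> data = data.numpy()
    #   >>> type(data.input.x)
    #   <class 'numpy.ndarray'>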
    def rearrange_element(self, key: str, pattern: str, **axes_lengths):
        r"""Rearrange the entry :obj:`key` in Data according to the provided
        pattern using `einops.rearrange
        <https://einops.rocks/api/rearrange/>`_."""
        key_pattern = self.pattern[key]
        if '->' in pattern:
            start_pattern, end_pattern = pattern.split('->')
            start_pattern = start_pattern.strip()
            end_pattern = end_pattern.strip()
            if key_pattern != start_pattern:
                raise RuntimeError(f"Starting pattern {start_pattern} does "
                                   f"not match with key pattern "
                                   f"{key_pattern}.")
        else:
            end_pattern = pattern
            pattern = key_pattern + ' -> ' + pattern
        self[key] = rearrange(self[key], pattern, **axes_lengths)
        self.pattern[key] = end_pattern
        if key in self.transform:
            self.transform[key] = self.transform[key].rearrange(end_pattern)
    def rearrange(self, patterns: Mapping):
        r"""Rearrange all given keys in Data according to the provided
        patterns using `einops.rearrange
        <https://einops.rocks/api/rearrange/>`_."""
        for key, pattern in patterns.items():
            self.rearrange_element(key, pattern)
        return self
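
    # Example (illustrative, continuing the construction sketch above):
    # flatten nodes and features of 'x' into a single axis; the stored
    # pattern is updated accordingly:
    #
    #   >>> data = data.rearrange({'x': 't n f -> t (n f)'})
    #   >>> data.pattern['x']
    #   't (n f)'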
    def subgraph_(self, subset: Tensor):
        edge_index, edge_mask = reduce_graph(subset,
                                             edge_index=self.edge_index,
                                             num_nodes=self.num_nodes)
        if subset.dtype == torch.bool:
            num_nodes = int(subset.sum())
        else:
            num_nodes = subset.size(0)
        for key, value in self:
            if key == 'edge_index':
                self.edge_index = edge_index
            elif key == 'edge_weight':
                self.edge_weight = self.edge_weight[edge_mask]
            elif key == 'num_nodes':
                self.num_nodes = num_nodes
            # prefer pattern indexing if available
            elif key in self.pattern:
                self[key] = take(value,
                                 self.pattern[key],
                                 node_index=subset,
                                 edge_mask=edge_mask)
            # fallback to PyG indexing (cannot index on multiple node dims)
            elif isinstance(value, Tensor):
                if self.is_node_attr(key):
                    node_dim = self.__cat_dim__(key, value, self._store)
                    self[key] = torch.index_select(value, node_dim, subset)
                elif self.is_edge_attr(key):
                    edge_dim = self.__cat_dim__(key, value, self._store)
                    self[key] = torch.index_select(value, edge_dim, edge_mask)
            if key in self.transform:
                scaler = self.transform[key]
                self.transform[key] = scaler.slice(node_index=subset)
        return self
    def subgraph(self, subset: Tensor):
        data = copy.copy(self)
        return data.subgraph_(subset)
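
    # Example (illustrative, for the data object as first constructed above,
    # i.e., with pattern 'x': 't n f'): keep only nodes 0 and 1; node-level
    # tensors are sliced along their 'n' dimension and edges are filtered
    # consistently:
    #
    #   >>> sub = data.subgraph(torch.tensor([0, 1]))
    #   >>> sub.input.x.shape  # nodes reduced from 4 to 2
    #   torch.Size([12, 2, 2])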