Source code for tsl.transforms.rearrange

from torch_geometric.transforms import BaseTransform

from tsl.data import Data, StaticBatch


[docs]class Rearrange(BaseTransform): """Rearrange all keys in :class:`~tsl.data.Data` according to the provided pattern using `einops.rearrange <https://einops.rocks/api/rearrange/>`_. If the objects is of type :class:`~tsl.data.StaticBatch`, then the batch dimension in the output pattern is automatically considered.""" def __init__(self, patterns: dict): self.item_patterns = patterns self.batch_patterns = dict() for key, pattern in patterns.items(): if pattern.startswith('t'): self.batch_patterns[key] = 'b ' + pattern else: self.batch_patterns[key] = pattern def __call__(self, data: Data) -> Data: if isinstance(data, StaticBatch): data.rearrange(self.batch_patterns) else: data.rearrange(self.item_patterns) return data
[docs]class NodeThenTime(Rearrange): """Rearrange all keys in :class:`~tsl.data.Data` such that the node dimension precedes the temporal one. For time-variant but node-invariant signals, a new dummy dimension is added.""" def __init__(self, original_patterns: dict): patterns = dict() for key, pattern in original_patterns.items(): if pattern.startswith('t'): if 'n' in pattern: patterns[key] = 'n t' + pattern[3:] else: patterns[key] = '1 ' + pattern super(NodeThenTime, self).__init__(patterns)