Transforms

In the spirit of torchvision and PyG, this module contains transform operations that are applied every time an item is retrieved from a SpatioTemporalDataset. A transform object takes a Data object as input and returns a transformed object of the same type.
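The contract can be sketched in plain Python. Everything in this example is hypothetical and for illustration only: `Data` is a dict subclass standing in for the real Data class, `ToyDataset` plays the role of a SpatioTemporalDataset by applying its transform inside `__getitem__`, and `ScaleX` is a made-up transform. The point is only that a transform is a callable taking a Data object and returning one, and that it runs on every item access.

```python
from typing import Callable, List, Optional


class Data(dict):
    """Hypothetical stand-in for a Data object: a dict of named fields."""


class ToyDataset:
    """Hypothetical dataset that applies `transform` on every item access."""

    def __init__(self, items: List[Data],
                 transform: Optional[Callable[[Data], Data]] = None):
        self.items = items
        self.transform = transform

    def __getitem__(self, idx: int) -> Data:
        # Shallow copy, so a transform that reassigns fields leaves stored items intact.
        sample = Data(self.items[idx])
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


class ScaleX:
    """Hypothetical transform: multiply every entry of field 'x' by a constant."""

    def __init__(self, factor: float):
        self.factor = factor

    def __call__(self, data: Data) -> Data:
        # Build a new list instead of mutating in place; return the same Data type.
        data['x'] = [v * self.factor for v in data['x']]
        return data


dataset = ToyDataset([Data(x=[1.0, 2.0]), Data(x=[3.0])], transform=ScaleX(10.0))
print(dataset[0]['x'])  # [10.0, 20.0]
```

Because the toy dataset copies each stored item before transforming it, repeated accesses always start from the untransformed sample.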

  • MaskedSubgraph – Reduce the graph in a sample by removing masked nodes.

  • Rearrange – Rearrange all keys in Data according to the provided patterns using einops.rearrange.

  • NodeThenTime – Rearrange all keys in Data such that the node dimension precedes the temporal one.

  • MaskInput – Whiten masked values in input_key according to the mask in mask_key.

class MaskedSubgraph

Reduce the graph in a sample by removing masked nodes.
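As a rough illustration of what removing masked nodes involves, the sketch below drops nodes from a node-indexed field and from an edge list, relabeling the surviving nodes. It is a hypothetical pure-Python helper, not this module's implementation: it assumes nodes lie along the first axis of `x`, edges are `(source, target)` pairs, and the nodes to remove are those flagged `False` in a per-node `keep` vector (how the real transform derives that vector from the sample is not covered here).

```python
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]


def masked_subgraph(x: Sequence[list], edges: Sequence[Edge],
                    keep: Sequence[bool]) -> Tuple[List[list], List[Edge]]:
    """Hypothetical helper: drop nodes with keep[i] == False and their edges."""
    # Map each surviving node's old index to a new, contiguous index.
    new_index = {}
    for old, kept in enumerate(keep):
        if kept:
            new_index[old] = len(new_index)
    # Node-level rows are kept only for surviving nodes, in original order.
    x_sub = [row for row, kept in zip(x, keep) if kept]
    # An edge survives only if both endpoints survive; endpoints are relabeled.
    edges_sub = [(new_index[s], new_index[d]) for s, d in edges
                 if s in new_index and d in new_index]
    return x_sub, edges_sub


x = [[0.1], [0.2], [0.3]]           # one feature row per node
edges = [(0, 1), (1, 2), (2, 0)]    # directed (source, target) pairs
x_sub, edges_sub = masked_subgraph(x, edges, keep=[True, False, True])
print(x_sub, edges_sub)  # [[0.1], [0.3]] [(1, 0)]
```

Relabeling keeps the node indices contiguous, so the reduced edge list still indexes directly into the reduced node rows.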

class Rearrange(patterns: dict)

Rearrange all keys in Data according to the provided patterns using einops.rearrange.
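To make the rearrangement concrete without requiring einops or torch, here is a hypothetical pure-Python `permute_rearrange` that handles only the simplest kind of einops pattern, a permutation of named axes such as `'t n f -> n t f'`, on nested lists. einops.rearrange itself also splits, merges, and adds axes, which this sketch deliberately omits.

```python
from typing import Any, Callable, Sequence, Tuple


def shape_of(nested: Any) -> Tuple[int, ...]:
    """Shape of a rectangular (non-ragged) nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


def build(shape: Sequence[int], fn: Callable[[Tuple[int, ...]], Any],
          prefix: Tuple[int, ...] = ()) -> Any:
    """Recursively build a nested list of `shape` whose leaf at index i is fn(i)."""
    if len(prefix) == len(shape):
        return fn(prefix)
    return [build(shape, fn, prefix + (i,)) for i in range(shape[len(prefix)])]


def permute_rearrange(nested: list, pattern: str) -> list:
    """Hypothetical sketch of rearrange restricted to pure axis permutations."""
    lhs, rhs = (side.split() for side in pattern.split('->'))
    if sorted(lhs) != sorted(rhs):
        raise ValueError('both sides must name the same axes')
    in_shape = shape_of(nested)
    if len(in_shape) != len(lhs):
        raise ValueError(f'pattern expects {len(lhs)} axes, got {len(in_shape)}')
    # perm[k] = position in the input of the k-th output axis.
    perm = [lhs.index(name) for name in rhs]
    out_shape = tuple(in_shape[p] for p in perm)

    def value_at(out_idx: Tuple[int, ...]) -> Any:
        # Scatter the output index back into input axis order, then look it up.
        in_idx = [0] * len(lhs)
        for k, p in enumerate(perm):
            in_idx[p] = out_idx[k]
        item = nested
        for i in in_idx:
            item = item[i]
        return item

    return build(out_shape, value_at)


x = [[[1], [2], [3]], [[4], [5], [6]]]  # axes 't n f' with t=2, n=3, f=1
print(permute_rearrange(x, 't n f -> n t f'))
# [[[1], [4]], [[2], [5]], [[3], [6]]]
```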

If the object is of type StaticBatch, the batch dimension is automatically taken into account in the output pattern.
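One plausible way to account for the batch dimension is to prefix every pattern with a leading batch axis. The helper below is a hypothetical sketch, not this module's code: it assumes batched items carry the batch as their first axis, names that axis `'b'`, and assumes each value of the `patterns` dict is a full `'input -> output'` string; the format Rearrange actually expects may differ.

```python
from typing import Dict


def add_batch_axis(pattern: str, batch_axis: str = 'b') -> str:
    """Hypothetical helper: prepend a leading batch axis to both sides of a pattern."""
    lhs, rhs = (side.strip() for side in pattern.split('->'))
    return f'{batch_axis} {lhs} -> {batch_axis} {rhs}'


# Hypothetical per-key patterns, in the shape of a dict passed to Rearrange.
patterns: Dict[str, str] = {'x': 't n f -> n t f', 'mask': 't n f -> n t f'}
batched = {key: add_batch_axis(p) for key, p in patterns.items()}
print(batched['x'])  # b t n f -> b n t f
```

Prefixing both sides leaves the batch axis in place, so the per-sample permutation is applied independently to every element of the batch.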

class NodeThenTime(original_patterns: dict)

Rearrange all keys in Data such that the node dimension precedes the temporal one.

For time-variant but node-invariant signals, a new dummy node dimension is added.
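The pattern logic can be sketched as follows. This hypothetical helper assumes axes named `'t'` (time) and `'n'` (node), with any other axes being trailing feature axes, and leaves patterns without a time axis unchanged. The dummy node axis is written `1`, assumed here to be einops syntax for a new length-1 axis in the output of rearrange.

```python
def node_then_time_pattern(pattern: str) -> str:
    """Hypothetical helper: build a rearrange pattern that moves 'n' before 't'.

    Assumes axes are named 't' (time), 'n' (node), and anything else is a
    trailing feature axis; patterns without 't' are left unchanged.
    """
    axes = pattern.split()
    if 't' not in axes:
        return f'{pattern} -> {pattern}'
    rest = [a for a in axes if a not in ('t', 'n')]
    # A node-invariant signal has no 'n': use '1', a new axis of length 1.
    node = 'n' if 'n' in axes else '1'
    return f"{pattern} -> {' '.join([node, 't'] + rest)}"


for p in ['t n f', 't f', 'n f']:
    print(node_then_time_pattern(p))
# t n f -> n t f
# t f -> 1 t f
# n f -> n f
```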

class MaskInput(input_key: str = 'x', mask_key: str = 'mask')

Whiten masked values in input_key according to the mask in mask_key.

Parameters:
  • input_key (str) – The key in Data to be masked. (default: 'x')

  • mask_key (str) – The key in Data to serve as mask. (default: 'mask')
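A minimal sketch of the masking behavior on flat lists, assuming that "whitening" means overwriting entries whose mask is `False` with zero and that `True` marks valid values. `MaskInputSketch` and its extra `fill` argument are hypothetical; only the `input_key`/`mask_key` names and their defaults come from the signature of MaskInput.

```python
from typing import Any, Dict, List


class MaskInputSketch:
    """Hypothetical re-creation of the masking behavior on flat lists.

    Assumes 'whitening' means replacing entries whose mask is False with
    `fill` (0.0 by default); the real transform operates on tensors.
    """

    def __init__(self, input_key: str = 'x', mask_key: str = 'mask',
                 fill: float = 0.0):
        self.input_key = input_key
        self.mask_key = mask_key
        self.fill = fill

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        values: List[float] = data[self.input_key]
        mask: List[bool] = data[self.mask_key]
        if len(values) != len(mask):
            raise ValueError('input and mask must have the same length')
        # Keep observed entries (mask True); overwrite masked-out ones with `fill`.
        data[self.input_key] = [v if m else self.fill for v, m in zip(values, mask)]
        return data


sample = {'x': [1.5, 2.0, 3.0], 'mask': [True, False, True]}
print(MaskInputSketch()(sample)['x'])  # [1.5, 0.0, 3.0]
```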