A Gentle Introduction to tsl#


This is a tutorial notebook about tsl (Torch Spatiotemporal), a Python library built upon the PyTorch ecosystem and tailored for spatiotemporal data processing.

In this notebook, we are going to see the necessary steps to go from data loading to model training.


Installation#


Let’s start with the installation. If you run the notebook in Colab, you can install tsl with these commands:

!git clone https://github.com/TorchSpatiotemporal/tsl.git
!pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch-scatter torch-sparse torch-geometric -f https://data.pyg.org/whl/torch-1.10.1+cu113.html
!pip install ./tsl

In particular, the third command installs the torch-geometric dependencies built for the specific environment you have on Colab with a GPU runtime. Please refer to the PyG installation guidelines for installation in other environments.

At the moment, we recommend installing tsl from the GitHub repository, to be sure you are up to date with the latest version.

Let’s check if everything is ok.

import tsl
import torch
import numpy as np

np.set_printoptions(suppress=True)
tsl.logger.disabled = True

print(f"tsl version  : {tsl.__version__}")
print(f"torch version: {torch.__version__}")

Usage#


tsl is more than a collection of layers. We can classify the library modules into:

  • Data loading modules: manage how to store, preprocess and visualize spatiotemporal data

  • Inference modules: methods and models that exploit the data to make inferences

We will go deeper into them in the next sections.

Data loading#


tsl comes with several datasets used in the spatiotemporal processing literature. You can find them inside the submodule tsl.datasets.

Loading the dataset#

As an example, we start by using the MetrLA dataset, a common benchmark for traffic forecasting. The dataset contains traffic readings collected from 207 loop detectors on highways in Los Angeles County, aggregated into 5-minute intervals over the 4 months from March 2012 to June 2012. Loading the dataset is as simple as this:

from tsl.datasets import MetrLA

dataset = MetrLA()

print(dataset)

We can see that the data are organized in a 3-dimensional array, with:

  • 34,272 temporal steps (one every 5 minutes, for 4 months)

  • 207 nodes (the loop detectors)

  • 1 channel (detected speed)

Nice! Besides storing the data of interest, the dataset comes with useful tools.

print(f"Sampling period: {dataset.freq}\n"
      f"Has missing values: {dataset.has_mask}\n"
      f"Percentage of missing values: {(1 - dataset.mask.mean()) * 100:.2f}%\n"
      f"Has dataset exogenous variables: {dataset.has_covariates}\n"
      f"Relevant attributes: {', '.join(dataset.attributes.keys())}")

Let’s look at the output. We know that the dataset has missing entries, with dataset.mask being a binary indicator associated with each timestep, node and channel (with ones indicating valid values).
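
For instance, we can check how coverage varies across sensors directly from the mask. This is a minimal sketch; it assumes the mask is a NumPy array of shape (steps, nodes, channels), consistently with the percentage computed above.

valid_fraction = dataset.mask.mean(axis=(0, 2))  # fraction of valid readings per sensor
print(f"Best sensor coverage : {valid_fraction.max():.2%}\n"
      f"Worst sensor coverage: {valid_fraction.min():.2%}")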

Also, the dataset has no exogenous variables, i.e., there are no time-varying features paired with the main signal. Instead, it has a useful attribute: the distance matrix. We call attributes those features that are static.

You can access exogenous variables and attributes by name, e.g., dataset.dist:

print(dataset.dist)

This matrix stores the pairwise distance between sensors, with inf denoting two non-neighboring sensors.
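
As a quick check, we can count how many sensor pairs are actually within range (a small sketch using the already imported NumPy):

finite = np.isfinite(dataset.dist)   # True where a distance is available
print(f"{finite.sum()} finite entries out of {dataset.dist.size} ({finite.mean():.1%})")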

Let’s now check what the speed readings look like.

dataset.dataframe().head(10)

Get connectivity#

Besides the time series, we need to somehow connect the nodes in order to properly use graph-based models.

With the method dataset.get_similarity() we can retrieve nodes’ similarities computed with different methods. The available similarity methods for a dataset can be found at dataset.similarity_options, while the default one is at dataset.similarity_score.

print(f"Default similarity: {dataset.similarity_score}\n"
      f"Available similarity options: {dataset.similarity_options}\n")

sim = dataset.get_similarity("distance")  # same as dataset.get_similarity()
print(sim[:10, :10])  # just check first 10 nodes for readability

With this method, we compute the weight \(w^{i,j}\) of the edge connecting the \(i\)-th and \(j\)-th nodes as

\[\begin{split} w^{i,j} = \left\{\begin{array}{cl} \exp \left(-\frac{\operatorname{dist}\left(i, j\right)^{2}}{\gamma}\right) & \operatorname{dist}\left(i, j\right) \leq \delta \\ 0 & \text{otherwise} \end{array}\right. , \end{split}\]
where \(\operatorname{dist}\left(i, j\right)\) is the distance between the \(i\)-th and \(j\)-th nodes, \(\gamma\) controls the width of the kernel and \(\delta\) is a threshold. Notice that in this case the similarity matrix is not symmetric, since the original preprocessed distance matrix is not symmetric either.
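
For reference, here is a rough sketch of how such a thresholded Gaussian kernel could be computed from the distance matrix; the values chosen for \(\gamma\) and \(\delta\) below are arbitrary assumptions for illustration and are not necessarily those used by the library.

dist = dataset.dist
finite = np.isfinite(dist)
gamma = dist[finite].std() ** 2   # kernel width (illustrative choice)
delta = dist[finite].mean()       # distance threshold (illustrative choice)

sim_sketch = np.exp(-dist ** 2 / gamma)
sim_sketch[~finite | (dist > delta)] = 0.0   # zero out non-neighboring / far sensors
print(sim_sketch[:5, :5])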

So far so good; now we can build an adjacency matrix out of the computed similarity.

The method dataset.get_connectivity() wraps dataset.get_similarity(), adds useful preprocessing options and, eventually, returns a (weighted) adjacency matrix in the requested layout.

adj = dataset.get_connectivity(threshold=0.1,
                               include_self=False,
                               normalize_axis=1,
                               layout="edge_index")

With this function call, under the hood we:

  1. compute the similarity matrix as before

  2. set to \(0\) values below \(0.1\) (threshold=0.1)

  3. remove self loops (include_self=False)

  4. normalize edge weights by the in degree of nodes (normalize_axis=1)

  5. request the sparse COO layout of PyG (layout="edge_index")

The connectivity matrix with edge_index layout is provided in COO format, adopting the convention and notation used in PyTorch Geometric. The returned connectivity is a tuple (edge_index, edge_weight), where edge_index lists all edges as pairs of source-target nodes (dimensions [2, E]) and edge_weight (dimension [E]) stores the corresponding weights.

edge_index, edge_weight = adj

print(edge_index.shape)
print(edge_weight)

The dense layout corresponds to the weighted adjacency matrix \(A \in \mathbb{R}^{N \times N}\). The module tsl.ops.connectivity contains useful operations for connectivities, including methods to change layout.

from tsl.ops.connectivity import edge_index_to_adj

dense = edge_index_to_adj(edge_index, edge_weight)
print(dense.shape)

Data processing#


In this section, we will see how to transfer data from a dataset to an inference model (e.g., a spatiotemporal graph neural network).

The SpatioTemporalDataset#

The first class that comes to our aid is tsl.data.SpatioTemporalDataset. This class is a subclass of torch.utils.data.Dataset and can be considered a wrapper of a tsl dataset, providing the interface for further processing.

In particular, a SpatioTemporalDataset object can be used to achieve the following:

  • perform the transformations required to feed the data to a model (e.g., casting to torch.Tensor, handling different shapes)

  • handle temporal slicing and windowing for training (e.g., splitting the data into \(\left( \text{window}, \text{horizon} \right)\) samples)

  • define the layout of inputs and targets (e.g., how node attributes and exogenous variables are arranged)

  • preprocess data before creating a batch

Let’s see how to go from a Dataset to a SpatioTemporalDataset.

from tsl.data import SpatioTemporalDataset

target, idx = dataset.numpy(return_idx=True)

torch_dataset = SpatioTemporalDataset(target=target,
                                      index=idx,
                                      connectivity=adj,
                                      mask=dataset.mask,
                                      horizon=12,
                                      window=12,
                                      stride=1)
print(torch_dataset)

As you can see, the number of samples is not the same as the number of steps in the dataset. Indeed, we divided the time series with a sliding window of 12 steps for the input (window=12) and a corresponding horizon of 12 steps (horizon=12). Thus, a sample spans a total of \(24\) steps.
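
As a quick check, with stride=1 and no gap between window and horizon, the number of samples should be the number of steps minus the sample length, plus one (a small sketch using the target array extracted above):

n_steps = target.shape[0]
print(n_steps - (12 + 12) + 1, len(torch_dataset))

But let’s look in detail at the layout of a sample: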

sample = torch_dataset[0]
print(sample)
torch_dataset[0].pattern

A sample has 5 main attributes:

  • sample.input is a mapping of data to be forwarded as input to the model.

  • sample.target is a mapping of data to be forwarded as target for the loss function of the model.

  • sample.mask stores the mask, if any. It is useful for computing the loss only on valid data.

  • sample.transform is a mapping whose keys are the names of the tensors to be transformed and whose values are the corresponding transformation functions (e.g., scaling, detrending).

  • sample.pattern stores the pattern, i.e., a more informative shape representation, of each tensor in sample.

Let’s check in more detail how each of these attributes is composed.

Input and Target#

A sample is a tsl.data.Data object which stores all that is needed to support inference. Both input and target are tsl.data.DataView objects on this storage. This means that they share the same methods but expose different subsets of keys. As a result, you cannot store two different tensors under the same key in input and target.

sample.input.to_dict()
sample.target.to_dict()

Mask and Transform#

mask and transform are just shortcuts to the corresponding objects inside the storage. The sample also exposes the properties has_mask and has_transform.

if sample.has_mask:
    print(sample.mask)
else:
    print("Sample has no mask.")
if sample.has_transform:
    print(sample.transform)
else:
    print("Sample has no transform functions.")

Pattern#

The pattern mapping is useful to get a glimpse of how data are arranged. The convention we use is the following:

  • 'b' stands for “batch dimension”

  • 'f' stands for “node features dimension”

  • 'e' stands for “edges dimension”

  • 'n' stands for “nodes dimension”

  • 't' stands for “time steps dimension”

sample.pattern
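
To make the convention concrete, we can pair each pattern with the actual shape of the corresponding tensor (a minimal sketch; it assumes that sample items can be accessed by key, as in PyG Data objects):

for key, pattern in sample.pattern.items():
    print(f"{key:12} {pattern:10} {tuple(sample[key].shape)}")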

The SpatioTemporalDataModule#

Usually, before running an experiment there are two quite common preprocessing steps:

  • splitting the dataset into training/validation/test sets

  • data preprocessing (scaling/normalizing data, detrending)

In tsl, these operations are carried out in the tsl.data.SpatioTemporalDataModule, which is based on pytorch-lightning’s data modules.

Let’s see an example

from tsl.data import SpatioTemporalDataModule
from tsl.data.preprocessing import StandardScaler

scalers = {'target': StandardScaler(axis=(0, 1))}

splitter = dataset.get_splitter(val_len=0.1, test_len=0.2)

dm = SpatioTemporalDataModule(
    dataset=torch_dataset,
    scalers=scalers,
    splitter=splitter,
    batch_size=64,
)

print(dm)

One could eventually extend the base datamodule to add further processing, if needed.

At this point, the DataModule object has not performed any processing yet (lazy approach).

We can execute the preprocessing routines by calling the setup method:

dm.setup()
print(dm)

After setup has been called, the datamodule has carried out the following operations:

  1. Split the dataset into training/validation/test sets according to the splitter argument.

  2. Fit each Scaler on the slice of training data corresponding to the scaler’s key.
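
After setup, we can, for instance, grab a batch from the training set through the standard pytorch-lightning dataloader interface (a minimal sketch):

train_loader = dm.train_dataloader()
batch = next(iter(train_loader))
print(batch)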

Scalers#

The tsl.data.preprocessing package offers several of the most common data normalization techniques under the tsl.data.preprocessing.Scaler interface. They adopt an API similar to scikit-learn’s scalers, with fit/transform/fit_transform/inverse_transform methods. Check the documentation for more details about this.
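
As a standalone example of that interface, here is a minimal sketch on random data; the exact input types accepted by the scalers are an assumption here, so double-check with the documentation.

from tsl.data.preprocessing import StandardScaler

x = torch.randn(100, 20, 1) * 10 + 5          # fake [steps, nodes, channels] data
scaler = StandardScaler(axis=(0, 1))          # statistics over time and nodes
x_scaled = scaler.fit_transform(x)
x_back = scaler.inverse_transform(x_scaled)   # back to the original scale
print(torch.allclose(x, x_back, atol=1e-4))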

Building a Model#


In this section, we will see how to build a very simple Graph Neural Network for Spatiotemporal data. All the neural network code inside tsl is under the tsl.nn module.

The NN module#

The tsl.nn module is organized as follows:

tsl
└── nn
    ├── base
    ├── blocks
    ├── layers
    └── models

The 3 most important submodules in it are layers, blocks, and models, ordered by increasing level of abstraction.

Layers#

A layer is a basic building block for our neural networks. In simple words, a layer takes an input, performs one (or a few) operations, and returns a transformation of the input. Examples of layers are DiffConv, which implements the diffusion convolution operator, and LayerNorm.

Blocks#

Blocks perform more complex transformations or combine several operations. We divide blocks into encoders, if they provide a representation of the input in a new space, and decoders, if they produce a meaningful output from a representation.
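
For instance, an encoder block is used like any other torch.nn.Module. Here is a minimal sketch with arbitrary sizes, assuming the [batch, steps, nodes, features] input layout used throughout this notebook:

from tsl.nn.blocks.encoders import RNN

rnn = RNN(input_size=1, hidden_size=16, n_layers=1)
out = rnn(torch.randn(8, 12, 207, 1))   # [batch, steps, nodes, features]
print(out.shape)                        # a hidden representation for every step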

Models#

We wrap a series of operations, represented by blocks and/or layers, in a model. A model is meant to take as input a batch of SpatioTemporalDataset items and return the desired output.

Let’s create a very simple model with an RNN encoder and a nonlinear GCN readout. To do so, we import RNN from the encoders and GCNDecoder from the decoders in the tsl.nn.blocks module.

from tsl.nn.blocks.encoders import RNN
from tsl.nn.blocks.decoders import GCNDecoder


class TimeThenSpaceModel(torch.nn.Module):
    def __init__(self,
                 input_size,
                 hidden_size,
                 rnn_layers,
                 gcn_layers,
                 horizon):
        super(TimeThenSpaceModel, self).__init__()

        self.input_encoder = torch.nn.Linear(input_size, hidden_size)

        self.encoder = RNN(input_size=hidden_size,
                           hidden_size=hidden_size,
                           n_layers=rnn_layers)

        self.decoder = GCNDecoder(
            input_size=hidden_size,
            hidden_size=hidden_size,
            output_size=input_size,
            horizon=horizon,
            n_layers=gcn_layers
        )

    def forward(self, x, edge_index, edge_weight):
        # x: [batches steps nodes channels]
        x = self.input_encoder(x)

        # encode the temporal dimension, keeping only the last hidden state
        x = self.encoder(x, return_last_state=True)

        # graph-based readout mapping the representation to the forecast horizon
        return self.decoder(x, edge_index, edge_weight)
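
Before wiring the model into the training machinery, we can sanity-check its output shape with random inputs. This is a minimal sketch: the sizes follow the MetrLA setup above, and torch.as_tensor is used since the connectivity computed earlier may be stored as NumPy arrays.

model = TimeThenSpaceModel(input_size=1, hidden_size=16,
                           rnn_layers=1, gcn_layers=2, horizon=12)

x = torch.randn(4, 12, 207, 1)            # [batch, window, nodes, channels]
ei = torch.as_tensor(edge_index).long()   # connectivity computed earlier
ew = torch.as_tensor(edge_weight).float()
with torch.no_grad():
    y_hat = model(x, ei, ew)
print(y_hat.shape)                        # expected: [4, 12, 207, 1]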

Fine, we have a model and we have data, now let’s train it!

Setting up training#


The Predictor#

In tsl, inference engines are implemented as LightningModules. tsl.engines.Predictor is a base class that can be extended to build more complex forecasting approaches. These modules are meant to wrap deep models in order to ease training and inference.

from tsl.metrics.torch import MaskedMAE, MaskedMAPE
from tsl.engines import Predictor

loss_fn = MaskedMAE()

metrics = {'mae': MaskedMAE(),
           'mape': MaskedMAPE(),
           'mae_at_15': MaskedMAE(at=2),  # '2' indicates the third time step,
                                           # which corresponds to 15 minutes ahead
           'mae_at_30': MaskedMAE(at=5),
           'mae_at_60': MaskedMAE(at=11), }

model_kwargs = {
    'input_size': dm.n_channels,  # 1 channel
    'horizon': dm.horizon,  # 12, the number of steps ahead to forecast
    'hidden_size': 16,
    'rnn_layers': 1,
    'gcn_layers': 2
}

# setup predictor
predictor = Predictor(
    model_class=TimeThenSpaceModel,
    model_kwargs=model_kwargs,
    optim_class=torch.optim.Adam,
    optim_kwargs={'lr': 0.003},
    loss_fn=loss_fn,
    metrics=metrics
)

Now let’s finalize the last details. We make use of TensorBoard to log and visualize metrics.

from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir="logs", name="tsl_intro", version=0)
%load_ext tensorboard
%tensorboard --logdir logs

We let pytorch_lightning.Trainer handle the dirty work for us. We can pass the datamodule directly to the trainer for fitting.

In this case, the trainer will call its setup method and then load the training and validation sets.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath='logs',
    save_top_k=1,
    monitor='val_mae',
    mode='min',
)

trainer = pl.Trainer(max_epochs=100,
                     logger=logger,
                     gpus=1 if torch.cuda.is_available() else None,
                     limit_train_batches=100,
                     callbacks=[checkpoint_callback], )

trainer.fit(predictor, datamodule=dm)

Testing#


Now let’s see how the trained model behaves on new unseen data.

predictor.load_model(checkpoint_callback.best_model_path)
predictor.freeze()

performance = trainer.test(predictor, datamodule=dm)

Cool! We succeeded in creating our first simple, yet effective, spatiotemporal model!

We are now tsl ninjas. We learned how to:

  • Load benchmark datasets

  • Organize data for processing

  • Preprocess the data

  • Build a Spatiotemporal GNN

  • Train and evaluate the model

We hope you enjoyed this introduction to tsl. Do not hesitate to contact us if you have any questions or problems while using it.

The tsl team.

🧡