A Gentle Introduction to tsl


In this tutorial notebook, we will see how to train our custom-made Spatiotemporal Graph Neural Network (STGNN) for traffic forecasting using tsl (Torch Spatiotemporal), a Python library for neural spatiotemporal data processing, with a focus on Graph Neural Networks.

It is built upon the most widely used libraries of the Python scientific computing ecosystem, with the goal of providing a straightforward path from data preprocessing to model prototyping.

In particular, tsl offers a wide range of utilities to develop neural networks in PyTorch and PyTorch Geometric (PyG) for processing spatiotemporal graph signals.



Quickstart


Installation


Let’s start by installing tsl from source along with its dependencies. Installing tsl from GitHub ensures that you get the latest version.

# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!pip install -q git+https://github.com/TorchSpatiotemporal/tsl.git

Refer to the tsl and PyG installation guidelines for setup in other environments.

Let’s check if everything is ok.

import tsl
import torch
import numpy as np
import pandas as pd

print(f"tsl version  : {tsl.__version__}")
print(f"torch version: {torch.__version__}")

pd.options.display.float_format = '{:.2f}'.format
np.set_printoptions(edgeitems=3, precision=3)
torch.set_printoptions(edgeitems=2, precision=3)

# Utility functions ################
def print_matrix(matrix):
    return pd.DataFrame(matrix)

def print_model_size(model):
    tot = sum([p.numel() for p in model.parameters() if p.requires_grad])
    out = f"Number of model ({model.__class__.__name__}) parameters:{tot:10d}"
    print("=" * len(out))
    print(out)

Usage


tsl is more than a collection of layers. We can classify the library modules into:

  • Data loading modules
    Manage how to store, load, and preprocess spatiotemporal data, providing a simple interface to make data ready for downstream neural models.

  • Inference modules
    Models and engines that take spatiotemporal data as input to make inferences for the task at hand, e.g., forecasting or imputation.

We will go deeper into each of them in the next sections.


Loading and Preprocessing Data


Loading a tabular dataset


tsl comes with several datasets used in the spatiotemporal processing literature. You can find them in the tsl.datasets submodule.

As an example, we start with the MetrLA dataset, a common benchmark for traffic forecasting. The dataset contains traffic readings collected from 207 loop detectors on highways in Los Angeles County, aggregated into 5-minute intervals over 4 months, from March 2012 to June 2012. Loading the dataset is as simple as:

from tsl.datasets import MetrLA

dataset = MetrLA(root='./data')

print(dataset)

All datasets in tsl are subclasses of the root class tsl.datasets.Dataset, which exposes useful APIs for spatiotemporal datasets. The data are organized as a 3-dimensional array, with:

  • 34,272 time steps (one every 5 minutes for 4 months)

  • 207 nodes (the loop detectors)

  • 1 channel (detected speed)

Nice! Other than storing the data of interest, the dataset comes with useful tools.

print(f"Sampling period: {dataset.freq}")
print(f"Has missing values: {dataset.has_mask}")
print(f"Percentage of missing values: {(1 - dataset.mask.mean()) * 100:.2f}%")
print(f"Has exogenous variables: {dataset.has_covariates}")
print(f"Covariates: {', '.join(dataset.covariates.keys())}")

Let’s look at the output. The dataset has missing entries, and dataset.mask is a binary indicator associated with each time step, node, and channel (ones indicate valid values).
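The following minimal numpy sketch uses a toy array (not MetrLA data, and not tsl code) to make the role of the mask concrete. The percentage of missing values is one minus the mean of the mask, as in the print statement above. A statistic restricted to valid entries is obtained by weighting with the mask.

```python
import numpy as np

# Toy readings shaped [time steps, nodes, channels], like MetrLA's 3D layout.
x = np.array([[[60.0], [0.0]],
              [[55.0], [42.0]],
              [[0.0], [48.0]]])
# Binary mask: 1 marks a valid reading, 0 a missing one.
mask = np.array([[[1], [0]],
                 [[1], [1]],
                 [[0], [1]]])

missing_pct = (1 - mask.mean()) * 100
# Mean over valid entries only: masked-out readings contribute nothing.
valid_mean = (x * mask).sum() / mask.sum()

print(f"Percentage of missing values: {missing_pct:.2f}%")  # 33.33%
print(f"Mean of valid readings: {valid_mean:.2f}")          # 51.25
```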

The dataset also has a covariate (i.e., an exogenous variable), the distance matrix, which contains the pairwise distances between sensors.

You can access covariates by dataset.{covariate_name}:

print_matrix(dataset.dist)

This matrix stores the pairwise distances between sensors, with inf denoting pairs of non-neighboring sensors.

Let’s now check what the speed readings look like.

dataset.dataframe()

Connecting sensors

Besides the time series, graph-based models need a way to connect the nodes.

With the method dataset.get_similarity() we can retrieve nodes’ similarities computed with different methods. The available similarity methods for a dataset can be found at dataset.similarity_options, while the default one is at dataset.similarity_score.

print(f"Default similarity: {dataset.similarity_score}")
print(f"Available similarity options: {dataset.similarity_options}")
print("==========================================")

sim = dataset.get_similarity("distance")  # or dataset.compute_similarity()

print("Similarity matrix W:")
print_matrix(sim)

With this method, we compute the weight \(w^{i,j}\) of the edge connecting the \(i\)-th and \(j\)-th nodes as

\[\begin{split} w^{i,j} = \left\{\begin{array}{cl} \exp \left(-\frac{\operatorname{dist}\left(i, j\right)^{2}}{\gamma}\right) & \operatorname{dist}\left(i, j\right) \leq \delta \\ 0 & \text{otherwise} \end{array}\right. , \end{split}\]

where \(\operatorname{dist}\left(i, j\right)\) is the distance between the \(i\)-th and \(j\)-th nodes, \(\gamma\) controls the kernel width, and \(\delta\) is a threshold. Notice that in this case the similarity matrix is not symmetric, because the original preprocessed distance matrix is not symmetric either.
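The kernel can be sketched in a few lines of numpy. The function name gaussian_similarity, the toy distance matrix, and the values of gamma and delta below are illustrative assumptions. This is not tsl's implementation, and it does not cover how tsl picks the kernel width and threshold for MetrLA. Note how the non-symmetric distances yield a non-symmetric W.

```python
import numpy as np

def gaussian_similarity(dist, gamma, delta):
    """w[i, j] = exp(-dist[i, j]**2 / gamma) if dist[i, j] <= delta, else 0.

    Infinite distances (non-neighboring sensors) fail the threshold and get 0.
    """
    dist = np.asarray(dist, dtype=float)
    within = dist <= delta
    w = np.zeros_like(dist)
    w[within] = np.exp(-dist[within] ** 2 / gamma)
    return w

# Toy, non-symmetric distance matrix with inf for unconnected pairs.
dist = np.array([[0.0, 1.0, np.inf],
                 [2.0, 0.0, 1.0],
                 [np.inf, 3.0, 0.0]])

W = gaussian_similarity(dist, gamma=4.0, delta=2.5)
print(np.round(W, 3))
```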

So far so good. Now we can build an adjacency matrix out of the computed similarity.

The method dataset.get_connectivity(), which calls dataset.get_similarity() under the hood, provides useful preprocessing options and finally returns a possibly sparse, possibly weighted adjacency matrix.

connectivity = dataset.get_connectivity(threshold=0.1,
                                        include_self=False,
                                        normalize_axis=1,
                                        layout="edge_index")

Let’s see what happens with this function call:

  1. compute the similarity matrix as before;

  2. set to 0 values below 0.1 (threshold=0.1);

  3. remove self loops (include_self=False);

  4. normalize edge weights by the in-degree of nodes (normalize_axis=1);

  5. request the sparse COO layout of PyG (layout="edge_index").
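The five steps can be mimicked with numpy on a small dense matrix. This is a hypothetical re-implementation for illustration only: build_connectivity is not a tsl function. It assumes that normalize_axis=1 divides each entry by the sum of its row, and that the row index becomes the first row of edge_index. tsl's own orientation convention may differ.

```python
import numpy as np

def build_connectivity(sim, threshold=0.1, include_self=False, normalize_axis=1):
    """Hypothetical sketch of the steps above; `sim` is the step-1 similarity."""
    adj = np.array(sim, dtype=float)
    adj[adj < threshold] = 0.0                # 2. drop weak edges
    if not include_self:
        np.fill_diagonal(adj, 0.0)            # 3. remove self loops
    if normalize_axis is not None:            # 4. normalize along the given axis
        totals = adj.sum(axis=normalize_axis, keepdims=True)
        adj = np.divide(adj, totals, out=np.zeros_like(adj), where=totals > 0)
    rows, cols = np.nonzero(adj)              # 5. sparse COO layout
    edge_index = np.stack([rows, cols])       # shape [2, E]
    edge_weight = adj[rows, cols]             # shape [E]
    return edge_index, edge_weight

sim = np.array([[1.00, 0.78, 0.05],
                [0.37, 1.00, 0.78],
                [0.00, 0.20, 1.00]])
edge_index, edge_weight = build_connectivity(sim)
print(edge_index)
print(np.round(edge_weight, 3))
```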

The connectivity matrix with edge_index layout is provided in COO format, adopting the convention and notation used in PyTorch Geometric. The returned connectivity is a tuple (edge_index, edge_weight), where edge_index lists all edges as pairs of source-target nodes (dimensions [2, E]) and edge_weight (dimension [E]) stores the corresponding weights.

edge_index, edge_weight = connectivity

print(f'edge_index {edge_index.shape}:\n', edge_index)
print(f'edge_weight {edge_weight.shape}:\n', edge_weight)

The "dense" layout instead corresponds to the weighted adjacency matrix \(A \in \mathbb{R}^{N \times N}\). The module tsl.ops.connectivity contains useful operations for connectivities, including methods to change layout.

from tsl.ops.connectivity import edge_index_to_adj

adj = edge_index_to_adj(edge_index, edge_weight)
print(f'A {adj.shape}:')
print_matrix(adj)

From the dense layout, the edge weights in sparse COO format can be easily retrieved as:

print(f'Sparse edge weights:\n', adj[edge_index[1], edge_index[0]])

Building a PyTorch-ready dataset


In this section, we will see how to go from a dataset of this kind to the spatiotemporal graph signals that are given as input to a neural network (e.g., an STGNN).

The first useful class is tsl.data.SpatioTemporalDataset. This class is a subclass of torch.utils.data.Dataset and is in charge of mapping a tabular dataset in your preferred format (e.g., a numpy array, a pandas dataframe, or the aforementioned tsl.datasets.Dataset) to a PyTorch-ready implementation.

In particular, a SpatioTemporalDataset object can be used to achieve the following:

  • Perform data manipulation operations required to feed the data to a PyTorch module (e.g., casting data to torch.tensor, handling possibly different shapes, synchronizing temporal data).

  • Create (input, target) samples for supervised learning following the sliding window approach.

  • Define how data should be arranged in a spatiotemporal graph signal (e.g., which are the inputs and targets, and how node attributes and covariates are mapped into a single graph).

  • Preprocess data before creating a spatiotemporal graph signal by applying transformations or scaling operations.

Let’s see how to go from a Dataset to a SpatioTemporalDataset.

from tsl.data import SpatioTemporalDataset

torch_dataset = SpatioTemporalDataset(target=dataset.dataframe(),
                                      connectivity=connectivity,
                                      mask=dataset.mask,
                                      horizon=12,
                                      window=12,
                                      stride=1)
print(torch_dataset)

As you can see, the number of samples is not the same as the number of time steps in the dataset. We split the historical time series with a sliding window: a lookback window of 12 time steps (window=12) and a corresponding horizon of 12 time steps (horizon=12). Thus, a single sample spans a total of 24 time steps. The stride parameter sets how many time steps separate two subsequent samples. The following picture helps visualize how these (and other) parameters affect the slicing of the original time series into samples.
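The sample count follows from simple arithmetic. The sketch below assumes that targets start immediately after the input window, with no additional delay between the two. It uses a hypothetical helper, sliding_window_indices, that returns the time indices of inputs and targets for each sample. Under this assumption, a series of T steps yields (T - (window + horizon)) // stride + 1 samples.

```python
import numpy as np

def sliding_window_indices(n_steps, window, horizon, stride=1):
    """Time indices (inputs, targets) of every sample (hypothetical helper).

    Assumes the targets start right after the input window, with no delay.
    """
    span = window + horizon  # time steps covered by one sample
    starts = range(0, n_steps - span + 1, stride)
    return [(np.arange(s, s + window), np.arange(s + window, s + span))
            for s in starts]

# Small example: 30 steps, window=12, horizon=12, stride=1.
samples = sliding_window_indices(30, window=12, horizon=12, stride=1)
print(len(samples))  # 7
x_idx, y_idx = samples[0]
print(x_idx.min(), x_idx.max(), y_idx.min(), y_idx.max())  # 0 11 12 23

# Same count for MetrLA's 34,272 steps, without materializing the indices.
n_steps, window, horizon, stride = 34272, 12, 12, 1
n_samples = (n_steps - (window + horizon)) // stride + 1
print(n_samples)  # 34249
```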

Spatiotemporal graph signals in tsl

We now have a PyTorch-based dataset containing a collection of spatiotemporal graph signals. We can fetch samples in the same way we fetch elements of a Python list. Let’s look in detail at the layout of a sample:

sample = torch_dataset[0]
print(sample)

A sample is of type tsl.data.Data, the base class for representing spatiotemporal graph signals in tsl. This class extends torch_geometric.data.Data, preserving all its functionalities and adding utilities for spatiotemporal data processing. The main APIs of Data include:

  • Data.input: view on the tensors stored in Data that are meant to serve as input to the model. In the simplest case of a single node-attribute matrix, we could just have Data.input.x.

  • Data.target: view on the tensors stored in Data used as labels to train the model. In the common case of a single label, we could just have Data.target.y.

  • Data.edge_index: graph connectivity in COO format (i.e., as node pairs).

  • Data.edge_weight: weights of the graph connectivity, if any.

  • Data.mask: binary mask indicating the data in Data.target.y to be used as ground-truth for the loss (default is None).

  • Data.transform: mapping of ScalerModule objects, whose keys must be transformable (or transformed) tensors in Data.

  • Data.pattern: mapping containing the pattern for each tensor in Data. Patterns add information about the dimensions of tensors (e.g., specifying which are the time step and node dimensions).

None of these attributes are required and custom attributes can be seamlessly added.

Let’s look in more detail at how each of these attributes is composed.

Input and Target

Data.input and Data.target provide a view on the unique (shared) storage in Data, such that the same key in Data.input and Data.target cannot reference different objects.

print(sample.input.to_dict())
print(sample.target.to_dict())

Mask and Transform

mask and transform are just shortcuts to the corresponding objects inside the storage. Data also exposes the properties has_mask and has_transform.

if sample.has_mask:
    print(sample.mask)
else:
    print("Sample has no mask.")
if sample.has_transform:
    print(sample.transform)
else:
    print("Sample has no transformation functions.")

Pattern

The pattern mapping gives a quick glimpse of how data are arranged. We use the following convention:

  • 't' stands for the time steps dimension

  • 'n' stands for a node dimension

  • 'e' stands for the edge dimension

  • 'f' stands for a feature dimension

  • 'b' stands for the batch dimension

print(sample.pattern)
print("==================   Or we can print patterns and shapes together   ==================")
print(sample)
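The hypothetical helper below (not a tsl API) is a toy illustration of the convention listed above. It pairs each letter of a pattern with the size of the corresponding dimension, assuming patterns are written as space-separated strings such as 't n f'. The shape (12, 207, 1) matches a MetrLA input window of 12 steps, 207 nodes, and 1 channel.

```python
def describe(pattern, shape):
    """Pair each pattern token with its dimension size (hypothetical helper)."""
    names = {'b': 'batch', 't': 'time steps', 'n': 'nodes',
             'e': 'edges', 'f': 'features'}
    tokens = pattern.split()
    if len(tokens) != len(shape):
        raise ValueError(f"pattern {pattern!r} does not match shape {shape}")
    return {names[tok]: size for tok, size in zip(tokens, shape)}

print(describe('t n f', (12, 207, 1)))
# {'time steps': 12, 'nodes': 207, 'features': 1}
print(describe('b t n f', (5, 12, 207, 1)))
# {'batch': 5, 'time steps': 12, 'nodes': 207, 'features': 1}
```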

Batching spatiotemporal graph signals

Getting a batch of spatiotemporal graph signals from a single dataset is as simple as accessing multiple elements from a list:

batch = torch_dataset[:5]
print(batch)

As you can see, the time-varying elements (i.e., x and y) now have an additional dimension, denoted by b in their pattern: the batch dimension. Along this new first dimension, we stacked the features of the first 5 spatiotemporal graphs in the dataset.

Note that this is possible only because we are assuming a fixed underlying topology, as also confirmed by the edge_index and edge_weight attributes. Explaining how Data objects with different graphs are batched together is beyond the scope of this notebook.
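With a fixed topology, batching the time-varying tensors amounts to stacking them along a new leading axis. The numpy sketch below uses random toy data shaped like the MetrLA samples. It illustrates the idea and is not tsl's collate function.

```python
import numpy as np

rng = np.random.default_rng(0)
# Five toy samples whose input x has pattern 't n f' = [12, 207, 1].
samples = [rng.normal(size=(12, 207, 1)) for _ in range(5)]

# Batching stacks the time-varying tensors along a new leading 'b' axis.
x_batch = np.stack(samples, axis=0)
print(x_batch.shape)  # (5, 12, 207, 1), i.e. pattern 'b t n f'

# Each slice along the batch axis is exactly one of the original samples.
assert all(np.array_equal(x_batch[i], samples[i]) for i in range(5))
```

In this setting, static attributes such as edge_index and edge_weight are shared by all samples. That is why they carry no b dimension in the printed batch.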

Preparing the dataset for training


Before running an experiment, there are usually two common preprocessing steps:

  • splitting the dataset into training/validation/test sets;

  • data preprocessing (scaling/normalizing data, detrending).

In tsl, these operations are managed by tsl.data.SpatioTemporalDataModule, which is based on the LightningDataModule from PyTorch Lightning. A DataModule standardizes the training, validation, and test splits, data preparation, and transformations, keeping them consistent across different environments and experiments.

Let’s see an example:

from tsl.data.datamodule import (SpatioTemporalDataModule,
                                 TemporalSplitter)
from tsl.data.preprocessing import StandardScaler

# Normalize data using mean and std computed over time and node dimensions
scalers = {'target': StandardScaler(axis=(0, 1))}

# Split data sequentially:
#   |------------ dataset -----------|
#   |--- train ---|- val -|-- test --|
splitter = TemporalSplitter(val_len=0.1, test_len=0.2)

dm = SpatioTemporalDataModule(
    dataset=torch_dataset,
    scalers=scalers,
    splitter=splitter,
    batch_size=64,
)

print(dm)

You can extend the base SpatioTemporalDataModule to add further processing that fits your needs.

At this point, the DataModule object has not actually performed any processing yet (lazy approach).

We can execute the preprocessing routines by calling the dm.setup() method.

dm.setup()
print(dm)

During setup, the datamodule performs the following operations:

  1. Carries out the dataset splitting into training/validation/test sets according to the provided Splitter.

  2. Fits each Scaler on the training portion of the torch_dataset tensor matching its key in scalers (here, 'target').

Splitters

Splitters in tsl are the objects defining the policy of data splitting.
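The sketch below shows a sequential split like the one requested above with TemporalSplitter(val_len=0.1, test_len=0.2). The function temporal_split is hypothetical. tsl's splitter may round differently, and it may also drop a few samples at the boundaries so that windows belonging to different sets do not overlap. This sketch ignores that.

```python
def temporal_split(n_samples, val_len=0.1, test_len=0.2):
    """Sequential train/val/test split of sample indices (hypothetical sketch).

    Fractions become sample counts; the test set takes the most recent samples.
    """
    n_val = int(val_len * n_samples)
    n_test = int(test_len * n_samples)
    n_train = n_samples - n_val - n_test
    train = range(0, n_train)
    val = range(n_train, n_train + n_val)
    test = range(n_train + n_val, n_samples)
    return train, val, test

train, val, test = temporal_split(100)
print(len(train), len(val), len(test))  # 70 10 20
```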

Scalers

The tsl.data.preprocessing package offers several of the most common data normalization techniques under the tsl.data.preprocessing.Scaler interface. They adopt an API similar to scikit-learn’s scalers, with fit/transform/fit_transform/inverse_transform methods.
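The minimal scaler below shows what fitting on the training data means. It has the fit/transform/inverse_transform interface and normalizes over axis=(0, 1), as in the StandardScaler used above. SimpleStandardScaler is an illustrative class, not tsl's StandardScaler, and the toy data are random.

```python
import numpy as np

class SimpleStandardScaler:
    """Minimal fit/transform/inverse_transform scaler (illustrative only)."""

    def __init__(self, axis=(0, 1)):
        self.axis = axis

    def fit(self, x):
        self.mean = x.mean(axis=self.axis, keepdims=True)
        self.std = x.std(axis=self.axis, keepdims=True)
        self.std[self.std == 0] = 1.0  # avoid division by zero for constant channels
        return self

    def transform(self, x):
        return (x - self.mean) / self.std

    def inverse_transform(self, x):
        return x * self.std + self.mean

    def fit_transform(self, x):
        return self.fit(x).transform(x)

# Toy target shaped [time, nodes, features]; statistics come from training data only.
rng = np.random.default_rng(0)
data = rng.normal(loc=50.0, scale=10.0, size=(100, 4, 1))
train = data[:70]

scaler = SimpleStandardScaler(axis=(0, 1)).fit(train)
scaled = scaler.transform(data)
print(scaled[:70].mean(), scaled[:70].std())  # about 0 and 1 on the training part
restored = scaler.inverse_transform(scaled)
```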


Building Spatiotemporal Graph Neural Networks


In this section, we will see how to build a very simple Spatiotemporal Graph Neural Network.

All the functions and classes needed to build neural networks in tsl are under the tsl.nn module.

The nn module


The tsl.nn module is organized as follows:

tsl
└── nn
    ├── base
    ├── blocks
    ├── layers
    └── models

The 3 most important submodules in it are layers, blocks, and models, ordered by increasing level of abstraction.

Layers

A layer is a basic building block for our neural networks. In simple words, a layer takes an input, performs one (or few) operations, and returns a transformation of the input. Examples of layers are DiffConv, which implements the diffusion convolution operation, or LayerNorm.
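The sketch below gives a rough idea of what a diffusion convolution computes. It follows the bidirectional formulation of DCRNN (Li et al., 2018): features are diffused for K steps with forward and backward random-walk transition matrices, and each step gets its own weight matrix. It is a simplified numpy illustration under these assumptions, not tsl's DiffConv. DiffConv's parameterization may differ, for example in how its k argument, self-loops, and bias are handled.

```python
import numpy as np

def diffusion_conv(x, adj, thetas_fwd, thetas_bwd):
    """Bidirectional diffusion convolution: sum_k P_f^k X T_f[k] + P_b^k X T_b[k].

    x: node features [n, f_in]; adj: weighted adjacency [n, n];
    thetas_fwd / thetas_bwd: K + 1 weight matrices [f_in, f_out] each,
    where k = 0 is the node itself and k = K the farthest diffusion step.
    """
    eps = 1e-12  # guards against nodes without outgoing/incoming edges
    p_fwd = adj / np.maximum(adj.sum(axis=1, keepdims=True), eps)      # D_out^-1 A
    p_bwd = adj.T / np.maximum(adj.T.sum(axis=1, keepdims=True), eps)  # D_in^-1 A^T
    out = np.zeros((x.shape[0], thetas_fwd[0].shape[1]))
    h_fwd, h_bwd = x, x  # P^0 X
    for theta_f, theta_b in zip(thetas_fwd, thetas_bwd):
        out += h_fwd @ theta_f + h_bwd @ theta_b
        h_fwd, h_bwd = p_fwd @ h_fwd, p_bwd @ h_bwd  # one more diffusion step
    return out

rng = np.random.default_rng(0)
adj = np.array([[0.0, 1.0, 0.0],
                [0.5, 0.0, 0.5],
                [0.0, 1.0, 0.0]])
x = rng.normal(size=(3, 2))
K = 2
thetas_fwd = [rng.normal(size=(2, 4)) for _ in range(K + 1)]
thetas_bwd = [rng.normal(size=(2, 4)) for _ in range(K + 1)]
z = diffusion_conv(x, adj, thetas_fwd, thetas_bwd)
print(z.shape)  # (3, 4)
```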

Blocks

Blocks perform more complex transformations or combine several operations. We divide blocks into encoders, if they provide a representation of the input in a new space, and decoders, if they produce a meaningful output from a representation.

Models

We wrap a series of operations, represented by blocks and/or layers, in a model. A model takes as input a spatiotemporal graph signal and returns the desired output, e.g., the forecasted node features at future time steps.

Designing a custom STGNN


Let’s get our hands dirty and create our first simple STGNN! We follow the Time-then-Space paradigm. A GRU shared among the nodes processes the temporal dimension. This gives us a single feature vector for each node as output, which is then propagated through the underlying graph using a Diffusion Convolutional GNN. Before and after, we add linear transformations to encode the input features and decode the learned representations. We also use node embeddings (free parameters learned individually for each node) to make our STGNN a global-local model (Cini et al., 2023).

All the layers that we need are provided inside tsl.nn. We use:

  • RNN from tsl.nn.blocks.encoders for the GRU;

  • DiffConv from tsl.nn.layers.graph_convs for the diffusion convolution;

  • NodeEmbedding from tsl.nn.layers for the node embeddings.

import torch.nn as nn

from tsl.nn.blocks.encoders import RNN
from tsl.nn.layers import NodeEmbedding, DiffConv
from einops.layers.torch import Rearrange  # reshape data with Einstein notation


class TimeThenSpaceModel(nn.Module):
    def __init__(self, input_size: int, n_nodes: int, horizon: int,
                 hidden_size: int = 32,
                 rnn_layers: int = 1,
                 gnn_kernel: int = 2):
        super(TimeThenSpaceModel, self).__init__()

        self.encoder = nn.Linear(input_size, hidden_size)

        self.node_embeddings = NodeEmbedding(n_nodes, hidden_size)

        self.time_nn = RNN(input_size=hidden_size,
                           hidden_size=hidden_size,
                           n_layers=rnn_layers,
                           cell='gru',
                           return_only_last_state=True)
        
        self.space_nn = DiffConv(in_channels=hidden_size,
                                 out_channels=hidden_size,
                                 k=gnn_kernel)

        self.decoder = nn.Linear(hidden_size, input_size * horizon)
        self.rearrange = Rearrange('b n (t f) -> b t n f', t=horizon)

    def forward(self, x, edge_index, edge_weight):
        # x: [batch time nodes features]
        x_enc = self.encoder(x)  # linear encoder: x_enc = xΘ + b
        x_emb = x_enc + self.node_embeddings()  # add node-identifier embeddings
        h = self.time_nn(x_emb)  # temporal processing: x=[b t n f] -> h=[b n f]
        z = self.space_nn(h, edge_index, edge_weight)  # spatial processing
        x_out = self.decoder(z)  # linear decoder: z=[b n f] -> x_out=[b n t⋅f]
        x_horizon = self.rearrange(x_out)
        return x_horizon
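The final Rearrange can be puzzling if you are new to einops. The numpy sketch below performs the same 'b n (t f) -> b t n f' rearrangement with a reshape followed by a transpose, using deterministic toy values. f=2 is chosen only to make the (t f) grouping visible; MetrLA has a single channel.

```python
import numpy as np

b, n, t, f = 2, 207, 12, 2
# Decoder output: one vector of length t*f per node, i.e. pattern 'b n (t f)'.
x_out = np.arange(b * n * t * f, dtype=float).reshape(b, n, t * f)

# Same as Rearrange('b n (t f) -> b t n f', t=t): split the last axis into
# (t, f), with t as the outer factor, then swap the node and time axes.
x_horizon = x_out.reshape(b, n, t, f).transpose(0, 2, 1, 3)
print(x_horizon.shape)  # (2, 12, 207, 2)

# Entry [i, s, j, c] (batch, step, node, feature) comes from x_out[i, j, s * f + c].
assert x_horizon[1, 5, 3, 1] == x_out[1, 3, 5 * f + 1]
```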

We can play with hyperparameters and make an instance of our model.

hidden_size = 32   #@param
rnn_layers = 1     #@param
gnn_kernel = 2     #@param

input_size = torch_dataset.n_channels   # 1 channel
n_nodes = torch_dataset.n_nodes         # 207 nodes
horizon = torch_dataset.horizon         # 12 time steps

stgnn = TimeThenSpaceModel(input_size=input_size,
                           n_nodes=n_nodes,
                           horizon=horizon,
                           hidden_size=hidden_size,
                           rnn_layers=rnn_layers,
                           gnn_kernel=gnn_kernel)
print(stgnn)
print_model_size(stgnn)

Fine, we loaded the data and built a model, so let’s train it!

Setting up training


We are now ready to train our model. We can set up the training procedure however we prefer. In the following, we use PyTorch Lightning’s Trainer to reduce the burden of the dirty work. Recall that tsl is tightly integrated with widely used PyTorch-based libraries, such as PyTorch Lightning and PyTorch Geometric.

The Predictor

In tsl, inference engines are implemented as LightningModules. tsl.engines.Predictor is a base class that can be extended to build more complex forecasting approaches. These modules wrap deep models to ease the training and inference phases.

from tsl.metrics.torch import MaskedMAE, MaskedMAPE
from tsl.engines import Predictor

loss_fn = MaskedMAE()

metrics = {'mae': MaskedMAE(),
           'mape': MaskedMAPE(),
           'mae_at_15': MaskedMAE(at=2),  # '2' indicates the third time step,
                                          # which corresponds to 15 minutes ahead
           'mae_at_30': MaskedMAE(at=5),
           'mae_at_60': MaskedMAE(at=11)}

# setup predictor
predictor = Predictor(
    model=stgnn,                   # our initialized model
    optim_class=torch.optim.Adam,  # specify optimizer to be used...
    optim_kwargs={'lr': 0.001},    # ...and parameters for its initialization
    loss_fn=loss_fn,               # which loss function to be used
    metrics=metrics                # metrics to be logged during train/val/test
)
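For intuition on the metrics above, the sketch below computes a masked MAE in numpy. Like the at argument, it can optionally restrict the computation to a single horizon step. masked_mae is an illustrative function that assumes [batch, time, nodes, features] arrays and a binary mask; tsl's MaskedMAE may differ in details.

```python
import numpy as np

def masked_mae(y_hat, y, mask, at=None):
    """Mean absolute error over entries where mask == 1 (illustrative sketch).

    Arrays are [batch, time, nodes, features]. If `at` is given, only that
    horizon step is evaluated: at=2 is the third step, i.e. 15 minutes ahead
    with 5-minute sampling.
    """
    if at is not None:
        y_hat, y, mask = y_hat[:, at], y[:, at], mask[:, at]
    err = np.abs(y_hat - y) * mask
    return err.sum() / mask.sum()

y = np.zeros((1, 3, 2, 1))
y_hat = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 100.0]).reshape(1, 3, 2, 1)
mask = np.ones_like(y)
mask[0, 2, 1, 0] = 0.0  # the 100.0 prediction has no ground truth and is ignored

print(masked_mae(y_hat, y, mask))        # 3.0 = (1 + 2 + 3 + 4 + 5) / 5
print(masked_mae(y_hat, y, mask, at=2))  # 5.0, only node 0 is valid at step 2
```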

Now let’s finalize the last details. We make use of TensorBoard to log and visualize metrics.

from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir="logs", name="tsl_intro", version=0)
%load_ext tensorboard
%tensorboard --logdir logs

We let pytorch_lightning.Trainer handle the dirty work for us. We can pass the datamodule directly to the trainer for fitting.

In this case, the trainer calls the setup method and then loads the training and validation sets.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath='logs',
    save_top_k=1,
    monitor='val_mae',
    mode='min',
)

trainer = pl.Trainer(max_epochs=100,
                     logger=logger,
                     accelerator='auto',       # use a GPU if available, else the CPU
                     devices=1,
                     limit_train_batches=100,  # end an epoch after 100 updates
                     callbacks=[checkpoint_callback])

trainer.fit(predictor, datamodule=dm)

Testing


Now let’s see how the trained model behaves on new unseen data.

predictor.load_model(checkpoint_callback.best_model_path)
predictor.freeze()

trainer.test(predictor, datamodule=dm);

Cool! We succeeded in creating our first simple – yet effective – Spatiotemporal GNN!

🥷 We are now tsl ninjas. We learned how to:

  • Load benchmark datasets

  • Organize data for processing

  • Preprocess the data

  • Build a Spatiotemporal GNN

  • Train and evaluate models

We hope you enjoyed this introduction to tsl. Now go out there and build the game-changing STGNN 🌎

🧡 The tsl team