Source code for tsl.nn.models.stgn.gated_gn_model

import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import Tensor, nn

from tsl.nn.layers.base import NodeEmbedding
from tsl.nn.layers.graph_convs import GatedGraphNetwork
from tsl.nn.models import BaseModel
from tsl.nn.utils import get_layer_activation, maybe_cat_exog


class GatedGraphNetworkModel(BaseModel):
    r"""Simple time-then-space model with an MLP with residual connections as
    encoder (flattened time dimension) and a gated GN decoder with node
    identification.

    Inspired by the FC-GNN model from the paper `"Multivariate Time Series
    Forecasting with Latent Graph Inference"
    <https://arxiv.org/abs/2203.03423>`_ (Satorras et al., 2022).

    Args:
        input_size (int): Size of the input.
        input_window_size (int): Size of the input window (this model cannot
            process sequences of variable length).
        hidden_size (int): Number of hidden units in each hidden layer.
        output_size (int): Size of the output. Defaults to ``input_size``.
        horizon (int): Forecasting steps.
        n_nodes (int): Number of nodes.
        exog_size (int): Size of the optional exogenous variables.
        enc_layers (int): Number of layers in the MLP encoder.
        gnn_layers (int): Number of GNN layers in the decoder.
        full_graph (bool): Whether to use a full graph for the GNN. In that
            case, the model turns into a dense spatial attention layer.
        activation (str): Activation function.
            (default: ``'silu'``)
    """
    return_type = Tensor

    def __init__(self,
                 input_size: int,
                 input_window_size: int,
                 horizon: int,
                 n_nodes: int,
                 hidden_size: int,
                 output_size: int = None,
                 exog_size: int = 0,
                 enc_layers: int = 1,
                 gnn_layers: int = 1,
                 full_graph: bool = True,
                 activation: str = 'silu'):
        super(GatedGraphNetworkModel, self).__init__()

        self.input_size = input_size
        self.output_size = output_size or input_size
        self.input_window_size = input_window_size
        self.full_graph = full_graph

        input_size += exog_size
        # Linear encoder acting on the flattened time dimension
        self.input_encoder = nn.Sequential(
            nn.Linear(input_size * input_window_size, hidden_size))

        # Residual MLP encoder blocks
        self.encoder_layers = nn.ModuleList([
            nn.Sequential(nn.Linear(hidden_size, hidden_size),
                          get_layer_activation(activation)(),
                          nn.Linear(hidden_size, hidden_size))
            for _ in range(enc_layers)
        ])

        # Learnable node embeddings used for node identification
        self.emb = NodeEmbedding(n_nodes=n_nodes, emb_size=hidden_size)

        self.gcn_layers = nn.ModuleList([
            GatedGraphNetwork(hidden_size, hidden_size, activation=activation)
            for _ in range(gnn_layers)
        ])

        self.decoder = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                     get_layer_activation(activation)())

        # Map the hidden state to the whole forecasting horizon in one shot
        self.readout = nn.Sequential(
            nn.Linear(hidden_size, horizon * self.output_size),
            Rearrange('b n (h f) -> b h n f', h=horizon, f=self.output_size))
    def forward(self, x, edge_index=None, u=None):
        """"""
        # x: [batches steps nodes features]
        x = maybe_cat_exog(x, u)

        if self.full_graph or edge_index is None:
            # Build a dense (fully-connected) graph, including self-loops
            num_nodes = x.size(-2)
            nodes = torch.arange(num_nodes, device=x.device)
            edge_index = torch.cartesian_prod(nodes, nodes).T

        # flatten time dimension
        x = rearrange(x[:, -self.input_window_size:], 'b s n f -> b n (s f)')
        x = self.input_encoder(x)

        for layer in self.encoder_layers:
            x = layer(x) + x

        # add node identification encoding
        x = self.emb() + x

        for layer in self.gcn_layers:
            x = layer(x, edge_index)

        x = self.decoder(x) + x

        return self.readout(x)
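
# A minimal usage sketch of the model on random data, not part of the
# original module. The shapes below (batch of 8, window of 12 steps,
# 20 nodes, 3 features, horizon of 6) are illustrative assumptions.
if __name__ == "__main__":
    batch_size, window, n_nodes, n_features, horizon = 8, 12, 20, 3, 6

    model = GatedGraphNetworkModel(input_size=n_features,
                                   input_window_size=window,
                                   horizon=horizon,
                                   n_nodes=n_nodes,
                                   hidden_size=32)

    # Input follows the [batch, steps, nodes, features] layout expected by
    # forward(). With full_graph=True (the default) no edge_index is needed:
    # a dense graph over the nodes is built on the fly.
    x = torch.randn(batch_size, window, n_nodes, n_features)
    out = model(x)

    # output_size defaults to input_size, so the forecast has shape
    # [batch, horizon, nodes, features]
    assert out.shape == (batch_size, horizon, n_nodes, n_features)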