Source code for tsl.nn.models.stgn.graph_wavenet_model

from typing import Optional

import torch
from einops import repeat
from torch import Tensor, nn
from torch.nn import functional as F
from torch_geometric.typing import Adj, OptTensor

from tsl.nn.blocks.decoders import MLPDecoder
from tsl.nn.blocks.encoders import TemporalConvNet
from tsl.nn.layers.base import NodeEmbedding
from tsl.nn.layers.graph_convs import DenseGraphConvOrderK, DiffConv
from tsl.nn.layers.norm import Norm
from tsl.nn.models.base_model import BaseModel


class GraphWaveNetModel(BaseModel):
    r"""The Graph WaveNet model from the paper `"Graph WaveNet for Deep
    Spatial-Temporal Graph Modeling" <https://arxiv.org/abs/1906.00121>`_
    (Wu et al., IJCAI 2019).

    Args:
        input_size (int): Number of features of the input sample.
        output_size (int): Number of output channels.
        horizon (int): Number of future time steps to forecast.
        exog_size (int): Number of features of the input covariate,
            if any.
            (default: :obj:`0`)
        hidden_size (int): Number of hidden units.
            (default: :obj:`32`)
        ff_size (int): Number of units in the nonlinear readout.
            (default: :obj:`256`)
        n_layers (int): Number of Graph WaveNet blocks.
            (default: :obj:`8`)
        temporal_kernel_size (int): Size of the temporal convolution kernel.
            (default: :obj:`2`)
        spatial_kernel_size (int): Order of the spatial diffusion process.
            (default: :obj:`2`)
        learned_adjacency (bool): If :obj:`True`, then consider an additional
            learned adjacency matrix.
            (default: :obj:`True`)
        n_nodes (int, optional): Number of nodes in the input graph, required
            only when :attr:`learned_adjacency` is :obj:`True`.
            (default: :obj:`None`)
        emb_size (int): Number of features in the node embeddings used for
            graph learning.
            (default: :obj:`10`)
        dilation (int): Dilation of the temporal convolutional kernels.
            (default: :obj:`2`)
        dilation_mod (int): Length of the cycle for the dilation coefficient.
            (default: :obj:`2`)
        norm (str): Normalization strategy.
            (default: :obj:`'batch'`)
        dropout (float): Dropout probability.
            (default: :obj:`0.3`)

    Raises:
        ValueError: If :attr:`learned_adjacency` is :obj:`True` but
            :attr:`n_nodes` is :obj:`None`.
    """
    return_type = Tensor

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 horizon: int,
                 exog_size: int = 0,
                 hidden_size: int = 32,
                 ff_size: int = 256,
                 n_layers: int = 8,
                 temporal_kernel_size: int = 2,
                 spatial_kernel_size: int = 2,
                 learned_adjacency: bool = True,
                 n_nodes: Optional[int] = None,
                 emb_size: int = 10,
                 dilation: int = 2,
                 dilation_mod: int = 2,
                 norm: str = 'batch',
                 dropout: float = 0.3):
        super(GraphWaveNetModel, self).__init__()

        if learned_adjacency:
            # Explicit check instead of `assert`: asserts are stripped when
            # Python runs with -O, which would defer the failure to an
            # obscure error inside NodeEmbedding.
            if n_nodes is None:
                raise ValueError(
                    "`n_nodes` must be provided when "
                    "`learned_adjacency` is True.")
            self.source_embeddings = NodeEmbedding(n_nodes, emb_size)
            self.target_embeddings = NodeEmbedding(n_nodes, emb_size)
        else:
            # Register placeholders under the SAME names used when the
            # embeddings exist, so the attributes are always defined
            # (as None) regardless of the configuration.
            self.register_parameter('source_embeddings', None)
            self.register_parameter('target_embeddings', None)

        # Input features and exogenous covariates are concatenated in
        # `forward` before being projected to the hidden size.
        self.input_encoder = nn.Linear(input_size + exog_size, hidden_size)

        temporal_conv_blocks = []
        spatial_convs = []
        skip_connections = []
        norms = []
        receptive_field = 1
        for i in range(n_layers):
            # Dilation cycles through dilation**0 ... dilation**(mod - 1).
            d = dilation**(i % dilation_mod)
            temporal_conv_blocks.append(
                TemporalConvNet(input_channels=hidden_size,
                                hidden_channels=hidden_size,
                                kernel_size=temporal_kernel_size,
                                dilation=d,
                                exponential_dilation=False,
                                n_layers=1,
                                causal_padding=False,
                                gated=True))
            spatial_convs.append(
                DiffConv(in_channels=hidden_size,
                         out_channels=hidden_size,
                         k=spatial_kernel_size))
            skip_connections.append(nn.Linear(hidden_size, ff_size))
            norms.append(Norm(norm, hidden_size))
            # Accumulate the number of input steps spanned by the stack;
            # `forward` left-pads inputs shorter than this.
            receptive_field += d * (temporal_kernel_size - 1)

        self.tconvs = nn.ModuleList(temporal_conv_blocks)
        self.sconvs = nn.ModuleList(spatial_convs)
        self.skip_connections = nn.ModuleList(skip_connections)
        self.norms = nn.ModuleList(norms)
        self.dropout = nn.Dropout(dropout)
        self.receptive_field = receptive_field

        # One dense graph convolution per layer, operating on the learned
        # adjacency; the list stays empty when no adjacency is learned.
        dense_sconvs = []
        if learned_adjacency:
            for _ in range(n_layers):
                dense_sconvs.append(
                    DenseGraphConvOrderK(input_size=hidden_size,
                                         output_size=hidden_size,
                                         support_len=1,
                                         order=spatial_kernel_size,
                                         include_self=False,
                                         channel_last=True))
        self.dense_sconvs = nn.ModuleList(dense_sconvs)

        self.readout = nn.Sequential(
            nn.ReLU(),
            MLPDecoder(input_size=ff_size,
                       hidden_size=2 * ff_size,
                       output_size=output_size,
                       horizon=horizon,
                       activation='relu'))

    def get_learned_adj(self) -> Tensor:
        """Compute the learned adjacency matrix from the node embeddings.

        The matrix is obtained as
        ``softmax(relu(source_emb @ target_emb.T), dim=1)``.

        Returns:
            Tensor: The learned adjacency matrix, normalized with a softmax
            over dimension 1.

        Raises:
            RuntimeError: If the model was built with
                ``learned_adjacency=False`` (no embeddings available).
        """
        if self.source_embeddings is None or self.target_embeddings is None:
            raise RuntimeError(
                "The learned adjacency is not available: the model was "
                "created with `learned_adjacency=False`.")
        logits = F.relu(self.source_embeddings() @ self.target_embeddings().T)
        adj = torch.softmax(logits, dim=1)
        return adj

    def forward(self,
                x: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None,
                u: OptTensor = None) -> Tensor:
        """Run the Graph WaveNet forward pass.

        Args:
            x (Tensor): Input sequence laid out as ``[b t n f]``.
            edge_index (Adj): Graph connectivity passed to the diffusion
                convolutions.
            edge_weight (OptTensor): Optional edge weights passed to the
                diffusion convolutions.
                (default: :obj:`None`)
            u (OptTensor): Optional covariates, either ``[b t f]`` (repeated
                over the node dimension) or already matching the layout of
                :obj:`x`.
                (default: :obj:`None`)

        Returns:
            Tensor: The output of the readout applied to the accumulated
            skip connections.
        """
        # x: [b t n f]
        if u is not None:
            if u.dim() == 3:
                # Covariates shared across nodes: replicate them per node.
                u = repeat(u, 'b t f -> b t n f', n=x.size(-2))
            x = torch.cat([x, u], -1)

        if self.receptive_field > x.size(1):
            # pad temporal dimension (on the left, i.e., before the first
            # observed step) up to the receptive field of the conv stack
            x = F.pad(x, (0, 0, 0, 0, self.receptive_field - x.size(1), 0))

        use_learned_adj = len(self.dense_sconvs) > 0
        adj_z = self.get_learned_adj() if use_learned_adj else None

        x = self.input_encoder(x)

        # Accumulator for the skip connections; broadcast against the first
        # skip output and then cropped to the shrinking temporal length.
        out = torch.zeros(1, x.size(1), 1, 1, device=x.device)
        for i, (tconv, sconv, skip_conn, norm) in enumerate(
                zip(self.tconvs, self.sconvs, self.skip_connections,
                    self.norms)):
            res = x
            # temporal conv
            x = tconv(x)
            # residual connection -> out
            out = skip_conn(x) + out[:, -x.size(1):]
            # spatial conv
            xs = sconv(x, edge_index, edge_weight)
            if use_learned_adj:
                x = xs + self.dense_sconvs[i](x, adj_z)
            else:
                x = xs
            x = self.dropout(x)
            # residual connection -> next layer (keep only the most recent
            # steps of the residual to match the conv output length)
            x = x + res[:, -x.size(1):]
            x = norm(x)

        return self.readout(out)