Source code for tsl.nn.layers.graph_convs.gated_gn

import torch
from torch import Tensor, nn
from torch_geometric.nn import MessagePassing
from torch_geometric.typing import Adj

from tsl.nn.utils import get_layer_activation


class GatedGraphNetwork(MessagePassing):
    r"""Gated Graph Network layer (with residual connections) inspired by the
    FC-GNN model from the paper `"Multivariate Time Series Forecasting with
    Latent Graph Inference" <https://arxiv.org/abs/2203.03423>`_
    (Satorras et al., 2022).

    Args:
        input_size (int): Input channels.
        output_size (int): Output channels.
        activation (str, optional): Activation function.
            (default: :obj:`'silu'`)
        parametrized_skip_conn (bool, optional): If :obj:`True`, then add a
            linear layer in the residual connection even if input and output
            dimensions match.
            (default: :obj:`False`)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 activation: str = 'silu',
                 parametrized_skip_conn: bool = False):
        super(GatedGraphNetwork, self).__init__(aggr="add", node_dim=-2)

        self.in_channels = input_size
        self.out_channels = output_size

        self.msg_mlp = nn.Sequential(
            nn.Linear(2 * input_size, output_size // 2),
            get_layer_activation(activation)(),
            nn.Linear(output_size // 2, output_size),
            get_layer_activation(activation)(),
        )

        self.gate_mlp = nn.Sequential(nn.Linear(output_size, 1), nn.Sigmoid())

        self.update_mlp = nn.Sequential(
            nn.Linear(input_size + output_size, output_size),
            get_layer_activation(activation)(),
            nn.Linear(output_size, output_size))

        if (input_size != output_size) or parametrized_skip_conn:
            self.skip_conn = nn.Linear(input_size, output_size)
        else:
            self.skip_conn = nn.Identity()

    def forward(self, x: Tensor, edge_index: Adj):
        """Aggregate gated messages from the neighbors of each node, then
        update node features with a residual (skip) connection."""
        out = self.propagate(edge_index, x=x)

        out = self.update_mlp(torch.cat([out, x], -1)) + self.skip_conn(x)

        return out

    def message(self, x_i: Tensor, x_j: Tensor):
        """Compute the message for edge :math:`(j, i)` from the concatenated
        features of the two endpoints, scaled by a learned scalar gate."""
        mij = self.msg_mlp(torch.cat([x_i, x_j], -1))
        return self.gate_mlp(mij) * mij
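
# A minimal usage sketch (not part of the original module); the sizes and the
# connectivity below are illustrative assumptions. Since the layer is built
# with node_dim=-2, node features may carry leading batch/time dimensions,
# i.e. x has shape [..., num_nodes, input_size].
if __name__ == "__main__":
    layer = GatedGraphNetwork(input_size=16, output_size=32)

    # 4 (batch/time) steps, 5 nodes, 16 input channels.
    x = torch.randn(4, 5, 16)

    # COO connectivity: directed edges 0->1, 1->2, 2->0, 3->4.
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 0, 4]])

    out = layer(x, edge_index)
    print(out.shape)  # torch.Size([4, 5, 32])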