# Source code for tsl.nn.models.stgn.agcrn_model

from torch import Tensor, nn
from torch_geometric.typing import OptTensor

from tsl.nn.blocks.decoders import LinearReadout
from tsl.nn.blocks.encoders import AGCRN
from tsl.nn.models.base_model import BaseModel
from tsl.nn.utils import maybe_cat_exog


class AGCRNModel(BaseModel):
    r"""The Adaptive Graph Convolutional Recurrent Network from the paper
    `"Adaptive Graph Convolutional Recurrent Network for Traffic Forecasting"
    <https://arxiv.org/abs/2007.02842>`_ (Bai et al., NeurIPS 2020).

    Args:
        input_size (int): Number of features of the input sample.
        output_size (int): Number of output channels.
        horizon (int): Number of future time steps to forecast.
        n_nodes (int): Number of nodes in the input (static) graph.
        hidden_size (int): Number of hidden units.
            (default: :obj:`64`)
        emb_size (int): Size of the learned node embeddings.
            (default: :obj:`10`)
        exog_size (int): Number of features of the input covariate, if any.
            (default: :obj:`0`)
        n_layers (int): Number of AGCRN cells.
            (default: :obj:`1`)
    """

    return_type = Tensor

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 horizon: int,
                 n_nodes: int,
                 hidden_size: int = 64,
                 emb_size: int = 10,
                 exog_size: int = 0,
                 n_layers: int = 1):
        super().__init__()

        # Project the input features (plus the exogenous covariates, when
        # exog_size > 0) to the hidden dimension used by the recurrent encoder.
        self.input_encoder = nn.Linear(input_size + exog_size, hidden_size)

        # The attribute name ``agrn`` is kept as-is: renaming a submodule
        # changes the model's ``state_dict`` keys and breaks saved checkpoints.
        # Only the last recurrent state is passed on to the readout.
        self.agrn = AGCRN(input_size=hidden_size,
                          emb_size=emb_size,
                          num_nodes=n_nodes,
                          hidden_size=hidden_size,
                          n_layers=n_layers,
                          return_only_last_state=True)

        # Map the final hidden state to ``horizon`` steps of
        # ``output_size`` channels.
        self.readout = LinearReadout(input_size=hidden_size,
                                     output_size=output_size,
                                     horizon=horizon)

    def forward(self, x: Tensor, u: OptTensor = None) -> Tensor:
        """Forecast ``horizon`` future steps from an input window.

        Args:
            x (Tensor): Input tensor with ``input_size`` features in its last
                dimension (presumably laid out as
                ``[batch, time, nodes, features]``; confirm against callers).
            u (Tensor, optional): Exogenous covariates with ``exog_size``
                features, concatenated to :obj:`x` by ``maybe_cat_exog``
                when given. (default: :obj:`None`)

        Returns:
            Tensor: The output of the linear readout.
        """
        x = maybe_cat_exog(x, u)
        x = self.input_encoder(x)
        out = self.agrn(x)
        return self.readout(out)