# Source code for tsl.nn.models.stgn.agcrn_model

from torch import nn, Tensor
from torch_geometric.typing import OptTensor

from tsl.nn.blocks.encoders.agcrn import AGCRN
from ..base_model import BaseModel
from ...blocks.decoders import LinearReadout
from ... import utils


class AGCRNModel(BaseModel):
    r"""The Adaptive Graph Convolutional Recurrent Network from the paper
    `"Adaptive Graph Convolutional Recurrent Network for Traffic Forecasting"
    <https://arxiv.org/abs/2007.02842>`_ (Bai et al., NeurIPS 2020).

    Args:
        input_size (int): Number of features of the input sample.
        output_size (int): Number of output channels.
        horizon (int): Number of future time steps to forecast.
        exog_size (int): Number of features of the input covariate
            (exogenous variables). Use :obj:`0` if there are none.
        n_nodes (int): Number of nodes in the input (static) graph.
        emb_size (int): Size of the learned node embeddings.
        hidden_size (int): Number of hidden units.
        n_layers (int): Number of AGCRN cells.
            (default: :obj:`1`)
    """

    def __init__(self, input_size: int, output_size: int, horizon: int,
                 exog_size: int, n_nodes: int, emb_size: int,
                 hidden_size: int, n_layers: int = 1):
        super(AGCRNModel, self).__init__(return_type=Tensor)

        # Input and covariates are concatenated along the feature axis
        # in ``forward``, hence the encoder's input width is the sum.
        self.input_encoder = nn.Linear(input_size + exog_size, hidden_size)

        # NOTE: attribute name ``agrn`` is kept as-is because it determines
        # the keys of saved state dicts / checkpoints.
        self.agrn = AGCRN(input_size=hidden_size,
                          emb_size=emb_size,
                          num_nodes=n_nodes,
                          hidden_size=hidden_size,
                          n_layers=n_layers)

        self.readout = LinearReadout(input_size=hidden_size,
                                     output_size=output_size,
                                     horizon=horizon)

    def forward(self, x: Tensor, u: OptTensor = None) -> Tensor:
        """Compute the forecast for the given input window.

        Args:
            x (Tensor): Input sequence.
            u (Tensor, optional): Exogenous covariates. They are merged into
                :obj:`x` by :func:`utils.maybe_cat_exog` (presumably a no-op
                when :obj:`None` — confirm against ``tsl.nn.utils``).
                (default: :obj:`None`)

        Returns:
            Tensor: The output of the linear readout applied to the
            representation returned by the AGCRN encoder.
        """
        x = utils.maybe_cat_exog(x, u)
        x = self.input_encoder(x)
        # The encoder returns a pair; only the first element is read out.
        h, _ = self.agrn(x)
        return self.readout(h)