Source code for tsl.nn.models.stgn.rnn_gcn_model

from einops import rearrange
from torch import Tensor, nn

from tsl.nn.blocks.decoders.gcn_decoder import GCNDecoder
from tsl.nn.blocks.encoders import RNN, ConditionalBlock
from tsl.nn.models import BaseModel


class RNNEncGCNDecModel(BaseModel):
    """Simple time-then-space model.

    Input time series are encoded in vectors using an RNN and then decoded
    using a stack of GCN layers.

    Args:
        input_size (int): Input size.
        hidden_size (int): Units in the hidden layers.
        output_size (int): Size of the optional readout.
        exog_size (int): Size of the exogenous variables.
        rnn_layers (int): Number of recurrent layers in the encoder.
        gcn_layers (int): Number of graph convolutional layers in the decoder.
        rnn_dropout (float, optional): Dropout probability in the RNN encoder.
        gcn_dropout (float, optional): Dropout probability in the GCN decoder.
        horizon (int): Forecasting horizon.
        cell_type (str, optional): Type of cell that should be used.
            (options: [``'gru'``, ``'lstm'``]).
            (default: ``'gru'``)
        activation (str, optional): Activation function.
    """

    return_type = Tensor

    def __init__(self,
                 input_size,
                 hidden_size,
                 output_size,
                 exog_size,
                 rnn_layers,
                 gcn_layers,
                 rnn_dropout,
                 gcn_dropout,
                 horizon,
                 cell_type='gru',
                 activation='relu'):
        super(RNNEncGCNDecModel, self).__init__()

        if exog_size:
            self.input_encoder = ConditionalBlock(input_size=input_size,
                                                  exog_size=exog_size,
                                                  output_size=hidden_size,
                                                  activation=activation)
        else:
            self.input_encoder = nn.Sequential(
                nn.Linear(input_size, hidden_size), )

        self.encoder = RNN(input_size=hidden_size,
                           hidden_size=hidden_size,
                           n_layers=rnn_layers,
                           return_only_last_state=True,
                           dropout=rnn_dropout,
                           cell=cell_type)

        self.decoder = GCNDecoder(input_size=hidden_size,
                                  hidden_size=hidden_size,
                                  output_size=output_size,
                                  horizon=horizon,
                                  n_layers=gcn_layers,
                                  activation=activation,
                                  dropout=gcn_dropout)
    def forward(self, x, edge_index, edge_weight, u=None, **kwargs):
        """"""
        # x: [batches steps nodes features]
        # u: [batches steps (nodes) features]
        if u is not None:
            if u.dim() == 3:
                u = rearrange(u, 'b s f -> b s 1 f')
            x = self.input_encoder(x, u)
        else:
            x = self.input_encoder(x)

        x = self.encoder(x)

        return self.decoder(x, edge_index, edge_weight)
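
The sketch below is not part of the module source; it is a minimal usage example assuming the imports above resolve, with dummy tensors shaped as documented in the ``forward`` comments and a toy ring graph whose edge index and weights are made up purely for illustration.

import torch

# Hypothetical dimensions for illustration only.
batch, steps, nodes, channels, horizon = 8, 12, 4, 1, 3

model = RNNEncGCNDecModel(input_size=channels,
                          hidden_size=32,
                          output_size=channels,
                          exog_size=0,       # no exogenous variables -> linear input encoder
                          rnn_layers=1,
                          gcn_layers=2,
                          rnn_dropout=0.,
                          gcn_dropout=0.,
                          horizon=horizon)

# x: [batches steps nodes features]
x = torch.randn(batch, steps, nodes, channels)

# Toy ring graph over the 4 nodes, in COO format with unit edge weights.
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
edge_weight = torch.ones(edge_index.size(1))

# Output is expected to have shape [batch, horizon, nodes, output_size].
y_hat = model(x, edge_index, edge_weight)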