Source code for tsl.nn.models.temporal.rnn_model

from typing import Optional

from einops import rearrange
from torch import Tensor, nn

from tsl.nn.blocks.decoders.mlp_decoder import MLPDecoder
from tsl.nn.blocks.encoders import RNN, ConditionalBlock
from tsl.nn.models.base_model import BaseModel


class RNNModel(BaseModel):
    r"""Simple RNN for multistep forecasting.

    The input (optionally conditioned on exogenous variables) is encoded step
    by step, the encoded sequence is fed to a recurrent network configured to
    return only its last state, and an MLP decoder maps that state to the
    predictions for the whole forecasting horizon.

    Args:
        input_size (int): Size of the input.
        output_size (int): Number of output channels.
        horizon (int): Forecasting horizon.
        exog_size (int): Size of the exogenous variables. If falsy, the model
            is built without the conditional encoder and must be called
            without :obj:`u`.
            (default: :obj:`0`)
        hidden_size (int): Number of units in the input encoder and in the
            recurrent cell.
            (default: :obj:`32`)
        ff_size (int): Number of units in the hidden layers of the decoder.
            (default: :obj:`64`)
        rec_layers (int): Number of RNN layers.
            (default: :obj:`1`)
        ff_layers (int): Number of hidden layers in the decoder.
            (default: :obj:`1`)
        rec_dropout (float, optional): Dropout probability in the RNN encoder.
            (default: :obj:`0.`)
        ff_dropout (float, optional): Dropout probability in the MLP decoder.
            (default: :obj:`0.`)
        cell_type (str, optional): Type of cell that should be used
            (options: [:obj:`gru`, :obj:`lstm`]).
            (default: :obj:`gru`)
        activation (str, optional): Activation function, forwarded to the
            conditional input encoder and to the decoder.
            (default: :obj:`relu`)
    """
    return_type = Tensor

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 horizon: int,
                 exog_size: int = 0,
                 hidden_size: int = 32,
                 ff_size: int = 64,
                 rec_layers: int = 1,
                 ff_layers: int = 1,
                 rec_dropout: float = 0.,
                 ff_dropout: float = 0.,
                 cell_type: str = 'gru',
                 activation: str = 'relu'):
        super().__init__()

        if exog_size:
            # Binary encoder: called as ``input_encoder(x, u)`` in forward().
            self.input_encoder = ConditionalBlock(input_size=input_size,
                                                  exog_size=exog_size,
                                                  output_size=hidden_size,
                                                  activation=activation)
        else:
            # Unary encoder: called as ``input_encoder(x)`` in forward().
            # NOTE(review): this branch always uses ReLU and ignores
            # ``activation``; kept as is to preserve existing behavior.
            self.input_encoder = nn.Sequential(
                nn.Linear(input_size, hidden_size),
                nn.ReLU(),
            )

        self.rnn = RNN(input_size=hidden_size,
                       hidden_size=hidden_size,
                       n_layers=rec_layers,
                       return_only_last_state=True,
                       dropout=rec_dropout,
                       cell=cell_type)

        self.readout = MLPDecoder(input_size=hidden_size,
                                  hidden_size=ff_size,
                                  output_size=output_size,
                                  horizon=horizon,
                                  n_layers=ff_layers,
                                  activation=activation,
                                  dropout=ff_dropout)

    def forward(self, x: Tensor, u: Optional[Tensor] = None) -> Tensor:
        """Encode the input window, run the RNN and decode the forecasts.

        Args:
            x (Tensor): Input sequence, laid out as
                ``[batches, steps, nodes, features]``.
            u (Tensor, optional): Exogenous variables, laid out as
                ``[batches, steps, (nodes), features]``. A 3-dimensional
                tensor gets a singleton node axis inserted before encoding.
                (default: :obj:`None`)

        Returns:
            Tensor: The output of the readout applied to the last state
            returned by the RNN.

        Raises:
            TypeError: If ``u`` is given but the model was built with a falsy
                ``exog_size`` (its input encoder accepts ``x`` only).
        """
        # x: [batches steps nodes features]
        # u: [batches steps (nodes) features]
        if u is None:
            x = self.input_encoder(x)
        else:
            encoder = self.input_encoder
            # The Linear+ReLU encoder built when ``exog_size`` is falsy takes
            # a single input, so passing ``u`` to it would fail with an
            # opaque positional-argument TypeError. Fail early and clearly.
            if (isinstance(encoder, nn.Sequential)
                    and not isinstance(encoder, ConditionalBlock)):
                raise TypeError(
                    f"{type(self).__name__} was built without exogenous "
                    "variables (falsy `exog_size`), so forward() cannot "
                    "accept `u`.")
            if u.dim() == 3:
                # Insert a singleton node axis into the 3-dimensional ``u``.
                u = rearrange(u, 'b t f -> b t 1 f')
            x = encoder(x, u)

        x = self.rnn(x)
        return self.readout(x)
class FCRNNModel(RNNModel):
    r"""A simple fully connected RNN for multistep forecasting that simply
    flattens data along the spatial dimension.

    All nodes are folded into the feature axis, so the underlying
    :class:`RNNModel` is built with ``input_size * n_nodes`` input features
    and ``output_size * n_nodes`` output channels.

    Args:
        input_size (int): Size of the input (per node).
        output_size (int): Number of output channels (per node).
        horizon (int): Forecasting horizon.
        n_nodes (int): Number of nodes flattened into the feature axis.
        exog_size (int, optional): Size of the exogenous variables, forwarded
            unchanged to :class:`RNNModel`.
            (default: :obj:`None`)
        hidden_size (int): Number of units in the recurrent cell.
            (default: :obj:`32`)
        ff_size (int): Number of units in the hidden layers of the decoder.
            (default: :obj:`64`)
        rec_layers (int): Number of RNN layers.
            (default: :obj:`1`)
        ff_layers (int): Number of hidden layers in the decoder.
            (default: :obj:`1`)
        rec_dropout (float, optional): Dropout probability in the RNN encoder.
            (default: :obj:`0.`)
        ff_dropout (float, optional): Dropout probability in the MLP decoder.
            (default: :obj:`0.`)
        cell_type (str, optional): Type of cell that should be used
            (options: [:obj:`gru`, :obj:`lstm`]).
            (default: :obj:`gru`)
        activation (str, optional): Activation function.
            (default: :obj:`relu`)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 horizon: int,
                 n_nodes: int,
                 exog_size: Optional[int] = None,
                 hidden_size: int = 32,
                 ff_size: int = 64,
                 rec_layers: int = 1,
                 ff_layers: int = 1,
                 rec_dropout: float = 0.,
                 ff_dropout: float = 0.,
                 cell_type: str = 'gru',
                 activation: str = 'relu'):
        # Every node contributes its own features once the spatial dimension
        # is flattened, hence both sizes scale with the number of nodes.
        flat_input_size = input_size * n_nodes
        flat_output_size = output_size * n_nodes
        # NOTE(review): ``exog_size`` is passed through as is, while forward()
        # flattens a 4-dimensional ``u`` over nodes; presumably the caller is
        # expected to account for that factor already -- confirm.
        super().__init__(input_size=flat_input_size,
                         output_size=flat_output_size,
                         horizon=horizon,
                         exog_size=exog_size,
                         hidden_size=hidden_size,
                         ff_size=ff_size,
                         rec_layers=rec_layers,
                         ff_layers=ff_layers,
                         rec_dropout=rec_dropout,
                         ff_dropout=ff_dropout,
                         cell_type=cell_type,
                         activation=activation)

    def forward(self, x: Tensor, u: Optional[Tensor] = None) -> Tensor:
        """Flatten nodes into features, forecast, and restore the node axis.

        Args:
            x (Tensor): Input sequence, laid out as
                ``[batches, steps, nodes, features]``.
            u (Tensor, optional): Exogenous variables, laid out as
                ``[batches, steps, (nodes), features]``. Only a 4-dimensional
                tensor is flattened here; anything else is handed to the
                parent class untouched.
                (default: :obj:`None`)

        Returns:
            Tensor: The parent's predictions with the flattened last axis
            split back into ``[batches, horizon, nodes, features]``.
        """
        # x: [batches, steps, nodes, features]
        # u: [batches, steps, (nodes), features]
        _, _, n_nodes, _ = x.size()

        # Same folding pattern for the input and for node-level exogenous.
        fold_nodes = 'b t n f -> b t 1 (n f)'
        x = rearrange(x, fold_nodes)
        if u is not None and u.dim() == 4:
            u = rearrange(u, fold_nodes)

        out = super().forward(x, u)
        # out: [b, h, 1, (n f)]
        return rearrange(out, 'b h 1 (n f) -> b h n f', n=n_nodes)