Source code for tsl.nn.blocks.encoders.rnn

import torch

from torch import nn
from einops import rearrange

from ...utils import maybe_cat_exog

class RNN(nn.Module):
    r"""Simple RNN encoder with optional linear readout.

    Args:
        input_size (int): Input size.
        hidden_size (int): Units in the hidden layers.
        exog_size (int, optional): Size of the optional exogenous variables.
            When given, it is added to :obj:`input_size` to size the input of
            the recurrent cell. (default: :obj:`None`)
        output_size (int, optional): Size of the optional readout. If
            :obj:`None`, no readout is applied. (default: :obj:`None`)
        n_layers (int, optional): Number of hidden layers.
            (default: :obj:`1`)
        dropout (float, optional): Dropout probability.
            (default: :obj:`0.`)
        cell (str, optional): Type of cell that should be used
            (options: [:obj:`gru`, :obj:`lstm`], case-insensitive).
            (default: :obj:`gru`)

    Raises:
        NotImplementedError: If :obj:`cell` is not one of the supported cells.
    """

    def __init__(self, input_size, hidden_size, exog_size=None,
                 output_size=None, n_layers=1, dropout=0., cell='gru'):
        super().__init__()
        # Supported recurrent cells, keyed by lower-case name.
        cells = {'gru': nn.GRU, 'lstm': nn.LSTM}
        cell_cls = None
        if isinstance(cell, str):
            # Accept any capitalization ('GRU', 'Lstm', ...).
            cell_cls = cells.get(cell.lower())
        if cell_cls is None:
            raise NotImplementedError(f'"{cell}" cell not implemented.')

        if exog_size is not None:
            # ``forward`` merges ``u`` into ``x`` via ``maybe_cat_exog``
            # (presumably along the feature axis), so the cell must accept
            # the widened input.
            input_size += exog_size
        # batch_first is left at its default (False): ``forward`` feeds
        # time-major tensors of shape [steps, batch, features].
        self.rnn = cell_cls(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=n_layers,
                            dropout=dropout)

        if output_size is not None:
            self.readout = nn.Linear(hidden_size, output_size)
        else:
            # Keep the attribute defined (as None) so ``forward`` can test it.
            self.register_parameter('readout', None)

    def forward(self, x, u=None, return_last_state=False):
        """Encode each node's sequence independently with the recurrent cell.

        Args:
            x (torch.Tensor): Input tensor with shape
                ``[batches, steps, nodes, features]``.
            u (torch.Tensor, optional): Optional exogenous variables, merged
                with :obj:`x` through ``maybe_cat_exog`` before encoding.
                (default: :obj:`None`)
            return_last_state (bool, optional): Whether to return only the
                state corresponding to the last time step.
                (default: :obj:`False`)

        Returns:
            torch.Tensor: Per-step outputs with shape
            ``[batches, steps, nodes, channels]``, or
            ``[batches, nodes, channels]`` if :obj:`return_last_state` is
            :obj:`True`. ``channels`` equals ``output_size`` when a readout
            is defined and ``hidden_size`` otherwise.
        """
        # x: [batches, steps, nodes, features]
        x = maybe_cat_exog(x, u)
        b = x.size(0)
        # Treat every node as an independent sequence: fold nodes into the
        # batch dimension and put time first, as the cell is not batch_first.
        x = rearrange(x, 'b s n f -> s (b n) f')
        # Keep the per-step outputs; discard the final hidden (and cell) states.
        x, *_ = self.rnn(x)
        # [steps, batches * nodes, features] -> [batches, steps, nodes, features]
        x = rearrange(x, 's (b n) f -> b s n f', b=b)
        if return_last_state:
            x = x[:, -1]
        if self.readout is not None:
            return self.readout(x)
        return x