"""Source code for ``tsl.nn.blocks.encoders.recurrent.rnn``: recurrent encoder blocks."""

from typing import List, Optional, Tuple

import torch
from einops import rearrange
from torch import Tensor, nn

from tsl.nn.layers.recurrent import GRUCell, LSTMCell, StateType
from tsl.nn.utils import maybe_cat_exog

from .base import RNNIBase


class RNN(nn.Module):
    """Simple RNN encoder with optional linear readout.

    Args:
        input_size (int): Input size.
        hidden_size (int): Units in the hidden layers.
        exog_size (int, optional): Size of the optional exogenous variables.
            (default: ``None``)
        output_size (int, optional): Size of the optional readout. If
            ``None``, no readout is applied. (default: ``None``)
        n_layers (int, optional): Number of hidden layers.
            (default: ``1``)
        return_only_last_state (bool, optional): If ``True``, return only the
            representation at the last time step instead of the whole
            sequence. (default: ``False``)
        cell (str, optional): Type of cell that should be used
            (options: ``'gru'``, ``'lstm'``). (default: ``'gru'``)
        bias (bool, optional): Whether the recurrent layers use bias weights.
            (default: ``True``)
        dropout (float, optional): Dropout probability, forwarded to the
            PyTorch recurrent module (applied between stacked layers).
            (default: ``0.``)
        **kwargs: Unused; silently ignored.

    Raises:
        NotImplementedError: If :obj:`cell` is neither ``'gru'`` nor
            ``'lstm'``.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 exog_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 n_layers: int = 1,
                 return_only_last_state: bool = False,
                 cell: str = 'gru',
                 bias: bool = True,
                 dropout: float = 0.,
                 **kwargs):
        super().__init__()
        self.return_only_last_state = return_only_last_state

        if cell == 'gru':
            cell = nn.GRU
        elif cell == 'lstm':
            cell = nn.LSTM
        else:
            raise NotImplementedError(f'"{cell}" cell not implemented.')

        # forward() merges u into x via maybe_cat_exog, so the recurrent
        # layer's input width grows by exog_size.
        if exog_size is not None:
            input_size += exog_size

        self.rnn = cell(input_size=input_size,
                        hidden_size=hidden_size,
                        num_layers=n_layers,
                        bias=bias,
                        dropout=dropout)

        if output_size is not None:
            self.readout = nn.Linear(hidden_size, output_size)
        else:
            # Keep ``self.readout`` defined (as None) so forward can test it.
            self.register_parameter('readout', None)

    def forward(self, x: Tensor, u: Optional[Tensor] = None) -> Tensor:
        """Process the input sequence :obj:`x` with optional exogenous
        variables :obj:`u`.

        Args:
            x (Tensor): Input data.
            u (Tensor, optional): Exogenous data. (default: ``None``)

        Shapes:
            x: :math:`(B, T, N, F_x)` where :math:`B` is the batch dimension,
                :math:`T` is the number of time steps, :math:`N` is the number
                of nodes, and :math:`F_x` is the number of input features.
            u: :math:`(B, T, N, F_u)` or :math:`(B, T, F_u)` where :math:`B` is
                the batch dimension, :math:`T` is the number of time steps,
                :math:`N` is the number of nodes (optional), and :math:`F_u` is
                the number of exogenous features.

        Returns:
            Tensor: :math:`(B, T, N, F_{out})`, or :math:`(B, N, F_{out})` if
            :obj:`return_only_last_state` is ``True``. :math:`F_{out}` is
            ``output_size`` when a readout is used and ``hidden_size``
            otherwise.
        """
        # x: [batch, steps, nodes, features]
        x = maybe_cat_exog(x, u)
        b = x.size(0)
        # Fold nodes into the batch: the recurrent module is built without
        # batch_first, so it expects [steps, batch, features].
        x = rearrange(x, 'b s n f -> s (b n) f')
        # Keep only the output sequence (GRU/LSTM also return final states).
        x, *_ = self.rnn(x)
        # [steps, batch * nodes, hidden] -> [batch, steps, nodes, hidden]
        x = rearrange(x, 's (b n) f -> b s n f', b=b)
        if self.return_only_last_state:
            x = x[:, -1]
        if self.readout is not None:
            return self.readout(x)
        return x
class RNNI(RNNIBase):
    """RNN encoder for sequences with missing data.

    Args:
        input_size (int): Input size.
        hidden_size (int): Units in the hidden layers.
        exog_size (int): Size of the optional exogenous variables.
            (default: ``0``)
        cell (str): Type of recurrent cell to be used, one of [:obj:`gru`,
            :obj:`lstm`]. (default: :obj:`gru`)
        concat_mask (bool): If :obj:`True`, then the input tensor is
            concatenated to the mask when fed to the RNN cell.
            (default: :obj:`True`)
        unitary_mask (bool): If :obj:`True`, then the mask is a single value
            and applies to all features. (default: :obj:`False`)
        flip_time (bool): If :obj:`True`, then the sequence is processed in
            the backward direction of time. (default: :obj:`False`)
        n_layers (int, optional): Number of hidden layers.
            (default: :obj:`1`)
        detach_input (bool): If :obj:`True`, call
            :meth:`~torch.Tensor.detach` on predictions before they are used
            to fill the gaps, breaking the error backpropagation.
            (default: :obj:`False`)
        cat_states_layers (bool): If :obj:`True`, then the states of the RNN
            are concatenated together. (default: :obj:`False`)

    Raises:
        NotImplementedError: If :obj:`cell` is neither ``'gru'`` nor
            ``'lstm'``.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 exog_size: int = 0,
                 cell: str = 'gru',
                 concat_mask: bool = True,
                 unitary_mask: bool = False,
                 flip_time: bool = False,
                 n_layers: int = 1,
                 detach_input: bool = False,
                 cat_states_layers: bool = False):
        if cell == 'gru':
            cell = GRUCell
        elif cell == 'lstm':
            cell = LSTMCell
        else:
            raise NotImplementedError(f'"{cell}" cell not implemented.')

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.exog_size = exog_size
        # A unitary mask has one channel shared by all features.
        self.mask_size = 1 if unitary_mask else input_size

        # Width of the first cell's input: data (+ mask) + exogenous.
        if concat_mask:
            input_size = input_size + self.mask_size
        input_size = input_size + exog_size

        cells = [
            cell(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(n_layers)
        ]
        super(RNNI, self).__init__(cells, detach_input, concat_mask,
                                   flip_time, cat_states_layers)
        # Maps the last layer's hidden state back to the input feature space.
        self.readout = nn.Linear(hidden_size, self.input_size)

    def state_readout(self, h: List[StateType]):
        """Project the last layer's state ``h[-1]`` to input space."""
        return self.readout(h[-1])

    def preprocess_input(self,
                         x: Tensor,
                         x_hat: Tensor,
                         input_mask: Tensor,
                         step: int,
                         u: Optional[Tensor] = None,
                         h: Optional[List[StateType]] = None):
        """Build the cell input at time ``step``.

        The base-class input is extended with the exogenous features
        ``u[:, step]`` when :obj:`u` is given. Global exogenous variables
        (fewer dimensions than the base input, e.g. :math:`(B, T, F_u)`) get
        singleton axes inserted before the feature axis and are broadcast to
        match, so they can be concatenated.
        """
        x_t = super().preprocess_input(x, x_hat, input_mask, step)
        if u is not None:
            u_t = u[:, step]
            # torch.cat does not broadcast: add the missing (node) axes and
            # expand them to x_t's leading shape.
            while u_t.dim() < x_t.dim():
                u_t = u_t.unsqueeze(-2)
            if u_t.shape[:-1] != x_t.shape[:-1]:
                u_t = u_t.expand(*x_t.shape[:-1], u_t.size(-1))
            x_t = torch.cat([x_t, u_t], -1)
        return x_t

    def single_pass(self, x: Tensor, h: List[StateType], *args,
                    **kwargs) -> List[StateType]:
        """Run one step through the cells; extra arguments are dropped."""
        return super().single_pass(x, h)

    def forward(self,
                x: Tensor,
                input_mask: Tensor,
                u: Optional[Tensor] = None,
                h: Optional[List[StateType]] = None) \
            -> Tuple[Tensor, Tensor, List[StateType]]:
        """Process the input sequence :obj:`x` with optional exogenous
        variables :obj:`u`.

        Args:
            x (Tensor): Input data.
            input_mask (Tensor): Mask of the input data (one channel per
                feature, or a single channel if :obj:`unitary_mask` was set).
                Presumably true where values are observed -- confirm in
                :class:`RNNIBase`.
            u (Tensor, optional): Exogenous data. (default: ``None``)
            h (list, optional): Initial per-layer states; the last entry is
                the one read out by :meth:`state_readout`.
                (default: ``None``)

        Shapes:
            x: :math:`(B, T, N, F_x)` where :math:`B` is the batch dimension,
                :math:`T` is the number of time steps, :math:`N` is the number
                of nodes, and :math:`F_x` is the number of input features.
            u: :math:`(B, T, N, F_u)` or :math:`(B, T, F_u)` where :math:`B` is
                the batch dimension, :math:`T` is the number of time steps,
                :math:`N` is the number of nodes (optional), and :math:`F_u` is
                the number of exogenous features.

        Returns:
            Whatever :meth:`RNNIBase.forward` returns, annotated as
            ``Tuple[Tensor, Tensor, List[StateType]]``.
        """
        return super().forward(x, input_mask, u=u, h=h)