Source code for tsl.nn.models.temporal.rnn_imputers_models

from typing import Optional, Union

import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import Tensor, nn

from tsl import logger
from tsl.nn.blocks.encoders.recurrent import RNNI
from tsl.nn.models.base_model import BaseModel


class RNNImputerModel(BaseModel):
    r"""Fill the blanks with 1-step-ahead predictions of a recurrent network.

    .. math::
        \bar{x}_{t} = m_{t} \cdot x_{t} + (1 - m_{t}) \cdot \hat{x}_{t}

    Args:
        input_size (int): Number of features of the input sample.
        hidden_size (int): Number of hidden units.
            (default: 64)
        exog_size (int): Number of features of the input covariate, if any.
            (default: :obj:`0`)
        cell (str): Type of recurrent cell to be used, one of [:obj:`gru`,
            :obj:`lstm`].
            (default: :obj:`gru`)
        concat_mask (bool): If :obj:`True`, then the input tensor is
            concatenated to the mask when fed to the RNN cell.
            (default: :obj:`True`)
        fully_connected (bool): If :obj:`True`, then node and feature
            dimensions are flattened together.
            (default: :obj:`False`)
        n_nodes (int, optional): The number of nodes in the input sample, to
            be provided in case :obj:`fully_connected` is :obj:`True`.
            (default: :obj:`None`)
        detach_input (bool): If :obj:`True`, call
            :meth:`~torch.Tensor.detach` on predictions before they are used
            to fill the gaps, breaking the error backpropagation.
            (default: :obj:`False`)
        n_layers (int, optional): Number of hidden layers.
            (default: :obj:`1`)
        cat_states_layers (bool): If :obj:`True`, then the states of the RNN
            are concatenated together.
            (default: :obj:`False`)
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int = 64,
                 exog_size: int = 0,
                 cell: str = 'gru',
                 concat_mask: bool = True,
                 fully_connected: bool = False,
                 n_nodes: Optional[int] = None,
                 detach_input: bool = False,
                 n_layers: int = 1,
                 cat_states_layers: bool = False):
        super(RNNImputerModel, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.exog_size = exog_size
        self.concat_mask = concat_mask
        self.fully_connected = fully_connected
        self.detach_input = detach_input
        self.n_layers = n_layers
        self.cat_states_layers = cat_states_layers

        if fully_connected:
            assert n_nodes is not None
            input_size = input_size * n_nodes
            self._to_pattern = 'b t (n f)'
        else:
            self._to_pattern = '(b n) t f'

        if concat_mask and fully_connected:
            logger.warning("Parameter 'concat_mask' can be True only when "
                           "'fully_connected' is False.")
            concat_mask = False

        self.rnn = RNNI(input_size=input_size,
                        hidden_size=hidden_size,
                        exog_size=exog_size,
                        cell=cell,
                        concat_mask=concat_mask,
                        n_layers=n_layers,
                        detach_input=detach_input,
                        cat_states_layers=cat_states_layers)
    def forward(self,
                x: Tensor,
                mask: Tensor,
                u: Optional[Tensor] = None,
                return_hidden: bool = False) -> Union[Tensor, list]:
        """"""
        # x: [batch, time, nodes, features]
        nodes = x.size(2)
        x = rearrange(x, f'b t n f -> {self._to_pattern}')
        mask = rearrange(mask, f'b t n f -> {self._to_pattern}')

        if u is not None:
            if self.fully_connected:  # fc and 'b t f'
                assert u.ndim == 3, \
                    "Only graph-level exogenous are supported in " \
                    "fully connected mode."
            elif u.ndim == 3:  # no fc and 'b t f'
                u = repeat(u, f'b t f -> {self._to_pattern}', n=nodes)
            else:  # no fc and 'b t n f'
                u = rearrange(u, f'b t n f -> {self._to_pattern}')

        x_hat, h, _ = self.rnn(x, mask, u)

        x_hat = rearrange(x_hat, f'{self._to_pattern} -> b t n f', n=nodes)
        if not return_hidden:
            return x_hat
        if not self.fully_connected:
            h = rearrange(h, f'{self._to_pattern} -> b t n f', n=nodes)
        return [x_hat, h]
    def predict(self,
                x: Tensor,
                mask: Tensor,
                u: Optional[Tensor] = None) -> Tensor:
        """"""
        return self.forward(x=x, mask=mask, u=u, return_hidden=False)
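
A minimal usage sketch for :class:`RNNImputerModel`, not part of the module
source: it assumes the ``[batch, time, nodes, features]`` layout expected by
``forward`` and a {0, 1}-valued mask with the same dtype as the input; all
sizes are illustrative::

    import torch

    from tsl.nn.models.temporal.rnn_imputers_models import RNNImputerModel

    # Hypothetical sizes: 8 windows, 24 time steps, 20 nodes, 1 feature.
    model = RNNImputerModel(input_size=1, hidden_size=32)
    x = torch.randn(8, 24, 20, 1)                     # input series
    mask = (torch.rand(8, 24, 20, 1) > 0.2).float()   # 1 where x is observed, 0 where missing
    x_hat = model(x, mask)                            # imputed series, shape [8, 24, 20, 1]
    x_hat, h = model(x, mask, return_hidden=True)     # also return the recurrent states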
class BiRNNImputerModel(BaseModel):
    r"""Fill the blanks with 1-step-ahead predictions of a bidirectional
    recurrent neural network.

    Args:
        input_size (int): Number of features of the input sample.
        hidden_size (int): Number of hidden units.
            (default: 64)
        exog_size (int): Number of features of the input covariate, if any.
            (default: :obj:`0`)
        cell (str): Type of recurrent cell to be used, one of [:obj:`gru`,
            :obj:`lstm`].
            (default: :obj:`gru`)
        concat_mask (bool): If :obj:`True`, then the input tensor is
            concatenated to the mask when fed to the RNN cell.
            (default: :obj:`True`)
        fully_connected (bool): If :obj:`True`, then node and feature
            dimensions are flattened together.
            (default: :obj:`False`)
        n_nodes (int, optional): The number of nodes in the input sample, to
            be provided in case :obj:`fully_connected` is :obj:`True`.
            (default: :obj:`None`)
        detach_input (bool): If :obj:`True`, call
            :meth:`~torch.Tensor.detach` on predictions before they are used
            to fill the gaps, breaking the error backpropagation.
            (default: :obj:`False`)
        n_layers (int, optional): Number of hidden layers.
            (default: :obj:`1`)
        return_previous_state (bool): If :obj:`True`, then the returned states
            are shifted one-step behind the imputations.
            (default: :obj:`True`)
        cat_states_layers (bool): If :obj:`True`, then the states of the RNN
            are concatenated together.
            (default: :obj:`False`)
        dropout (float, optional): Dropout probability in the decoder.
            (default: ``0.``)
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int = 64,
                 exog_size: int = 0,
                 cell: str = 'gru',
                 concat_mask: bool = True,
                 fully_connected: bool = False,
                 n_nodes: Optional[int] = None,
                 detach_input: bool = False,
                 n_layers: int = 1,
                 cat_states_layers: bool = False,
                 dropout: float = 0.):
        super(BiRNNImputerModel, self).__init__(return_type=list)

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.exog_size = exog_size
        self.concat_mask = concat_mask
        self.fully_connected = fully_connected
        self.detach_input = detach_input
        self.n_layers = n_layers
        self.cat_states_layers = cat_states_layers

        if fully_connected:
            assert n_nodes is not None
            input_size = input_size * n_nodes
            self._to_pattern = 'b t (n f)'
        else:
            self._to_pattern = '(b n) t f'

        if concat_mask and fully_connected:
            logger.warning("Parameter 'concat_mask' can be True only when "
                           "'fully_connected' is False.")
            concat_mask = False

        self.fwd_rnn = RNNI(input_size=input_size,
                            hidden_size=hidden_size,
                            exog_size=exog_size,
                            cell=cell,
                            concat_mask=concat_mask,
                            n_layers=n_layers,
                            detach_input=detach_input,
                            cat_states_layers=cat_states_layers)
        self.bwd_rnn = RNNI(input_size=input_size,
                            hidden_size=hidden_size,
                            exog_size=exog_size,
                            cell=cell,
                            concat_mask=concat_mask,
                            flip_time=True,
                            n_layers=n_layers,
                            detach_input=detach_input,
                            cat_states_layers=cat_states_layers)
        self.dropout = nn.Dropout(dropout)

        out_size = hidden_size * (n_layers if cat_states_layers else 1)
        if fully_connected:
            assert n_nodes is not None
            self.readout = nn.Sequential(
                nn.Linear(2 * out_size, input_size),
                Rearrange('... t (n h) -> ... t n h', n=n_nodes))
        else:
            self.readout = nn.Linear(2 * out_size, self.input_size)
    def forward(self,
                x: Tensor,
                mask: Tensor,
                u: Optional[Tensor] = None,
                return_hidden: bool = False,
                return_predictions: bool = True) -> Union[Tensor, list]:
        """"""
        # x: [batch, time, nodes, features]
        nodes = x.size(2)
        x = rearrange(x, f'b t n f -> {self._to_pattern}')
        mask = rearrange(mask, f'b t n f -> {self._to_pattern}')

        if u is not None:
            if self.fully_connected:  # fc and 'b t f'
                assert u.ndim == 3, \
                    "Only graph-level exogenous are supported in " \
                    "fully connected mode."
            elif u.ndim == 3:  # no fc and 'b t f'
                u = repeat(u, f'b t f -> {self._to_pattern}', n=nodes)
            else:  # no fc and 'b t n f'
                u = rearrange(u, f'b t n f -> {self._to_pattern}')

        x_hat_fwd, h_fwd, _ = self.fwd_rnn(x, mask, u)
        x_hat_bwd, h_bwd, _ = self.bwd_rnn(x, mask, u)

        if not self.fully_connected:
            h_fwd = rearrange(h_fwd, f'{self._to_pattern} -> b t n f',
                              n=nodes)
            h_bwd = rearrange(h_bwd, f'{self._to_pattern} -> b t n f',
                              n=nodes)

        h = self.dropout(torch.cat([h_fwd, h_bwd], -1))
        x_hat = self.readout(h)

        if not (return_predictions or return_hidden):
            return x_hat

        res = [x_hat]
        if return_predictions:
            x_hat_fwd = rearrange(x_hat_fwd, f'{self._to_pattern} -> b t n f',
                                  n=nodes)
            x_hat_bwd = rearrange(x_hat_bwd, f'{self._to_pattern} -> b t n f',
                                  n=nodes)
            res.append((x_hat_fwd, x_hat_bwd))
        if return_hidden:
            res.append(h)
        return res
    def predict(self,
                x: Tensor,
                mask: Tensor,
                u: Optional[Tensor] = None) -> Tensor:
        """"""
        return self.forward(x=x,
                            mask=mask,
                            u=u,
                            return_hidden=False,
                            return_predictions=False)
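
A corresponding sketch for :class:`BiRNNImputerModel`, again illustrative and
not part of the source: with the default flags, ``forward`` returns the
combined imputation together with the forward and backward one-step-ahead
predictions, while ``predict`` returns only the combined imputation::

    import torch

    from tsl.nn.models.temporal.rnn_imputers_models import BiRNNImputerModel

    model = BiRNNImputerModel(input_size=1, hidden_size=32, dropout=0.1)
    x = torch.randn(8, 24, 20, 1)                     # [batch, time, nodes, features]
    mask = (torch.rand(8, 24, 20, 1) > 0.2).float()   # assumed {0, 1}-valued, same dtype as x
    x_hat, (x_fwd, x_bwd) = model(x, mask)            # combined imputation + directional predictions
    y_hat = model.predict(x, mask)                    # combined imputation only
    assert y_hat.shape == x.shape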