Source code for tsl.nn.models.temporal.rnn_imputers_models

from typing import Optional, Union

import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn, Tensor

from tsl.nn.functional import reverse_tensor
from tsl.nn.models.base_model import BaseModel


class RNNImputerModel(BaseModel):
    r"""Fill the blanks with 1-step-ahead predictions of a recurrent network.

    .. math ::
        \bar{x}_{t} = m_{t} \cdot x_{t} + (1 - m_{t}) \cdot \hat{x}_{t}

    Args:
        input_size (int): Number of features of the input sample.
        hidden_size (int): Number of hidden units.
            (default: :obj:`64`)
        exog_size (int): Number of features of the input covariate, if any.
            (default: :obj:`0`)
        cell (str): Type of recurrent cell to be used, one of [:obj:`gru`,
            :obj:`lstm`].
            (default: :obj:`gru`)
        concat_mask (bool): If :obj:`True`, then the input tensor is
            concatenated to the mask when fed to the RNN cell.
            (default: :obj:`True`)
        fully_connected (bool): If :obj:`True`, then the node and feature
            dimensions are flattened together.
            (default: :obj:`False`)
        n_nodes (int, optional): The number of nodes in the input sample, to
            be provided in case :obj:`fully_connected` is :obj:`True`.
            (default: :obj:`None`)
        detach_input (bool): If :obj:`True`, call :meth:`~torch.Tensor.detach`
            on predictions before they are used to fill the gaps, breaking
            the error backpropagation.
            (default: :obj:`False`)
        state_init (str): How to initialize the state of the recurrent cell,
            one of [:obj:`zero`, :obj:`noise`]. With :obj:`noise`, the states
            are drawn from a normal distribution.
            (default: :obj:`zero`)
    """

    def __init__(self, input_size: int, hidden_size: int = 64,
                 exog_size: int = 0, cell: str = 'gru',
                 concat_mask: bool = True, fully_connected: bool = False,
                 n_nodes: Optional[int] = None, detach_input: bool = False,
                 state_init: str = 'zero'):
        super(RNNImputerModel, self).__init__()

        if fully_connected:
            # Feed all nodes jointly to a single cell: flatten the node and
            # feature dimensions together (requires n_nodes).
            assert n_nodes is not None
            input_size = input_size * n_nodes
            self._to_pattern = 'b t (n f)'
        else:
            # Process each node independently, folding the node dimension
            # into the batch dimension.
            self._to_pattern = '(b n) t f'

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.exog_size = exog_size
        self.concat_mask = concat_mask
        self.fully_connected = fully_connected
        self.detach_input = detach_input
        self.state_init = state_init

        if cell == 'gru':
            cell = nn.GRUCell
        elif cell == 'lstm':
            cell = nn.LSTMCell
        else:
            raise NotImplementedError(f'"{cell}" cell not implemented.')

        if concat_mask:
            input_size = 2 * input_size
        input_size = input_size + exog_size

        self.rnn_cell = cell(input_size=input_size, hidden_size=hidden_size)
        self.readout = nn.Linear(hidden_size, self.input_size)
    def init_hidden_state(self, x: Tensor):
        if self.state_init == 'zero':
            return torch.zeros((x.size(0), self.hidden_size),
                               device=x.device, dtype=x.dtype)
        if self.state_init == 'noise':
            return torch.randn(x.size(0), self.hidden_size,
                               device=x.device, dtype=x.dtype)
        raise NotImplementedError(f'"{self.state_init}" initialization '
                                  'not implemented.')
    def _preprocess_input(self, x: Tensor, x_hat: Tensor, m: Tensor,
                          u: Optional[Tensor] = None):
        # Fill the gaps in x with the model's own predictions, optionally
        # stopping the gradients through them.
        if self.detach_input:
            x_p = torch.where(m, x, x_hat.detach())
        else:
            x_p = torch.where(m, x, x_hat)
        if u is not None:
            x_p = torch.cat([x_p, u], -1)
        if self.concat_mask:
            x_p = torch.cat([x_p, m], -1)
        return x_p
    def forward(self, x: Tensor, input_mask: Tensor,
                u: Optional[Tensor] = None,
                return_hidden: bool = False) -> Union[Tensor, list]:
        """"""
        # x: [batch, time, nodes, features]
        steps, nodes = x.size(1), x.size(2)
        x = rearrange(x, f'b t n f -> {self._to_pattern}')
        input_mask = rearrange(input_mask, f'b t n f -> {self._to_pattern}')
        if u is not None:
            u = rearrange(u, f'b t n f -> {self._to_pattern}')

        h = self.init_hidden_state(x)
        x_hat = self.readout(h)
        hs = [h]
        preds = [x_hat]
        for s in range(steps - 1):
            u_t = None if u is None else u[:, s]
            x_t = self._preprocess_input(x[:, s], x_hat, input_mask[:, s],
                                         u_t)
            h = self.rnn_cell(x_t, h)
            x_hat = self.readout(h)
            hs.append(h)
            preds.append(x_hat)

        x_hat = torch.stack(preds, 1)  # [b t (n f)] or [(b n) t f]
        h = torch.stack(hs, 1)  # [b t h] or [(b n) t h]

        x_hat = rearrange(x_hat, f'{self._to_pattern} -> b t n f', n=nodes)
        if not return_hidden:
            return x_hat
        if not self.fully_connected:
            h = rearrange(h, f'{self._to_pattern} -> b t n f', n=nodes)
        return [x_hat, h]
    def predict(self, x: Tensor, input_mask: Tensor,
                u: Optional[Tensor] = None) -> Tensor:
        """"""
        return self.forward(x=x, input_mask=input_mask, u=u,
                            return_hidden=False)
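
# --- Illustrative usage sketch (not part of the original module) -----------
# A minimal example of RNNImputerModel on random data; all sizes and the
# missing-data rate below are arbitrary assumptions chosen for illustration.
def _example_rnn_imputer():
    model = RNNImputerModel(input_size=3, hidden_size=32)
    x = torch.randn(8, 24, 5, 3)  # [batch, time, nodes, features]
    mask = torch.rand(8, 24, 5, 3) > 0.2  # True where the value is observed
    x_hat = model(x, input_mask=mask)  # 1-step-ahead reconstructions
    assert x_hat.shape == (8, 24, 5, 3)
    # Keep the observed values and fill only the gaps with the predictions.
    return torch.where(mask, x, x_hat)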
class BiRNNImputerModel(BaseModel):
    r"""Fill the blanks with a bidirectional GRU 1-step-ahead predictor."""

    def __init__(self, input_size: int, hidden_size: int = 64,
                 exog_size: int = 0, cell: str = 'gru', dropout: float = 0.,
                 concat_mask: bool = True, fully_connected: bool = False,
                 n_nodes: Optional[int] = None, detach_input: bool = False,
                 state_init: str = 'zero'):
        super(BiRNNImputerModel, self).__init__(return_type=list)

        self.fwd_rnn = RNNImputerModel(input_size, hidden_size,
                                       exog_size=exog_size, cell=cell,
                                       concat_mask=concat_mask,
                                       n_nodes=n_nodes,
                                       fully_connected=fully_connected,
                                       detach_input=detach_input,
                                       state_init=state_init)
        self.bwd_rnn = RNNImputerModel(input_size, hidden_size,
                                       exog_size=exog_size, cell=cell,
                                       concat_mask=concat_mask,
                                       n_nodes=n_nodes,
                                       fully_connected=fully_connected,
                                       detach_input=detach_input,
                                       state_init=state_init)
        self.dropout = nn.Dropout(dropout)

        if fully_connected:
            assert n_nodes is not None
            self.read_out = nn.Sequential(
                nn.Linear(2 * hidden_size, input_size * n_nodes),
                Rearrange('... t (n h) -> ... t n h', n=n_nodes))
        else:
            self.read_out = nn.Linear(2 * hidden_size, input_size)
    def forward(self, x: Tensor, input_mask: Tensor,
                u: Optional[Tensor] = None,
                return_hidden: bool = False) -> list:
        """"""
        # x: [batches, steps, nodes, features]
        x_hat_fwd, h_fwd = self.fwd_rnn(x, input_mask, u=u,
                                        return_hidden=True)
        # Run the second RNN on the time-reversed sequence, then flip its
        # outputs back to the original order.
        u_rev = reverse_tensor(u, 1) if u is not None else None
        x_hat_bwd, h_bwd = self.bwd_rnn(reverse_tensor(x, 1),
                                        reverse_tensor(input_mask, 1),
                                        u=u_rev, return_hidden=True)
        x_hat_bwd = reverse_tensor(x_hat_bwd, 1)
        h_bwd = reverse_tensor(h_bwd, 1)
        # Final imputation from the concatenated forward/backward states.
        h = self.dropout(torch.cat([h_fwd, h_bwd], -1))
        x_hat = self.read_out(h)
        if return_hidden:
            return [x_hat, (x_hat_fwd, x_hat_bwd), h]
        return [x_hat, (x_hat_fwd, x_hat_bwd)]
    def predict(self, x: Tensor, input_mask: Tensor,
                u: Optional[Tensor] = None) -> Tensor:
        """"""
        return self.forward(x=x, input_mask=input_mask, u=u,
                            return_hidden=False)[0]
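
# --- Illustrative usage sketch (not part of the original module) -----------
# The bidirectional variant returns a list; sizes are again arbitrary
# assumptions chosen for illustration.
def _example_birnn_imputer():
    model = BiRNNImputerModel(input_size=3, hidden_size=32, dropout=0.1)
    x = torch.randn(8, 24, 5, 3)  # [batch, time, nodes, features]
    mask = torch.rand(8, 24, 5, 3) > 0.2  # True where the value is observed
    # forward returns [x_hat, (x_hat_fwd, x_hat_bwd)], while predict returns
    # only the bidirectional reconstruction x_hat.
    x_hat, (x_hat_fwd, x_hat_bwd) = model(x, input_mask=mask)
    assert x_hat.shape == x_hat_fwd.shape == x_hat_bwd.shape == (8, 24, 5, 3)
    return model.predict(x, input_mask=mask)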