Source code for tsl.nn.blocks.encoders.recurrent.base

from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor, nn

from tsl.nn.layers.recurrent.base import RNNCellBase, StateType
from tsl.utils import ensure_list


class RNNBase(nn.Module):
    r"""Base class for implementing recurrent neural networks (RNNs)."""

    def __init__(self,
                 cells: Union[RNNCellBase, List[RNNCellBase], nn.ModuleList],
                 cat_states_layers: bool = False,
                 return_only_last_state: bool = False):
        super().__init__()
        self.cat_states_layers = cat_states_layers
        self.return_only_last_state = return_only_last_state
        if not isinstance(cells, nn.ModuleList):
            cells = nn.ModuleList(ensure_list(cells))
        self.cells = cells
        self.n_layers = len(self.cells)

    def __repr__(self) -> str:
        args = [
            f'cell={self.cells[0].__class__.__name__}',
            f'return_only_last_state={self.return_only_last_state}'
        ]
        return f"{self.__class__.__name__}({', '.join(args)})"

    def reset_parameters(self):
        for cell in self.cells:
            cell.reset_parameters()

    def initialize_state(self, x: Tensor) -> List[StateType]:
        return [cell.initialize_state(x) for cell in self.cells]

    def single_pass(self, x: Tensor, h: List[StateType], *args,
                    **kwargs) -> List[StateType]:
        # x: [batch, *, channels]
        # h[i]: [batch, *, channels]
        out = []
        input = x
        for i, cell in enumerate(self.cells):
            input = h_new = cell(input, h[i], *args, **kwargs)
            # multi-state cells (e.g., LSTMs) return a tuple of states:
            # feed only the first (hidden) state to the next layer
            if isinstance(h_new, (list, tuple)):
                input = input[0]
            out.append(h_new)
        return out

    def forward(self,
                x: Tensor,
                *args,
                h: Optional[List[StateType]] = None,
                **kwargs) -> Tuple[Tensor, List[StateType]]:
        """"""
        # x: [batch, time, *, features]
        if h is None:
            h = self.initialize_state(x)
        elif not isinstance(h, list):
            h = [h]
        # recurrent pass over the temporal dimension
        out = []
        steps = x.size(1)
        for step in range(steps):
            h_out = h = self.single_pass(x[:, step], h, *args, **kwargs)
            # for multi-state rnns (e.g., LSTMs), use first state for readout
            if not isinstance(h_out[0], torch.Tensor):
                h_out = [_h[0] for _h in h_out]
            # append hidden state of the last layer
            if self.cat_states_layers:  # concat rnn layers' states
                h_out = torch.cat(h_out, dim=-1)
            else:  # or take last layer's state
                h_out = h_out[-1]
            out.append(h_out)

        if self.return_only_last_state:
            # out: [batch, *, features]
            return out[-1]
        # out: [batch, time, *, features]
        out = torch.stack(out, dim=1)
        return out, h
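
# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module).
# It drives ``RNNBase`` with a single GRU cell over a sequence without a
# node dimension. The import and constructor below are assumptions: they
# presume ``tsl.nn.layers.recurrent.GRUCell(input_size, hidden_size)``
# implements the ``RNNCellBase`` interface (``initialize_state`` plus a
# per-step forward); swap in whichever cell you actually use.
#
#     from tsl.nn.layers.recurrent import GRUCell
#
#     cell = GRUCell(input_size=16, hidden_size=32)
#     rnn = RNNBase(cell)              # single-layer recurrent encoder
#     x = torch.randn(8, 24, 16)       # [batch, time, features]
#     out, h = rnn(x)                  # out: [8, 24, 32], h: per-layer states
# ---------------------------------------------------------------------------
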
class RNNIBase(RNNBase):
    r"""Base class for implementing recurrent neural networks (RNNs) for
    sequences with missing data."""

    def __init__(self,
                 cells: Union[RNNCellBase, List[RNNCellBase], nn.ModuleList],
                 detach_input: bool = False,
                 concat_mask: bool = False,
                 flip_time: bool = False,
                 cat_states_layers: bool = False):
        super().__init__(cells,
                         cat_states_layers,
                         return_only_last_state=False)
        self.detach_input = detach_input
        self.concat_mask = concat_mask
        self.flip_time = flip_time

    def state_readout(self, h: List[StateType]):
        raise NotImplementedError

    def preprocess_input(self,
                         x: Tensor,
                         x_hat: Tensor,
                         input_mask: Tensor,
                         step: int,
                         *args,
                         h: Optional[List[StateType]] = None,
                         **kwargs):
        if self.detach_input:
            x_hat = x_hat.detach()
        # fill missing entries in the input with the current readout
        x_t = torch.where(input_mask, x, x_hat)
        if self.concat_mask:
            x_t = torch.cat([x_t, input_mask], -1)
        return x_t

    def postprocess_state(self, h: List[StateType]):
        # for multi-state rnns (e.g., LSTMs), use first state for readout
        if not isinstance(h[0], torch.Tensor):
            h = [_h[0] for _h in h]
        if self.cat_states_layers:  # concat rnn layers' states
            h = torch.cat(h, dim=-1)
        else:  # or take last layer's state
            h = h[-1]
        return h

    def forward(self,
                x: Tensor,
                input_mask: Tensor,
                *args,
                h: Optional[List[StateType]] = None,
                **kwargs) -> Tuple[Tensor, Tensor, List[StateType]]:
        """"""
        # x: [batch, time, *, features]
        if h is None:
            h = self.initialize_state(x)
        elif not isinstance(h, list):
            h = [h]
        # recurrent pass over the temporal dimension
        h_out, x_out = [], []
        steps = x.size(1)
        steps = range(steps) if not self.flip_time else range(
            steps - 1, -1, -1)
        for step in steps:
            # readout phase
            x_hat = self.state_readout(h)
            x_out.append(x_hat)
            h_out.append(self.postprocess_state(h))
            # preprocess input to rnn's cell given readout
            x_t = self.preprocess_input(x[:, step], x_hat,
                                        input_mask[:, step], step, *args,
                                        **kwargs)
            h = self.single_pass(x_t, h, *args, **kwargs)

        if self.flip_time:
            x_out, h_out = x_out[::-1], h_out[::-1]

        x_out = torch.stack(x_out, dim=1)  # [b t * f]
        h_out = torch.stack(h_out, dim=1)  # [b t * f]

        return x_out, h_out, h
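
# ---------------------------------------------------------------------------
# Subclassing sketch (added for illustration; not part of the original
# module). ``RNNIBase`` leaves ``state_readout`` abstract: a concrete
# imputation RNN maps the hidden state back to the input space, so that at
# each step missing values can be replaced by the model's own predictions.
# ``GRUCell`` and the linear readout below are assumptions, not guarantees
# about the tsl API.
#
#     from tsl.nn.layers.recurrent import GRUCell
#
#     class GRUImputer(RNNIBase):
#
#         def __init__(self, input_size: int, hidden_size: int):
#             cell = GRUCell(input_size=input_size, hidden_size=hidden_size)
#             super().__init__(cell, concat_mask=False)
#             self.readout = nn.Linear(hidden_size, input_size)
#
#         def state_readout(self, h: List[StateType]) -> Tensor:
#             # predict the input from the (last layer's) hidden state
#             return self.readout(self.postprocess_state(h))
#
#     imputer = GRUImputer(input_size=16, hidden_size=32)
#     x = torch.randn(8, 24, 16)            # [batch, time, features]
#     mask = torch.rand(8, 24, 16) > 0.2    # True where x is observed
#     x_hat, h_seq, h = imputer(x, mask)    # x_hat: [8, 24, 16]
# ---------------------------------------------------------------------------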