Source code for tsl.nn.layers.multi.recurrent

from typing import Tuple

import torch
from torch import Tensor

from ..recurrent.base import GRUCellBase, LSTMCellBase
from .linear import MultiLinear


class MultiGRUCell(GRUCellBase):
    r"""Multiple parallel gated recurrent unit (GRU) cells.

    .. math::

        \begin{array}{ll}
            r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
            z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
            n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
            h' = (1 - z) * n + z * h
        \end{array}

    where :math:`\sigma` is the sigmoid function, and :math:`*` is the
    Hadamard product.

    Args:
        input_size (int): The number of features in the instance input
            sample.
        hidden_size (int): The number of features in the instance hidden
            state.
        n_instances (int): The number of parallel GRU cells. Each cell has
            different weights.
        bias (bool): If :obj:`True`, then the layer will learn an additive
            bias for each instance gate.
            (default: :obj:`True`)
        device (optional): The device of the parameters.
            (default: :obj:`None`)
        dtype (optional): The data type of the parameters.
            (default: :obj:`None`)

    Examples::

        >>> rnn = MultiGRUCell(20, 32, 10)
        >>> input = torch.randn(64, 12, 10, 20)
        >>> h = None
        >>> output = []
        >>> for i in range(12):
        ...     h = rnn(input[:, i], h)
        ...     output.append(h)
        >>> output = torch.stack(output, dim=1)
        >>> print(output.size())
        torch.Size([64, 12, 10, 32])
    """

    def __init__(self, input_size: int, hidden_size: int, n_instances: int,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.input_size = input_size
        self.n_instances = n_instances
        # instantiate gates
        in_size = input_size + hidden_size
        forget_gate = MultiLinear(in_size, hidden_size, n_instances,
                                  bias=bias, **factory_kwargs)
        update_gate = MultiLinear(in_size, hidden_size, n_instances,
                                  bias=bias, **factory_kwargs)
        candidate_gate = MultiLinear(in_size, hidden_size, n_instances,
                                     bias=bias, **factory_kwargs)
        super().__init__(hidden_size, forget_gate, update_gate,
                         candidate_gate)

    def initialize_state(self, x) -> Tensor:
        return torch.zeros(x.size(0), self.n_instances, self.hidden_size,
                           dtype=x.dtype, device=x.device)
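

# Usage sketch (not part of the original module): a single forward step of a
# MultiGRUCell with an explicitly initialized state, following the shapes in
# the docstring example above, i.e. (batch, instances, features). The helper
# name `_example_multi_gru_step` is hypothetical and for illustration only.
def _example_multi_gru_step():
    cell = MultiGRUCell(input_size=20, hidden_size=32, n_instances=10)
    x = cell.initialize_state(torch.randn(64, 10, 20))  # dummy call just to get dtype/device
    x = torch.randn(64, 10, 20)          # one time step: (batch, instances, input_size)
    h = cell.initialize_state(x)         # zeros of shape (batch, instances, hidden_size)
    h = cell(x, h)                       # updated hidden state, same shape as h
    return h.shape                       # expected: torch.Size([64, 10, 32])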


class MultiLSTMCell(LSTMCellBase):
    r"""Multiple parallel long short-term memory (LSTM) cells.

    .. math::

        \begin{array}{ll}
            i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
            f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
            g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
            o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
            c' = f * c + i * g \\
            h' = o * \tanh(c') \\
        \end{array}

    where :math:`\sigma` is the sigmoid function, and :math:`*` is the
    Hadamard product.

    Args:
        input_size (int): The number of features in the instance input
            sample.
        hidden_size (int): The number of features in the instance hidden
            state.
        n_instances (int): The number of parallel LSTM cells. Each cell has
            different weights.
        bias (bool): If :obj:`True`, then the layer will learn an additive
            bias for each instance gate.
            (default: :obj:`True`)
        device (optional): The device of the parameters.
            (default: :obj:`None`)
        dtype (optional): The data type of the parameters.
            (default: :obj:`None`)

    Examples::

        >>> rnn = MultiLSTMCell(20, 32, 10)
        >>> input = torch.randn(64, 12, 10, 20)
        >>> h = None
        >>> output = []
        >>> for i in range(12):
        ...     h = rnn(input[:, i], h)  # h = h, c
        ...     output.append(h[0])  # i-th output is h_i
        >>> output = torch.stack(output, dim=1)
        >>> print(output.size())
        torch.Size([64, 12, 10, 32])
    """

    def __init__(self, input_size: int, hidden_size: int, n_instances: int,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.input_size = input_size
        self.n_instances = n_instances
        # instantiate gates
        in_size = input_size + hidden_size
        input_gate = MultiLinear(in_size, hidden_size, n_instances,
                                 bias=bias, **factory_kwargs)
        forget_gate = MultiLinear(in_size, hidden_size, n_instances,
                                  bias=bias, **factory_kwargs)
        cell_gate = MultiLinear(in_size, hidden_size, n_instances,
                                bias=bias, **factory_kwargs)
        output_gate = MultiLinear(in_size, hidden_size, n_instances,
                                  bias=bias, **factory_kwargs)
        super().__init__(hidden_size, input_gate, forget_gate, cell_gate,
                         output_gate)

    def initialize_state(self, x) -> Tuple[Tensor, Tensor]:
        return (torch.zeros(x.size(0), self.n_instances, self.hidden_size,
                            dtype=x.dtype, device=x.device),
                torch.zeros(x.size(0), self.n_instances, self.hidden_size,
                            dtype=x.dtype, device=x.device))
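

# Usage sketch (not part of the original module): a single forward step of a
# MultiLSTMCell, where the state is the (h, c) tuple returned by
# ``initialize_state`` and by the cell itself, as in the docstring example.
# The helper name `_example_multi_lstm_step` is hypothetical and for
# illustration only.
def _example_multi_lstm_step():
    cell = MultiLSTMCell(input_size=20, hidden_size=32, n_instances=10)
    x = torch.randn(64, 10, 20)          # one time step: (batch, instances, input_size)
    state = cell.initialize_state(x)     # (h, c), each of shape (batch, instances, hidden_size)
    state = cell(x, state)               # updated (h, c)
    h, c = state
    return h.shape, c.shape              # expected: torch.Size([64, 10, 32]) for both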