# Source code for tsl.nn.blocks.encoders.multi.rnn
from tsl.nn.blocks.encoders.recurrent import RNNBase
from tsl.nn.layers.multi import MultiGRUCell, MultiLSTMCell
class MultiRNN(RNNBase):
    """A Recurrent Neural Network whose cells' weights are not shared among
    the different instances.

    One instance-specific cell (``MultiGRUCell`` or ``MultiLSTMCell``) is
    built per layer, and the resulting list of cells is handed to
    :class:`RNNBase`.

    Args:
        input_size: Number of input features fed to the first layer.
        hidden_size: Hidden size of every layer. It is also the input size of
            every layer after the first.
        n_instances: Forwarded to every cell as its third positional argument
            (the number of instances whose weights are kept separate).
        n_layers: Number of stacked cells to build. (default: ``1``)
        cat_states_layers: Forwarded unchanged to :class:`RNNBase`.
            (default: ``False``)
        return_only_last_state: Forwarded unchanged to :class:`RNNBase`.
            (default: ``False``)
        cell: Name of the cell type, ``'gru'`` or ``'lstm'``.
            (default: ``'gru'``)
        bias: Forwarded as a keyword argument to every cell.
            (default: ``True``)
        **kwargs: Extra keyword arguments forwarded to every cell.

    Raises:
        NotImplementedError: If ``cell`` is neither ``'gru'`` nor ``'lstm'``.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 n_instances: int,
                 n_layers: int = 1,
                 cat_states_layers: bool = False,
                 return_only_last_state: bool = False,
                 cell: str = 'gru',
                 bias: bool = True,
                 **kwargs):
        # Ordered (name, class) pairs compared with ``==``, so any value of
        # ``cell`` is tested exactly as an if/elif chain would test it.
        supported_cells = (('gru', MultiGRUCell), ('lstm', MultiLSTMCell))
        cell_cls = next(
            (cls for name, cls in supported_cells if cell == name), None)
        if cell_cls is None:
            raise NotImplementedError(f'"{cell}" cell not implemented.')

        rnn_cells = []
        layer_input_size = input_size  # only the first layer sees raw input
        for _ in range(n_layers):
            rnn_cells.append(
                cell_cls(layer_input_size,
                         hidden_size,
                         n_instances,
                         bias=bias,
                         **kwargs))
            # Every subsequent layer consumes the previous layer's hidden size.
            layer_input_size = hidden_size

        super().__init__(rnn_cells, cat_states_layers, return_only_last_state)