Source code for tsl.nn.blocks.encoders.recurrent.dcrnn
from tsl.nn.layers.recurrent import DCRNNCell
from .base import RNNBase
[docs]class DCRNN(RNNBase):
"""The Diffusion Convolutional Recurrent Neural Network from the paper
`"Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic
Forecasting" <https://arxiv.org/abs/1707.01926>`_ (Li et al., ICLR 2018).
Args:
input_size: Size of the input.
hidden_size: Number of units in the hidden state.
n_layers: Number of layers.
k: Size of the diffusion kernel.
root_weight: Whether to learn a separate transformation for the central
node.
"""
def __init__(self,
input_size: int,
hidden_size: int,
n_layers: int = 1,
cat_states_layers: bool = False,
return_only_last_state: bool = False,
k: int = 2,
root_weight: bool = True,
add_backward: bool = True,
bias: bool = True):
self.input_size = input_size
self.hidden_size = hidden_size
self.k = k
rnn_cells = [
DCRNNCell(input_size if i == 0 else hidden_size,
hidden_size,
k=k,
root_weight=root_weight,
add_backward=add_backward,
bias=bias) for i in range(n_layers)
]
super(DCRNN, self).__init__(rnn_cells, cat_states_layers,
return_only_last_state)