Source code for tsl.nn.layers.recurrent.dense_dcrnn

from tsl.nn.layers.graph_convs.dense_graph_conv import DenseGraphConvOrderK
from tsl.nn.layers.recurrent.base import GraphGRUCellBase


[docs]class DenseDCRNNCell(GraphGRUCellBase): r"""Dense implementation of the Diffusion Convolutional Recurrent cell from the paper `"Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting" <https://arxiv.org/abs/1707.01926>`_ (Li et al., ICLR 2018). In this implementation, the adjacency matrix is dense and the convolution is performed with matrix multiplication. Args: input_size: Size of the input. hidden_size: Number of units in the hidden state. k: Size of the diffusion kernel. root_weight (bool): Whether to learn a separate transformation for the central node. """ def __init__(self, input_size: int, hidden_size: int, k: int = 2, root_weight: bool = False): # instantiate gates forget_gate = DenseGraphConvOrderK(input_size + hidden_size, hidden_size, support_len=2, order=k, include_self=root_weight, channel_last=True) update_gate = DenseGraphConvOrderK(input_size + hidden_size, hidden_size, support_len=2, order=k, include_self=root_weight, channel_last=True) candidate_gate = DenseGraphConvOrderK(input_size + hidden_size, hidden_size, support_len=2, order=k, include_self=root_weight, channel_last=True) super(DenseDCRNNCell, self).__init__(hidden_size=hidden_size, forget_gate=forget_gate, update_gate=update_gate, candidate_gate=candidate_gate)