Source code for tsl.nn.layers.recurrent.dcrnn

from tsl.nn.layers.graph_convs.diff_conv import DiffConv
from tsl.nn.layers.recurrent.base import GraphGRUCellBase


[docs]class DCRNNCell(GraphGRUCellBase): """The Diffusion Convolutional Recurrent cell from the paper `"Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting" <https://arxiv.org/abs/1707.01926>`_ (Li et al., ICLR 2018). Args: input_size: Size of the input. hidden_size: Number of units in the hidden state. k: Size of the diffusion kernel. root_weight: Whether to learn a separate transformation for the central node. """ def __init__(self, input_size: int, hidden_size: int, k: int = 2, root_weight: bool = True, add_backward: bool = True, bias: bool = True): # instantiate gates forget_gate = DiffConv(input_size + hidden_size, hidden_size, k=k, root_weight=root_weight, add_backward=add_backward, bias=bias) update_gate = DiffConv(input_size + hidden_size, hidden_size, k=k, root_weight=root_weight, add_backward=add_backward, bias=bias) candidate_gate = DiffConv(input_size + hidden_size, hidden_size, k=k, root_weight=root_weight, add_backward=add_backward, bias=bias) super(DCRNNCell, self).__init__(hidden_size=hidden_size, forget_gate=forget_gate, update_gate=update_gate, candidate_gate=candidate_gate)