# Source code for tsl.nn.blocks.encoders.recurrent.dense_dcrnn

from typing import Optional

from torch import Tensor

from tsl.nn.layers.graph_convs.dense_graph_conv import DenseGraphConvOrderK
from tsl.nn.layers.recurrent import DenseDCRNNCell

from .base import RNNBase

class DenseDCRNN(RNNBase):
    """Dense implementation of the Diffusion Convolutional Recurrent Neural
    Network from the paper `"Diffusion Convolutional Recurrent Neural Network:
    Data-Driven Traffic Forecasting" <https://arxiv.org/abs/1707.01926>`_
    (Li et al., ICLR 2018).

    In this implementation, the adjacency matrix is dense and the convolution
    is performed with matrix multiplication.

    Args:
        input_size: Size of the input.
        hidden_size: Number of units in the hidden state.
        n_layers: Number of stacked recurrent layers. The first layer receives
            ``input_size`` features, every following layer receives
            ``hidden_size`` features.
        cat_states_layers: Passed unchanged to :class:`RNNBase` (presumably
            controls whether the states of all layers are concatenated in the
            output — confirm against :class:`RNNBase`).
        return_only_last_state: Passed unchanged to :class:`RNNBase`
            (presumably returns only the final state instead of the whole
            sequence — confirm against :class:`RNNBase`).
        k: Size of the diffusion kernel, forwarded to every
            :class:`DenseDCRNNCell`.
        root_weight: Whether to learn a separate transformation for the
            central node.
    """
    # Class-level constant consumed by RNNBase; presumably the number of
    # hidden-state tensors each cell carries (one for a GRU-style cell).
    _n_states = 1

    def __init__(self, input_size: int, hidden_size: int, n_layers: int = 1,
                 cat_states_layers: bool = False,
                 return_only_last_state: bool = False,
                 k: int = 2,
                 root_weight: bool = False):
        # Plain (non-Module, non-Parameter) attributes may be set before
        # nn.Module initialisation; the order is kept as in the original in
        # case RNNBase.__init__ relies on or overrides them.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.k = k
        rnn_cells = [
            DenseDCRNNCell(input_size if i == 0 else hidden_size,
                           hidden_size,
                           k=k,
                           root_weight=root_weight)
            for i in range(n_layers)
        ]
        super(DenseDCRNN, self).__init__(rnn_cells, cat_states_layers,
                                         return_only_last_state)

    def forward(self, x: Tensor, adj, h: Optional[Tensor] = None, **kwargs):
        """Run the stacked DCRNN cells over ``x`` on the graph ``adj``.

        The dense adjacency matrix is converted once per call into a support
        via :meth:`DenseGraphConvOrderK.compute_support`. That support is then
        handed to :meth:`RNNBase.forward` as the ``support`` keyword.

        Args:
            x: Input tensor (the expected layout is defined by
                :class:`RNNBase` — confirm there).
            adj: Dense adjacency matrix of the graph.
            h: Optional initial hidden state, forwarded to
                :meth:`RNNBase.forward`.
            **kwargs: Accepted for interface compatibility and ignored; they
                are NOT forwarded to :meth:`RNNBase.forward`.

        Returns:
            The output of :meth:`RNNBase.forward`.
        """
        support = DenseGraphConvOrderK.compute_support(adj)
        return super(DenseDCRNN, self).forward(x, h=h, support=support)