# Source code for tsl.nn.blocks.encoders.dcrnn

import torch

from tsl.nn.layers.graph_convs.diff_conv import DiffConv
from tsl.nn.blocks.encoders.gcrnn import _GraphGRUCell, _GraphRNN


class DCRNNCell(_GraphGRUCell):
    """Diffusion Convolutional Recurrent Cell.

    A GRU-style cell in which the forget, update and candidate gates are all
    ``DiffConv`` layers sharing the same configuration. Every gate takes
    ``input_size + output_size`` features and produces ``output_size``
    features. Presumably this is the input concatenated with the previous
    hidden state; that concatenation lives in ``_GraphGRUCell``, which is not
    shown here (TODO confirm).

    Args:
        input_size: Size of the input.
        output_size: Number of units in the hidden state.
        k: Size of the diffusion kernel, forwarded unchanged to every
            ``DiffConv`` gate.
        root_weight: Whether to learn a separate transformation for the
            central node, forwarded unchanged to every ``DiffConv`` gate.
    """

    def __init__(self, input_size, output_size, k=2, root_weight=True):
        super().__init__()
        gate_in_channels = input_size + output_size

        def _make_gate():
            # All three gates are configured identically.
            return DiffConv(gate_in_channels, output_size,
                            k=k, root_weight=root_weight)

        # Creation order (forget -> update -> candidate) is kept so parameter
        # initialisation draws from the RNG in the same sequence and
        # submodules register in the same order.
        self.forget_gate = _make_gate()
        self.update_gate = _make_gate()
        self.candidate_gate = _make_gate()
class DCRNN(_GraphRNN):
    r"""Diffusion Convolutional Recurrent Network, from the paper
    `"Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic
    Forecasting" <https://arxiv.org/abs/1707.01926>`_.

    Stacks ``n_layers`` :class:`DCRNNCell` instances in a
    ``torch.nn.ModuleList`` stored as ``rnn_cells``. The first cell consumes
    ``input_size`` features. Every later cell consumes ``hidden_size``
    features, and every cell outputs ``hidden_size`` features.

    Args:
        input_size: Size of the input.
        hidden_size: Number of units in the hidden state.
        n_layers: Number of layers.
        k: Size of the diffusion kernel, forwarded to every cell.
        root_weight: Whether to learn a separate transformation for the
            central node, forwarded to every cell.
    """

    # Number of recurrent states per layer. Presumably read by _GraphRNN
    # (not shown) when initialising states; a GRU-style cell keeps one.
    _n_states = 1

    def __init__(self, input_size, hidden_size, n_layers=1, k=2,
                 root_weight=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.k = k

        def _cell_input_size(layer):
            # Layer 0 reads the raw input; deeper layers read the hidden
            # state produced by the layer below.
            return self.input_size if layer == 0 else self.hidden_size

        # Cells are built lazily in layer order, so n_layers == 0 still
        # yields an empty list and initialisation order is unchanged.
        self.rnn_cells = torch.nn.ModuleList(
            DCRNNCell(input_size=_cell_input_size(layer),
                      output_size=self.hidden_size,
                      k=self.k,
                      root_weight=root_weight)
            for layer in range(self.n_layers)
        )