Source code for tsl.nn.layers.recurrent.agcrn

from tsl.nn.layers.graph_convs.adaptive_graph_conv import AdaptiveGraphConv
from tsl.nn.layers.recurrent.base import GraphGRUCellBase


class AGCRNCell(GraphGRUCellBase):
    """Adaptive Graph Convolutional recurrent cell, as introduced in the paper
    `"Adaptive Graph Convolutional Recurrent Network for Traffic Forecasting"
    <https://arxiv.org/abs/2007.02842>`_ (Bai et al., NeurIPS 2020).

    The forget, update and candidate gates are three separate
    :class:`AdaptiveGraphConv` layers, all constructed with the same
    arguments and handed over to :class:`GraphGRUCellBase`.

    Args:
        input_size: Size of the input.
        emb_size: Size of the input node embeddings.
        hidden_size: Output size.
        num_nodes: Number of nodes in the input graph.
        bias: Forwarded as the ``bias`` argument of every gate's
            :class:`AdaptiveGraphConv`. Defaults to ``True``.
    """

    def __init__(self,
                 input_size: int,
                 emb_size: int,
                 hidden_size: int,
                 num_nodes: int,
                 bias: bool = True):
        # Plain attributes are stored before the parent constructor is
        # invoked; this ordering is kept on purpose.
        self.input_size = input_size
        self.emb_size = emb_size
        self.num_nodes = num_nodes

        def make_gate():
            # Each gate is built with ``input_size + hidden_size`` as its
            # first positional argument and ``hidden_size`` as output size.
            return AdaptiveGraphConv(input_size + hidden_size,
                                     emb_size,
                                     output_size=hidden_size,
                                     num_nodes=num_nodes,
                                     bias=bias)

        # Instantiate the gates in a fixed order (forget, update, candidate)
        # so that layer construction happens in the same sequence every time.
        gate_names = ("forget_gate", "update_gate", "candidate_gate")
        gates = {gate_name: make_gate() for gate_name in gate_names}

        super().__init__(hidden_size=hidden_size, **gates)