Source code for tsl.nn.layers.recurrent.grin

from typing import List, Optional, Union

import torch
import torch.nn as nn
from torch import LongTensor, Tensor
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import (from_scipy_sparse_matrix, remove_self_loops,
                                   to_scipy_sparse_matrix)
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_sparse import SparseTensor, remove_diag

from tsl.nn.layers.base import NodeEmbedding
from tsl.nn.layers.graph_convs import DiffConv
from tsl.nn.layers.norm import LayerNorm
from tsl.ops.connectivity import asymmetric_norm, power_series, transpose

from .dcrnn import DCRNNCell


def compute_support(edge_index: Adj,
                    edge_weight: OptTensor = None,
                    order: int = 1,
                    num_nodes: Optional[int] = None,
                    add_backward: bool = True):
    r"""Compute the list of (possibly multi-hop) diffusion supports of a graph.

    Self-loops are removed and the adjacency is normalized with
    :func:`asymmetric_norm` along ``dim=1``. The first support is the
    normalized adjacency :math:`A` itself; for ``order > 1`` the powers
    :math:`A^2, \dots, A^{order}` are appended, each with its diagonal set to
    zero. Note that every power is computed from the previous,
    already diagonal-zeroed, one.

    Args:
        edge_index: Graph connectivity as an edge-index tensor.
        edge_weight: Optional edge weights.
        order: Number of powers of the normalized adjacency to compute.
            (default: ``1``)
        num_nodes: Number of nodes; inferred from ``edge_index`` if omitted.
        add_backward: If ``True``, also append the supports computed on the
            transposed graph (backward diffusion). (default: ``True``)

    Returns:
        list: ``(edge_index, edge_weight)`` pairs for the forward graph
        (the normalized adjacency plus ``order - 1`` higher powers), followed,
        if ``add_backward`` is set, by the same for the transposed graph.
        All pairs are on the same device as the normalized ``edge_index``.
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
    ei, ew = asymmetric_norm(edge_index,
                             edge_weight,
                             dim=1,
                             num_nodes=num_nodes)
    device = ei.device
    a = to_scipy_sparse_matrix(ei, ew, num_nodes)
    support = [(ei, ew)]
    ak = a
    for _ in range(order - 1):
        # `@` is unambiguous matrix multiplication (for scipy sparse *arrays*
        # `*` is element-wise), and always yields a new matrix, so the
        # in-place edits below never touch a previously stored power.
        ak = ak @ a
        ak.setdiag(0.)
        ak.eliminate_zeros()
        # The scipy round-trip produces CPU tensors: move them back so every
        # support lives on the same device as the first one.
        ei_k, ew_k = from_scipy_sparse_matrix(ak)
        support.append((ei_k.to(device), ew_k.to(device)))
    if add_backward:
        ei_t, ew_t = transpose(edge_index, edge_weight)
        return support + compute_support(ei_t, ew_t, order, num_nodes, False)
    return support


class SpatialDecoder(nn.Module):
    r"""Spatial decoder used by the GRIN cell.

    The decoder works as follows:

    1. It concatenates the input, the mask, the hidden state and (optionally)
       the exogenous variables, and projects them to ``hidden_size``.
    2. It aggregates information from neighboring nodes with a diffusion
       convolution. Self-loops are removed and ``root_weight=False``, so a
       node's own value is not used.
    3. It reads out a prediction from the aggregated representation
       concatenated with the hidden state.

    Args:
        input_size (int): Number of input channels. The mask is expected to
            have the same number of channels.
        hidden_size (int): Number of hidden units.
        output_size (int, optional): Number of output channels. If
            :obj:`None`, defaults to ``input_size``. (default: :obj:`None`)
        exog_size (int): Number of exogenous channels. (default: :obj:`0`)
        order (int): Order of the spatial diffusion. For ``order > 1`` the
            support is built from a power series of the normalized adjacency.
            (default: :obj:`1`)
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 output_size: Optional[int] = None,
                 exog_size: int = 0,
                 order: int = 1):
        super(SpatialDecoder, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size or input_size
        self.order = order

        # input + mask + hidden state + (eventually) exogenous
        in_channels = input_size * 2 + hidden_size + exog_size

        self.lin_in = nn.Linear(in_channels, hidden_size)
        self.graph_conv = DiffConv(in_channels=hidden_size,
                                   out_channels=hidden_size,
                                   root_weight=False,
                                   k=1)
        self.lin_out = nn.Linear(2 * hidden_size, hidden_size)
        self.read_out = nn.Linear(2 * hidden_size, self.output_size)
        self.activation = nn.PReLU()

    def __repr__(self):
        attrs = ['input_size', 'hidden_size', 'output_size', 'order']
        attrs = ', '.join([f'{attr}={getattr(self, attr)}' for attr in attrs])
        return f"{self.__class__.__name__}({attrs})"

    def compute_support(self,
                        edge_index: LongTensor,
                        edge_weight: OptTensor = None,
                        num_nodes: Optional[int] = None,
                        add_backward: bool = True):
        """Build the ``self.order``-order diffusion support of the graph.

        The adjacency is normalized along ``dim=1``, expanded with
        :func:`power_series` up to ``self.order`` and stripped of self-loops.

        Returns:
            The ``(edge_index, edge_weight)`` pair of the forward support. If
            ``add_backward`` is set, returns a tuple
            ``((ei, ew), (ei_t, ew_t))`` whose second element is the support
            of the transposed graph.
        """
        ei, ew = asymmetric_norm(edge_index,
                                 edge_weight,
                                 dim=1,
                                 num_nodes=num_nodes)
        ei, ew = power_series(ei, ew, self.order)
        ei, ew = remove_self_loops(ei, ew)
        if add_backward:
            ei_t, ew_t = transpose(edge_index, edge_weight)
            return (ei, ew), self.compute_support(ei_t, ew_t, num_nodes, False)
        return ei, ew

    def forward(self,
                x: Tensor,
                mask: Tensor,
                h: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None,
                u: OptTensor = None):
        """Decode one time step.

        Args:
            x: Input features, ``[batch, nodes, channels]``.
            mask: Validity mask, concatenated to the input as-is.
            h: Hidden state, ``hidden_size`` channels.
            edge_index: Graph connectivity (tensor or ``SparseTensor``).
            edge_weight: Optional edge weights.
            u: Optional exogenous variables.

        Returns:
            tuple: The readout (``output_size`` channels) and the
            representation it was computed from (``2 * hidden_size``
            channels).
        """
        # x: [batch, nodes, channels]
        x_in = [x, mask, h]
        if u is not None:
            x_in.append(u)
        x_in = self.lin_in(torch.cat(x_in, -1))
        if self.order > 1:
            support = self.compute_support(edge_index,
                                           edge_weight,
                                           x.size(1),
                                           add_backward=True)
            # Inject the precomputed support into the convolution, and always
            # clear it afterwards so a failure cannot leave a stale support
            # behind for later calls.
            self.graph_conv._support = support
            try:
                out = self.graph_conv(x_in, edge_index=None)
            finally:
                self.graph_conv._support = None
        else:
            # Drop self-loops so the convolution only aggregates neighbors.
            if isinstance(edge_index, Tensor):
                edge_index, edge_weight = remove_self_loops(
                    edge_index, edge_weight)
            elif isinstance(edge_index, SparseTensor):
                edge_index = remove_diag(edge_index)
            out = self.graph_conv(x_in, edge_index, edge_weight)
        # out MLP
        out = torch.cat([out, h], -1)
        out = self.activation(self.lin_out(out))
        out = torch.cat([out, h], -1)
        return self.read_out(out), out


class GRINCell(nn.Module):
    r"""The Graph Recurrent Imputation cell with `Diffusion Convolution
    <https://arxiv.org/abs/1707.01926>`_ from the paper `"Filling the G_ap_s:
    Multivariate Time Series Imputation by Graph Neural Networks"
    <https://arxiv.org/abs/2108.00298>`_ (Cini et al., ICLR 2022).

    Args:
        input_size (int): Size of the input.
        hidden_size (int): Number of units in the DCRNN hidden layer.
        exog_size (int): Number of channels in the exogenous variables, if
            any. (default: :obj:`0`)
        n_layers (int): Number of stacked DCRNN cells.
            (default: :obj:`1`)
        n_nodes (int, optional): Number of nodes in the input graph. If given,
            the initial hidden state is a learnable node embedding, otherwise
            it is all zeros. (default: :obj:`None`)
        kernel_size (int): Order of the spatial diffusion process in the DCRNN
            cells. (default: :obj:`2`)
        decoder_order (int): Order of the spatial diffusion process in the
            spatial decoder. (default: :obj:`1`)
        layer_norm (bool, optional): If :obj:`True`, then use layer
            normalization. (default: :obj:`False`)
        dropout (float, optional): Dropout probability applied to the output
            of each DCRNN cell except the last one. (default: :obj:`0`)
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 exog_size: int = 0,
                 n_layers: int = 1,
                 n_nodes: Optional[int] = None,
                 kernel_size: int = 2,
                 decoder_order: int = 1,
                 layer_norm: bool = False,
                 dropout: float = 0.):
        super(GRINCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.u_size = exog_size
        self.n_layers = n_layers
        self.kernel_size = kernel_size

        # input + mask + (eventually) exogenous
        rnn_input_size = 2 * self.input_size + exog_size

        # Spatio-temporal encoder (rnn_input_size -> hidden_size)
        self.cells = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(self.n_layers):
            in_channels = rnn_input_size if i == 0 else self.hidden_size
            cell = DCRNNCell(input_size=in_channels,
                             hidden_size=self.hidden_size,
                             k=kernel_size,
                             root_weight=True)
            self.cells.append(cell)
            norm = LayerNorm(self.hidden_size) if layer_norm else nn.Identity()
            self.norms.append(norm)

        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

        # First stage readout
        self.first_stage = nn.Linear(self.hidden_size, self.input_size)

        # Spatial decoder (input + mask + hidden + exogenous -> input_size)
        self.spatial_decoder = SpatialDecoder(input_size=input_size,
                                              hidden_size=hidden_size,
                                              exog_size=exog_size,
                                              order=decoder_order)

        # Hidden state initialization embedding
        if n_nodes is not None:
            self.h0 = nn.ModuleList()
            for _ in range(self.n_layers):
                self.h0.append(NodeEmbedding(n_nodes, self.hidden_size))
        else:
            self.register_parameter('h0', None)

    def __repr__(self):
        attrs = ['input_size', 'hidden_size', 'kernel_size', 'n_layers']
        attrs = ', '.join([f'{attr}={getattr(self, attr)}' for attr in attrs])
        return f"{self.__class__.__name__}({attrs})"

    def get_h0(self, x):
        """Return the initial hidden state as a list of ``n_layers`` tensors.

        Args:
            x: Input sequence ``[batch, steps, nodes, channels]``.
        """
        if self.h0 is not None:
            return [h(expand=(x.shape[0], -1, -1)) for h in self.h0]
        size = (self.n_layers, x.shape[0], x.shape[2], self.hidden_size)
        # Match the input's dtype so e.g. float64 models do not receive a
        # float32 state.
        return [*torch.zeros(size, device=x.device, dtype=x.dtype)]

    def update_state(self, x, h, edge_index, edge_weight):
        """Run one step of the stacked DCRNN cells.

        ``h`` (one tensor per layer) is updated in place and returned.
        Dropout, if enabled, is applied between layers only.
        """
        # x: [batch, nodes, channels]
        rnn_in = x
        for layer, (cell, norm) in enumerate(zip(self.cells, self.norms)):
            h[layer] = norm(cell(rnn_in, h[layer], edge_index, edge_weight))
            rnn_in = h[layer]
            if self.dropout is not None and layer < (self.n_layers - 1):
                rnn_in = self.dropout(rnn_in)
        return h

    def forward(self,
                x: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None,
                mask: OptTensor = None,
                u: OptTensor = None,
                h: Union[List[Tensor], Tensor] = None):
        """Impute the input sequence step by step.

        Args:
            x: Input sequence ``[batch, steps, nodes, channels]``.
            edge_index: Graph connectivity.
            edge_weight: Optional edge weights.
            mask: Validity mask, same shape as ``x``. Non-zero entries mark
                observed values. If :obj:`None`, all values are considered
                valid.
            u: Optional exogenous variables, indexed per step as ``x``.
            h: Optional initial hidden state, given either as a list of
                per-layer tensors or as a tensor stacked along dim 0. It is
                never modified in place.

        Returns:
            tuple: ``(imputations, predictions, representations, states)``,
            where ``imputations`` and ``predictions`` are the second and first
            stage estimates ``[batch, steps, nodes, channels]``, and
            ``representations`` are the spatial decoder's representations
            stacked along dim 1. ``states`` are the hidden states,
            ``[n_layers, batch, steps, nodes, hidden_size]``.
        """
        # x: [batch, steps, nodes, channels]
        steps = x.size(1)

        # infer all valid if mask is None
        if mask is None:
            mask = torch.ones_like(x, dtype=torch.uint8)

        # init hidden state using node embedding or the empty state
        if h is None:
            h = self.get_h0(x)  # [[b n h] * n_layers]
        else:
            # Always work on a fresh list: update_state assigns into it in
            # place, which would otherwise overwrite the caller's list.
            h = list(h)

        # Temporal conv
        predictions, imputations, states = [], [], []
        representations = []
        for step in range(steps):
            x_s = x[:, step]
            m_s = mask[:, step]
            observed = m_s.bool()
            h_s = h[-1]
            u_s = u[:, step] if u is not None else None

            # firstly impute missing values with predictions from state
            xs_hat_1 = self.first_stage(h_s)
            # fill missing values in input with prediction
            x_s = torch.where(observed, x_s, xs_hat_1)

            # retrieve maximum information from neighbors
            xs_hat_2, repr_s = self.spatial_decoder(x_s,
                                                    m_s,
                                                    h_s,
                                                    u=u_s,
                                                    edge_index=edge_index,
                                                    edge_weight=edge_weight)

            # fill missing values again with the second stage imputations
            x_s = torch.where(observed, x_s, xs_hat_2)
            inputs = [x_s, m_s]
            if u_s is not None:
                inputs.append(u_s)
            inputs = torch.cat(inputs, dim=-1)  # x_hat_2 + mask + exogenous

            # update state with original sequence filled using imputations
            h = self.update_state(inputs, h, edge_index, edge_weight)

            # store imputations and states
            imputations.append(xs_hat_2)
            predictions.append(xs_hat_1)
            states.append(torch.stack(h, dim=0))
            representations.append(repr_s)

        # Aggregate outputs -> [batch, steps, nodes, channels]
        imputations = torch.stack(imputations, dim=1)
        predictions = torch.stack(predictions, dim=1)
        states = torch.stack(states, dim=2)
        representations = torch.stack(representations, dim=1)

        return imputations, predictions, representations, states