# Source code for tsl.nn.models.stgn.grin_model

from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.typing import Adj, OptTensor

from tsl.nn.layers.base import NodeEmbedding
from tsl.nn.layers.recurrent import GRINCell
from tsl.nn.models.base_model import BaseModel


class GRINModel(BaseModel):
    r"""The Graph Recurrent Imputation Network with DCRNN cells from the paper
    `"Filling the G_ap_s: Multivariate Time Series Imputation by Graph Neural
    Networks" <https://arxiv.org/abs/2108.00298>`_ (Cini et al., ICLR 2022).

    Args:
        input_size (int): Size of the input.
        hidden_size (int): Number of units in the DCRNN hidden layer.
            (default: :obj:`64`)
        ff_size (int): Number of units in the nonlinear readout.
            (default: :obj:`128`)
        embedding_size (int, optional): Number of features in the optional
            node embeddings.
            (default: :obj:`None`)
        exog_size (int): Number of channels in the exogenous variables, if any.
            (default: :obj:`0`)
        n_layers (int): Number of DCRNN cells.
            (default: :obj:`1`)
        n_nodes (int, optional): Number of nodes in the input graph. Required
            when :attr:`embedding_size` is given.
            (default: :obj:`None`)
        kernel_size (int): Order of the spatial diffusion process in the DCRNN
            cells.
            (default: :obj:`2`)
        decoder_order (int): Order of the spatial diffusion process in the
            spatial decoder.
            (default: :obj:`1`)
        layer_norm (bool, optional): If :obj:`True`, then use layer
            normalization.
            (default: :obj:`False`)
        dropout (float, optional): Dropout probability in the DCRNN cells.
            (default: :obj:`0`)
        ff_dropout (float, optional): Dropout probability in the readout.
            (default: :obj:`0`)
        merge_mode (str, optional): Strategy used to merge representations
            coming from the two branches of the bidirectional model. One of
            :obj:`'mlp'`, :obj:`'mean'`, :obj:`'sum'`, :obj:`'min'`,
            :obj:`'max'`.
            (default: :obj:`'mlp'`)
    """

    return_type = list

    def __init__(self,
                 input_size: int,
                 hidden_size: int = 64,
                 ff_size: int = 128,
                 embedding_size: Optional[int] = None,
                 exog_size: int = 0,
                 n_layers: int = 1,
                 n_nodes: Optional[int] = None,
                 kernel_size: int = 2,
                 decoder_order: int = 1,
                 layer_norm: bool = False,
                 dropout: float = 0.,
                 ff_dropout: float = 0.,
                 merge_mode: str = 'mlp'):
        super(GRINModel, self).__init__()
        # One GRIN cell per direction of the bidirectional model.
        self.fwd_gril = GRINCell(input_size=input_size,
                                 hidden_size=hidden_size,
                                 exog_size=exog_size,
                                 n_layers=n_layers,
                                 dropout=dropout,
                                 kernel_size=kernel_size,
                                 decoder_order=decoder_order,
                                 n_nodes=n_nodes,
                                 layer_norm=layer_norm)
        self.bwd_gril = GRINCell(input_size=input_size,
                                 hidden_size=hidden_size,
                                 exog_size=exog_size,
                                 n_layers=n_layers,
                                 dropout=dropout,
                                 kernel_size=kernel_size,
                                 decoder_order=decoder_order,
                                 n_nodes=n_nodes,
                                 layer_norm=layer_norm)

        if embedding_size is not None:
            # Learned per-node embeddings require knowing the node count.
            assert n_nodes is not None
            self.emb = NodeEmbedding(n_nodes, embedding_size)
        else:
            # Register as a (None) parameter so `self.emb` always exists.
            self.register_parameter('emb', None)

        self.merge_mode = merge_mode
        if merge_mode == 'mlp':
            # Readout input: fwd/bwd hidden representations (2 * 2 * hidden),
            # the (masked) input, and the optional node embedding.
            # NOTE: `embedding_size or 0` avoids a TypeError (int + None)
            # when no embedding is used, which are the default settings.
            in_channels = (4 * hidden_size + input_size +
                           (embedding_size or 0))
            self.out = nn.Sequential(nn.Linear(in_channels, ff_size),
                                     nn.ReLU(),
                                     nn.Dropout(ff_dropout),
                                     nn.Linear(ff_size, input_size))
        elif merge_mode in ['mean', 'sum', 'min', 'max']:
            # Elementwise reduction over the two branches.
            self.out = getattr(torch, merge_mode)
        else:
            raise ValueError("Merge option %s not allowed." % merge_mode)

    def forward(self,
                x: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None,
                mask: OptTensor = None,
                u: OptTensor = None) -> list:
        """Run the bidirectional imputation pass.

        Returns a list ``[imputation, (fwd_out, bwd_out, fwd_pred, bwd_pred)]``
        with the merged imputation and the per-branch outputs/predictions.
        """
        # x: [batch, steps, nodes, channels]
        fwd_out, fwd_pred, fwd_repr, _ = self.fwd_gril(x,
                                                       edge_index,
                                                       edge_weight,
                                                       mask=mask,
                                                       u=u)
        # Backward pass: reverse the time axis, process, then flip back.
        rev_x = x.flip(1)
        rev_mask = mask.flip(1) if mask is not None else None
        rev_u = u.flip(1) if u is not None else None
        *bwd, _ = self.bwd_gril(rev_x,
                                edge_index,
                                edge_weight,
                                mask=rev_mask,
                                u=rev_u)
        bwd_out, bwd_pred, bwd_repr = [res.flip(1) for res in bwd]

        if self.merge_mode == 'mlp':
            inputs = [fwd_repr, bwd_repr, mask]
            if self.emb is not None:
                b, s, *_ = fwd_repr.size()  # fwd_repr: [b t n f]
                inputs += [self.emb(expand=(b, s, -1, -1))]
            imputation = torch.cat(inputs, dim=-1)
            imputation = self.out(imputation)
        else:
            imputation = torch.stack([fwd_out, bwd_out], dim=-1)
            imputation = self.out(imputation, dim=-1)
            # torch.min/torch.max with `dim` return a (values, indices)
            # namedtuple; keep only the values tensor so the return type is
            # consistent across merge modes ('mean'/'sum' already return a
            # tensor and are unaffected).
            if isinstance(imputation, tuple):
                imputation = imputation[0]

        return [imputation, (fwd_out, bwd_out, fwd_pred, bwd_pred)]

    def predict(self,
                x: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None,
                mask: OptTensor = None,
                u: OptTensor = None) -> Tensor:
        """Return only the merged imputation (first element of forward())."""
        imputation = self.forward(x=x,
                                  mask=mask,
                                  u=u,
                                  edge_index=edge_index,
                                  edge_weight=edge_weight)[0]
        return imputation