# Source code for tsl.nn.models.stgn.spin_model

from typing import List, Optional, Tuple

import torch
from torch import Tensor, nn
from torch.nn import LayerNorm
from torch_geometric.nn import inits
from torch_geometric.typing import OptTensor

from tsl.nn.blocks.encoders import MLP
from tsl.nn.layers.base import NodeEmbedding, PositionalEncoding
from tsl.nn.layers.graph_convs import (
    HierarchicalSpatiotemporalCrossAttention, SpatiotemporalCrossAttention)
from tsl.nn.models.base_model import BaseModel


class SPINPositionalEncoder(nn.Module):
    """Joint space-time positional encoder used by the SPIN models.

    Projects the input linearly, adds (learned or externally provided)
    node embeddings, refines the sum with an MLP and finally adds a
    positional encoding along the temporal dimension.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 n_layers: int = 1,
                 n_nodes: Optional[int] = None):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)
        self.activation = nn.LeakyReLU()
        self.mlp = MLP(out_channels,
                       out_channels,
                       out_channels,
                       n_layers=n_layers,
                       activation='relu')
        self.positional = PositionalEncoding(out_channels)
        # Learn a per-node embedding table only when the number of nodes
        # is known; otherwise callers must supply `node_emb` in forward().
        if n_nodes is None:
            self.register_parameter('node_emb', None)
        else:
            self.node_emb = NodeEmbedding(n_nodes, out_channels)

    def forward(self,
                x: Tensor,
                node_emb: OptTensor = None,
                node_index: OptTensor = None) -> Tensor:
        """Encode ``x`` with node and temporal positional information.

        Args:
            x: input of shape ``[b t f]`` or ``[b t n f]``.
            node_emb: optional node embeddings ``[n f]``; looked up from
                the internal table when omitted.
            node_index: optional node subset for the embedding lookup.

        Returns:
            Tensor of shape ``[b t n f']`` (broadcast over nodes).
        """
        if node_emb is None:
            node_emb = self.node_emb(node_index=node_index)
        out = self.lin(x)
        # Insert a node axis so the sum broadcasts: [b t f] -> [b t 1 f]
        if out.ndim == 3:
            out = out.unsqueeze(-2)
        out = self.activation(out + node_emb)
        return self.positional(self.mlp(out))


class SPINModel(BaseModel):
    r"""The Spatiotemporal Point Inference Network (SPIN) from the paper
    `"Learning to Reconstruct Missing Data from Spatiotemporal Graphs with
    Sparse Observations" <https://arxiv.org/abs/2205.13479>`_
    (Marisca et al., NeurIPS 2022).
    """
    return_type = tuple

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 n_nodes: int,
                 exog_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 temporal_self_attention: bool = True,
                 reweigh: Optional[str] = 'softmax',
                 n_layers: int = 4,
                 eta: int = 3,
                 message_layers: int = 1):
        super().__init__()

        # Fall back to the input size when exogenous/output sizes are omitted.
        exog_size = exog_size or input_size
        output_size = output_size or input_size

        self.n_nodes = n_nodes
        self.n_layers = n_layers
        self.eta = eta
        self.temporal_self_attention = temporal_self_attention

        # Joint (node, timestamp) positional encoder for the exogenous input.
        self.u_enc = SPINPositionalEncoder(in_channels=exog_size,
                                           out_channels=hidden_size,
                                           n_layers=2,
                                           n_nodes=n_nodes)

        self.h_enc = MLP(input_size, hidden_size, n_layers=2)
        self.h_norm = LayerNorm(hidden_size)

        # Embeddings used (from layer `eta` on) to tell valid entries
        # apart from masked ones.
        self.valid_emb = NodeEmbedding(n_nodes, hidden_size)
        self.mask_emb = NodeEmbedding(n_nodes, hidden_size)

        self.x_skip = nn.ModuleList()
        self.encoder, self.readout = nn.ModuleList(), nn.ModuleList()
        for layer in range(n_layers):
            x_skip = nn.Linear(input_size, hidden_size)
            encoder = SpatiotemporalCrossAttention(
                input_size=hidden_size,
                output_size=hidden_size,
                msg_size=hidden_size,
                msg_layers=message_layers,
                temporal_self_attention=temporal_self_attention,
                reweigh=reweigh,
                mask_temporal=True,
                # Spatial masking is active only in the first `eta` layers.
                mask_spatial=layer < eta,
                norm=True,
                root_weight=True,
                dropout=0.0)
            readout = MLP(hidden_size, hidden_size, output_size, n_layers=2)
            self.x_skip.append(x_skip)
            self.encoder.append(encoder)
            self.readout.append(readout)

    def forward(self,
                x: Tensor,
                u: Tensor,
                mask: Tensor,
                edge_index: Tensor,
                node_index: OptTensor = None,
                target_nodes: OptTensor = None) -> Tuple[Tensor, List[Tensor]]:
        """Impute missing values in ``x`` given exogenous ``u`` and ``mask``.

        Returns the final imputation and the list of intermediate (per-layer)
        imputations.
        """
        if target_nodes is None:
            target_nodes = slice(None)

        # Whiten missing values
        x = x * mask

        # POSITIONAL ENCODING #################################################
        # Obtain spatio-temporal positional encoding for every node-step pair #
        # in both observed and target sets. Encoding are obtained by jointly  #
        # processing node and time positional encoding.                       #

        # Build (node, timestamp) encoding
        q = self.u_enc(u, node_index=node_index)
        # Condition value on key
        h = self.h_enc(x) + q

        # ENCODER #############################################################
        # Obtain representations h^i_t for every (i, t) node-step pair by     #
        # only taking into account valid data in representation set.          #

        # Replace H in missing entries with queries Q
        h = torch.where(mask.bool(), h, q)
        # Normalize features
        h = self.h_norm(h)

        imputations = []
        for layer in range(self.n_layers):
            if layer == self.eta:
                # Condition H on two different embeddings to distinguish
                # valid values from masked ones
                valid = self.valid_emb(node_index=node_index)
                masked = self.mask_emb(node_index=node_index)
                h = torch.where(mask.bool(), h + valid, h + masked)
            # Masked Temporal GAT for encoding representation
            h = h + self.x_skip[layer](x) * mask  # skip connection for valid x
            h = self.encoder[layer](h, edge_index, mask=mask)
            # Read from H to get imputations
            target_readout = self.readout[layer](h[..., target_nodes, :])
            imputations.append(target_readout)

        # Get final layer imputations
        x_hat = imputations.pop(-1)
        return x_hat, imputations

    def predict(self,
                x: Tensor,
                u: Tensor,
                mask: Tensor,
                edge_index: Tensor,
                node_index: OptTensor = None,
                target_nodes: OptTensor = None) -> Tensor:
        """Return only the final imputation (first element of :meth:`forward`)."""
        imputation = self.forward(x=x,
                                  u=u,
                                  mask=mask,
                                  edge_index=edge_index,
                                  node_index=node_index,
                                  target_nodes=target_nodes)[0]
        return imputation
class SPINHierarchicalModel(BaseModel):
    r"""The Hierarchical Spatiotemporal Point Inference Network (SPIN-H)
    from the paper `"Learning to Reconstruct Missing Data from Spatiotemporal
    Graphs with Sparse Observations" <https://arxiv.org/abs/2205.13479>`_
    (Marisca et al., NeurIPS 2022).
    """
    return_type = tuple

    def __init__(self,
                 input_size: int,
                 h_size: int,
                 z_size: int,
                 n_nodes: int,
                 z_heads: int = 1,
                 exog_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 n_layers: int = 5,
                 eta: int = 3,
                 message_layers: int = 1,
                 reweigh: Optional[str] = 'softmax',
                 update_z_cross: bool = True,
                 norm: bool = True,
                 spatial_aggr: str = 'add'):
        super().__init__()

        # Fall back to the input size when exogenous/output sizes are omitted.
        exog_size = exog_size or input_size
        output_size = output_size or input_size

        self.h_size = h_size
        self.z_size = z_size

        self.n_nodes = n_nodes
        self.z_heads = z_heads
        self.n_layers = n_layers
        self.eta = eta

        # Static node embeddings V and their projection used to condition
        # the hierarchical state Z.
        self.v = NodeEmbedding(n_nodes, h_size)
        self.lin_v = nn.Linear(h_size, z_size, bias=False)
        self.z = nn.Parameter(torch.Tensor(1, z_heads, n_nodes, z_size))
        inits.uniform(z_size, self.z)
        self.z_norm = LayerNorm(z_size)

        self.u_enc = SPINPositionalEncoder(in_channels=exog_size,
                                           out_channels=h_size,
                                           n_layers=2)

        self.h_enc = MLP(input_size, h_size, n_layers=2)
        self.h_norm = LayerNorm(h_size)

        # Two pairs of (valid, masked) embeddings: the first conditions the
        # initial representation, the second is applied from layer `eta` on.
        self.v1 = NodeEmbedding(n_nodes, h_size)
        self.m1 = NodeEmbedding(n_nodes, h_size)

        self.v2 = NodeEmbedding(n_nodes, h_size)
        self.m2 = NodeEmbedding(n_nodes, h_size)

        self.x_skip = nn.ModuleList()
        self.encoder, self.readout = nn.ModuleList(), nn.ModuleList()
        for layer in range(n_layers):
            x_skip = nn.Linear(input_size, h_size)
            encoder = HierarchicalSpatiotemporalCrossAttention(
                h_size=h_size,
                z_size=z_size,
                msg_size=h_size,
                msg_layers=message_layers,
                reweigh=reweigh,
                mask_temporal=True,
                # Spatial masking is active only in the first `eta` layers.
                mask_spatial=layer < eta,
                update_z_cross=update_z_cross,
                norm=norm,
                root_weight=True,
                aggr=spatial_aggr,
                dropout=0.0)
            readout = MLP(h_size, z_size, output_size, n_layers=2)
            self.x_skip.append(x_skip)
            self.encoder.append(encoder)
            self.readout.append(readout)

    def forward(self,
                x: Tensor,
                u: Tensor,
                mask: Tensor,
                edge_index: Tensor,
                node_index: OptTensor = None,
                target_nodes: OptTensor = None) -> Tuple[Tensor, List[Tensor]]:
        """Impute missing values in ``x`` given exogenous ``u`` and ``mask``.

        Returns the final imputation and the list of intermediate (per-layer)
        imputations.
        """
        if target_nodes is None:
            target_nodes = slice(None)
        if node_index is None:
            node_index = slice(None)

        # POSITIONAL ENCODING #################################################
        # Obtain spatio-temporal positional encoding for every node-step pair #
        # in both observed and target sets. Encoding are obtained by jointly  #
        # processing node and time positional encoding.                       #
        # Condition also embeddings Z on V.                                   #
        v_nodes = self.v(node_index=node_index)
        z = self.z[..., node_index, :] + self.lin_v(v_nodes)

        # Build (node, timestamp) encoding
        q = self.u_enc(u, node_index=node_index, node_emb=v_nodes)
        # Condition value on key
        h = self.h_enc(x) + q

        # ENCODER #############################################################
        # Obtain representations h^i_t for every (i, t) node-step pair by     #
        # only taking into account valid data in representation set.          #

        # Replace H in missing entries with queries Q. Then, condition H on
        # two different embeddings to distinguish valid values from masked
        # ones.
        h = torch.where(mask.bool(), h + self.v1(), q + self.m1())
        # Normalize features
        h, z = self.h_norm(h), self.z_norm(z)

        imputations = []
        for layer in range(self.n_layers):
            if layer == self.eta:
                # Condition H on two different embeddings to distinguish
                # valid values from masked ones
                h = torch.where(mask.bool(), h + self.v2(), h + self.m2())
            # Skip connection from input x
            h = h + self.x_skip[layer](x) * mask
            # Masked Temporal GAT for encoding representation
            h, z = self.encoder[layer](h, z, edge_index, mask=mask)
            target_readout = self.readout[layer](h[..., target_nodes, :])
            imputations.append(target_readout)

        x_hat = imputations.pop(-1)
        return x_hat, imputations

    def predict(self,
                x: Tensor,
                u: Tensor,
                mask: Tensor,
                edge_index: Tensor,
                node_index: OptTensor = None,
                target_nodes: OptTensor = None) -> Tensor:
        """Return only the final imputation (first element of :meth:`forward`)."""
        imputation = self.forward(x=x,
                                  u=u,
                                  mask=mask,
                                  edge_index=edge_index,
                                  node_index=node_index,
                                  target_nodes=target_nodes)[0]
        return imputation