Source code for tsl.nn.layers.graph_convs.spatiotemporal_cross_attention

from typing import Optional, Tuple, Union

import torch
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptPairTensor, OptTensor

from tsl.nn.blocks.encoders import TemporalMLPAttention
from tsl.nn.functional import sparse_softmax
from tsl.nn.layers.norm import LayerNorm


class SpatiotemporalCrossAttention(MessagePassing):

    def __init__(self,
                 input_size: Union[int, Tuple[int, int]],
                 output_size: int,
                 msg_size: Optional[int] = None,
                 msg_layers: int = 1,
                 root_weight: bool = True,
                 reweigh: Optional[str] = None,
                 temporal_self_attention: bool = True,
                 mask_temporal: bool = True,
                 mask_spatial: bool = True,
                 norm: bool = True,
                 dropout: float = 0.,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(SpatiotemporalCrossAttention, self).__init__(node_dim=-2,
                                                           **kwargs)

        # store dimensions
        if isinstance(input_size, int):
            self.src_size = self.tgt_size = input_size
        else:
            self.src_size, self.tgt_size = input_size
        self.output_size = output_size
        self.msg_size = msg_size or self.output_size

        self.mask_temporal = mask_temporal
        self.mask_spatial = mask_spatial
        self.root_weight = root_weight
        self.dropout = dropout

        if temporal_self_attention:
            self.self_attention = TemporalMLPAttention(input_size=input_size,
                                                       output_size=output_size,
                                                       msg_size=msg_size,
                                                       msg_layers=msg_layers,
                                                       reweigh=reweigh,
                                                       dropout=dropout,
                                                       root_weight=False,
                                                       norm=False)
        else:
            self.register_parameter('self_attention', None)

        self.cross_attention = TemporalMLPAttention(input_size=input_size,
                                                    output_size=output_size,
                                                    msg_size=msg_size,
                                                    msg_layers=msg_layers,
                                                    reweigh=reweigh,
                                                    dropout=dropout,
                                                    root_weight=False,
                                                    norm=False)

        if self.root_weight:
            self.lin_skip = Linear(self.tgt_size, self.output_size,
                                   bias_initializer='zeros')
        else:
            self.register_parameter('lin_skip', None)

        if norm:
            self.norm = LayerNorm(output_size)
        else:
            self.register_parameter('norm', None)

        self.reset_parameters()

    def reset_parameters(self):
        self.cross_attention.reset_parameters()
        if self.self_attention is not None:
            self.self_attention.reset_parameters()
        if self.lin_skip is not None:
            self.lin_skip.reset_parameters()
        if self.norm is not None:
            self.norm.reset_parameters()

    def forward(self, x: OptPairTensor, edge_index: Adj,
                edge_weight: OptTensor = None, mask: OptTensor = None):
        # inputs: [batch, steps, nodes, channels]
        if isinstance(x, Tensor):
            x_src = x_tgt = x
        else:
            x_src, x_tgt = x
            x_tgt = x_tgt if x_tgt is not None else x_src
        n_src, n_tgt = x_src.size(-2), x_tgt.size(-2)

        # propagate query, key and value
        out = self.propagate(x=(x_src, x_tgt),
                             edge_index=edge_index,
                             edge_weight=edge_weight,
                             mask=mask if self.mask_spatial else None,
                             size=(n_src, n_tgt))

        if self.self_attention is not None:
            s, t = x_src.size(1), x_tgt.size(1)
            if s == t:
                attn_mask = ~torch.eye(t, t, dtype=torch.bool,
                                       device=x_tgt.device)
            else:
                attn_mask = None
            temp = self.self_attention(
                x=(x_src, x_tgt),
                mask=mask if self.mask_temporal else None,
                temporal_mask=attn_mask)
            out = out + temp

        # skip connection
        if self.root_weight:
            out = out + self.lin_skip(x_tgt)

        if self.norm is not None:
            out = self.norm(out)

        return out

    def message(self, x_i: Tensor, x_j: Tensor, edge_weight: OptTensor,
                mask_j: OptTensor) -> Tensor:
        # [batch, steps, edges, channels]
        out = self.cross_attention((x_j, x_i), mask=mask_j)
        if edge_weight is not None:
            out = out * edge_weight.view(-1, 1)
        return out
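

# Illustrative usage sketch (not part of the original module). It shows the
# [batch, steps, nodes, channels] input layout expected by
# `SpatiotemporalCrossAttention.forward`; the tensor sizes and the
# fully-connected `edge_index` below are arbitrary assumptions made only for
# this example.
def _example_spatiotemporal_cross_attention():
    batch, steps, nodes, channels = 2, 12, 4, 16
    x = torch.randn(batch, steps, nodes, channels)
    # directed edges of a fully-connected graph on `nodes` nodes, no self-loops
    edge_index = torch.tensor(
        [[i, j] for i in range(nodes) for j in range(nodes) if i != j],
        dtype=torch.long).t().contiguous()
    layer = SpatiotemporalCrossAttention(input_size=channels, output_size=32)
    out = layer(x, edge_index)
    return out  # expected shape: [batch, steps, nodes, 32]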


class HierarchicalSpatiotemporalCrossAttention(MessagePassing):

    def __init__(self,
                 h_size: int,
                 z_size: int,
                 msg_size: Optional[int] = None,
                 msg_layers: int = 1,
                 root_weight: bool = True,
                 reweigh: Optional[str] = None,
                 update_z_cross: bool = True,
                 mask_temporal: bool = True,
                 mask_spatial: bool = True,
                 norm: bool = True,
                 dropout: float = 0.,
                 aggr: str = 'add',
                 **kwargs):
        self.spatial_aggr = aggr
        if aggr == 'softmax':
            aggr = 'add'
        super(HierarchicalSpatiotemporalCrossAttention,
              self).__init__(node_dim=-2, aggr=aggr, **kwargs)

        # store dimensions
        self.h_size = h_size
        self.z_size = z_size
        self.msg_size = msg_size or z_size

        self.mask_temporal = mask_temporal
        self.mask_spatial = mask_spatial
        self.root_weight = root_weight
        self.norm = norm
        self.dropout = dropout

        self._z_cross = None

        self.zh_self = TemporalMLPAttention(input_size=(h_size, z_size),
                                            output_size=z_size,
                                            msg_size=msg_size,
                                            msg_layers=msg_layers,
                                            reweigh=reweigh,
                                            dropout=dropout,
                                            root_weight=True,
                                            norm=True)

        self.hz_self = TemporalMLPAttention(input_size=(z_size, h_size),
                                            output_size=h_size,
                                            msg_size=msg_size,
                                            msg_layers=msg_layers,
                                            reweigh=reweigh,
                                            dropout=dropout,
                                            root_weight=True,
                                            norm=False)

        if update_z_cross:
            self.zh_cross = TemporalMLPAttention(input_size=(h_size, z_size),
                                                 output_size=z_size,
                                                 msg_size=msg_size,
                                                 msg_layers=msg_layers,
                                                 reweigh=reweigh,
                                                 dropout=dropout,
                                                 root_weight=True,
                                                 norm=True)
        else:
            self.register_parameter('zh_cross', None)

        self.hz_cross = TemporalMLPAttention(input_size=(z_size, h_size),
                                             output_size=h_size,
                                             msg_size=msg_size,
                                             msg_layers=msg_layers,
                                             reweigh=None,
                                             dropout=dropout,
                                             root_weight=True,
                                             norm=False)

        if self.spatial_aggr == 'softmax':
            self.lin_alpha_h = Linear(h_size, 1, bias=False)
            self.lin_alpha_z = Linear(z_size, 1, bias=False)
        else:
            self.register_parameter('lin_alpha_h', None)
            self.register_parameter('lin_alpha_z', None)

        if self.root_weight:
            self.h_skip = Linear(h_size, h_size, bias_initializer='zeros')
            self.z_skip = Linear(z_size, z_size, bias_initializer='zeros')
        else:
            self.register_parameter('h_skip', None)
            self.register_parameter('z_skip', None)

        if self.norm:
            self.h_norm = LayerNorm(h_size)
            self.z_norm = LayerNorm(z_size)
        else:
            self.register_parameter('h_norm', None)
            self.register_parameter('z_norm', None)

        self.reset_parameters()

    def reset_parameters(self):
        self.zh_self.reset_parameters()
        self.hz_self.reset_parameters()
        if self.zh_cross is not None:
            self.zh_cross.reset_parameters()
        self.hz_cross.reset_parameters()
        if self.spatial_aggr == 'softmax':
            self.lin_alpha_h.reset_parameters()
            self.lin_alpha_z.reset_parameters()
        if self.root_weight:
            self.h_skip.reset_parameters()
            self.z_skip.reset_parameters()
        if self.norm:
            self.h_norm.reset_parameters()
            self.z_norm.reset_parameters()

    def forward(self, h: Tensor, z: Tensor, edge_index: Adj,
                mask: OptTensor = None):
        # inputs: [batch, steps, nodes, channels]
        z_out = self.zh_self(x=(h, z),
                             mask=mask if self.mask_temporal else None)
        h_self = self.hz_self(x=(z_out, h))

        # propagate query, key and value
        n_src, n_tgt = h.size(-2), z.size(-2)
        h_out = self.propagate(h=h_self, z=z_out,
                               edge_index=edge_index,
                               mask=mask if self.mask_spatial else None,
                               size=(n_src, n_tgt))

        if self._z_cross is not None:
            z_out = self.aggregate(self._z_cross, edge_index[1],
                                   dim_size=n_tgt)
            self._z_cross = None

        # skip connection
        if self.root_weight:
            h_out = h_out + self.h_skip(h)
            z_out = z_out + self.z_skip(z)

        if self.norm:
            h_out = self.h_norm(h_out)
            z_out = self.z_norm(z_out)

        return h_out, z_out

    def h_cross_message(self, h_i: Tensor, z_j: Tensor, index,
                        size_i) -> Tensor:
        # [batch, steps, edges, channels]
        h_cross = self.hz_cross((z_j, h_i))

        if self.spatial_aggr == 'softmax':
            alpha_h = self.lin_alpha_h(h_cross)
            alpha_h = sparse_softmax(alpha_h, index, num_nodes=size_i,
                                     dim=self.node_dim)
            h_cross = alpha_h * h_cross

        return h_cross

    def hz_cross_message(self, h_i: Tensor, h_j: Tensor, z_i: Tensor, index,
                         size_i, mask_j: OptTensor) -> Tensor:
        # [batch, steps, edges, channels]
        z_cross = self.zh_cross((h_j, z_i), mask=mask_j)
        h_cross = self.hz_cross((z_cross, h_i))

        if self.spatial_aggr == 'softmax':
            # reweigh z
            alpha_z = self.lin_alpha_z(z_cross)
            alpha_z = sparse_softmax(alpha_z, index, num_nodes=size_i,
                                     dim=self.node_dim)
            z_cross = alpha_z * z_cross
            # reweigh h
            alpha_h = self.lin_alpha_h(h_cross)
            alpha_h = sparse_softmax(alpha_h, index, num_nodes=size_i,
                                     dim=self.node_dim)
            h_cross = alpha_h * h_cross

        self._z_cross = z_cross

        return h_cross

    def message(self, h_i: Tensor, h_j: Tensor, z_i: Tensor, z_j: Tensor,
                index, size_i, mask_j: OptTensor) -> Tensor:
        if self.zh_cross is not None:
            return self.hz_cross_message(h_i, h_j, z_i, index, size_i, mask_j)
        return self.h_cross_message(h_i, z_j, index, size_i)
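

# Illustrative usage sketch (not part of the original module). Here both `h`
# and `z` are assumed to share the [batch, steps, nodes, channels] layout and
# to exchange information through the same spatial `edge_index`; the tensor
# sizes and the ring-shaped graph below are assumptions made only for this
# example.
def _example_hierarchical_cross_attention():
    batch, steps, nodes = 2, 12, 4
    h_size, z_size = 16, 8
    h = torch.randn(batch, steps, nodes, h_size)
    z = torch.randn(batch, steps, nodes, z_size)
    # directed ring over the 4 nodes: 0->1, 1->2, 2->3, 3->0
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)
    layer = HierarchicalSpatiotemporalCrossAttention(h_size=h_size,
                                                     z_size=z_size)
    h_out, z_out = layer(h, z, edge_index)
    # expected shapes: h_out [batch, steps, nodes, h_size],
    #                  z_out [batch, steps, nodes, z_size]
    return h_out, z_out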