Source code for tsl.nn.blocks.encoders.mlp_attention

from typing import Optional, Tuple, Union

import torch
from torch import Tensor, nn
from torch.nn import LayerNorm
from torch.nn import functional as F
from torch_geometric.nn import Linear, MessagePassing
from torch_geometric.typing import Adj, OptTensor, PairTensor
from torch_scatter import scatter
from torch_scatter.utils import broadcast

from tsl.nn.functional import sparse_softmax

from .mlp import MLP


class MLPAttention(MessagePassing):

    def __init__(self, input_size: Union[int, Tuple[int, int]],
                 output_size: int,
                 msg_size: Optional[int] = None,
                 msg_layers: int = 1,
                 root_weight: bool = True,
                 reweigh: Optional[str] = None,
                 norm: bool = True,
                 dropout: float = 0.0,
                 dim: int = -2,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=dim, **kwargs)

        self.output_size = output_size
        if isinstance(input_size, int):
            self.src_size = self.tgt_size = input_size
        else:
            self.src_size, self.tgt_size = input_size
        self.msg_size = msg_size or self.output_size
        self.msg_layers = msg_layers

        assert reweigh in ['softmax', 'l1', None]
        self.reweigh = reweigh
        self.root_weight = root_weight
        self.dropout = dropout

        # linear projections of source and target features;
        # key bias is discarded in softmax
        self.lin_src = Linear(self.src_size, self.output_size,
                              weight_initializer='glorot',
                              bias_initializer='zeros')
        self.lin_tgt = Linear(self.tgt_size, self.output_size,
                              weight_initializer='glorot',
                              bias=False)

        # optional skip connection from the target features
        if self.root_weight:
            self.lin_skip = Linear(self.tgt_size, self.output_size,
                                   bias=False)
        else:
            self.register_parameter('lin_skip', None)

        # MLP computing the message from the summed projections
        self.msg_nn = nn.Sequential(
            nn.PReLU(init=0.2),
            MLP(self.output_size, self.msg_size, self.output_size,
                n_layers=self.msg_layers,
                dropout=self.dropout,
                activation='prelu'))

        # scalar gate producing the (unnormalized) attention score per edge
        if self.reweigh == 'softmax':
            self.msg_gate = nn.Linear(self.output_size, 1, bias=False)
        else:
            self.msg_gate = nn.Sequential(nn.Linear(self.output_size, 1),
                                          nn.Sigmoid())

        if norm:
            self.norm = LayerNorm(self.output_size)
        else:
            self.register_parameter('norm', None)

        self.reset_parameters()
    def reset_parameters(self):
        self.lin_src.reset_parameters()
        self.lin_tgt.reset_parameters()
        if self.lin_skip is not None:
            self.lin_skip.reset_parameters()
    def forward(self, x: PairTensor, edge_index: Adj,
                mask: OptTensor = None):
        # if query/key not provided, defaults to x (e.g., for self-attention)
        if isinstance(x, Tensor):
            x_src = x_tgt = x
        else:
            x_src, x_tgt = x
            x_tgt = x_tgt if x_tgt is not None else x_src
        N_src, N_tgt = x_src.size(self.node_dim), x_tgt.size(self.node_dim)

        msg_src = self.lin_src(x_src)
        msg_tgt = self.lin_tgt(x_tgt)
        msg = (msg_src, msg_tgt)

        # propagate_type: (msg: PairTensor, mask: OptTensor)
        out = self.propagate(edge_index, msg=msg, mask=mask,
                             size=(N_src, N_tgt))

        # skip connection
        if self.root_weight:
            out = out + self.lin_skip(x_tgt)

        if self.norm is not None:
            out = self.norm(out)

        return out
    def normalize_weights(self, weights, index, num_nodes, mask=None):
        # mask weights
        if mask is not None:
            fill_value = float("-inf") if self.reweigh == 'softmax' else 0.
            weights = weights.masked_fill(torch.logical_not(mask), fill_value)
        # optional reweighing
        if self.reweigh == 'l1':
            expanded_index = broadcast(index, weights, self.node_dim)
            weights_sum = scatter(weights, expanded_index, self.node_dim,
                                  dim_size=num_nodes, reduce='sum')
            weights_sum = weights_sum.index_select(self.node_dim, index)
            weights = weights / (weights_sum + 1e-5)
        elif self.reweigh == 'softmax':
            weights = sparse_softmax(weights, index, num_nodes=num_nodes,
                                     dim=self.node_dim)
        return weights
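
    # Note (added for clarity): with reweigh='l1', the sigmoid gates of the
    # edges incoming to each target are divided by their sum, so the weights
    # of every target sum to (approximately) one; with reweigh='softmax', the
    # raw scores are normalized with a sparse softmax over the same edge
    # groups; with reweigh=None, the (optionally masked) sigmoid gates are
    # used as-is.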
    def message(self, msg_j: Tensor, msg_i: Tensor, index, size_i,
                mask_j: OptTensor = None) -> Tensor:
        # combine source and target projections, then compute a scalar
        # attention score for each edge
        msg = self.msg_nn(msg_j + msg_i)
        gate = self.msg_gate(msg)
        alpha = self.normalize_weights(gate, index, size_i, mask_j)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        out = alpha * msg
        return out
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.output_size}, '
                f'dim={self.node_dim}, '
                f'root_weight={self.root_weight})')
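
# Usage sketch (added; not part of the original module): a minimal example of
# MLPAttention on a small graph. The tensor layout is an assumption; with the
# default dim=-2, attention is computed over the second-to-last dimension.
def _example_mlp_attention():  # hypothetical helper, for illustration only
    attn = MLPAttention(input_size=16, output_size=32, reweigh='softmax')
    x = torch.randn(8, 4, 16)                   # [batch, nodes, channels]
    edge_index = torch.tensor([[0, 1, 2, 3],    # source nodes
                               [1, 2, 3, 0]])   # target nodes
    return attn(x, edge_index)                  # -> [8, 4, 32]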

class TemporalMLPAttention(MLPAttention):

    def __init__(self, input_size: Union[int, Tuple[int, int]],
                 output_size: int,
                 msg_size: Optional[int] = None,
                 msg_layers: int = 1,
                 root_weight: bool = True,
                 reweigh: Optional[str] = None,
                 norm: bool = True,
                 dropout: float = 0.0,
                 **kwargs):
        # attention is computed along the temporal dimension (dim=1)
        kwargs.setdefault('dim', 1)
        super().__init__(input_size=input_size,
                         output_size=output_size,
                         msg_size=msg_size,
                         msg_layers=msg_layers,
                         root_weight=root_weight,
                         reweigh=reweigh,
                         dropout=dropout,
                         norm=norm,
                         **kwargs)
    def forward(self, x: PairTensor, mask: OptTensor = None,
                temporal_mask: OptTensor = None,
                causal_lag: Optional[int] = None):
        #  x: [b s * c]  query: [b l * c]  key: [b s * c]
        #  mask: [b s * c]  temporal_mask: [l s]
        if isinstance(x, Tensor):
            x_src = x_tgt = x
        else:
            x_src, x_tgt = x
            x_tgt = x_tgt if x_tgt is not None else x_src
        l, s = x_tgt.size(self.node_dim), x_src.size(self.node_dim)
        i = torch.arange(l, dtype=torch.long, device=x_src.device)
        j = torch.arange(s, dtype=torch.long, device=x_src.device)

        # compute temporal index, from j to i
        if temporal_mask is None and isinstance(causal_lag, int):
            # causal connectivity: step i attends only to steps j <= i - lag,
            # expressed as a tuple of (row, col) index tensors
            temporal_mask = tuple(
                torch.tril_indices(l, l, offset=-causal_lag,
                                   device=x_src.device))
        if temporal_mask is not None:
            # boolean [l, s] masks are checked for shape; index tuples built
            # above are used directly for advanced indexing
            if isinstance(temporal_mask, Tensor):
                assert temporal_mask.size() == (l, s)
            i, j = torch.meshgrid(i, j)
            edge_index = torch.stack((j[temporal_mask], i[temporal_mask]))
        else:
            # fully-connected temporal graph: every step attends to all steps
            edge_index = torch.cartesian_prod(j, i).T
        return super(TemporalMLPAttention, self).forward(x, edge_index,
                                                         mask=mask)
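
# Usage sketch (added; not part of the original module): causal self-attention
# over the temporal dimension with TemporalMLPAttention. The tensor layout
# [batch, steps, nodes, channels] is an assumption based on the shape comments
# in forward; attention runs along dim=1.
def _example_temporal_mlp_attention():  # hypothetical helper, illustration only
    attn = TemporalMLPAttention(input_size=16, output_size=32, reweigh='l1')
    x = torch.randn(8, 12, 4, 16)   # [batch, steps, nodes, channels]
    # boolean [steps, steps] mask: each step attends to strictly earlier steps
    temporal_mask = torch.ones(12, 12, dtype=torch.bool).tril(diagonal=-1)
    return attn(x, temporal_mask=temporal_mask)  # -> [8, 12, 4, 32]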