Source code for tsl.nn.blocks.encoders.transformer

from functools import partial
from typing import Optional

import torch.nn.functional as F
from torch import Tensor, nn

from tsl.nn import utils
from tsl.nn.layers.base import MultiHeadAttention
from tsl.nn.layers.norm import LayerNorm


class TransformerLayer(nn.Module):
    r"""A Transformer layer from the paper `"Attention Is All You Need"
    <https://arxiv.org/abs/1706.03762>`_ (Vaswani et al., NeurIPS 2017).

    This layer can be instantiated to attend the temporal or spatial dimension.

    Args:
        input_size (int): Input size.
        hidden_size (int): Dimension of the learned representations.
        ff_size (int): Units in the MLP after self attention.
        n_heads (int, optional): Number of parallel attention heads.
        axis (str, optional): Dimension on which to apply attention to update
            the representations. Can be either 'time' or 'nodes'.
            (default: :obj:`'time'`)
        causal (bool, optional): If :obj:`True`, then causally mask attention
            scores in temporal attention (has an effect only if :attr:`axis` is
            :obj:`'time'`). (default: :obj:`True`)
        activation (str, optional): Activation function.
        dropout (float, optional): Dropout probability.
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 ff_size=None,
                 n_heads=1,
                 axis='time',
                 causal=True,
                 activation='elu',
                 dropout=0.):
        super(TransformerLayer, self).__init__()
        self.att = MultiHeadAttention(embed_dim=hidden_size,
                                      qdim=input_size,
                                      kdim=input_size,
                                      vdim=input_size,
                                      heads=n_heads,
                                      axis=axis,
                                      causal=causal)

        if input_size != hidden_size:
            self.skip_conn = nn.Linear(input_size, hidden_size)
        else:
            self.skip_conn = nn.Identity()

        self.norm1 = LayerNorm(input_size)

        self.mlp = nn.Sequential(LayerNorm(hidden_size),
                                 nn.Linear(hidden_size, ff_size),
                                 utils.get_layer_activation(activation)(),
                                 nn.Dropout(dropout),
                                 nn.Linear(ff_size, hidden_size),
                                 nn.Dropout(dropout))

        self.dropout = nn.Dropout(dropout)

        self.activation = utils.get_functional_activation(activation)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None):
        """"""
        # x: [batch, steps, nodes, features]
        x = self.skip_conn(x) + self.dropout(
            self.att(self.norm1(x), attn_mask=mask)[0])
        x = x + self.mlp(x)
        return x
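
# Usage sketch (illustrative, not part of the tsl source): applying
# TransformerLayer to a spatiotemporal tensor laid out as
# [batch, steps, nodes, features], the layout assumed by forward() above.
# All sizes below are arbitrary example values.
import torch

layer = TransformerLayer(input_size=16, hidden_size=32, ff_size=64,
                         n_heads=4, axis='time', causal=True, dropout=0.1)
x = torch.randn(8, 12, 20, 16)  # [batch=8, steps=12, nodes=20, features=16]
out = layer(x)  # -> [8, 12, 20, 32]: features updated by attention over time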

class SpatioTemporalTransformerLayer(nn.Module):
    r"""A :class:`~tsl.nn.blocks.encoders.TransformerLayer` which attends both
    the spatial and temporal dimensions by stacking two
    :class:`~tsl.nn.layers.base.MultiHeadAttention` layers.

    Args:
        input_size (int): Input size.
        hidden_size (int): Dimension of the learned representations.
        ff_size (int): Units in the MLP after self attention.
        n_heads (int, optional): Number of parallel attention heads.
        causal (bool, optional): If :obj:`True`, then causally mask attention
            scores in temporal attention. (default: :obj:`True`)
        activation (str, optional): Activation function.
        dropout (float, optional): Dropout probability.
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 ff_size=None,
                 n_heads=1,
                 causal=True,
                 activation='elu',
                 dropout=0.):
        super(SpatioTemporalTransformerLayer, self).__init__()
        self.temporal_att = MultiHeadAttention(embed_dim=hidden_size,
                                               qdim=input_size,
                                               kdim=input_size,
                                               vdim=input_size,
                                               heads=n_heads,
                                               axis='time',
                                               causal=causal)

        self.spatial_att = MultiHeadAttention(embed_dim=hidden_size,
                                              qdim=hidden_size,
                                              kdim=hidden_size,
                                              vdim=hidden_size,
                                              heads=n_heads,
                                              axis='nodes',
                                              causal=False)

        self.skip_conn = nn.Linear(input_size, hidden_size)

        self.norm1 = LayerNorm(input_size)
        self.norm2 = LayerNorm(hidden_size)

        self.mlp = nn.Sequential(LayerNorm(hidden_size),
                                 nn.Linear(hidden_size, ff_size),
                                 utils.get_layer_activation(activation)(),
                                 nn.Dropout(dropout),
                                 nn.Linear(ff_size, hidden_size),
                                 nn.Dropout(dropout))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None):
        """"""
        # x: [batch, steps, nodes, features]
        x = self.skip_conn(x) + self.dropout(
            self.temporal_att(self.norm1(x), attn_mask=mask)[0])
        x = x + self.dropout(
            self.spatial_att(self.norm2(x), attn_mask=mask)[0])
        x = x + self.mlp(x)
        return x
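
# Usage sketch (illustrative, not part of the tsl source): the layer first
# attends over the time axis (causally by default), then over the node axis;
# both attention blocks share hidden_size, as wired in __init__ above.
# Example sizes are arbitrary; reuses `torch` imported in the sketch above.
stt_layer = SpatioTemporalTransformerLayer(input_size=16, hidden_size=32,
                                           ff_size=64, n_heads=4, dropout=0.1)
x = torch.randn(8, 12, 20, 16)  # [batch, steps, nodes, features]
out = stt_layer(x)  # -> [8, 12, 20, 32]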

class Transformer(nn.Module):
    r"""A stack of Transformer layers.

    Args:
        input_size (int): Input size.
        hidden_size (int): Dimension of the learned representations.
        ff_size (int): Units in the MLP after self attention.
        output_size (int, optional): Size of an optional linear readout.
        n_layers (int, optional): Number of Transformer layers.
        n_heads (int, optional): Number of parallel attention heads.
        axis (str, optional): Dimension on which to apply attention to update
            the representations. Can be either 'time', 'nodes', or 'both'.
            (default: :obj:`'time'`)
        causal (bool, optional): If :obj:`True`, then causally mask attention
            scores in temporal attention (has an effect only if :attr:`axis` is
            :obj:`'time'` or :obj:`'both'`). (default: :obj:`True`)
        activation (str, optional): Activation function.
        dropout (float, optional): Dropout probability.
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 ff_size=None,
                 output_size=None,
                 n_layers=1,
                 n_heads=1,
                 axis='time',
                 causal=True,
                 activation='elu',
                 dropout=0.):
        super(Transformer, self).__init__()
        self.f = getattr(F, activation)

        if ff_size is None:
            ff_size = hidden_size

        if axis in ['time', 'nodes']:
            transformer_layer = partial(TransformerLayer, axis=axis)
        elif axis == 'both':
            transformer_layer = SpatioTemporalTransformerLayer
        else:
            raise ValueError(f'"{axis}" is not a valid axis.')

        layers = []
        for i in range(n_layers):
            layers.append(
                transformer_layer(
                    input_size=input_size if i == 0 else hidden_size,
                    hidden_size=hidden_size,
                    ff_size=ff_size,
                    n_heads=n_heads,
                    causal=causal,
                    activation=activation,
                    dropout=dropout))

        self.net = nn.Sequential(*layers)

        if output_size is not None:
            self.readout = nn.Linear(hidden_size, output_size)
        else:
            self.register_parameter('readout', None)

    def forward(self, x: Tensor):
        """"""
        x = self.net(x)
        if self.readout is not None:
            return self.readout(x)
        return x
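
# Usage sketch (illustrative, not part of the tsl source): a two-layer stack
# attending both axes, with a linear readout mapping hidden_size to a scalar
# per step and node. Example sizes are arbitrary; reuses `torch` imported in
# the first sketch above.
model = Transformer(input_size=16, hidden_size=32, ff_size=64, output_size=1,
                    n_layers=2, n_heads=4, axis='both', causal=True,
                    dropout=0.1)
x = torch.randn(8, 12, 20, 16)  # [batch, steps, nodes, features]
y = model(x)  # -> [8, 12, 20, 1]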