Source code for tsl.nn.base.attention.attention

from typing import Optional

import torch
from einops import rearrange
from torch import nn, Tensor

from torch_geometric.nn.dense import Linear
from torch_geometric.typing import OptTensor

from tsl.nn.layers.positional_encoding import PositionalEncoding
from tsl.nn.utils import get_functional_activation

@torch.jit.script
def _get_causal_mask(seq_len: int, diagonal: int = 0,
                     device: Optional[torch.device] = None):
    # boolean mask that is True above the given diagonal, i.e., for the
    # (future) steps to be excluded from attention
    ones = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
    causal_mask = torch.triu(ones, diagonal)
    return causal_mask
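
# A minimal sketch of the mask produced above, assuming the boolean-mask
# semantics of nn.MultiheadAttention (True = position is *not* attended):
# with diagonal=1, every step attends only to itself and to previous steps.
#
# >>> _get_causal_mask(3, diagonal=1)
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])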


class AttentionEncoder(nn.Module):

    def __init__(self, embed_dim,
                 qdim: Optional[int] = None,
                 kdim: Optional[int] = None,
                 vdim: Optional[int] = None,
                 add_positional_encoding: bool = False,
                 bias: bool = True,
                 activation: Optional[str] = None) -> None:
        super(AttentionEncoder, self).__init__()
        self.embed_dim = embed_dim
        self.qdim = qdim
        self.kdim = kdim
        self.vdim = vdim
        # project each input to the embedding dimension only if its input
        # dimension is given, otherwise pass it through unchanged
        self.lin_query = Linear(qdim, self.embed_dim, bias) \
            if qdim is not None else nn.Identity()
        self.lin_key = Linear(kdim, self.embed_dim, bias) \
            if kdim is not None else nn.Identity()
        self.lin_value = Linear(vdim, self.embed_dim, bias) \
            if vdim is not None else nn.Identity()
        self.activation = get_functional_activation(activation)
        self.pe = PositionalEncoding(self.embed_dim) \
            if add_positional_encoding else nn.Identity()

    def forward(self, query: Tensor, key: OptTensor = None,
                value: OptTensor = None):
        # inputs: [batches, steps, nodes, channels]
        if key is None:
            key = query
        if value is None:
            value = key
        query = self.pe(self.activation(self.lin_query(query)))
        key = self.pe(self.activation(self.lin_key(key)))
        value = self.activation(self.lin_value(value))
        return query, key, value
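
# A minimal usage sketch of AttentionEncoder; the sizes below are illustrative
# assumptions and inputs follow the [batch, steps, nodes, channels] layout
# expected by the forward pass above.
#
# >>> enc = AttentionEncoder(embed_dim=32, qdim=16, kdim=16, vdim=16)
# >>> x = torch.randn(8, 12, 20, 16)   # [batch, steps, nodes, channels]
# >>> q, k, v = enc(x)                 # key and value default to the query
# >>> q.shape
# torch.Size([8, 12, 20, 32])
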
class MultiHeadAttention(nn.MultiheadAttention):

    def __init__(self, embed_dim, heads,
                 qdim: Optional[int] = None,
                 kdim: Optional[int] = None,
                 vdim: Optional[int] = None,
                 axis='steps',
                 dropout=0.,
                 bias=True,
                 add_bias_kv=False,
                 add_zero_attn=False,
                 device=None,
                 dtype=None,
                 causal=False) -> None:
        if axis in ['steps', 0]:
            shape = 's (b n) c'
        elif axis in ['nodes', 1]:
            if causal:
                raise ValueError(
                    f'Cannot use causal attention for axis "{axis}".')
            shape = 'n (b s) c'
        else:
            raise ValueError("Axis can either be 'steps' (0) or 'nodes' (1), "
                             f"not '{axis}'.")
        self._in_pattern = f'b s n c -> {shape}'
        self._out_pattern = f'{shape} -> b s n c'
        self.causal = causal
        # Impose batch dimension as the second one
        super(MultiHeadAttention, self).__init__(embed_dim, heads,
                                                 dropout=dropout,
                                                 bias=bias,
                                                 add_bias_kv=add_bias_kv,
                                                 add_zero_attn=add_zero_attn,
                                                 kdim=kdim,
                                                 vdim=vdim,
                                                 batch_first=False,
                                                 device=device,
                                                 dtype=dtype)
        # change projections
        if qdim is not None and qdim != embed_dim:
            self.qdim = qdim
            self.q_proj = Linear(self.qdim, embed_dim)
        else:
            self.qdim = embed_dim
            self.q_proj = nn.Identity()
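
    # Note on the patterns built in __init__: with axis='steps' the input is
    # rearranged as 'b s n c -> s (b n) c', so attention runs along time
    # independently for every (batch, node) pair; with axis='nodes' the
    # pattern 'n (b s) c' runs it along the node dimension for every
    # (batch, step) pair. The extra q_proj covers query inputs whose feature
    # size qdim differs from embed_dim, which nn.MultiheadAttention does not
    # handle on its own.
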

    def forward(self,
                query: Tensor,
                key: OptTensor = None,
                value: OptTensor = None,
                key_padding_mask: OptTensor = None,
                need_weights: bool = True,
                attn_mask: OptTensor = None):
        # inputs: [batches, steps, nodes, channels] -> [s (b n) c]
        if key is None:
            key = query
        if value is None:
            value = key
        batch = value.shape[0]
        query, key, value = [rearrange(x, self._in_pattern)
                             for x in (query, key, value)]
        if self.causal:
            causal_mask = _get_causal_mask(key.size(0), diagonal=1,
                                           device=query.device)
            if attn_mask is None:
                attn_mask = causal_mask
            else:
                # combine the (boolean) masks: a position is masked out if it
                # is excluded by either mask
                attn_mask = torch.logical_or(attn_mask, causal_mask)
        attn_output, attn_weights = super(MultiHeadAttention, self).forward(
            self.q_proj(query), key, value, key_padding_mask, need_weights,
            attn_mask)
        attn_output = rearrange(attn_output, self._out_pattern,
                                b=batch).contiguous()
        if attn_weights is not None:
            attn_weights = rearrange(attn_weights, '(b d) l m -> b d l m',
                                     b=batch).contiguous()
        return attn_output, attn_weights
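
# A minimal usage sketch of MultiHeadAttention with illustrative sizes:
# attention over the temporal axis of a [batch, steps, nodes, channels]
# tensor, with future steps masked out by the causal mask.
#
# >>> mha = MultiHeadAttention(embed_dim=32, heads=4, axis='steps',
# ...                          causal=True)
# >>> x = torch.randn(8, 12, 20, 32)   # [batch, steps, nodes, channels]
# >>> out, weights = mha(x)            # key and value default to the query
# >>> out.shape
# torch.Size([8, 12, 20, 32])
# >>> weights.shape                    # [batch, nodes, steps, steps]
# torch.Size([8, 20, 12, 12])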