Source code for tsl.nn.base.attention.linear_attention

from typing import Optional

import torch
try:
    from fast_transformers.attention import CausalLinearAttention as CLAttention
    from fast_transformers.masking import TriangularCausalMask, LengthMask
except ModuleNotFoundError:
    CLAttention = None
from torch import Tensor
from torch import nn
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import OptTensor


class CausalLinearAttention(nn.Module):

    def __init__(self, embed_dim, heads,
                 qdim: Optional[int] = None,
                 kdim: Optional[int] = None,
                 vdim: Optional[int] = None,
                 out_channels: Optional[int] = None,
                 concat: bool = True,
                 dim: int = 1) -> None:
        super(CausalLinearAttention, self).__init__()
        if CLAttention is None:
            raise RuntimeError("Install optional dependency 'fast_transformers'"
                               " to use CausalLinearAttention.")
        # store dimensions
        self.embed_dim = int(embed_dim)
        self.qdim = int(qdim) if qdim is not None else self.embed_dim
        self.kdim = int(kdim) if kdim is not None else self.embed_dim
        self.vdim = int(vdim) if vdim is not None else self.embed_dim
        self.out_channels = int(out_channels) if out_channels is not None \
            else self.embed_dim

        self.heads = heads
        self.concat = concat
        self.dim = dim

        if self.concat:
            self.head_dim = self.embed_dim // self.heads
            out_dim = self.out_channels // self.heads
            assert self.head_dim * self.heads == self.embed_dim, \
                "embed_dim must be divisible by heads"
            assert out_dim * self.heads == self.out_channels, \
                "out_channels must be divisible by heads"
        else:
            self.head_dim, out_dim = self.embed_dim, self.out_channels

        self.lin_key = Linear(self.kdim, self.heads * self.head_dim,
                              bias_initializer='zeros')
        self.lin_query = Linear(self.qdim, self.heads * self.head_dim,
                                bias_initializer='zeros')
        self.lin_value = Linear(self.vdim, self.heads * out_dim,
                                bias_initializer='zeros')

        self.attention = CLAttention(self.head_dim)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
    def forward(self, query: Tensor, key: OptTensor = None,
                value: OptTensor = None):
        # If key and value are not provided, fall back to self-attention
        if key is None:
            key = query
        if value is None:
            value = key
        L, H, E = query.size(self.dim), self.heads, self.head_dim
        # move sequence dimension to penultimate dimension -> [*b s c]
        query = query.transpose(self.dim, -2)
        key = key.transpose(self.dim, -2)
        value = value.transpose(self.dim, -2)
        # broadcast query and key to the batch shape of value
        B = value.shape[:-2]
        if not (torch.tensor(query.shape[:-2] + key.shape[:-2]) == 1).all():
            query = query.expand(*B, *query.shape[-2:])
            key = key.expand(*B, *key.shape[-2:])
        # project and split heads
        query = self.lin_query(query).view(-1, L, H, E)
        key = self.lin_key(key).view(-1, L, H, E)
        value = self.lin_value(value).view(-1, L, H, E)

        attn_mask = TriangularCausalMask(L, device=query.device)
        key_lengths = LengthMask(torch.LongTensor([1]), 1, device=query.device)
        out = self.attention(query.float(), key.float(), value.float(),
                             attn_mask, query_lengths=None,
                             key_lengths=key_lengths)
        # reshape out to [*b, s, *n, c]
        if not self.concat:
            out = out.view(*B, L, H, E).mean(-2) \
                .transpose(self.dim, -2).contiguous()
        else:
            out = out.view(*B, L, -1).transpose(self.dim, -2).contiguous()
        return out
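
A minimal usage sketch (not part of the module above), assuming the optional dependency `fast_transformers` is installed and inputs follow tsl's [batch, steps, nodes, channels] layout, with causal attention applied along dim=1 (the steps dimension):

    # Illustrative example only; shapes and values are assumptions.
    import torch
    from tsl.nn.base.attention.linear_attention import CausalLinearAttention

    attn = CausalLinearAttention(embed_dim=32, heads=4, dim=1)
    x = torch.randn(8, 24, 20, 32)  # [batch, steps, nodes, channels]
    out = attn(x)                   # causal linear self-attention over the 24 steps
    # out should keep the input layout: [8, 24, 20, 32]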