# Source code for tsl.nn.layers.graph_convs.graph_attention

import math
from typing import Optional

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from torch import nn as nn
from torch_geometric.data import Batch, Data
from torch_geometric.nn.conv import GATConv, MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import add_remaining_self_loops

from tsl.nn.functional import sparse_softmax


class AttentionScores(MessagePassing):
    """Computes scaled dot-product attention scores over the edges of a
    graph (optionally conditioned on edge features), without attending to
    the values.

    The returned coefficients are normalized over each target node's
    neighborhood and averaged over the attention heads.
    """

    def __init__(self,
                 embed_dim,
                 heads: int = 1,
                 qdim: Optional[int] = None,
                 kdim: Optional[int] = None,
                 edge_dim: Optional[int] = None,
                 add_self_loops: bool = True,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(AttentionScores, self).__init__(node_dim=-3, **kwargs)

        self.embed_dim = int(embed_dim)
        self.qdim = int(qdim) if qdim is not None else self.embed_dim
        self.kdim = int(kdim) if kdim is not None else self.embed_dim

        self.heads = heads
        self.edge_dim = edge_dim
        self.add_self_loops = add_self_loops
        self._alpha = None

        self.lin_key = Linear(self.kdim, heads * embed_dim)
        self.lin_query = Linear(self.qdim, heads * embed_dim)
        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * embed_dim, bias=False)
        else:
            self.lin_edge = self.register_parameter('lin_edge', None)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        if self.edge_dim:
            self.lin_edge.reset_parameters()

    def forward(self,
                query: Tensor,
                key: Tensor,
                edge_index: Adj,
                edge_attr: OptTensor = None):
        n = (key if key.size(-2) > query.size(-2) else query).size(-2)
        h_and_c = self.heads, self.embed_dim
        # project and split heads
        query = self.lin_query(query).view(query.shape[:-1] + h_and_c)
        key = self.lin_key(key).view(key.shape[:-1] + h_and_c)
        # add self loops (i.e., attend also to self)
        if self.add_self_loops:
            edge_index, edge_attr = add_remaining_self_loops(edge_index,
                                                             edge_attr,
                                                             num_nodes=n)
        # project edge_attr
        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_attr = self.lin_edge(edge_attr).view(edge_attr.shape[:-1] +
                                                      h_and_c)
        # propagate and attend
        self.propagate(edge_index,
                       q=query,
                       k=key,
                       edge_attr=edge_attr,
                       size=(n, n))
        # retrieve scores
        alpha = self._alpha.mean(-1)
        self._alpha = None

        return alpha

    def message(self, q_i: Tensor, k_j: OptTensor, edge_attr: OptTensor,
                index: Tensor, size_i: int) -> Tensor:

        # add edge_attr to query and key (if provided)
        if edge_attr is not None:
            q_i = q_i + edge_attr
            k_j = k_j + edge_attr

        # compute scaled dot-product attention scores
        alpha = (q_i * k_j).sum(dim=-1) / math.sqrt(self.embed_dim)
        alpha = sparse_softmax(alpha, index, num_nodes=size_i, dim=-2)
        self._alpha = alpha

        return alpha
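
# Note (added for clarity): `AttentionScores` reuses the message-passing
# machinery only to compute the normalized coefficients; the aggregated
# output of `propagate` is discarded, and only the scores cached in
# `self._alpha` during message passing are returned.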


class MultiHeadGraphAttention(MessagePassing):
    """The multi-head attention from the paper `Attention Is All You Need
    <https://arxiv.org/abs/1706.03762>`_ (Vaswani et al., NeurIPS 2017) for
    graph-structured data.

    Args:
        embed_dim (int): Size of the embedding dimension.
        num_heads (int): Number of attention heads.
            (default: :obj:`1`)
        qdim (int, optional): Number of features of the query. If
            :obj:`None`, then defaults to :attr:`embed_dim`.
            (default: :obj:`None`)
        kdim (int, optional): Number of features of the key. If :obj:`None`,
            then defaults to :attr:`embed_dim`.
            (default: :obj:`None`)
        vdim (int, optional): Number of features of the value. If
            :obj:`None`, then defaults to :attr:`embed_dim`.
            (default: :obj:`None`)
        edge_dim (int, optional): Number of edge features (:obj:`None` if
            there are no edge features).
            (default: :obj:`None`)
        concat (bool): If :obj:`True`, then the heads' outputs are
            concatenated along the feature dimension, and the dimension of
            each head's output is :obj:`embed_dim / num_heads`. Note that
            the total number of features in output is :attr:`embed_dim` in
            both cases.
            (default: :obj:`True`)
        dropout (float, optional): The dropout rate.
            (default: :obj:`0`)
        root_weight (bool): If :obj:`True`, then add a skip connection from
            the input with a linear transformation.
            (default: :obj:`True`)
        bias (bool, optional): If :obj:`True`, then add a bias vector in
            output.
            (default: :obj:`True`)
        **kwargs: keyword arguments for the ``super(MessagePassing)`` call.
    """

    def __init__(self,
                 embed_dim: int,
                 num_heads: int = 1,
                 qdim: Optional[int] = None,
                 kdim: Optional[int] = None,
                 vdim: Optional[int] = None,
                 edge_dim: Optional[int] = None,
                 concat: bool = True,
                 dropout: float = 0.,
                 root_weight: bool = True,
                 bias: bool = True,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(MultiHeadGraphAttention, self).__init__(node_dim=-3, **kwargs)

        self.embed_dim = int(embed_dim)
        self.qdim = int(qdim) if qdim is not None else self.embed_dim
        self.kdim = int(kdim) if kdim is not None else self.embed_dim
        self.vdim = int(vdim) if vdim is not None else self.embed_dim
        self.edge_dim = edge_dim

        self.num_heads = num_heads
        self.concat = concat
        self.dropout = dropout
        self.root_weight = root_weight
        self._alpha = None

        if self.concat:
            self.head_dim = self.embed_dim // self.num_heads
            assert self.head_dim * self.num_heads == self.embed_dim, \
                "embed_dim must be divisible by heads"
        else:
            self.head_dim = self.embed_dim

        # key bias is discarded in softmax
        self.lin_key = Linear(self.kdim, num_heads * self.head_dim,
                              bias=False)
        self.lin_query = Linear(self.qdim, num_heads * self.head_dim,
                                bias_initializer='zeros')
        self.lin_value = Linear(self.vdim, num_heads * self.head_dim,
                                bias_initializer='zeros')

        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, num_heads * self.head_dim,
                                   bias=False)
        else:
            self.lin_edge = self.register_parameter('lin_edge', None)

        if self.root_weight:
            self.lin_skip = Linear(self.vdim, self.embed_dim, bias=bias)
        else:
            self.lin_skip = self.register_parameter('lin_skip', None)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        if self.edge_dim:
            self.lin_edge.reset_parameters()
        if self.lin_skip is not None:
            self.lin_skip.reset_parameters()

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                edge_index: Adj,
                edge_attr: OptTensor = None,
                return_attention_weights: Optional[bool] = False,
                return_attention_matrix: Optional[bool] = False):
        """"""
        # inputs: [*batch, nodes, channels]
        n = value.size(-2)
        x = value  # save original input for skip connection
        # project and split heads
        h_and_c = self.num_heads, self.head_dim
        query = self.lin_query(query).view(query.shape[:-1] + h_and_c)
        key = self.lin_key(key).view(key.shape[:-1] + h_and_c)
        value = self.lin_value(value).view(value.shape[:-1] + h_and_c)
        # project edge_attr
        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_attr = self.lin_edge(edge_attr).view(edge_attr.shape[:-1] +
                                                      h_and_c)
        # propagate and attend
        out = self.propagate(edge_index,
                             q=query,
                             k=key,
                             v=value,
                             edge_attr=edge_attr,
                             size=(n, n))
        # concatenate or average heads
        if self.concat:
            out = out.view(*out.shape[:-2], self.embed_dim)
        else:
            out = out.mean(dim=-2)
        # skip connection
        if self.root_weight:
            out = out + self.lin_skip(x)
        # retrieve scores
        alpha = self._alpha
        self._alpha = None

        if return_attention_weights:
            return out, alpha
        elif return_attention_matrix:
            # arrange weights as a (batched) dense matrix
            W = torch.zeros((*alpha.shape[:-2], n, n),
                            dtype=alpha.dtype,
                            device=alpha.device)
            W[..., edge_index[0], edge_index[1]] = alpha.mean(-1)
            return out, W
        else:
            return out

    def message(self, q_i: Tensor, k_j: OptTensor, v_j: OptTensor,
                edge_attr: OptTensor, index: Tensor, size_i: int) -> Tensor:
        """"""
        # add edge_attr to query and key (if provided)
        if edge_attr is not None:
            q_i = q_i + edge_attr
            k_j = k_j + edge_attr

        # compute scaled dot-product attention scores
        alpha = (q_i * k_j).sum(dim=-1) / math.sqrt(self.embed_dim)
        alpha = sparse_softmax(alpha, index, num_nodes=size_i, dim=-2)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        # attend: weight the values by the attention scores
        out = v_j * alpha[..., None]
        return out
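
# Illustrative usage sketch (not part of the original module); sizes and
# tensors below are arbitrary placeholders. Inputs follow the
# [*batch, nodes, channels] layout expected by `forward`, e.g. a
# spatiotemporal tensor [batch, steps, nodes, channels]:
#
#     attn = MultiHeadGraphAttention(embed_dim=32, num_heads=4)
#     x = torch.randn(8, 12, 20, 32)          # [batch, steps, nodes, channels]
#     edge_index = torch.tensor([[0, 1, 2],   # source nodes
#                                [1, 2, 0]])  # target nodes
#     out = attn(x, x, x, edge_index)         # self-attention: q = k = v = x
#     out, alpha = attn(x, x, x, edge_index, return_attention_weights=True)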


class GATLayer(nn.Module):

    def __init__(self, d_model, n_heads, concat=False, dropout=0.1):
        super(GATLayer, self).__init__()
        self.gat_conv = GATConv(in_channels=d_model,
                                out_channels=d_model,
                                heads=n_heads,
                                dropout=dropout,
                                concat=concat)

    def forward(self, x, edge_index):
        # x: [b, s, n, f]
        b, s, n, _ = x.size()
        x = rearrange(x, 'b s n f -> (b s n) f')
        rep = b * s
        batch_edge_index = Batch.from_data_list([
            Data(edge_index=edge_index, num_nodes=n),
        ] * rep)
        x = self.gat_conv(x, batch_edge_index.edge_index)
        x = rearrange(x, '(b s n) f -> b s n f', b=b, s=s)
        return x
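
# Illustrative usage sketch (not part of the original module); sizes below
# are arbitrary placeholders. `GATLayer` applies the same `GATConv` to every
# (batch, step) slice by replicating the graph over those dimensions:
#
#     layer = GATLayer(d_model=32, n_heads=4)
#     x = torch.randn(8, 12, 20, 32)        # [batch, steps, nodes, features]
#     edge_index = torch.tensor([[0, 1, 2],
#                                [1, 2, 0]])
#     out = layer(x, edge_index)            # -> [8, 12, 20, 32]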