from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import Adj, OptPairTensor, OptTensor
from torch_geometric.utils import add_self_loops, remove_self_loops
from torch_sparse import SparseTensor, set_diag
from tsl.nn.functional import sparse_softmax
class GATConv(MessagePassing):
    r"""Extension of :class:`~torch_geometric.nn.conv.GATConv` for static graphs
    with multidimensional features.

    The graph attentional operator from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
        \right)\right)}.

    If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`,
    the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j
        \, \Vert \, \mathbf{\Theta}_{e} \mathbf{e}_{i,j}]\right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k
        \, \Vert \, \mathbf{\Theta}_{e} \mathbf{e}_{i,k}]\right)\right)}.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`True`, the output dimension of
            each attention head is :obj:`out_channels/heads` and all heads'
            output are concatenated, resulting in :obj:`out_channels` number of
            features. If set to :obj:`False`, the multi-head attentions are
            averaged instead of concatenated.
            (default: :obj:`True`)
        dim (int): The axis along which to propagate. (default: :obj:`-2`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). (default: :obj:`None`)
        fill_value (float or Tensor or str, optional): The way to generate
            edge features of self-loops (in case :obj:`edge_dim != None`).
            If given as :obj:`float` or :class:`torch.Tensor`, edge features of
            self-loops will be directly given by :obj:`fill_value`.
            If given as :obj:`str`, edge features of self-loops are computed by
            aggregating all features of edges that point to the specific node,
            according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"mean"`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(*, |\mathcal{V}|, *, F_{in})` or
          :math:`((*, |\mathcal{V_s}|, *, F_s), (*, |\mathcal{V_t}|, *, F_t))`
          if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge features :math:`(|\mathcal{E}|, D)` *(optional)*
        - **output:**
          node features :math:`(*, |\mathcal{V}|, *, F_{out})` or
          :math:`(*, |\mathcal{V}_t|, *, F_{out})` if bipartite,
          attention_weights
          :math:`((2, |\mathcal{E}|), (|\mathcal{E}|, *, H))` if
          :obj:`need_weights` is :obj:`True` else :obj:`None`
    """

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        dim: int = -2,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        add_self_loops: bool = True,
        edge_dim: Optional[int] = None,
        fill_value: Union[float, Tensor, str] = 'mean',
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=dim, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.edge_dim = edge_dim
        self.fill_value = fill_value

        # When concatenating, `out_channels` is split evenly across heads;
        # when averaging, every head produces the full `out_channels`.
        if self.concat:
            self.head_channels = self.out_channels // self.heads
            assert self.head_channels * self.heads == self.out_channels, \
                "`out_channels` must be divisible by `heads`."
        else:
            self.head_channels = self.out_channels

        # In case we are operating in bipartite graphs, we apply separate
        # transformations 'lin_src' and 'lin_dst' to source and target nodes:
        if isinstance(in_channels, int):
            self.lin_src = Linear(in_channels,
                                  heads * self.head_channels,
                                  bias=False,
                                  weight_initializer='glorot')
            self.lin_dst = self.lin_src
        else:
            self.lin_src = Linear(in_channels[0],
                                  heads * self.head_channels,
                                  bias=False,
                                  weight_initializer='glorot')
            self.lin_dst = Linear(in_channels[1],
                                  heads * self.head_channels,
                                  bias=False,
                                  weight_initializer='glorot')

        # The learnable parameters to compute attention coefficients.
        # Allocated uninitialized here; `reset_parameters` fills them in.
        self.att_src = Parameter(torch.empty(1, heads, self.head_channels))
        self.att_dst = Parameter(torch.empty(1, heads, self.head_channels))

        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim,
                                   heads * self.head_channels,
                                   bias=False,
                                   weight_initializer='glorot')
            self.att_edge = Parameter(
                torch.empty(1, heads, self.head_channels))
        else:
            self.lin_edge = None
            self.register_parameter('att_edge', None)

        if bias and concat:
            self.bias = Parameter(torch.empty(heads * self.head_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the linear projections, the attention vectors
        (Glorot) and the bias (zeros)."""
        self.lin_src.reset_parameters()
        self.lin_dst.reset_parameters()
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        glorot(self.att_src)
        glorot(self.att_dst)
        # `att_edge` and `bias` may be registered as `None`.
        glorot(self.att_edge)
        zeros(self.bias)

    def forward(self,
                x: Union[Tensor, OptPairTensor],
                edge_index: Adj,
                edge_attr: OptTensor = None,
                need_weights: bool = False):
        """Run the attention layer.

        Args:
            x: Node features, either a single tensor or a
                ``(source, target)`` pair for bipartite graphs (the target
                entry may be :obj:`None`).
            edge_index: Graph connectivity, as an edge-index tensor or a
                :class:`torch_sparse.SparseTensor`.
            edge_attr: Optional edge features.
            need_weights: Whether to also return the attention coefficients.

        Returns:
            A tuple ``(out, alpha)``. ``alpha`` is :obj:`None` unless
            :obj:`need_weights` is set; then it is ``(edge_index, alpha)``
            for a tensor ``edge_index`` or a ``SparseTensor`` holding the
            coefficients as values.

        Raises:
            NotImplementedError: If self-loops are requested together with
                edge features while ``edge_index`` is a ``SparseTensor``.
        """
        # `x` may be a (source, target) pair: rank and node count have to be
        # read from an actual tensor, not from the tuple.
        if isinstance(x, Tensor):
            ref_src, ref_dst = x, x
        else:
            ref_src, ref_dst = x[0], x[1]

        configured_dim = self.node_dim
        # The projected features are reshaped to (..., H, C) below, i.e. they
        # have one more axis than the input, so a negative node axis is
        # converted to an absolute one for the duration of this call.
        self.node_dim = (configured_dim + ref_src.dim()
                         if configured_dim < 0 else configured_dim)
        try:
            H, C = self.heads, self.head_channels
            n_src = ref_src.size(self.node_dim)
            n_dst = None if ref_dst is None else ref_dst.size(self.node_dim)

            # We first transform the input node features. If a tuple is
            # passed, we transform source and target node features via
            # separate weights:
            if isinstance(x, Tensor):
                x_src = x_dst = self.lin_src(x).view(*x.shape[:-1], H, C)
            else:  # Tuple of source and target node features:
                x_src, x_dst = x
                x_src = self.lin_src(x_src).view(*x_src.shape[:-1], H, C)
                if x_dst is not None:
                    x_dst = self.lin_dst(x_dst).view(*x_dst.shape[:-1], H, C)
            x = (x_src, x_dst)

            # Next, we compute node-level attention coefficients, both for
            # source and target nodes (if present):
            alpha_src = (x_src * self.att_src).sum(dim=-1)
            alpha_dst = (None if x_dst is None else
                         (x_dst * self.att_dst).sum(-1))
            alpha = (alpha_src, alpha_dst)

            if self.add_self_loops:
                if isinstance(edge_index, Tensor):
                    # Self-loops only make sense for node ids present on
                    # both sides of a bipartite graph.
                    n_loops = n_src if n_dst is None else min(n_src, n_dst)
                    edge_index, edge_attr = remove_self_loops(
                        edge_index, edge_attr)
                    edge_index, edge_attr = add_self_loops(
                        edge_index,
                        edge_attr,
                        fill_value=self.fill_value,
                        num_nodes=n_loops)
                elif isinstance(edge_index, SparseTensor):
                    if self.edge_dim is None:
                        edge_index = set_diag(edge_index)
                    else:
                        raise NotImplementedError(
                            "The usage of 'edge_attr' and 'add_self_loops' "
                            "simultaneously is currently not yet supported for "
                            "'edge_index' in a 'SparseTensor' form")

            # edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor)
            alpha = self.edge_updater(edge_index,
                                      alpha=alpha,
                                      edge_attr=edge_attr)

            # propagate_type: (x: OptPairTensor, alpha: Tensor)
            out = self.propagate(edge_index,
                                 x=x,
                                 alpha=alpha,
                                 size=(n_src, n_dst))

            if self.concat:
                out = out.view(*out.shape[:-2], self.out_channels)
            else:
                out = out.mean(dim=-2)

            if self.bias is not None:
                # Out-of-place add: do not mutate the propagate result.
                out = out + self.bias

            if need_weights:
                # alpha rearrange: [... e ... h] -> [e ... h]
                alpha = torch.movedim(alpha, self.node_dim, 0)
                if isinstance(edge_index, Tensor):
                    alpha = (edge_index, alpha)
                elif isinstance(edge_index, SparseTensor):
                    alpha = edge_index.set_value(alpha, layout='coo')
            else:
                alpha = None

            return out, alpha
        finally:
            # Always restore the configured axis, even if the call failed.
            self.node_dim = configured_dim

    def edge_update(self, alpha_j: Tensor, alpha_i: OptTensor,
                    edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                    size_i: Optional[int]) -> Tensor:
        """Compute the normalized attention coefficient of every edge.

        Source and target scores are summed, the edge-feature score is added
        when ``edge_attr`` is given, and the result goes through LeakyReLU,
        a softmax grouped by ``index`` along ``self.node_dim``, and dropout.
        """
        # Given edge-level attention coefficients for source and target nodes,
        # we simply need to sum them up to "emulate" concatenation:
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i

        if edge_attr is not None:
            if edge_attr.dim() == 1:
                edge_attr = edge_attr.view(-1, 1)
            assert self.lin_edge is not None
            edge_attr = self.lin_edge(edge_attr)
            edge_attr = edge_attr.view(-1, self.heads, self.head_channels)
            alpha_edge = (edge_attr * self.att_edge).sum(dim=-1)
            # Reshape the per-edge score so that it broadcasts against
            # `alpha`: singleton everywhere except edge and head axes.
            shape = [1] * (alpha.ndim - 1) + [self.heads]
            shape[self.node_dim] = alpha_edge.size(0)
            alpha = alpha + alpha_edge.view(shape)

        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = sparse_softmax(alpha,
                               index,
                               num_nodes=size_i,
                               ptr=ptr,
                               dim=self.node_dim)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return alpha

    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
        """Weight the source features of each edge by its attention
        coefficient (broadcast over the channel axis)."""
        return alpha.unsqueeze(-1) * x_j

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')