Source code for tsl.nn.layers.recurrent.evolvegcn

import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor, nn
from torch.nn import Parameter
from torch_geometric.nn import MessagePassing
from torch_sparse import SparseTensor, matmul

from tsl.nn.layers.graph_convs.mixin import NormalizedAdjacencyMixin
from tsl.nn.utils import get_functional_activation


class _TopK(torch.nn.Module):

    def __init__(self, input_size, k):
        super().__init__()
        self.k = k

        self.p = Parameter(torch.Tensor(input_size, 1))
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize based on the number of rows
        stdv = 1. / math.sqrt(self.p.size(0))
        self.p.data.uniform_(-stdv, stdv)

    def forward(self, x):
        """"""
        scores = x @ self.p / self.p.norm()
        vals, topk_idxs = scores.topk(self.k, dim=-2)  # b k 1 -> b k f
        idxs = topk_idxs.expand(-1, self.k, x.size(-1))
        x = torch.gather(x, -2, idxs)
        out = x * F.tanh(vals)

        return out
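

# Illustrative sketch (not part of the library): _TopK scores each node
# against the learned direction ``p``, keeps the ``k`` best-scoring nodes per
# batch element, and gates their features with the tanh of the scores. The
# shapes used here are assumptions for demonstration only.
def _demo_topk(batch: int = 4, nodes: int = 10, features: int = 8, k: int = 3):
    pool = _TopK(input_size=features, k=k)
    x = torch.randn(batch, nodes, features)   # [batch, nodes, features]
    out = pool(x)
    assert out.shape == (batch, k, features)  # k nodes kept per batch element
    return out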


class _EvolveGCNCell(MessagePassing, NormalizedAdjacencyMixin):
    r"""
    """

    def __init__(self,
                 in_size,
                 out_size,
                 norm,
                 activation='relu',
                 root_weight=False,
                 bias=True,
                 cached=False):
        super(_EvolveGCNCell, self).__init__(aggr='add')
        self.in_size = in_size
        self.out_size = out_size
        self.norm = norm
        self.cached = cached

        self.activation_fn = get_functional_activation(activation)

        self.W0 = nn.Parameter(torch.Tensor(in_size, out_size))

        if bias:
            self.bias = Parameter(torch.Tensor(out_size))
        else:
            self.register_parameter('bias', None)

        if root_weight:
            self.skip_con = nn.Linear(in_size, out_size, bias=False)
        else:
            self.register_parameter('skip_con', None)

    def reset_parameters(self):
        std = 1. / math.sqrt(self.out_size)
        self.W0.data.uniform_(-std, std)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)
        if self.skip_con is not None:
            self.skip_con.reset_parameters()


class EvolveGCNHCell(_EvolveGCNCell):
    r"""EvolveGCNH cell from the paper `"EvolveGCN: Evolving Graph
    Convolutional Networks for Dynamic Graphs"
    <https://arxiv.org/abs/1902.10191>`_ (Pareja et al., AAAI 2020).

    This variant of the model adapts the weights of the graph convolution by
    looking at node features.

    Args:
        in_size (int): Size of the input.
        out_size (int): Number of units in the hidden state.
        norm (str): Method used to normalize the adjacency matrix.
        activation (str): Activation function after the GCN layer.
        root_weight (bool): Whether to add a parametrized skip connection.
        bias (bool): Whether to learn a bias.
        cached (bool): Whether to cache the normalized edge weights.
    """
    _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
    _cached_adj_t: Optional[SparseTensor]

    def __init__(self,
                 in_size,
                 out_size,
                 norm,
                 activation='relu',
                 root_weight=False,
                 bias=True,
                 cached=False):
        super(EvolveGCNHCell, self).__init__(in_size,
                                             out_size,
                                             norm=norm,
                                             activation=activation,
                                             root_weight=root_weight,
                                             bias=bias,
                                             cached=cached)
        self.gru_cell = nn.GRUCell(input_size=in_size, hidden_size=in_size)
        self.pooling_layer = _TopK(input_size=in_size, k=out_size)
        self.reset_parameters()

    def reset_parameters(self):
        super(EvolveGCNHCell, self).reset_parameters()
        self.gru_cell.reset_parameters()
        self.pooling_layer.reset_parameters()

    def forward(self, x, h, edge_index, edge_weight=None):
        """"""
        edge_index, edge_weight = self.normalize_edge_index(
            x, edge_index, edge_weight, use_cached=self.cached)
        if h is None:
            # At the first step, broadcast the initial GCN weights to the batch.
            W = repeat(self.W0, 'din dout -> b din dout', b=x.size(0))
        else:
            W = h
        # Evolve the GCN weights with a GRU driven by the top-k node features.
        h_gru = rearrange(W, 'b din dout -> (b dout) din')
        x_gru = rearrange(self.pooling_layer(x), 'b dout din -> (b dout) din')
        h_gru = self.gru_cell(x_gru, h_gru)
        W = rearrange(h_gru, '(b dout) din -> b din dout', dout=self.out_size)

        out = torch.matmul(x, W)
        out = self.propagate(edge_index, x=out, edge_weight=edge_weight)

        if self.bias is not None:
            out += self.bias

        if self.skip_con is not None:
            out += self.skip_con(x)

        return self.activation_fn(out), W

    def message(self, x_j: Tensor, edge_weight) -> Tensor:
        """"""
        return edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        """"""
        # adj_t: SparseTensor [nodes, nodes]
        # x: [(batch,) nodes, channels]
        return matmul(adj_t, x, reduce=self.aggr)
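

# Usage sketch (not part of the library): roll an EvolveGCNHCell over a
# sequence, carrying the evolved GCN weights ``h`` across steps. Tensor shapes
# and the normalization name 'sym' are illustrative assumptions; check
# NormalizedAdjacencyMixin for the normalizations actually supported. Note
# that _TopK requires at least ``out_size`` nodes in the graph.
def _demo_evolvegcn_h(x_seq: Tensor, edge_index: Tensor,
                      edge_weight: Optional[Tensor] = None) -> Tensor:
    # x_seq: [batch, steps, nodes, in_size]
    cell = EvolveGCNHCell(in_size=x_seq.size(-1), out_size=32, norm='sym')
    h = None  # evolved weights; initialized from cell.W0 at the first step
    outputs = []
    for t in range(x_seq.size(1)):
        out, h = cell(x_seq[:, t], h, edge_index, edge_weight)
        outputs.append(out)
    return torch.stack(outputs, dim=1)  # [batch, steps, nodes, out_size]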


class EvolveGCNOCell(_EvolveGCNCell):
    r"""EvolveGCNO cell from the paper `"EvolveGCN: Evolving Graph
    Convolutional Networks for Dynamic Graphs"
    <https://arxiv.org/abs/1902.10191>`_ (Pareja et al., AAAI 2020).

    This variant of the model evolves the weights of the graph convolution
    without looking at node features.

    Args:
        in_size (int): Size of the input.
        out_size (int): Number of units in the hidden state.
        norm (str): Method used to normalize the adjacency matrix.
        activation (str): Activation function after the GCN layer.
        root_weight (bool): Whether to add a parametrized skip connection.
        bias (bool): Whether to learn a bias.
        cached (bool): Whether to cache the normalized edge weights.
    """
    _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
    _cached_adj_t: Optional[SparseTensor]

    def __init__(self,
                 in_size,
                 out_size,
                 norm,
                 activation='relu',
                 root_weight=False,
                 bias=True,
                 cached=False):
        super(EvolveGCNOCell, self).__init__(in_size,
                                             out_size,
                                             norm=norm,
                                             activation=activation,
                                             root_weight=root_weight,
                                             bias=bias,
                                             cached=cached)
        self.lstm_cell = nn.LSTMCell(input_size=in_size, hidden_size=in_size)
        self.reset_parameters()

    def reset_parameters(self):
        super(EvolveGCNOCell, self).reset_parameters()
        self.lstm_cell.reset_parameters()

    def forward(self, x, hs, edge_index, edge_weight=None):
        """"""
        edge_index, edge_weight = self.normalize_edge_index(
            x, edge_index, edge_weight, use_cached=self.cached)
        if hs is None:
            # At the first step, broadcast the initial GCN weights to the batch.
            W = repeat(self.W0, 'din dout -> b din dout', b=x.size(0))
            hc = None
        else:
            W, hc = hs
        # Evolve the GCN weights with an LSTM that takes the previous weights
        # as input (no node features are used, unlike the -H variant).
        x_lstm = rearrange(W, 'b din dout -> (b dout) din')
        hc = self.lstm_cell(x_lstm, hc)
        W = rearrange(hc[0], '(b dout) din -> b din dout', dout=self.out_size)

        out = torch.matmul(x, W)
        out = self.propagate(edge_index, x=out, edge_weight=edge_weight)

        if self.bias is not None:
            out += self.bias

        if self.skip_con is not None:
            out += self.skip_con(x)

        return self.activation_fn(out), (W, hc)
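

# Usage sketch (not part of the library): the -O variant is stepped like the
# -H variant above, but its recurrent state is the tuple ``(W, hc)`` of
# evolved weights and LSTM state. Shapes and the 'sym' normalization are
# illustrative assumptions.
def _demo_evolvegcn_o(x_seq: Tensor, edge_index: Tensor,
                      edge_weight: Optional[Tensor] = None) -> Tensor:
    # x_seq: [batch, steps, nodes, in_size]
    cell = EvolveGCNOCell(in_size=x_seq.size(-1), out_size=32, norm='sym')
    hs = None  # becomes (W, hc) after the first step
    outputs = []
    for t in range(x_seq.size(1)):
        out, hs = cell(x_seq[:, t], hs, edge_index, edge_weight)
        outputs.append(out)
    return torch.stack(outputs, dim=1)  # [batch, steps, nodes, out_size]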