Source code for tsl.nn.layers.graph_convs.graph_conv

import torch
from torch import Tensor, nn
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj, OptTensor
from torch_sparse import SparseTensor, matmul

from tsl.nn.layers.graph_convs.mixin import NormalizedAdjacencyMixin
from tsl.nn.utils import get_functional_activation


[docs]class GraphConv(MessagePassing, NormalizedAdjacencyMixin): r"""A simple graph convolutional operator where the message function is a simple linear projection and aggregation a simple average. In other terms: .. math:: \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1} \mathbf{\tilde{A}} \mathbf{X} \boldsymbol{\Theta} + \mathbf{b} . Args: input_size (int): Size of the input features. output_size (int): Size of the output features. bias (bool): If :obj:`False`, then the layer will not learn an additive bias vector. (default: :obj:`True`) norm (str): The normalization used for edges and edge weights. If :obj:`'mean'`, then edge weights are normalized as :math:`a_{j \rightarrow i} = \frac{a_{j \rightarrow i}} {deg_{i}}`, other available options are: :obj:`'gcn'`, :obj:`'asym'` and :obj:`'none'`. (default: :obj:`'mean'`) root_weight (bool): If :obj:`True`, then add a linear layer for the root node itself (a skip connection). (default :obj:`True`) activation (str, optional): Activation function to be used, :obj:`None` for identity function (i.e., no activation). (default: :obj:`None`) cached (bool): If :obj:`True`, then cached the normalized edge weights computed in the first call. (default :obj:`False`) """ def __init__(self, input_size: int, output_size: int, bias: bool = True, norm: str = 'mean', root_weight: bool = True, activation: str = None, cached: bool = False): super(GraphConv, self).__init__(aggr="add", node_dim=-2) self.in_channels = input_size self.out_channels = output_size self.norm = norm self.cached = cached self.activation = get_functional_activation(activation) self.lin = nn.Linear(input_size, output_size, bias=False) if root_weight: self.root_lin = nn.Linear(input_size, output_size, bias=False) else: self.register_parameter('root_lin', None) if bias: self.bias = Parameter(torch.Tensor(output_size)) else: self.register_parameter('bias', None) self.reset_parameters()
[docs] def reset_parameters(self): self.lin.reset_parameters() if self.root_lin is not None: self.root_lin.reset_parameters() if self.bias is not None: torch.nn.init.zeros_(self.bias)
def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: """""" out = self.lin(x) edge_index, edge_weight = self.normalize_edge_index( x, edge_index=edge_index, edge_weight=edge_weight, use_cached=self.cached) out = self.propagate(edge_index, x=out, edge_weight=edge_weight) if self.root_lin is not None: out += self.root_lin(x) if self.bias is not None: out += self.bias return self.activation(out) def message(self, x_j: Tensor, edge_weight) -> Tensor: """""" return edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor: """""" # adj_t: SparseTensor [nodes, nodes] # x: [(batch,) nodes, channels] return matmul(adj_t, x, reduce=self.aggr)