from torch import nn
from tsl.nn.layers.graph_convs import DiffConv
from tsl.nn.layers.norm import Norm
from tsl.nn.utils import get_layer_activation
from .tcn import TemporalConvNet
class SpatioTemporalConvNet(nn.Module):
    r"""Spatiotemporal convolutional encoder.

    Applies a stack of temporal convolutions, summed with a linear skip
    connection of the input, followed by residual diffusion convolutions
    over a graph.

    Args:
        input_size (int): Input size.
        output_size (int): Channels in the output representation.
        temporal_kernel_size (int): Size of the temporal convolutional kernel.
        spatial_kernel_size (int): Size of the spatial diffusion kernel.
        temporal_convs (int, optional): Number of temporal convolutions.
            (default: :obj:`2`)
        spatial_convs (int, optional): Number of spatial convolutions.
            (default: :obj:`1`)
        dilation (int, optional): Dilation coefficient of the temporal
            convolutional kernel. The temporal network is always built with
            :obj:`exponential_dilation=True`.
            (default: :obj:`1`)
        norm (str, optional): Type of normalization applied to the input of
            the temporal network and to the input of every spatial
            convolution.
            (default: :obj:`'none'`)
        dropout (float, optional): Dropout probability, used both inside the
            temporal network and after every spatial convolution.
            (default: :obj:`0.`)
        gated (bool, optional): Whether to use the GatedTanH activation
            function after temporal convolutions.
            (default: :obj:`False`)
        pad (bool, optional): Whether to pad the input sequence to preserve
            the sequence length.
            (default: :obj:`True`)
        activation (str, optional): Activation function.
            (default: :obj:`'relu'`)
    """

    def __init__(self,
                 input_size,
                 output_size,
                 temporal_kernel_size,
                 spatial_kernel_size,
                 temporal_convs=2,
                 spatial_convs=1,
                 dilation=1,
                 norm='none',
                 dropout=0.,
                 gated=False,
                 pad=True,
                 activation='relu'):
        super().__init__()
        # Kept as a public attribute; it is not read again inside this class.
        self.pad = pad

        # Temporal branch: normalization followed by the stacked temporal
        # convolutions.
        self.tcn = nn.Sequential(
            Norm(norm_type=norm, in_channels=input_size),
            TemporalConvNet(input_channels=input_size,
                            hidden_channels=output_size,
                            kernel_size=temporal_kernel_size,
                            dilation=dilation,
                            exponential_dilation=True,
                            n_layers=temporal_convs,
                            activation=activation,
                            causal_padding=pad,
                            dropout=dropout,
                            gated=gated))

        # Linear projection of the raw input so that it can be summed with
        # the temporal branch output (input_size -> output_size).
        self.skip_conn = nn.Linear(input_size, output_size)

        # One normalization layer per spatial convolution; the two lists are
        # iterated in lockstep in ``forward``.
        self.spatial_convs = nn.ModuleList([
            DiffConv(in_channels=output_size,
                     out_channels=output_size,
                     k=spatial_kernel_size)
            for _ in range(spatial_convs)
        ])
        self.spatial_norms = nn.ModuleList([
            Norm(norm_type=norm, in_channels=output_size)
            for _ in range(spatial_convs)
        ])

        self.dropout = nn.Dropout(dropout)
        self.activation = get_layer_activation(activation)()

    def forward(self, x, edge_index, edge_weight=None):
        """Encode the input with temporal, then spatial, convolutions.

        Args:
            x: Input tensor whose last dimension has size ``input_size``.
            edge_index: Graph connectivity, forwarded as-is to every spatial
                convolution.
            edge_weight (optional): Edge weights, forwarded as-is to every
                spatial convolution.
                (default: :obj:`None`)

        Returns:
            The encoded tensor, with ``output_size`` channels in the last
            dimension.
        """
        # Temporal convolutions plus linear skip connection.
        # NOTE(review): with ``pad=False`` the temporal branch presumably
        # returns a shorter sequence than the skip branch, in which case this
        # sum fails on a shape mismatch -- confirm against TemporalConvNet.
        x = self.skip_conn(x) + self.tcn(x)

        # Residual spatial convolutions: normalize, diffuse over the graph,
        # apply activation and dropout, then add back to the running state.
        for conv, norm in zip(self.spatial_convs, self.spatial_norms):
            x_neigh = conv(norm(x), edge_index, edge_weight)
            x = x + self.dropout(self.activation(x_neigh))
        return x