Source code for tsl.nn.blocks.encoders.tcn

import torch.nn as nn
from einops import rearrange

from tsl.nn.layers.base import GatedTemporalConv, TemporalConv
from tsl.nn.utils import get_functional_activation, maybe_cat_exog


class TemporalConvNet(nn.Module):
    r"""Simple TCN encoder with optional linear readout.

    Args:
        input_channels (int): Input size.
        hidden_channels (int): Channels in the hidden layers.
        kernel_size (int): Size of the convolutional kernel.
        dilation (int): Dilation coefficient of the convolutional kernel.
        stride (int, optional): Stride of the convolutional kernel.
        exog_channels (int, optional): Channels of the optional exogenous
            variables.
        output_channels (int, optional): Channels in the output layer.
        n_layers (int, optional): Number of hidden layers.
            (default: 1)
        gated (bool, optional): Whether to use the GatedTanH activation
            function.
            (default: :obj:`False`)
        dropout (float, optional): Dropout probability.
        activation (str, optional): Activation function.
            (default: :obj:`'relu'`)
        exponential_dilation (bool, optional): Whether to increase the
            dilation factor exponentially at each layer.
        weight_norm (bool, optional): Whether to apply weight normalization
            to the temporal convolutional filters.
        causal_padding (bool, optional): Whether to pad the input sequence
            to preserve causality.
        bias (bool, optional): Whether to add a learnable bias to the output.
        channel_last (bool, optional): If :obj:`True`, input must have layout
            (b s n c), (b c n s) otherwise.
    """

    def __init__(self,
                 input_channels,
                 hidden_channels,
                 kernel_size,
                 dilation,
                 stride=1,
                 exog_channels=None,
                 output_channels=None,
                 n_layers=1,
                 gated=False,
                 dropout=0.,
                 activation='relu',
                 exponential_dilation=False,
                 weight_norm=False,
                 causal_padding=True,
                 bias=True,
                 channel_last=True):
        super(TemporalConvNet, self).__init__()
        self.channel_last = channel_last
        base_conv = TemporalConv if not gated else GatedTemporalConv

        # Exogenous variables are concatenated along the channel dimension,
        # so they widen the input to the first convolution.
        if exog_channels is not None:
            input_channels += exog_channels

        layers = []
        d = dilation
        receptive_field = 1
        for i in range(n_layers):
            if exponential_dilation:
                d = dilation**i
            layers.append(
                base_conv(input_channels=input_channels
                          if i == 0 else hidden_channels,
                          output_channels=hidden_channels,
                          kernel_size=kernel_size,
                          dilation=d,
                          stride=stride,
                          causal_pad=causal_padding,
                          weight_norm=weight_norm,
                          bias=bias))
            receptive_field += (kernel_size - 1) * d * stride
        self.receptive_field = receptive_field

        self.convs = nn.ModuleList(layers)
        # Gated convolutions apply their own gated tanh nonlinearity, so no
        # extra activation is added in that case.
        self.f = get_functional_activation(
            activation) if not gated else nn.Identity()
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

        if output_channels is not None:
            self.readout = TemporalConv(input_channels=hidden_channels,
                                        output_channels=output_channels,
                                        kernel_size=1)
        else:
            self.register_parameter('readout', None)

    def forward(self, x, u=None):
        """"""
        # x: [batch, steps, nodes, channels] if channel_last,
        #    [batch, channels, nodes, steps] otherwise.
        if self.channel_last:
            x = maybe_cat_exog(x, u, -1)
            x = rearrange(x, 'b s n c -> b c n s')
        else:
            x = maybe_cat_exog(x, u, 1)

        for conv in self.convs:
            x = self.dropout(self.f(conv(x)))

        if self.readout is not None:
            x = self.readout(x)

        if self.channel_last:
            x = rearrange(x, 'b c n s -> b s n c')
        return x
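
A minimal usage sketch of the encoder defined above (not part of the library source; the tensor shapes and hyperparameter values are illustrative assumptions):

import torch

from tsl.nn.blocks.encoders.tcn import TemporalConvNet

# Toy input in the default channel-last layout (batch, steps, nodes, channels).
x = torch.randn(8, 24, 20, 3)   # 8 samples, 24 steps, 20 nodes, 3 features
u = torch.randn(8, 24, 20, 2)   # optional exogenous variables, 2 channels

tcn = TemporalConvNet(input_channels=3,
                      hidden_channels=32,
                      kernel_size=3,
                      dilation=2,
                      exog_channels=2,
                      output_channels=16,
                      n_layers=3,
                      exponential_dilation=True)

out = tcn(x, u)                 # causal padding preserves the 24 steps
print(out.shape)                # torch.Size([8, 24, 20, 16])
print(tcn.receptive_field)      # 1 + 2 * (1 + 2 + 4) = 15

With exponential_dilation enabled, layer i uses dilation**i, so stacking a few layers widens the receptive field geometrically without increasing the kernel size.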