# Source code for tsl.nn.layers.base.temporal_conv

from typing import Tuple, Union

import torch.nn as nn
from einops import rearrange
from torch import Tensor

from tsl.nn.functional import gated_tanh


class TemporalConv(nn.Module):
    r"""Learns a standard temporal convolutional filter.

    The filter is implemented as a :class:`~torch.nn.Conv2d` with a
    ``(1, kernel_size)`` kernel acting on the last (time) axis, so that
    multiple input sequences (the third axis) are filtered with the same
    weights.

    Args:
        input_channels (int): Input size.
        output_channels (int): Output size.
        kernel_size (int): Size of the convolution kernel.
        dilation (int, optional): Spacing between kernel elements.
            (default: ``1``)
        stride (int, optional): Stride of the convolution.
            (default: ``1``)
        bias (bool, optional): Whether to add a learnable bias to the output
            of the convolution. (default: ``True``)
        padding (int or tuple, optional): Zero padding of the time axis.
            An int ``p`` pads both sides by ``p``; a 1-tuple ``(p,)`` is
            treated like the int ``p``; a longer tuple/list pads the left
            and right side by its first and second entry, respectively.
            Used only if :obj:`causal_pad` is :obj:`False`.
            (default: ``0``)
        causal_pad (bool, optional): Whether to pad the input as to preserve
            causality, i.e., left-pad the time axis with
            ``(kernel_size - 1) * dilation`` zeros. (default: ``True``)
        weight_norm (bool, optional): Whether to apply weight normalization
            to the parameters of the filter. (default: ``False``)
        channel_last (bool, optional): If :obj:`True`, inputs and outputs
            are laid out as ``[batch, time, nodes, features]``; otherwise as
            ``[batch, features, nodes, time]``. (default: ``False``)

    Raises:
        ValueError: If :obj:`padding` is an empty tuple/list and
            :obj:`causal_pad` is :obj:`False`.
    """

    def __init__(self,
                 input_channels: int,
                 output_channels: int,
                 kernel_size: int,
                 dilation: int = 1,
                 stride: int = 1,
                 bias: bool = True,
                 padding: Union[int, Tuple[int, ...]] = 0,
                 causal_pad: bool = True,
                 weight_norm: bool = False,
                 channel_last: bool = False):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.bias = bias
        # The user-supplied value is kept as-is (it is ignored when
        # ``causal_pad`` is set).
        self.padding = padding
        self.causal_pad = causal_pad
        self.weight_norm = weight_norm
        self.channel_last = channel_last

        # ZeroPad2d expects (left, right, top, bottom): only the time axis
        # (last dim) is padded, never the node axis.
        if causal_pad:
            # Left-only padding: no output step can see future inputs.
            pad = ((kernel_size - 1) * dilation, 0, 0, 0)
        elif isinstance(padding, int):
            pad = (padding, padding, 0, 0)
        elif isinstance(padding, (list, tuple)):
            if len(padding) == 0:
                raise ValueError("`padding` must not be an empty sequence.")
            left = padding[0]
            # A 1-tuple (the annotated ``Tuple[int]`` form) is symmetric.
            right = padding[1] if len(padding) > 1 else left
            pad = (left, right, 0, 0)
        else:
            # Any other type is handed to ZeroPad2d unchanged.
            pad = padding
        self.pad_layer = nn.ZeroPad2d(pad)

        # We use Conv2d here to accommodate multiple input sequences
        self.conv = nn.Conv2d(input_channels, output_channels,
                              kernel_size=(1, kernel_size),
                              stride=(1, stride),
                              padding=(0, 0),
                              dilation=(1, dilation),
                              bias=bias)

        if self.weight_norm:
            # NOTE(review): nn.utils.weight_norm is deprecated in recent
            # PyTorch in favour of nn.utils.parametrizations.weight_norm;
            # switching would rename the filter parameters, so it is kept.
            self.conv = nn.utils.weight_norm(self.conv)

    def __repr__(self):
        args = [str(self.input_channels), str(self.output_channels),
                f'kernel_size={self.kernel_size}', f'stride={self.stride}']
        if self.causal_pad:
            # Report the left padding actually applied by ``pad_layer``,
            # not the (ignored) ``padding`` argument.
            causal = (self.kernel_size - 1) * self.dilation
            args.append(f'causal_padding={causal}')
        elif self.padding != 0:
            args.append(f'padding={self.padding}')
        if self.dilation != 1:
            args.append(f'dilation={self.dilation}')
        if not self.bias:
            args.append('bias=False')
        return f"{self.__class__.__name__}({', '.join(args)})"

    def forward(self, x: Tensor) -> Tensor:
        """Pad the time axis and apply the temporal convolution.

        Args:
            x (Tensor): Input of shape ``[batch, features, nodes, time]``,
                or ``[batch, time, nodes, features]`` if
                :obj:`channel_last` is :obj:`True`.

        Returns:
            Tensor: Output in the same layout as the input, with
            :obj:`output_channels` features.
        """
        if self.channel_last:
            x = rearrange(x, 'b t n f -> b f n t')
        # x: [batch, features, nodes, time]
        x = self.pad_layer(x)
        x = self.conv(x)
        if self.channel_last:
            x = rearrange(x, 'b f n t -> b t n f')
        return x
class GatedTemporalConv(TemporalConv):
    """Temporal convolutional filter with gated tanh connection.

    The parent :class:`TemporalConv` is built with ``2 * output_channels``
    output channels, and its result is passed to :func:`gated_tanh` along
    the feature axis (axis ``1``, or the last axis when
    :obj:`channel_last` is :obj:`True`). Arguments have the same meaning as
    in :class:`TemporalConv`.
    """

    def __init__(self,
                 input_channels: int,
                 output_channels: int,
                 kernel_size: int,
                 dilation: int = 1,
                 stride: int = 1,
                 bias: bool = True,
                 padding: Union[int, Tuple[int]] = 0,
                 causal_pad: bool = True,
                 weight_norm: bool = False,
                 channel_last: bool = False):
        # Twice the requested channels; presumably gated_tanh splits them
        # into a value half and a gate half (see tsl.nn.functional).
        super().__init__(input_channels=input_channels,
                         output_channels=2 * output_channels,
                         kernel_size=kernel_size,
                         dilation=dilation,
                         stride=stride,
                         bias=bias,
                         padding=padding,
                         causal_pad=causal_pad,
                         weight_norm=weight_norm,
                         channel_last=channel_last)

    def forward(self, x: Tensor) -> Tensor:
        """Run the temporal convolution, then gate it with tanh.

        Args:
            x (Tensor): Input in the layout selected by
                :obj:`channel_last` (see :class:`TemporalConv`).

        Returns:
            Tensor: The output of :func:`gated_tanh` applied along the
            feature axis.
        """
        conv_out = super().forward(x)
        if self.channel_last:
            feature_axis = -1
        else:
            feature_axis = 1
        return gated_tanh(conv_out, dim=feature_axis)