Source code for tsl.nn.layers.multi.conv

import math
from typing import Union

import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from torch.nn import init


class MultiConv1d(nn.Module):
    """Applies convolutions with different weights to each of the instances
    in the input data."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 n_instances: int,
                 kernel_size: int,
                 stride: int = 1,
                 padding: Union[str, int] = 0,
                 dilation: int = 1,
                 bias: bool = True,
                 device=None,
                 dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(MultiConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_instances = n_instances
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        # One kernel bank per instance, stacked along the output-channel axis
        # so that a single grouped convolution applies all of them at once.
        self.weight = nn.Parameter(
            torch.empty((n_instances * out_channels, in_channels, kernel_size),
                        **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(
                torch.empty(n_instances, out_channels, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def extra_repr(self) -> str:
        return f'{self.in_channels}, {self.out_channels}, ' \
               f'kernel_size={self.kernel_size}, n_instances={self.n_instances}'

    def reset_parameters(self) -> None:
        # Uniform initialization with the same fan-in bound used by nn.Conv1d.
        bound = 1 / math.sqrt(self.in_channels * self.kernel_size)
        init.uniform_(self.weight.data, -bound, bound)
        if self.bias is not None:
            init.uniform_(self.bias.data, -bound, bound)

    def forward(self, x):
        # x: (batch, steps, n_instances, in_channels)
        x = rearrange(x, '... t n f -> ... (n f) t')
        # groups=n_instances restricts each instance's filters to its own
        # slice of the input channels.
        out = F.conv1d(x, weight=self.weight, bias=None, stride=self.stride,
                       dilation=self.dilation, groups=self.n_instances,
                       padding=self.padding)
        out = rearrange(out, '... (n f) t -> ... t n f', f=self.out_channels)
        if self.bias is not None:
            # Bias broadcasts over the trailing (n_instances, out_channels) dims.
            out = out + self.bias
        return out
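
# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of MultiConv1d, assuming an input of shape
# (batch, steps, n_instances, channels); all sizes below are made up:
#
#   conv = MultiConv1d(in_channels=8, out_channels=16, n_instances=4,
#                      kernel_size=3)
#   x = torch.randn(32, 12, 4, 8)  # (batch, steps, n_instances, channels)
#   y = conv(x)                    # -> (32, 10, 4, 16): 12 - (3 - 1) steps
#                                  #    survive with no padding
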
class MultiTemporalConv(nn.Module):
    """Multi temporal convolutional filters.

    Inputs are expected to be of shape (batch, steps, n_instances, channels).

    Args:
        input_channels (int): Input size.
        output_channels (int): Output size.
        kernel_size (int): Size of the convolution kernel.
        n_instances (int): Number of parallel instances, each with its own
            set of weights.
        dilation (int, optional): Spacing between kernel elements.
        stride (int, optional): Stride of the convolution.
        bias (bool, optional): Whether to add a learnable bias to the output
            of the convolution.
        padding (int, optional): Padding of the input. Used only if
            `causal_padding` is `False`.
        causal_padding (bool, optional): Whether to pad the input so as to
            preserve causality.
    """

    def __init__(self,
                 input_channels: int,
                 output_channels: int,
                 kernel_size: int,
                 n_instances: int,
                 dilation: int = 1,
                 stride: int = 1,
                 bias: bool = True,
                 padding: int = 0,
                 causal_padding: bool = True):
        super().__init__()
        # Pad only the temporal dimension (third from last), on the left.
        self._causal_pad_sizes = (0, 0, 0, 0, (kernel_size - 1) * dilation, 0)
        if causal_padding:
            assert padding == 0, \
                'Explicit padding cannot be used with causal padding.'
        self.causal_padding = causal_padding
        self.conv = MultiConv1d(in_channels=input_channels,
                                out_channels=output_channels,
                                n_instances=n_instances,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding,
                                dilation=dilation,
                                bias=bias)

    def forward(self, x):
        if self.causal_padding:
            # Left-pad the steps dimension so no output depends on the future.
            x = F.pad(x, self._causal_pad_sizes, mode='constant', value=0.)
        x = self.conv(x)
        return x
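
# --- Usage sketch (illustrative only, not part of the original module) ---
# With the default causal_padding=True, the input is left-padded along the
# steps dimension by (kernel_size - 1) * dilation, so the number of steps is
# preserved and outputs never see future inputs. Shapes are assumptions:
#
#   tconv = MultiTemporalConv(input_channels=8, output_channels=16,
#                             kernel_size=3, n_instances=4)
#   x = torch.randn(32, 12, 4, 8)  # (batch, steps, n_instances, channels)
#   y = tconv(x)                   # -> (32, 12, 4, 16): all 12 steps kept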