Source code for tsl.nn.layers.multi.linear

import math
from typing import Union

import torch
from torch import Tensor, nn
from torch.nn import init


class MultiLinear(nn.Module):
    r"""Applies linear transformations with different weights to the different
    instances in the input data.

    .. math::

        \mathbf{X}^{\prime} = [\boldsymbol{\Theta}_i \mathbf{x}_i +
        \mathbf{b}_i]_{i=0,\ldots,N}

    Args:
        in_channels (int): Size of instance input sample.
        out_channels (int): Size of instance output sample.
        n_instances (int): The number :math:`N` of parallel linear operations.
            Each operation has different weights and biases.
        instance_dim (int or str): Dimension of the instances (must match
            :attr:`n_instances` at runtime).
            (default: :obj:`-2`)
        channel_dim (int or str): Dimension of the input channels.
            (default: :obj:`-1`)
        bias (bool): If :obj:`True`, then the layer will learn an additive
            bias for each instance.
            (default: :obj:`True`)
        device (optional): The device of the parameters.
            (default: :obj:`None`)
        dtype (optional): The data type of the parameters.
            (default: :obj:`None`)

    Examples:
        >>> m = MultiLinear(20, 32, 10, pattern='t n f', instance_dim='n')
        >>> input = torch.randn(64, 12, 10, 20)  # shape: [b t n f]
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([64, 12, 10, 32])
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 n_instances: int,
                 *,
                 ndim: int = None,
                 pattern: str = None,
                 instance_dim: Union[int, str] = -2,
                 channel_dim: Union[int, str] = -1,
                 bias: bool = True,
                 device=None,
                 dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(MultiLinear, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_instances = n_instances
        self.ndim = ndim
        self.instance_dim = instance_dim
        self.channel_dim = channel_dim

        # initialize by pattern, e.g.:
        #   pattern='t n f', instance_dim='n', channel_dim='f'
        #   pattern='t n f', instance_dim=-2, channel_dim=-1
        if pattern is not None:
            pattern = pattern.replace(' ', '')
            self.instance_dim = instance_dim if isinstance(instance_dim, str) \
                else pattern[instance_dim]
            self.channel_dim = channel_dim if isinstance(channel_dim, str) \
                else pattern[channel_dim]
            self.einsum_pattern = self._compute_einsum_pattern(pattern=pattern)
            self.bias_shape = self._compute_bias_shape(pattern=pattern)
            self.reshape_bias = False
        # initialize with negative dim indexing (default), e.g.:
        #   instance_dim=-2, channel_dim=-1  (pattern=None, ndim=None)
        elif ndim is None and instance_dim < 0 and channel_dim < 0:
            ndim = abs(min(instance_dim, channel_dim))
            self.einsum_pattern = self._compute_einsum_pattern(ndim)
            self.bias_shape = self._compute_bias_shape(ndim)
            self.reshape_bias = False
        # initialize with ndim and dim (positive/negative) indexing, e.g.:
        #   ndim=3, instance_dim=1, channel_dim=-1
        elif ndim is not None:
            # initialize with lazy ndim calculation, e.g.:
            #   ndim=-1, instance_dim=1, channel_dim=-1
            if ndim < 0:
                # lazily initialize einsum pattern at the first forward pass
                self.einsum_pattern = None
                self.bias_shape = (n_instances, out_channels)
                self.reshape_bias = True
                self._hook = self.register_forward_pre_hook(
                    self.initialize_module)
            else:
                self.einsum_pattern = self._compute_einsum_pattern(ndim)
                self.bias_shape = self._compute_bias_shape(ndim)
                self.reshape_bias = False
        # cannot initialize if all of the following hold:
        #   1. pattern is None
        #   2. ndim is None and (instance_dim >= 0 or channel_dim >= 0)
        else:
            raise ValueError("One of 'pattern' or 'ndim' must be given if one "
                             "of 'instance_dim' or 'channel_dim' is positive.")

        self.weight: nn.Parameter = nn.Parameter(
            torch.empty((n_instances, in_channels, out_channels),
                        **factory_kwargs))
        if bias:
            self.bias: nn.Parameter = nn.Parameter(
                torch.empty(*self.bias_shape, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def extra_repr(self) -> str:
        """"""
        return 'in_channels={}, out_channels={}, n_instances={}'.format(
            self.in_channels, self.out_channels, self.n_instances)

    def reset_parameters(self) -> None:
        bound = 1 / math.sqrt(self.in_channels)
        init.uniform_(self.weight.data, -bound, bound)
        if self.bias is not None:
            init.uniform_(self.bias.data, -bound, bound)

    def _compute_bias_shape(self, ndim: int = None, pattern: str = None):
        if ndim is not None:
            bias_shape = [1] * ndim
            bias_shape[self.instance_dim] = self.n_instances
            bias_shape[self.channel_dim] = self.out_channels
        elif pattern is not None:
            pattern = pattern.replace(' ', '')
            bias_shape = []
            for token in pattern:
                if token == self.channel_dim:
                    bias_shape.append(self.out_channels)
                elif token == self.instance_dim:
                    bias_shape.append(self.n_instances)
                else:
                    bias_shape.append(1)
        else:
            raise ValueError("One of 'pattern' or 'ndim' must be given.")
        return tuple(bias_shape)

    def _compute_einsum_pattern(self, ndim: int = None, pattern: str = None):
        if ndim is not None:
            pattern = [chr(s + 97) for s in range(ndim)]  # 'a', 'b', ...
            pattern[self.instance_dim] = 'x'
            pattern[self.channel_dim] = 'y'
            input_pattern = ''.join(pattern)
            pattern[self.channel_dim] = 'z'
            output_pattern = ''.join(pattern)
            weight_pattern = 'xyz'
        elif pattern is not None:
            input_pattern = pattern.replace(' ', '')
            output_pattern = input_pattern.replace(self.channel_dim, 'z')
            weight_pattern = f'{self.instance_dim}{self.channel_dim}z'
        else:
            raise ValueError("One of 'pattern' or 'ndim' must be given.")
        return f"...{input_pattern},{weight_pattern}->...{output_pattern}"

    @torch.no_grad()
    def initialize_module(self, module, input):
        self.ndim = input[0].ndim
        self.einsum_pattern = self._compute_einsum_pattern(self.ndim)
        self.bias_shape = self._compute_bias_shape(self.ndim)
        self._hook.remove()
        delattr(self, '_hook')
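
    # For reference, the einsum string built in `_compute_einsum_pattern`
    # drives the contraction performed in `forward`. With the default
    # configuration (instance_dim=-2, channel_dim=-1, hence ndim=2) it is
    #     '...xy,xyz->...xz'
    # i.e. an input of shape [..., n_instances, in_channels] is contracted
    # with the weight of shape [n_instances, in_channels, out_channels] into
    # an output of shape [..., n_instances, out_channels]. With
    # pattern='t n f', instance_dim='n' (so channel_dim='f') it becomes
    #     '...tnf,nfz->...tnz'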
    def forward(self, input: Tensor) -> Tensor:
        r"""Compute :math:`\mathbf{X}^{\prime} = [\boldsymbol{\Theta}_i
        \mathbf{x}_i + \mathbf{b}_i]_{i=0,\ldots,N}`"""
        out = torch.einsum(self.einsum_pattern, input, self.weight)
        if self.bias is not None:
            if self.reshape_bias:
                out = out + self.bias.view(*self.bias_shape).contiguous()
            else:
                out = out + self.bias
        return out
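
# Minimal usage sketch (illustrative only; it assumes the class is importable
# as ``tsl.nn.layers.multi.linear.MultiLinear`` and that the input follows the
# [batch, time, nodes, features] layout of the docstring example):
#
#     import torch
#     from tsl.nn.layers.multi.linear import MultiLinear
#
#     x = torch.randn(64, 12, 10, 20)  # [b t n f]
#
#     # default indexing: instances on dim -2, channels on dim -1
#     layer = MultiLinear(20, 32, n_instances=10)
#     print(layer(x).shape)  # torch.Size([64, 12, 10, 32])
#
#     # pattern-based indexing: dimensions addressed by name
#     layer = MultiLinear(20, 32, 10, pattern='t n f', instance_dim='n')
#     print(layer(x).shape)  # torch.Size([64, 12, 10, 32])
#
#     # lazy initialization: positive dim indices, with ndim inferred from
#     # the input at the first forward pass
#     layer = MultiLinear(20, 32, 10, ndim=-1, instance_dim=2, channel_dim=3)
#     print(layer(x).shape)  # torch.Size([64, 12, 10, 32])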