# Source code for tsl.nn.layers.multi.dense

from typing import Optional, Union

from torch import Tensor, nn

from tsl.nn.utils import get_layer_activation

from .linear import MultiLinear

class MultiDense(MultiLinear):
    r"""Applies linear transformations with different weights to the different
    instances in the input data with a final nonlinear activation.

    .. math::

        \mathbf{X}^{\prime} = \left[\sigma\left(\boldsymbol{\Theta}_i
        \mathbf{x}_i + \mathbf{b}_i \right)\right]_{i=0,\ldots,N-1}

    Args:
        in_channels (int): Size of instance input sample.
        out_channels (int): Size of instance output sample.
        n_instances (int): The number :math:`N` of parallel linear
            operations. Each operation has different weights and biases.
        activation (str, optional): Activation function to be used.
            (default: :obj:`'relu'`)
        dropout (float, optional): Dropout rate, applied after the
            activation. A rate of :obj:`0` (or lower) disables dropout.
            (default: :obj:`0`)
        ndim (int, optional): Keyword-only. Forwarded unchanged to
            :class:`MultiLinear`; see its documentation for the semantics.
            (default: :obj:`None`)
        pattern (str, optional): Keyword-only. Forwarded unchanged to
            :class:`MultiLinear`; see its documentation for the semantics.
            (default: :obj:`None`)
        instance_dim (int or str): Dimension of the instances (must match
            :attr:`n_instances` at runtime).
            (default: :obj:`-2`)
        channel_dim (int or str): Dimension of the input channels.
            (default: :obj:`-1`)
        bias (bool): If :obj:`True`, then the layer will learn an additive
            bias for each instance.
            (default: :obj:`True`)
        device (optional): The device of the parameters.
            (default: :obj:`None`)
        dtype (optional): The data type of the parameters.
            (default: :obj:`None`)
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 n_instances: int,
                 activation: str = 'relu',
                 dropout: float = 0.,
                 *,
                 ndim: Optional[int] = None,
                 pattern: Optional[str] = None,
                 instance_dim: Union[int, str] = -2,
                 channel_dim: Union[int, str] = -1,
                 bias: bool = True,
                 device=None,
                 dtype=None) -> None:
        # All parameter creation and shape handling is delegated to
        # MultiLinear; this class only adds the activation and dropout.
        super().__init__(in_channels,
                         out_channels,
                         n_instances=n_instances,
                         ndim=ndim,
                         pattern=pattern,
                         instance_dim=instance_dim,
                         channel_dim=channel_dim,
                         bias=bias,
                         device=device,
                         dtype=dtype)
        # get_layer_activation(...) returns a callable (presumably an
        # nn.Module class) that is instantiated here with no arguments.
        self.activation = get_layer_activation(activation)()
        # Use a no-op module instead of Dropout when no dropout is requested.
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

    def forward(self, input: Tensor) -> Tensor:
        r"""Compute :math:`\mathbf{X}^{\prime} = \left[\sigma\left(
        \boldsymbol{\Theta}_i\mathbf{x}_i + \mathbf{b}_i \right)\right]_{
        i=0,\ldots,N-1}`, followed by dropout."""
        # Per-instance affine map (MultiLinear.forward), then the
        # nonlinearity, then dropout (identity when dropout is disabled).
        out = self.activation(super().forward(input))
        return self.dropout(out)