# Source code for tsl.nn.layers.ops.concatenate

from typing import List, Tuple, Union

from torch import Tensor, nn

from tsl.nn.functional import expand_then_cat


class Concatenate(nn.Module):
    """Concatenate tensors along dimension :attr:`dim`.

    The tensors dimensions are matched (i.e., broadcasted if necessary)
    before concatenation.

    Args:
        dim (int): The dimension to concatenate on.
            (default: :obj:`0`)
    """

    def __init__(self, dim: int = 0):
        super().__init__()
        # Stored as a plain attribute; it is forwarded unchanged to
        # ``expand_then_cat`` on every call to :meth:`forward`.
        self.dim = dim

    def forward(
        self, tensors: Union[Tuple[Tensor, ...], List[Tensor]]
    ) -> Tensor:
        """Returns :func:`~tsl.nn.functional.expand_then_cat` on input
        tensors.

        Args:
            tensors (tuple or list of Tensor): The tensors to be
                concatenated.

        Returns:
            Tensor: The result of ``expand_then_cat(tensors, self.dim)``.
        """
        return expand_then_cat(tensors, self.dim)