# Source code for tsl.nn.layers.ops.select
from torch import Tensor, nn
class Select(nn.Module):
    """Apply :func:`~torch.select` to select one element from a
    :class:`~torch.Tensor` along a dimension.

    This layer returns a view of the original tensor with the given dimension
    removed, so the output has one dimension fewer than the input.

    Args:
        dim (int): The dimension to slice.
        index (int): The index to select with.
    """

    def __init__(self, dim: int, index: int):
        super().__init__()
        # Plain Python attributes: the layer holds no parameters or buffers.
        self.dim = dim
        self.index = index

    def extra_repr(self) -> str:
        """Return the configured ``dim`` and ``index`` for the module repr."""
        return f"dim={self.dim}, index={self.index}"

    def forward(self, tensor: Tensor) -> Tensor:
        """Returns :func:`~torch.select` on input tensor.

        Args:
            tensor (Tensor): The tensor to slice.

        Returns:
            Tensor: The slice of ``tensor`` at position ``index`` along
            ``dim``, with that dimension removed.
        """
        return tensor.select(self.dim, self.index)