# Source code for tsl.nn.layers.ops.lambda_module

from typing import Callable

from torch import Tensor, nn


class Lambda(nn.Module):
    """Call a generic function on the input.

    Wraps an arbitrary callable so that it can be used as a layer inside
    an :class:`torch.nn.Module` pipeline (e.g., :class:`torch.nn.Sequential`).

    Args:
        function (callable): The function to call in :obj:`forward(input)`.

    Raises:
        TypeError: If :obj:`function` is not callable.
    """

    def __init__(self, function: Callable):
        super().__init__()
        # Fail fast at construction time instead of at the first forward pass,
        # where the error would point away from the actual mistake.
        if not callable(function):
            raise TypeError(
                f"Lambda expects a callable, got {type(function).__name__}."
            )
        self.function = function

    def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
        """Returns :obj:`self.function(input, *args, **kwargs)`.

        Any extra positional or keyword arguments are forwarded unchanged to
        :obj:`self.function`. Calling with only :obj:`input` is equivalent
        to :obj:`self.function(input)`.
        """
        return self.function(input, *args, **kwargs)

    def extra_repr(self) -> str:
        """Show the wrapped function in the module's string representation."""
        # Prefer the qualified name. Fall back to repr() for callables that
        # lack one, e.g. functools.partial objects.
        name = getattr(self.function, "__qualname__", None) or repr(self.function)
        return f"function={name}"