Source code for tsl.nn.layers.ops.lambda_module
from typing import Callable
from torch import Tensor, nn
class Lambda(nn.Module):
    """Call a generic function on the input.

    Wraps an arbitrary callable so that it can be used like any other
    module. When the callable is a ``lambda`` whose source code can be
    retrieved, the text following the ``lambda`` keyword is stored in
    ``__function_src__``; in every other case that attribute is the empty
    string.

    Args:
        function (callable): The function to call in :obj:`forward(input)`.
    """

    def __init__(self, function: Callable):
        super().__init__()
        self.function = function
        self.__function_src__ = ''
        # Not every callable exposes ``__name__`` (e.g. ``functools.partial``
        # objects or module instances), so look it up defensively instead of
        # failing in the constructor.
        if getattr(function, '__name__', None) == '<lambda>':
            self.__function_src__ = self._lambda_source(function)

    @staticmethod
    def _lambda_source(function: Callable) -> str:
        """Return the text of a lambda after the ``lambda`` keyword.

        The extraction is a textual heuristic applied to the source line(s)
        on which the lambda is defined: everything up to and including
        ``'lambda '`` is dropped, then everything from the last ``')'`` on.
        An empty string is returned when the source cannot be retrieved.
        """
        import inspect

        try:
            src = inspect.getsource(function)
        except (OSError, TypeError):
            # Source is unavailable (e.g. interactive session, ``exec``-ed
            # code); the stored source is informational only, so degrade
            # gracefully rather than refusing to build the module.
            return ''
        src = src[src.find('lambda') + 7:]  # cut from 'lambda ' on
        last_paren = src.rfind(')')
        # Cut from the last ')' on; keep everything if there is none.
        return src if last_paren < 0 else src[:last_paren]

    def forward(self, input: Tensor) -> Tensor:
        """Returns :obj:`self.function(input)`."""
        return self.function(input)