Source code for tsl.transforms.imputation

from torch_geometric.transforms import BaseTransform

from tsl.data import Data


class MaskInput(BaseTransform):
    """Whiten masked values in :attr:`input_key` according to the mask in
    :attr:`mask_key`.

    The entry stored under :attr:`input_key` is multiplied by the entry
    stored under :attr:`mask_key`, so positions where the mask is zero
    become zero in the input.

    Args:
        input_key (str): The key in ``Data`` to be masked.
            (default: :obj:`'x'`)
        mask_key (str): The key in ``Data`` to serve as mask.
            (default: :obj:`'mask'`)
    """

    def __init__(self, input_key: str = 'x', mask_key: str = 'mask'):
        self.input_key = input_key
        self.mask_key = mask_key

    def __call__(self, data: Data) -> Data:
        """Mask ``data[input_key]`` with ``data[mask_key]``.

        Args:
            data (Data): The object holding both the input and the mask.

        Returns:
            Data: The same ``data`` object that was passed in, with the
            masked input stored back under :attr:`input_key`.
        """
        # Augmented assignment: the stored value is multiplied and written
        # back under the same key, so `data` itself is modified (no copy of
        # `data` is made here).
        # NOTE(review): masking by multiplication presumably leaves NaN/inf
        # at masked positions untouched (NaN * 0 is NaN) -- confirm that the
        # inputs are finite before this transform runs.
        data[self.input_key] *= data[self.mask_key]
        return data