Source code for tsl.engines.imputer

from typing import Callable, List, Mapping, Optional, Tuple, Type, Union

import torch
from torch import Tensor
from torch_geometric.data.storage import recursive_apply
from torchmetrics import Metric

from .predictor import Predictor


class Imputer(Predictor):
    r""":class:`~pytorch_lightning.core.LightningModule` to implement
    imputers.

    An imputer is an engine designed to fill out missing values in
    spatiotemporal data.

    Args:
        model (torch.nn.Module, optional): Model implementing the imputer.
            Ignored if argument :obj:`model_class` is not null. This argument
            should mainly be used for inference.
            (default: :obj:`None`)
        model_class (type, optional): Class of :obj:`~torch.nn.Module`
            implementing the imputer. If not :obj:`None`, argument
            :obj:`model` will be ignored.
            (default: :obj:`None`)
        model_kwargs (mapping, optional): Dictionary of arguments to be
            forwarded to :obj:`model_class` at instantiation.
            (default: :obj:`None`)
        optim_class (type, optional): Class of :obj:`~torch.optim.Optimizer`
            implementing the optimizer to be used for training the model.
            (default: :obj:`None`)
        optim_kwargs (mapping, optional): Dictionary of arguments to be
            forwarded to :obj:`optim_class` at instantiation.
            (default: :obj:`None`)
        loss_fn (callable, optional): Loss function to be used for training
            the model.
            (default: :obj:`None`)
        scale_target (bool): Whether to scale the target before evaluating
            the loss. The metrics instead will always be evaluated in the
            original range.
            (default: :obj:`False`)
        whiten_prob (float or list): Randomly mask out a valid datapoint
            during a training step with probability :obj:`whiten_prob`. If a
            list is passed, :obj:`whiten_prob` is sampled from the list for
            each batch.
            (default: :obj:`0.05`)
        prediction_loss_weight (float): The weight to assign to predictions
            (if any) in the loss. The loss is computed as

            .. math::

                L = \ell(\bar{y}, y, m) + \lambda \sum_i \ell(\hat{y}_i, y, m)

            where :math:`\ell(\bar{y}, y, m)` is the imputation loss,
            :math:`\ell(\hat{y}_i, y, m)` is the forecasting error of
            prediction :math:`\hat{y}_i`, and :math:`\lambda` is
            :obj:`prediction_loss_weight`.
            (default: :obj:`1.0`)
        impute_only_missing (bool): Whether to impute only missing values in
            inference or the whole sequence.
            (default: :obj:`True`)
        warm_up_steps (int, tuple): Number of steps to be considered as warm
            up stage at the beginning of the sequence. If a tuple is
            provided, the trimming is applied both at the beginning and the
            end of the sequence.
            (default: :obj:`0`)
        metrics (mapping, optional): Set of metrics to be logged during
            train, val and test steps. The metric's name will be
            automatically prefixed with the loop in which the metric is
            computed (e.g., metric :obj:`mae` will be logged as
            :obj:`train_mae` when evaluated during training).
            (default: :obj:`None`)
        scheduler_class (type): Class of
            :obj:`~torch.optim.lr_scheduler._LRScheduler` implementing the
            learning rate scheduler to be used during training.
            (default: :obj:`None`)
        scheduler_kwargs (mapping): Dictionary of arguments to be forwarded
            to :obj:`scheduler_class` at instantiation.
            (default: :obj:`None`)
    """

    def __init__(self,
                 model: Optional[torch.nn.Module] = None,
                 loss_fn: Optional[Callable] = None,
                 scale_target: bool = False,
                 metrics: Optional[Mapping[str, Metric]] = None,
                 *,
                 whiten_prob: Optional[Union[float, List[float]]] = 0.05,
                 prediction_loss_weight: float = 1.0,
                 impute_only_missing: bool = True,
                 warm_up_steps: Union[int, Tuple[int, int]] = 0,
                 model_class: Optional[Type] = None,
                 model_kwargs: Optional[Mapping] = None,
                 optim_class: Optional[Type] = None,
                 optim_kwargs: Optional[Mapping] = None,
                 scheduler_class: Optional[Type] = None,
                 scheduler_kwargs: Optional[Mapping] = None):
        super(Imputer, self).__init__(model=model,
                                      model_class=model_class,
                                      model_kwargs=model_kwargs,
                                      optim_class=optim_class,
                                      optim_kwargs=optim_kwargs,
                                      loss_fn=loss_fn,
                                      scale_target=scale_target,
                                      metrics=metrics,
                                      scheduler_class=scheduler_class,
                                      scheduler_kwargs=scheduler_kwargs)

        if isinstance(whiten_prob, (list, tuple)):
            self.whiten_prob = torch.tensor(whiten_prob)
        else:
            self.whiten_prob = whiten_prob

        self.prediction_loss_weight = prediction_loss_weight
        self.impute_only_missing = impute_only_missing

        if isinstance(warm_up_steps, int):
            self.warm_up_steps = (warm_up_steps, 0)
        elif isinstance(warm_up_steps, (list, tuple)):
            self.warm_up_steps = tuple(warm_up_steps)
        if len(self.warm_up_steps) != 2:
            raise ValueError("'warm_up_steps' must be an int of time steps "
                             "to be cut at the beginning of the sequence or "
                             "a pair of int if the sequence must be trimmed "
                             "in a bidirectional way.")
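    # A minimal instantiation sketch (not part of the library source; the
    # model class and loss function below are hypothetical placeholders):
    #
    #     imputer = Imputer(model_class=MyImputationModel,  # hypothetical
    #                       model_kwargs=dict(hidden_size=64),
    #                       optim_class=torch.optim.Adam,
    #                       optim_kwargs=dict(lr=1e-3),
    #                       loss_fn=masked_mae,             # any masked loss
    #                       whiten_prob=[0.2, 0.5, 0.8],    # sampled per batch
    #                       prediction_loss_weight=1.0,
    #                       warm_up_steps=(3, 0))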
    def trim_warm_up(self, *args):
        """Trim all tensors in :obj:`args`, removing a number of first and
        last time steps equal to :obj:`(self.warm_up_steps[0],
        self.warm_up_steps[1])`, respectively."""
        left, right = self.warm_up_steps
        # assume time is in the second dimension (after the batch dimension)
        trim = lambda s: s[:, left:s.size(1) - right]  # noqa
        args = recursive_apply(args, trim)
        if len(args) == 1:
            return args[0]
        return args
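    # Shape walkthrough for ``trim_warm_up`` (illustrative only, assuming a
    # ``(batch, time, nodes, features)`` layout with time in the second
    # dimension):
    #
    #     x = torch.randn(8, 24, 36, 1)    # 24 time steps
    #     imputer.warm_up_steps = (2, 1)
    #     imputer.trim_warm_up(x).shape    # -> torch.Size([8, 21, 36, 1])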
    # Imputation data hooks ###################################################
    def on_train_batch_start(self, batch, batch_idx: int) -> None:
        r"""For every training batch, randomly mask out values with
        probability :obj:`p = self.whiten_prob`. Then, whiten missing values
        in :obj:`batch.input.x`."""
        super(Imputer, self).on_train_batch_start(batch, batch_idx)
        batch.original_mask = batch.mask
        if self.whiten_prob is not None:
            # randomly mask out values with probability p = whiten_prob
            mask = batch.mask
            p = self.whiten_prob
            if isinstance(p, Tensor) and p.ndim > 0:
                # broadcast p to the mask's size
                p_size = [mask.size(0)] + [1] * (mask.ndim - 1)
                # sample one p for each element in the batch
                p = p[torch.randint(len(p), p_size)].to(device=mask.device)
            # set each nonzero element of mask to 0 with probability p
            whiten_mask = torch.rand(mask.size(), device=mask.device) > p
            batch.mask = mask & whiten_mask
            # whiten missing values
            if 'x' in batch.input:
                batch.input.x = batch.input.x * batch.mask
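    # Worked sketch of the whitening step on a toy mask (illustrative only):
    #
    #     mask = torch.ones(4, 12, 3, 1, dtype=torch.bool)  # all valid
    #     p = torch.tensor([0.2, 0.5, 0.8])
    #     p_size = [mask.size(0)] + [1] * (mask.ndim - 1)   # [4, 1, 1, 1]
    #     p = p[torch.randint(len(p), p_size)]              # one p per sequence
    #     whiten_mask = torch.rand(mask.size()) > p
    #     mask = mask & whiten_mask
    #
    # On average a fraction ``p`` of the valid entries is hidden and must be
    # reconstructed by the model, while ``batch.original_mask`` keeps track
    # of the values that were actually observed.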
    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        # Make predictions
        y_hat = self.predict(**batch.input)
        # Rescale outputs
        trans = batch.transform.get('y')
        if trans is not None:
            y_hat = trans.inverse_transform(y_hat)
        # fill missing values in target data
        if self.impute_only_missing:
            y_hat = torch.where(batch.mask.bool(), batch.y, y_hat)
        # return dict
        output = dict(**batch.target,
                      y_hat=y_hat,
                      mask=batch.mask,
                      eval_mask=batch.eval_mask)
        return output

    def shared_step(self, batch, mask):
        y = y_loss = batch.y
        y_hat = y_hat_loss = self.predict_batch(
            batch, preprocess=False, postprocess=not self.scale_target)

        # if scale_target, compute loss in the scaled space; metrics are
        # always evaluated in the original range
        if self.scale_target:
            y_loss = batch.transform['y'].transform(y)
            y_hat = batch.transform['y'].inverse_transform(y_hat)

        y_hat_loss, y_loss, mask = self.trim_warm_up(y_hat_loss, y_loss, mask)

        # models may return (imputation, [intermediate predictions])
        if isinstance(y_hat_loss, (list, tuple)):
            imputation, predictions = y_hat_loss
            y_hat = y_hat[0]
        else:
            imputation, predictions = y_hat_loss, []

        # loss = l(imputation, y, m) + weight * sum_i l(prediction_i, y, m)
        loss = self.loss_fn(imputation, y_loss, mask)
        for pred in predictions:
            pred_loss = self.loss_fn(pred, y_loss, mask)
            loss += self.prediction_loss_weight * pred_loss

        return y_hat.detach(), y, loss

    def training_step(self, batch, batch_idx):
        y_hat, y, loss = self.shared_step(batch, batch.original_mask)
        # Logging
        self.train_metrics.update(y_hat, y, batch.eval_mask)
        self.log_metrics(self.train_metrics, batch_size=batch.batch_size)
        self.log_loss('train', loss, batch_size=batch.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        y_hat, y, val_loss = self.shared_step(batch, batch.mask)
        # Logging
        self.val_metrics.update(y_hat, y, batch.eval_mask)
        self.log_metrics(self.val_metrics, batch_size=batch.batch_size)
        self.log_loss('val', val_loss, batch_size=batch.batch_size)
        return val_loss

    def test_step(self, batch, batch_idx):
        # Compute outputs and rescale
        y_hat = self.predict_step(batch, batch_idx)['y_hat']
        # reconstruction loss
        test_loss = self.loss_fn(y_hat, batch.y, batch.mask)
        # Logging
        self.test_metrics.update(y_hat.detach(), batch.y, batch.eval_mask)
        self.log_metrics(self.test_metrics, batch_size=batch.batch_size)
        self.log_loss('test', test_loss, batch_size=batch.batch_size)
        return test_loss
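# End-to-end sketch of how an Imputer is typically trained and evaluated with
# PyTorch Lightning (``imputer`` and ``datamodule`` are placeholders built
# elsewhere, e.g., as in the instantiation sketch above):
#
#     import pytorch_lightning as pl
#
#     trainer = pl.Trainer(max_epochs=100)
#     trainer.fit(imputer, datamodule=datamodule)
#     trainer.test(imputer, datamodule=datamodule)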