Source code for tsl.engines.predictor

import inspect
from typing import Callable, Mapping, Optional, Type

import pytorch_lightning as pl
import torch
from torchmetrics import Metric, MetricCollection

from tsl import logger
from tsl.data import Data
from tsl.metrics.torch import MaskedMetric
from tsl.nn.models import BaseModel
from tsl.utils import foo_signature


class Predictor(pl.LightningModule):
    """:class:`~pytorch_lightning.core.LightningModule` to implement
    predictors.

    Input data should follow the format ``[batch, steps, nodes, features]``.

    Args:
        model (torch.nn.Module, optional): Model implementing the predictor.
            Ignored if argument ``model_class`` is not :obj:`None`. This
            argument should mainly be used for inference.
            (default: :obj:`None`)
        model_class (type, optional): Class of :obj:`~torch.nn.Module`
            implementing the predictor. If not :obj:`None`, argument ``model``
            will be ignored.
            (default: :obj:`None`)
        model_kwargs (mapping, optional): Dictionary of arguments to be
            forwarded to :obj:`model_class` at instantiation.
            (default: :obj:`None`)
        optim_class (type, optional): Class of :obj:`~torch.optim.Optimizer`
            implementing the optimizer to be used for training the model.
            (default: :obj:`None`)
        optim_kwargs (mapping, optional): Dictionary of arguments to be
            forwarded to :obj:`optim_class` at instantiation.
            (default: :obj:`None`)
        loss_fn (callable, optional): Loss function to be used for training
            the model.
            (default: :obj:`None`)
        scale_target (bool): Whether to scale the target before evaluating
            the loss. The metrics, instead, will always be evaluated in the
            original range.
            (default: :obj:`False`)
        metrics (mapping, optional): Set of metrics to be logged during
            train, val and test steps. The metric's name will be
            automatically prefixed with the loop in which the metric is
            computed (e.g., metric :obj:`mae` will be logged as
            :obj:`train_mae` when evaluated during training).
            (default: :obj:`None`)
        scheduler_class (type, optional): Class of
            :obj:`~torch.optim.lr_scheduler._LRScheduler` implementing the
            learning rate scheduler to be used during training.
            (default: :obj:`None`)
        scheduler_kwargs (mapping, optional): Dictionary of arguments to be
            forwarded to :obj:`scheduler_class` at instantiation.
            (default: :obj:`None`)
    """

    def __init__(self,
                 model: Optional[torch.nn.Module] = None,
                 loss_fn: Optional[Callable] = None,
                 scale_target: bool = False,
                 metrics: Optional[Mapping[str, Metric]] = None,
                 *,
                 model_class: Optional[Type] = None,
                 model_kwargs: Optional[Mapping] = None,
                 optim_class: Optional[Type] = None,
                 optim_kwargs: Optional[Mapping] = None,
                 scheduler_class: Optional = None,
                 scheduler_kwargs: Optional[Mapping] = None):
        super(Predictor, self).__init__()
        self.save_hyperparameters(ignore=['loss_fn', 'model'], logger=False)
        self.model_cls = model_class
        self.model_kwargs = model_kwargs or dict()
        self._model_fwd_signature = None  # automatically set on model assignment

        self.optim_class = optim_class
        self.optim_kwargs = optim_kwargs or dict()
        self.scheduler_class = scheduler_class
        self.scheduler_kwargs = scheduler_kwargs or dict()

        if loss_fn is not None:
            self.loss_fn = self._check_metric(loss_fn, on_step=True)
        else:
            self.loss_fn = None

        self.scale_target = scale_target

        if metrics is None:
            metrics = dict()
        self._set_metrics(metrics)

        if self.model_cls is not None:
            # instantiate the model from its class and keyword arguments
            self.model = self.model_cls(**self.model_kwargs)
        else:
            self.model = model

    def __setattr__(self, key, value):
        super(Predictor, self).__setattr__(key, value)
        if key == 'model' and value is not None:
            self._model_fwd_signature = foo_signature(self.model.forward)
            self._check_kwargs = True

    def reset_model(self):
        """"""
        if self.model_cls is not None:
            self.model = self.model_cls(**self.model_kwargs)
        else:
            self.model = None
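
    # --- Usage example (illustrative, not part of the library source) ------
    # A minimal sketch of how a Predictor could be instantiated. The model
    # class, its hyperparameters, and the learning rate below are placeholder
    # assumptions; any torch.nn.Module (or tsl BaseModel) works the same way.
    #
    #   from tsl.metrics.torch import MaskedMAE
    #   from tsl.nn.models import RNNModel
    #
    #   predictor = Predictor(
    #       model_class=RNNModel,
    #       model_kwargs=dict(input_size=1, hidden_size=32,
    #                         output_size=1, horizon=12),
    #       optim_class=torch.optim.Adam,
    #       optim_kwargs=dict(lr=1e-3),
    #       loss_fn=MaskedMAE(),
    #       metrics={'mae': MaskedMAE()},
    #   )
    # ------------------------------------------------------------------------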

    def load_model(self, filename: str):
        """Load the model's weights from the checkpoint at :attr:`filename`.

        Differently from
        :meth:`~pytorch_lightning.core.LightningModule.load_from_checkpoint`,
        this method allows loading the state_dict also for models instantiated
        outside the predictor, without checking that the hyperparameters of
        the checkpoint's model are the same as those of the predictor's model.
        """
        storage = torch.load(filename, lambda storage, loc: storage)
        # if predictor.model has been instantiated inside the predictor
        if self.model_cls is not None:
            model_cls = storage['hyper_parameters']['model_class']
            model_kwargs = storage['hyper_parameters']['model_kwargs']
            # check that model class and hyperparameters are the same
            assert model_cls == self.model_cls
            if model_kwargs is not None:
                for k, v in model_kwargs.items():
                    assert v == self.model_kwargs[k]
        else:
            logger.warning("Predictor with already instantiated model is "
                           f"loading a state_dict from {filename}. Cannot "
                           "check if model hyperparameters are the same.")
        self.load_state_dict(storage['state_dict'])
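
    # Usage example (illustrative): restoring weights into an already
    # configured predictor; the checkpoint path is a placeholder.
    #
    #   predictor.load_model('checkpoints/epoch=19-step=1000.ckpt')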

    @property
    def is_tsl_model(self):
        """"""
        return self.model is not None and isinstance(self.model, BaseModel)

    @property
    def trainable_parameters(self) -> int:
        """"""
        return sum(p.numel() for p in self.model.parameters()
                   if p.requires_grad)

    @property
    def filter_forward_kwargs(self) -> bool:
        """"""
        return (self._model_fwd_signature is not None
                and not self._model_fwd_signature['has_kwargs'])

    def _filter_forward_kwargs(self, kwargs: dict) -> dict:
        """"""
        if self._check_kwargs:
            model_args = self._model_fwd_signature['signature']
            filtered = set(kwargs).difference(model_args)
            forwarded = set(kwargs).intersection(model_args)
            msg = f"Only args {list(forwarded)} are forwarded to the model " \
                  f"({self.model.__class__.__name__}). "
            if len(filtered):
                msg = f"Arguments {list(filtered)} are filtered out. " + msg
            logger.warning(msg)
            self._check_kwargs = False
        return {
            k: v
            for k, v in kwargs.items()
            if k in self._model_fwd_signature['signature']
        }

    def forward(self, *args, **kwargs):
        """"""
        if self.filter_forward_kwargs:
            kwargs = self._filter_forward_kwargs(kwargs)
        return self.model(*args, **kwargs)

    def predict(self, *args, **kwargs):
        """"""
        predict_fn = self.model.predict if self.is_tsl_model else self.model
        if self.filter_forward_kwargs:
            kwargs = self._filter_forward_kwargs(kwargs)
        return predict_fn(*args, **kwargs)

    @staticmethod
    def _check_metric(metric, on_step=False):
        # wrap plain callables into a MaskedMetric; clone and reset metrics
        # that are already MaskedMetric instances
        if not isinstance(metric, MaskedMetric):
            if 'reduction' in inspect.getfullargspec(metric).args:
                metric_kwargs = {'reduction': 'none'}
            else:
                metric_kwargs = dict()
            return MaskedMetric(metric,
                                compute_on_step=on_step,
                                metric_fn_kwargs=metric_kwargs)
        metric = metric.clone()
        metric.reset()
        return metric

    def _set_metrics(self, metrics):
        self.train_metrics = MetricCollection(
            metrics={k: self._check_metric(m) for k, m in metrics.items()},
            prefix='train_')
        self.val_metrics = MetricCollection(
            metrics={k: self._check_metric(m) for k, m in metrics.items()},
            prefix='val_')
        self.test_metrics = MetricCollection(
            metrics={k: self._check_metric(m) for k, m in metrics.items()},
            prefix='test_')

    def log_metrics(self, metrics, **kwargs):
        """"""
        self.log_dict(metrics,
                      on_step=False,
                      on_epoch=True,
                      logger=True,
                      prog_bar=True,
                      **kwargs)

    def log_loss(self, name, loss, **kwargs):
        """"""
        self.log(name + '_loss',
                 loss.detach(),
                 on_step=False,
                 on_epoch=True,
                 logger=True,
                 prog_bar=False,
                 **kwargs)

    def _unpack_batch(self, batch):
        """Unpack a batch into inputs, targets, mask, and transform.

        :param batch: the batch
        :return: inputs, targets, mask, transform
        """
        inputs, targets = batch.input, batch.target
        mask = batch.get('mask')
        transform = batch.get('transform')
        return inputs, targets, mask, transform
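
    # Usage example (illustrative): plain callables passed as loss_fn or
    # metrics are wrapped by _check_metric above, so these two metrics behave
    # roughly the same (torch.nn.functional.l1_loss is just an example):
    #
    #   import torch.nn.functional as F
    #   m1 = Predictor._check_metric(F.l1_loss)
    #   m2 = MaskedMetric(F.l1_loss, metric_fn_kwargs={'reduction': 'none'})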

    def predict_batch(self,
                      batch: Data,
                      preprocess: bool = False,
                      postprocess: bool = True,
                      return_target: bool = False,
                      **forward_kwargs):
        """This method takes as input a :class:`~tsl.data.Data` object and
        outputs the predictions.

        Note that this method works seamlessly for all :class:`~tsl.data.Data`
        subclasses like :class:`~tsl.data.StaticBatch` and
        :class:`~tsl.data.DisjointBatch`.

        Args:
            batch (Data): The batch to be forwarded to the model.
            preprocess (bool, optional): If :obj:`True`, then preprocess
                tensors in :attr:`batch.input` using transformation modules
                in :attr:`batch.transform`. Note that inputs are preprocessed
                before creating the batch by default.
                (default: :obj:`False`)
            postprocess (bool, optional): If :obj:`True`, then postprocess the
                model output using transformation modules for
                :attr:`batch.target` in :attr:`batch.transform`.
                (default: :obj:`True`)
            return_target (bool, optional): If :obj:`True`, then returns also
                the prediction target :attr:`batch.target` and the prediction
                mask :attr:`batch.mask`, besides the model output. In this
                case, the order of the arguments in the return is
                :attr:`batch.target`, :obj:`y_hat`, :attr:`batch.mask`.
                (default: :obj:`False`)
            **forward_kwargs: additional keyword arguments passed to the
                forward method.
        """
        inputs, targets, mask, transform = self._unpack_batch(batch)
        if preprocess:
            for key, trans in transform.items():
                if key in inputs:
                    inputs[key] = trans.transform(inputs[key])

        if forward_kwargs is None:
            forward_kwargs = dict()
        y_hat = self.forward(**inputs, **forward_kwargs)
        # Rescale outputs
        if postprocess:
            trans = transform.get('y')
            if trans is not None:
                y_hat = trans.inverse_transform(y_hat)
        if return_target:
            y = targets.get('y')
            return y, y_hat, mask
        return y_hat
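
    # Usage example (illustrative): computing rescaled predictions for a
    # single batch; 'datamodule' is assumed to be an already set-up
    # tsl SpatioTemporalDataModule (or any loader yielding Data batches).
    #
    #   batch = next(iter(datamodule.test_dataloader()))
    #   y, y_hat, mask = predictor.predict_batch(batch, return_target=True)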

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """"""
        # Unpack batch
        x, y, mask, transform = self._unpack_batch(batch)
        # Make predictions
        y_hat = self.predict_batch(batch, preprocess=False, postprocess=True)

        output = dict(**y, y_hat=y_hat)
        if mask is not None:
            output['mask'] = mask
        return output

    def collate_prediction_outputs(self, outputs):
        """Collate the outputs of the :meth:`predict_step` method.

        Args:
            outputs: Per-batch outputs of the :meth:`predict_step` method.

        Returns:
            The collated outputs.
        """
        processed_res = dict()
        keys = set()
        # iterate over the outputs of each batch
        for res in outputs:
            for k, v in res.items():
                if k in keys:
                    processed_res[k].append(v)
                else:
                    processed_res[k] = [v]
                    keys.add(k)
        # concatenate results along the batch dimension
        for k, v in processed_res.items():
            processed_res[k] = torch.cat(v, 0)
        return processed_res
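
    # Usage example (illustrative): collating the per-batch outputs returned
    # by Lightning's Trainer.predict into single tensors; trainer and
    # datamodule are assumed to be already configured.
    #
    #   outputs = trainer.predict(predictor,
    #                             dataloaders=datamodule.test_dataloader())
    #   outputs = predictor.collate_prediction_outputs(outputs)
    #   y_hat = outputs['y_hat']  # stacked along the batch dimension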

    def training_step(self, batch, batch_idx):
        """"""
        y = y_loss = batch.y
        mask = batch.get('mask')

        # Compute predictions
        y_hat_loss = self.predict_batch(batch,
                                        preprocess=False,
                                        postprocess=not self.scale_target)
        y_hat = y_hat_loss.detach()

        # Scale target and output, if needed
        if self.scale_target:
            y_loss = batch.transform['y'].transform(y)
            y_hat = batch.transform['y'].inverse_transform(y_hat)

        # Compute loss
        loss = self.loss_fn(y_hat_loss, y_loss, mask)

        # Logging
        self.train_metrics.update(y_hat, y, mask)
        self.log_metrics(self.train_metrics, batch_size=batch.batch_size)
        self.log_loss('train', loss, batch_size=batch.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        """"""
        y = y_loss = batch.y
        mask = batch.get('mask')

        # Compute predictions
        y_hat_loss = self.predict_batch(batch,
                                        preprocess=False,
                                        postprocess=not self.scale_target)
        y_hat = y_hat_loss.detach()

        # Scale target and output, if needed
        if self.scale_target:
            y_loss = batch.transform['y'].transform(y)
            y_hat = batch.transform['y'].inverse_transform(y_hat)

        # Compute loss
        val_loss = self.loss_fn(y_hat_loss, y_loss, mask)

        # Logging
        self.val_metrics.update(y_hat, y, mask)
        self.log_metrics(self.val_metrics, batch_size=batch.batch_size)
        self.log_loss('val', val_loss, batch_size=batch.batch_size)
        return val_loss

    def test_step(self, batch, batch_idx):
        """"""
        # Compute outputs and rescale
        y_hat = self.predict_batch(batch, preprocess=False, postprocess=True)

        y, mask = batch.y, batch.get('mask')
        test_loss = self.loss_fn(y_hat, y, mask)

        # Logging
        self.test_metrics.update(y_hat.detach(), y, mask)
        self.log_metrics(self.test_metrics, batch_size=batch.batch_size)
        self.log_loss('test', test_loss, batch_size=batch.batch_size)
        return test_loss

    def compute_metrics(self, batch, preprocess=False, postprocess=True):
        """"""
        # Compute outputs and rescale
        y_hat = self.predict_batch(batch, preprocess, postprocess)
        y, mask = batch.y, batch.get('mask')
        self.test_metrics.update(y_hat.detach(), y, mask)
        metrics_dict = self.test_metrics.compute()
        self.test_metrics.reset()
        return metrics_dict, y_hat

    def configure_optimizers(self):
        """"""
        cfg = dict()
        optimizer = self.optim_class(self.parameters(), **self.optim_kwargs)
        cfg['optimizer'] = optimizer
        if self.scheduler_class is not None:
            # 'monitor' (if any) is forwarded to Lightning, not the scheduler
            metric = self.scheduler_kwargs.pop('monitor', None)
            scheduler = self.scheduler_class(optimizer,
                                             **self.scheduler_kwargs)
            cfg['lr_scheduler'] = scheduler
            if metric is not None:
                cfg['monitor'] = metric
        return cfg
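

# Usage example (illustrative, not part of the library source): a typical
# training loop with PyTorch Lightning. Note that configure_optimizers pops
# the optional 'monitor' key from scheduler_kwargs and forwards it to
# Lightning, as needed by schedulers such as ReduceLROnPlateau; 'val_loss' is
# the metric logged by log_loss during validation.
#
#   predictor = Predictor(
#       model_class=...,                       # as in the example above
#       model_kwargs=...,
#       optim_class=torch.optim.Adam,
#       optim_kwargs=dict(lr=1e-3),
#       loss_fn=...,
#       scheduler_class=torch.optim.lr_scheduler.ReduceLROnPlateau,
#       scheduler_kwargs=dict(patience=5, monitor='val_loss'),
#   )
#   trainer = pl.Trainer(max_epochs=100)
#   trainer.fit(predictor, datamodule=datamodule)
#   trainer.test(predictor, datamodule=datamodule)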