# Source code for tsl.nn.models.base_model

import inspect
from argparse import ArgumentParser
from typing import Optional, Set

from torch import nn

from tsl.typing import ModelReturnOptions
from tsl.utils.python_utils import ensure_list, foo_signature


def _forward_packer(model, input, output):
    """Forward hook that enforces ``model.return_type`` on the forward output.

    Registered by :class:`BaseModel` when :attr:`return_type` is set.

    Args:
        model: The module whose forward just ran; its ``return_type`` is read.
        input: The forward inputs (unused).
        output: The value returned by the forward pass.

    Returns:
        ``output`` unchanged if it already is an instance of
        ``model.return_type``. Otherwise, when the expected type is
        :obj:`list`, ``output`` converted through ``ensure_list``.

    Raises:
        TypeError: if ``output`` has the wrong type and it cannot be coerced.
    """
    expected = model.return_type
    if not isinstance(output, expected):
        # Only list outputs can be coerced; any other mismatch is an error.
        if expected is not list:
            raise TypeError(f"return type of forward ({type(output)}) does not "
                            f"match with {model.__class__.__name__}.return_type "
                            f"({expected}).")
        output = ensure_list(output)
    return output


class BaseModel(nn.Module):
    r"""Base class for creating neural models.

    This class provides useful utilities for the model designer:

    * the methods :meth:`~tsl.nn.models.BaseModel.add_model_specific_args`
      and :meth:`~tsl.nn.models.BaseModel.add_argparse_args` allow to
      automatically add to an :class:`~argparse.ArgumentParser` the arguments
      needed to initialize the model (with typing and default values).
    * the method :meth:`~tsl.nn.models.BaseModel.loss` can be used to compute
      a custom loss on the provided training target. Inference modules in tsl
      will call this method for the loss computation, if implemented in the
      model.
    * the method :meth:`~tsl.nn.models.BaseModel.predict` can be used to
      define a variation of the :meth:`~torch.nn.Module.forward` function for
      only inference purposes (e.g., removing outputs used only for auxiliary
      tasks during training).
    * the parameter :attr:`return_type` specifies the return type of the
      forward function (:class:`~torch.Tensor`, :obj:`list` or :obj:`dict`).
      When it is not :obj:`None`, a forward hook checks that the output
      matches it, converting the output when a :obj:`list` is expected.
    """
    return_type: ModelReturnOptions = None

    def __init__(self):
        super().__init__()
        # Only check/coerce forward outputs when a subclass declares the
        # expected return type.
        if self.return_type is not None:
            self.register_forward_hook(_forward_packer)
        # Cache the signatures of __init__ and forward on the instance.
        model_signature = self.get_model_signature()
        self.model_signature = model_signature['signature']
        self.has_model_args = model_signature['has_args']
        self.has_model_kwargs = model_signature['has_kwargs']
        forward_signature = self.get_forward_signature()
        self.forward_signature = forward_signature['signature']
        self.has_forward_args = forward_signature['has_args']
        self.has_forward_kwargs = forward_signature['has_kwargs']

    @property
    def has_loss(self) -> bool:
        """Returns :obj:`True` if the model has implemented the
        :meth:`~tsl.nn.models.BaseModel.loss` method."""
        # The first component of an override's qualified name is the name of
        # the defining subclass, not 'BaseModel'.
        return self.loss.__qualname__.split('.')[0] != 'BaseModel'

    @property
    def has_predict(self) -> bool:
        """Returns :obj:`True` if the model has implemented the
        :meth:`~tsl.nn.models.BaseModel.predict` method."""
        return self.predict.__qualname__.split('.')[0] != 'BaseModel'

    def loss(self, target, *args, **kwargs):
        """Compute a custom loss w.r.t. :attr:`target`.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError

    def predict(self, *args, **kwargs):
        """Forward function used only for inference.

        By default it invokes :class:`~torch.nn.Module`'s ``__call__``, so it
        behaves exactly like calling the model (forward hooks included).
        """
        return super().__call__(*args, **kwargs)

    def reset_parameters(self):
        """Reset the parameters of the model.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError

    @classmethod
    def get_model_signature(cls) -> dict:
        """Get signature of the model's :class:`~tsl.nn.models.BaseModel`'s
        :obj:`__init__` function.

        The returned dict provides the keys ``'signature'``, ``'has_args'``
        and ``'has_kwargs'``.
        """
        return foo_signature(cls)

    @classmethod
    def get_forward_signature(cls) -> dict:
        """Get signature of the model's
        :meth:`~tsl.nn.models.BaseModel.forward` function.

        The returned dict provides the keys ``'signature'``, ``'has_args'``
        and ``'has_kwargs'``.
        """
        return foo_signature(cls.forward)

    @classmethod
    def filter_model_args_(cls, mapping: dict):
        """Remove from :attr:`mapping` (in place) all the keys that are not in
        :class:`~tsl.nn.models.BaseModel`'s :obj:`__init__` function.

        If :obj:`__init__` accepts ``**kwargs``, :attr:`mapping` is left
        untouched, since any key could then be consumed.
        """
        model_sign = cls.get_model_signature()
        if model_sign['has_kwargs']:
            return
        model_signature = model_sign['signature']
        # Collect the keys first: a dict cannot shrink while being iterated.
        del_keys = [k for k in mapping if k not in model_signature]
        for k in del_keys:
            del mapping[k]

    @classmethod
    def model_excluded_args(cls) -> Set:
        """Set of arguments of :meth:`__init__` to be excluded when adding
        model's args to an :class:`~argparse.ArgumentParser` (see
        :meth:`~tsl.nn.models.BaseModel.add_model_specific_args`)."""
        return {
            'input_size', 'output_size', 'exog_size', 'n_nodes', 'horizon',
            'window'
        }

    @classmethod
    def add_model_specific_args(cls, parser: ArgumentParser):
        """Adds to the :class:`~argparse.ArgumentParser` :attr:`parser` the
        arguments needed to initialize the model (with typing and default
        values). The arguments added are all the parameters of the
        :meth:`__init__` method, excluding the keys returned by
        :meth:`~tsl.nn.models.BaseModel.model_excluded_args`."""
        return cls.add_argparse_args(parser,
                                     exclude_args=cls.model_excluded_args())

    @classmethod
    def add_argparse_args(cls,
                          parser: ArgumentParser,
                          exclude_args: Optional[Set] = None):
        """Adds to the :class:`~argparse.ArgumentParser` :attr:`parser` all the
        parameters of the :meth:`__init__` method (with typing and default
        values).

        Each parameter ``foo_bar`` becomes the option ``--foo-bar``. Its
        ``type`` is the parameter annotation if present. Otherwise it is the
        type of the default value, when that default is not :obj:`None`.

        Args:
            parser (ArgumentParser): The parser to which arguments are added.
            exclude_args (set, optional): Names of :meth:`__init__` parameters
                not to be added. ``self`` and variadic ``*args``/``**kwargs``
                parameters are always skipped. (default: :obj:`None`)

        Returns:
            ArgumentParser: The same :attr:`parser`.
        """
        sign = inspect.signature(cls.__init__)
        # filter excluded arguments (``self`` is never a CLI option)
        excluded = {'self'}
        if exclude_args is not None:
            excluded.update(exclude_args)
        # *args / **kwargs cannot be mapped to a single named option.
        variadic = (inspect.Parameter.VAR_POSITIONAL,
                    inspect.Parameter.VAR_KEYWORD)
        # parse signature
        for name, param in sign.parameters.items():
            if name in excluded or param.kind in variadic:
                continue
            kwargs = dict()
            if param.annotation is not inspect.Parameter.empty:
                kwargs['type'] = param.annotation
            if param.default is not inspect.Parameter.empty:
                kwargs['default'] = param.default
                # Infer the type from the default. Skip a None default:
                # ``NoneType`` cannot convert a command-line string.
                if 'type' not in kwargs and param.default is not None:
                    kwargs['type'] = type(param.default)
            # NOTE(review): argparse applies ``type`` as a plain converter, so
            # a ``bool`` type maps any non-empty string (even "False") to
            # True. Typing constructs such as ``Optional[int]`` also do not
            # work as converters. Confirm callers handle these cases.
            parser.add_argument('--' + name.replace('_', '-'), **kwargs)
        return parser