Source code for tsl.experiment.experiment

import inspect
import os
import os.path as osp
import sys
from functools import wraps
from typing import Optional, Callable, List, Union

import torch
from pytorch_lightning import seed_everything

from tsl import logger, config
from tsl.imports import _HYDRA_AVAILABLE
from tsl.utils.python_utils import ensure_list

if _HYDRA_AVAILABLE:
    import hydra
    from hydra.core.hydra_config import HydraConfig
    from omegaconf import DictConfig, OmegaConf, flag_override
    from omegaconf.errors import ConfigAttributeError
else:
    hydra = DictConfig = None


def get_hydra_cli_arg(key: str, delete: bool = False) -> Optional[str]:
    """Look up a ``key=value`` argument in :obj:`sys.argv`.

    Only tokens that contain an ``=`` sign are considered, and the token is
    split on the *first* ``=`` only, so values that themselves contain ``=``
    are returned in full.

    Args:
        key (str): Name on the left-hand side of the ``=`` sign.
        delete (bool): If :obj:`True`, remove the matching token from
            :obj:`sys.argv` before returning.
            (default: :obj:`False`)

    Returns:
        The value of the first matching ``key=value`` token, or :obj:`None`
        if there is no such token.
    """
    for idx, arg in enumerate(sys.argv):
        name, sep, value = arg.partition("=")
        # `sep` is empty for bare tokens (no '='): those carry no value and
        # must not be treated as a match.
        if sep and name == key:
            if delete:
                # safe: we return right after mutating the list
                del sys.argv[idx]
            return value
    return None


def _pre_experiment_routine(cfg: DictConfig):
    """Default pre-run hook: prepare the experiment configuration in-place.

    The routine seeds the run, stores run information under ``cfg.run``
    (``seed``, ``dir``, ``name`` and, when Hydra's ``output_subdir`` is set,
    ``tsl_subdir``), propagates ``data_dir`` to the tsl config, optionally
    unlocks the config for extension, sets the number of PyTorch threads and
    logs the resulting configuration.

    Args:
        cfg (DictConfig): The experiment configuration. It is modified
            in-place.

    Returns:
        DictConfig: The same :obj:`cfg` object that was passed in.
    """
    hconf = HydraConfig.get()

    # set the seed for the run (keep the value returned by seed_everything)
    seed = seed_everything(cfg.get('seed', None))

    # add run args to cfg
    run_args = dict(seed=seed, dir=hconf.runtime.output_dir)
    if hconf.get('output_subdir') is not None:
        # Hydra's output subdir lives inside the run output dir. ``cfg.run``
        # is only assigned below, so the path must be built from
        # ``run_args['dir']`` rather than from ``cfg.run.dir``.
        run_args['tsl_subdir'] = osp.join(run_args['dir'],
                                          hconf.output_subdir)
        # remove hydra conf from logging
        os.unlink(osp.join(run_args['tsl_subdir'], 'hydra.yaml'))
    # set run name (interpolation, resolved when the config is resolved)
    run_args['name'] = "${now:%Y-%m-%d_%H-%M-%S}_${run.seed}"
    # temporarily lift struct mode so that the new 'run' node can be written
    with flag_override(cfg, 'struct', False):
        cfg.run = DictConfig(run_args)

    # override data_dir in tsl config
    config.data_dir = cfg.get('data_dir', config.data_dir)

    # if True, then allow for adding new args to the cfg
    if cfg.get('allow_config_extension', False):
        OmegaConf.set_struct(cfg, False)

    # set the PyTorch num_threads from here
    if 'num_threads' in cfg:
        torch.set_num_threads(cfg.num_threads)

    logger.info("\n**** Experiment config ****\n" +
                OmegaConf.to_yaml(cfg, resolve=True))

    return cfg


class Experiment:
    r"""Simple class to handle the routines used to run experiments.

    This class relies heavily on the `Hydra <https://hydra.cc/>`_ framework,
    check `Hydra docs <https://hydra.cc/docs/intro/>`_ for usage information.

    Hydra is an optional dependency of tsl, to install it using pip:

    .. code-block:: bash

        pip install hydra-core

    Args:
        run_fn (callable): Python function that actually runs the experiment
            when called. The run function must accept a single argument,
            being the experiment configuration.
        config_path (str, optional): Path to configuration files.
            If not specified the default will be used.
        config_name (str, optional): Name of the configuration file in
            :attr:`config_path` to be used. The :obj:`.yaml` extension can be
            omitted.
        pre_run_hooks (list): Ordered list of functions to call on
            :meth:`~tsl.experiment.Experiment.run` before the :attr:`run_fn`.
            Every hook must accept a single argument, being the experiment
            configuration, and act in-place on the configuration.

    Raises:
        RuntimeError: If Hydra is not installed, or if :attr:`run_fn`
            declares more than one positional parameter.
    """

    def __init__(self, run_fn: Callable,
                 config_path: Optional[str] = None,
                 config_name: Optional[str] = None,
                 pre_run_hooks: Optional[Union[Callable,
                                               List[Callable]]] = None):
        if not _HYDRA_AVAILABLE:
            raise RuntimeError("Install optional dependency 'hydra-core'"
                               f" to use {self.__class__.__name__}.")
        # store the run configuration (set when the run function is executed)
        self.cfg: Optional[DictConfig] = None

        # fall back to the config directory stored in the tsl config
        if config_path is None:
            config_path = config.config_dir
        # allow override of config_path as Hydra cli arg:
        #   config_path={config_path} same as --config-path {config_path}
        override_config_path = get_hydra_cli_arg('config_path', delete=True)
        config_path = override_config_path or config_path
        # relative paths are resolved against the file that defines run_fn
        if not osp.isabs(config_path):
            root_path = osp.dirname(inspect.getfile(run_fn))
            config_path = osp.abspath(osp.join(root_path, config_path))
        self.config_path = config_path
        # store config_dir in tsl config
        config.config_dir = self.config_path

        # allow override of config_name as Hydra cli arg:
        #   config={config_name} same as --config-name {config_name}
        override_config_name = get_hydra_cli_arg('config', delete=True)
        self.config_name = override_config_name or config_name

        # pass 'hydra.output_subdir=null' to Hydra as a command-line override
        sys.argv.insert(1, 'hydra.output_subdir=null')

        # the default routine always runs first, user hooks follow in order
        self._pre_run_hooks = [_pre_experiment_routine]
        if pre_run_hooks is not None:
            pre_run_hooks = ensure_list(pre_run_hooks)
            for hook in pre_run_hooks:
                self.register_pre_run_hook(hook)

        self.run_fn = self.register_run_function(run_fn)
        self.run_output = None

    def register_pre_run_hook(self, hook: Callable):
        """Append :obj:`hook` to the functions called before the run function.

        Args:
            hook (callable): Function called with the experiment
                configuration as its only argument.
        """
        self._pre_run_hooks.append(hook)

    def register_run_function(self, run_fn: Callable) -> Callable:
        """Wrap :obj:`run_fn` with the pre-run logic and with ``hydra.main``.

        The wrapped function executes all registered pre-run hooks, stores
        and logs the configuration, then calls :obj:`run_fn` and keeps its
        result in :attr:`run_output`.

        Args:
            run_fn (callable): The function running the experiment.

        Returns:
            callable: The function decorated with ``hydra.main``.

        Raises:
            RuntimeError: If :obj:`run_fn` declares more than one positional
                parameter.
        """
        args = inspect.getfullargspec(run_fn).args
        if len(args) > 1:
            raise RuntimeError("run_fn must have a single 'cfg' parameter.")

        def run_fn_decorator(func: Callable) -> Callable:
            @wraps(func)
            def decorated_run_fn(cfg: DictConfig):
                # execute pre-run hooks
                for hook in self._pre_run_hooks:
                    hook(cfg)
                # store final config
                self.cfg = cfg
                self.log_config()
                self.run_output = func(cfg)
                return self.run_output

            return decorated_run_fn

        return hydra.main(config_path=self.config_path,
                          config_name=self.config_name,
                          version_base=None)(run_fn_decorator(run_fn))

    def log_config(self) -> None:
        """Save config as ``.yaml`` file in
        :meth:`~tsl.experiment.Experiment.run_dir`.

        Raises:
            RuntimeError: If the run directory is not available, i.e.,
                :meth:`~tsl.experiment.Experiment.run_dir` is :obj:`None`.
        """
        run_dir = self.run_dir
        if run_dir is None:
            # fail with an explicit message rather than a TypeError raised
            # from inside os.path.join
            raise RuntimeError("Cannot log the experiment config: the run "
                               "directory is not set.")
        # write the YAML as UTF-8 regardless of the platform locale
        with open(osp.join(run_dir, 'config.yaml'), 'w',
                  encoding='utf-8') as fp:
            fp.write(OmegaConf.to_yaml(self.cfg, resolve=True))

    def __repr__(self):
        return "{}(config_path={}, config_name={}, run_fn={})".format(
            self.__class__.__name__, self.config_path, self.config_name,
            self.run_fn.__name__
        )

    @property
    def run_dir(self):
        """Directory of the current run, where logs and artifacts are stored.

        It is :obj:`None` when no configuration has been stored yet or when
        the configuration has no ``run.dir`` entry."""
        if self.cfg is not None:
            try:
                return self.cfg.run.dir
            except ConfigAttributeError:
                return None
        return None

    def run(self):
        """Run the experiment routine.

        Returns:
            The value stored in :attr:`run_output` after calling the
            registered run function."""
        self.run_fn()
        return self.run_output