import functools
from copy import deepcopy
from datetime import datetime
from typing import Callable, Mapping, Optional, Tuple, Union
import numpy as np
from tsl.utils.python_utils import ensure_list
from ..spatiotemporal_dataset import SpatioTemporalDataset
from ..synch_mode import SynchMode
# Public names exported by ``from <this module> import *``.
__all__ = [
    'Splitter', 'CustomSplitter', 'TemporalSplitter', 'AtTimeStepSplitter'
]
from ...typing import Index
class Splitter:
    r"""Base class for splitter module.

    A splitter computes the training, validation and test indices of a
    dataset. Subclasses implement :meth:`fit`, which is expected to store the
    indices with :meth:`set_indices`. Every instance wraps its own ``fit``
    (see :meth:`_track_fit`): after a call, :attr:`fitted` is :obj:`True`
    and the call returns :attr:`indices`.
    """

    def __init__(self):
        # ``reset`` initializes both the indices dict and the fitted flag.
        self.reset()

    def __new__(cls, *args, **kwargs) -> "Splitter":
        obj = super().__new__(cls)
        # track `fit` calls: the per-instance attribute shadows the class
        # method, so the wrapper is bound to this very object.
        obj.fit = cls._track_fit(obj, obj.fit)
        return obj

    @staticmethod
    def _track_fit(obj: "Splitter", fn: callable) -> callable:
        """A decorator to track fit calls.

        When ``splitter.fit(...)`` is called, :obj:`splitter.fitted` is set to
        :obj:`True`.

        Args:
            obj: Object whose function will be tracked.
            fn: Function that will be wrapped.

        Returns:
            Decorated method to track :obj:`fit` calls; it returns
            ``obj.indices``.
        """

        @functools.wraps(fn)
        def fit(dataset: "SpatioTemporalDataset") -> dict:
            fn(dataset)
            obj._fitted = True
            return obj.indices

        return fit

    def __getstate__(self) -> dict:
        # The per-instance ``fit`` wrapper is a closure and cannot be pickled
        # (_pickle.PicklingError: Can't pickle <...>: it's not the same object
        # as <...>). ``__new__`` builds a fresh wrapper for every new instance.
        d = self.__dict__.copy()
        d.pop('fit', None)
        return d

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`split`."""
        return self.split(*args, **kwargs)

    def __repr__(self):
        lens = ", ".join(map(lambda kv: "%s=%s" % kv, self.lens().items()))
        return "%s(%s)" % (self.__class__.__name__, lens)

    @property
    def indices(self):
        """Dict with keys ``'train'``, ``'val'`` and ``'test'``."""
        return self.__indices

    @property
    def fitted(self):
        """Whether :meth:`fit` has been called since the last reset."""
        return self._fitted

    @property
    def train_idxs(self):
        return self.__indices.get('train')

    @property
    def val_idxs(self):
        return self.__indices.get('val')

    @property
    def test_idxs(self):
        return self.__indices.get('test')

    @property
    def train_len(self):
        return len(self.train_idxs) if self.train_idxs is not None else None

    @property
    def val_len(self):
        return len(self.val_idxs) if self.val_idxs is not None else None

    @property
    def test_len(self):
        return len(self.test_idxs) if self.test_idxs is not None else None

    def set_indices(self, train=None, val=None, test=None):
        """Set the given indices; arguments left to :obj:`None` are
        ignored."""
        if train is not None:
            self.__indices['train'] = train
        if val is not None:
            self.__indices['val'] = val
        if test is not None:
            self.__indices['test'] = test

    def reset(self):
        """Clear all indices and mark the splitter as not fitted."""
        self.__indices = dict(train=None, val=None, test=None)
        self._fitted = False

    def lens(self) -> dict:
        return dict(train_len=self.train_len,
                    val_len=self.val_len,
                    test_len=self.test_len)

    def copy(self) -> "Splitter":
        """Return a deep copy of this splitter.

        The copy has the same class as ``self`` and owns a ``fit`` wrapper
        bound to itself, so fitting or resetting the copy never affects the
        original.
        """
        cls = self.__class__
        # ``__new__`` installs a fresh tracked ``fit`` bound to the new object;
        # the old wrapper must not be copied (deepcopy shares functions, so it
        # would still point at ``self``).
        new = cls.__new__(cls)
        state = {k: v for k, v in self.__dict__.items() if k != 'fit'}
        new.__dict__.update(deepcopy(state))
        return new

    def fit(self, dataset: "SpatioTemporalDataset"):
        raise NotImplementedError

    def split(self, dataset: "SpatioTemporalDataset") -> dict:
        """Return the indices, fitting on ``dataset`` first if needed."""
        if self.fitted:
            return self.indices
        else:
            return self.fit(dataset)
class CustomSplitter(Splitter):
    r"""Create a :class:`~tsl.data.datamodule.splitters.Splitter` using custom
    validation and test sets splitting functions.

    Args:
        val_split_fn: Called as ``val_split_fn(dataset, **val_kwargs)``; must
            return a pair ``(train_idxs, val_idxs)``.
        test_split_fn: Called as ``test_split_fn(dataset, **test_kwargs)``;
            the second element of the returned pair is the test set.
        val_kwargs: Extra keyword arguments for ``val_split_fn``.
        test_kwargs: Extra keyword arguments for ``test_split_fn``.
        mask_test_indices_in_val: If :obj:`True` and the test set is not
            empty, the test indices are passed to ``val_split_fn`` as the
            ``mask`` keyword argument.
    """

    def __init__(self,
                 val_split_fn: Callable = None,
                 test_split_fn: Callable = None,
                 val_kwargs: Mapping = None,
                 test_kwargs: Mapping = None,
                 mask_test_indices_in_val: bool = True):
        super().__init__()
        self.val_split_fn = val_split_fn
        self.test_split_fn = test_split_fn
        self.val_kwargs = val_kwargs or {}
        self.test_kwargs = test_kwargs or {}
        self.mask_test_indices_in_val = mask_test_indices_in_val

    @property
    def val_policy(self):
        """Name of the validation split function, if it is callable."""
        split_fn = self.val_split_fn
        if not callable(split_fn):
            return None
        return split_fn.__name__

    @property
    def test_policy(self):
        """Name of the test split function, if it is callable."""
        split_fn = self.test_split_fn
        if not callable(split_fn):
            return None
        return split_fn.__name__

    def fit(self, dataset: SpatioTemporalDataset):
        _, test_idxs = self.test_split_fn(dataset, **self.test_kwargs)
        # Keep test samples out of the train/validation split, if requested.
        use_mask = self.mask_test_indices_in_val and len(test_idxs)
        kwargs = (dict(**self.val_kwargs, mask=test_idxs)
                  if use_mask else self.val_kwargs)
        train_idxs, val_idxs = self.val_split_fn(dataset, **kwargs)
        self.set_indices(train_idxs, val_idxs, test_idxs)
class FixedIndicesSplitter(Splitter):
    r"""Create a :class:`~tsl.data.datamodule.splitters.Splitter` using fixed
    indices for training, validation and test sets.

    The splitter is already fitted after construction. The given indices are
    also stored, so that calling :meth:`fit` (e.g., after :meth:`reset`)
    restores them instead of leaving every split empty.

    Args:
        train_idxs: Training indices (left unset if :obj:`None`).
        val_idxs: Validation indices (left unset if :obj:`None`).
        test_idxs: Test indices (left unset if :obj:`None`).
    """

    def __init__(self,
                 train_idxs: Optional[Index] = None,
                 val_idxs: Optional[Index] = None,
                 test_idxs: Optional[Index] = None):
        super(FixedIndicesSplitter, self).__init__()
        self._fixed_idxs = (train_idxs, val_idxs, test_idxs)
        self.set_indices(train_idxs, val_idxs, test_idxs)
        self._fitted = True

    def fit(self, dataset: SpatioTemporalDataset):
        # The dataset is not used: re-apply the fixed indices (``reset`` wipes
        # them). ``getattr`` keeps instances pickled before this attribute
        # existed working; with all-None indices this is a no-op.
        self.set_indices(*getattr(self, '_fixed_idxs', (None, None, None)))
class TemporalSplitter(Splitter):
    r"""Split the data sequentially with specified lengths.

    The last ``test_len`` samples form the test set and the ``val_len``
    samples right before them form the validation set; the remaining initial
    samples form the training set. The last ``dataset.samples_offset``
    samples of the training and of the validation set are dropped
    (presumably to avoid overlapping windows between adjacent sets — TODO
    confirm).

    Args:
        val_len: Validation length. Values lower than 1 are a fraction of the
            samples left after removing the test set; other values are an
            absolute number of samples. :obj:`None` means no validation set.
        test_len: Test length. Values lower than 1 are a fraction of all the
            samples; other values are an absolute number of samples.
            :obj:`None` means no test set.

    Raises:
        ValueError: (in :meth:`fit`) if validation and test sets together
            exceed the number of samples.
    """

    def __init__(self,
                 val_len: Union[int, float] = None,
                 test_len: Union[int, float] = None):
        super(TemporalSplitter, self).__init__()
        self._val_len = val_len
        self._test_len = test_len

    @staticmethod
    def _to_length(length, n):
        """Turn a fraction (< 1) of ``n`` or an absolute count into an int;
        :obj:`None` becomes 0."""
        if length is None:
            return 0
        if length < 1:
            return int(length * n)
        # cast: argparse (and users) may pass absolute lengths as floats
        return int(length)

    def fit(self, dataset: SpatioTemporalDataset):
        n_samples = len(dataset)
        idx = np.arange(n_samples)
        test_len = self._to_length(self._test_len, n_samples)
        val_len = self._to_length(self._val_len, n_samples - test_len)
        if val_len + test_len > n_samples:
            raise ValueError("val_len (%d) + test_len (%d) exceed the number "
                             "of samples (%d)" % (val_len, test_len, n_samples))
        test_start = n_samples - test_len
        val_start = test_start - val_len
        offset = dataset.samples_offset
        # Clamp at 0: a negative bound would slice from the end of the array.
        self.set_indices(idx[:max(val_start - offset, 0)],
                         idx[val_start:max(test_start - offset, 0)],
                         idx[test_start:])

    @staticmethod
    def add_argparse_args(parser):
        # Parsed as float (the original ``float or int`` evaluates to
        # ``float``); absolute lengths are cast back to int in ``fit``.
        parser.add_argument('--val-len', type=float, default=0.1)
        parser.add_argument('--test-len', type=float, default=0.2)
        return parser
class AtTimeStepSplitter(Splitter):
    r"""Split the data at given time steps (only for
    :class:`~tsl.data.SpatioTemporalDataset` with
    :class:`~pandas.DatetimeIndex` index).

    Args:
        first_val_ts: First time step of the validation set, as a
            :class:`~datetime.datetime` or a tuple ``(year, month, day, ...)``.
        first_test_ts: First time step of the test set (same format).
        last_val_ts: Last time step of the validation set (same format).
        last_test_ts: Last time step of the test set (same format).
        drop_following_steps: If :obj:`True`, validation samples from the
            first test sample onwards are dropped and training uses only the
            samples preceding the validation set (or the test set, if there
            is no validation set). Otherwise validation excludes test samples
            and training takes every sample not in validation or test.
    """

    def __init__(self,
                 first_val_ts: Union[Tuple, datetime] = None,
                 first_test_ts: Union[Tuple, datetime] = None,
                 last_val_ts: Union[Tuple, datetime] = None,
                 last_test_ts: Union[Tuple, datetime] = None,
                 drop_following_steps: bool = True):
        super(AtTimeStepSplitter, self).__init__()
        self.first_val_ts = first_val_ts
        self.first_test_ts = first_test_ts
        self.last_val_ts = last_val_ts
        self.last_test_ts = last_test_ts
        self.drop_following_steps = drop_following_steps

    def fit(self, dataset: SpatioTemporalDataset):
        test_idx = indices_between(dataset,
                                   first_ts=self.first_test_ts,
                                   last_ts=self.last_test_ts)
        val_idx = indices_between(dataset,
                                  first_ts=self.first_val_ts,
                                  last_ts=self.last_val_ts)
        if self.drop_following_steps:
            # ``.min()`` fails on empty arrays: guard every call.
            if len(test_idx):
                val_idx = val_idx[val_idx < test_idx.min()]
            if len(val_idx):
                train_end = val_idx.min()
            elif len(test_idx):
                train_end = test_idx.min()
            else:
                train_end = len(dataset)
            train_idx = np.arange(train_end)
        else:
            val_idx = np.setdiff1d(val_idx, test_idx)
            train_idx = np.setdiff1d(np.arange(len(dataset)), test_idx)
            train_idx = np.setdiff1d(train_idx, val_idx)
        self.set_indices(train_idx, val_idx, test_idx)

    @staticmethod
    def add_argparse_args(parser):
        # Time steps are given as integers, e.g. ``--first-val-ts 2020 6 1``,
        # and converted with ``datetime(*ts)`` by ``indices_between``. (The
        # previous ``type=list or tuple`` was ``list``, which split the
        # argument string into single characters.)
        parser.add_argument('--first-val-ts', type=int, nargs='+',
                            default=None)
        parser.add_argument('--first-test-ts', type=int, nargs='+',
                            default=None)
        parser.add_argument('--last-val-ts', type=int, nargs='+',
                            default=None)
        parser.add_argument('--last-test-ts', type=int, nargs='+',
                            default=None)
        return parser
# --- Index utilities ---
def indices_between(dataset: "SpatioTemporalDataset",
                    first_ts: Union[Tuple, datetime] = None,
                    last_ts: Union[Tuple, datetime] = None):
    """Return the positions of the samples lying between two time steps.

    Bounds given as tuples/lists (at least ``(year, month, day)``) are
    converted to :class:`~datetime.datetime` with ``dataset.index.tzinfo``.
    With ``first_loc, last_loc = dataset.index.slice_locs(first_ts, last_ts)``,
    the position ``i`` is selected when
    ``first_loc - horizon_offset <= dataset.indices[i]`` and
    ``dataset.indices[i] < last_loc - horizon_offset - 1``.

    Args:
        dataset: Dataset exposing ``index``, ``indices`` and
            ``horizon_offset``.
        first_ts: Lower bound (:obj:`None` for no bound).
        last_ts: Upper bound (:obj:`None` for no bound).

    Returns:
        np.ndarray: 1-D array of positions into ``dataset.indices``.
    """

    def _to_datetime(ts):
        if ts is None or isinstance(ts, datetime):
            return ts
        # ts must be (tuple, list) and len(ts) >= 3
        return datetime(*ts, tzinfo=dataset.index.tzinfo)

    first_day_loc, last_day_loc = dataset.index.slice_locs(
        _to_datetime(first_ts), _to_datetime(last_ts))
    first_sample_loc = first_day_loc - dataset.horizon_offset
    last_sample_loc = last_day_loc - dataset.horizon_offset - 1
    sample_idxs = np.asarray(dataset.indices)
    in_range = (first_sample_loc <= sample_idxs) & \
               (sample_idxs < last_sample_loc)
    # np.nonzero returns a tuple of arrays for ndarrays (no ``.ravel()``);
    # flatnonzero directly yields the 1-D positions.
    return np.flatnonzero(in_range)
def split_at_ts(dataset, ts, mask=None):
    """Split the samples into those before and from time step ``ts`` on.

    Args:
        dataset: Dataset to split.
        ts: Split time step, as accepted by ``indices_between``.
        mask: Optional sample positions removed from both outputs.

    Returns:
        tuple: ``(prev_idxs, from_day_idxs)``. ``prev_idxs`` spans every
        position before the first selected one (all positions if none).
    """
    after = indices_between(dataset, first_ts=ts)
    cut = after[0] if len(after) else len(dataset)
    before = np.arange(cut)
    if mask is None:
        return before, after
    return np.setdiff1d(before, mask), np.setdiff1d(after, mask)
def disjoint_months(dataset, months=None, synch_mode=SynchMode.WINDOW):
    """Split the samples by calendar month.

    A sample belongs to a group when both the first and the last time step of
    its window (``SynchMode.WINDOW``) or of its horizon
    (``SynchMode.HORIZON``) fall in that group's months. Samples spanning both
    groups are in neither.

    Args:
        dataset: Dataset to split.
        months: Month number(s) (1-12) of the second group.
        synch_mode: Whether to check the window or the horizon.

    Returns:
        tuple: ``(prev_idxs, after_idxs)``; ``after_idxs`` are the samples in
        ``months``, ``prev_idxs`` those in the remaining months.

    Raises:
        ValueError: if ``synch_mode`` is neither WINDOW nor HORIZON.
    """
    idxs = np.arange(len(dataset))
    months = ensure_list(months)
    # offsets, w.r.t. each sample index, of the first and last checked steps
    if synch_mode is SynchMode.WINDOW:
        start = 0
        end = dataset.window - 1
    elif synch_mode is SynchMode.HORIZON:
        start = dataset.horizon_offset
        end = dataset.horizon_offset + dataset.horizon - 1
    else:
        raise ValueError("synch_mode can only be one of "
                         f"{[SynchMode.WINDOW, SynchMode.HORIZON]}")
    indices = np.asarray(dataset._indices)
    # computed once, shared by both groups
    start_months = dataset.index[indices + start].month
    end_months = dataset.index[indices + end].month

    def _fully_in(selected):
        # np.isin supersedes np.in1d (deprecated in NumPy 2.0)
        in_months = np.isin(start_months, selected) & \
                    np.isin(end_months, selected)
        return idxs[in_months]

    after_idxs = _fully_in(months)
    prev_idxs = _fully_in(np.setdiff1d(np.arange(1, 13), months))
    return prev_idxs, after_idxs
# SPLIT FUNCTIONS
def split_function_builder(fn, *args, name=None, **kwargs):
    """Bind extra arguments to a split function.

    The returned function has signature ``(dataset, length=None, mask=None)``
    and calls ``fn(dataset, *args, length=length, mask=mask, **kwargs)``.
    Its ``__name__`` is ``name`` when truthy, ``"wrapped__<fn name>"``
    otherwise.
    """

    def wrapper_split_fn(dataset, length=None, mask=None):
        return fn(dataset, *args, length=length, mask=mask, **kwargs)

    wrapper_split_fn.__name__ = name if name else f"wrapped__{fn.__name__}"
    return wrapper_split_fn
def subset_len(length, set_size, period=None):
    """Convert a subset length specification into a number of samples.

    Args:
        length: :obj:`None` or a value ``<= 0`` gives 0. A value in ``(0, 1)``
            is a fraction of ``period`` (at least 1 sample). A value in
            ``[period, set_size]`` is rescaled to a single period as
            ``length / set_size * period``. Any other value is taken as an
            absolute count.
        set_size: Total number of samples.
        period: Period size (defaults to ``set_size``).

    Returns:
        int: The subset length.

    Raises:
        ValueError: if ``length`` is greater than ``set_size``.
    """
    if period is None:
        period = set_size
    if length is None or length <= 0:
        return 0
    if length < 1:
        return max(int(length * period), 1)
    if length > set_size:
        raise ValueError("Provided length of %i is greater than set_size %i" %
                         (length, set_size))
    if length >= period:
        # ``length == set_size`` is rescaled too (previously it was returned
        # unchanged, i.e. larger than a single period).
        return int(length / set_size * period)
    # always an int: callers use the result in slices
    return int(length)
def tail_of_period(iterable, length, mask=None, period=None):
    """Split positions ``0 .. len(iterable) - 1`` taking the tail of each
    period.

    The positions are chunked into consecutive periods. After removing the
    masked positions, the last ``length`` positions of each chunk (see
    ``subset_len``) go to the second output and the rest to the first one.
    Chunks shorter than ``length`` go entirely to the second output.

    Args:
        iterable: Sized object; only its length is used.
        length: Tail length, as accepted by ``subset_len``.
        mask: Positions excluded from both outputs.
        period: Chunk size (defaults to the whole length).

    Returns:
        tuple: ``(prev_idxs, after_idxs)`` integer arrays.
    """
    size = len(iterable)
    # ``or 1`` avoids a zero ``range`` step for empty inputs
    period = period or size or 1
    if mask is None:
        mask = []
    indices = np.arange(size)
    length = int(subset_len(length, size, period))
    prev_idxs, after_idxs = [], []
    for start in range(0, size, period):
        batch_idxs = np.setdiff1d(indices[start:start + period], mask)
        # clamp: a negative cut would wrap around the batch
        cut = max(len(batch_idxs) - length, 0)
        prev_idxs.extend(batch_idxs[:cut])
        after_idxs.extend(batch_idxs[cut:])
    # explicit dtype: empty lists would otherwise become float arrays
    return (np.array(prev_idxs, dtype=indices.dtype),
            np.array(after_idxs, dtype=indices.dtype))
def random(iterable, length, mask=None):
    """Randomly split positions ``0 .. len(iterable) - 1`` in two sets.

    Masked positions are excluded and the rest are shuffled with NumPy's
    global random state. The second set has ``subset_len(length, size)``
    positions, or all of them if fewer are left.

    Returns:
        list: ``[prev_idxs, after_idxs]``, both sorted.
    """
    size = len(iterable)
    if mask is None:
        mask = []
    indices = np.setdiff1d(np.arange(size), mask)
    np.random.shuffle(indices)
    # clamp: masking may leave fewer candidates than requested, and a
    # negative split point would wrap around
    split_at = max(len(indices) - int(subset_len(length, size)), 0)
    res = [np.sort(indices[:split_at]), np.sort(indices[split_at:])]
    return res
def past_pretest_days(dataset, length, mask):
    """Pick validation samples right before the first test day, across years.

    The first test day is the time step at ``lookback + delay`` from the
    start of the sample at position ``min(mask)`` (the first horizon step).
    For every year in ``dataset.index``, the samples whose first horizon step
    precedes the same month/day of that year are gathered (groups are
    cumulative). The last samples of each group form the validation set:
    ``length // n_years`` per year, with the remainder from the last year.

    Args:
        dataset: Dataset exposing ``index``, ``_indices``, ``lookback`` and
            ``delay``.
        length: Number of validation samples (an int).
        mask: Test sample positions (required).

    Returns:
        tuple: ``(train_idxs, val_idxs)``; training excludes both validation
        and masked (test) samples.

    Raises:
        ValueError: if no sample precedes the first test day in any year.
    """
    import calendar

    # get the first day of testing, as the first step of the horizon
    keep_until = np.min(mask)
    first_testing_day_idx = dataset._indices[keep_until]
    first_testing_day = dataset.index[first_testing_day_idx +
                                      dataset.lookback + dataset.delay]
    # extract samples before first day of testing through the years
    tz_info = dataset.index.tzinfo
    years = sorted({int(year) for year in dataset.index.year})
    month = first_testing_day.month
    yearly_testing_loc = []
    for year in years:
        # clamp the day to the month length (e.g., Feb 29 in non-leap years)
        day = min(first_testing_day.day, calendar.monthrange(year, month)[1])
        ftd_year = datetime(year, month, day, tzinfo=tz_info)
        yearly_testing_loc.append(dataset.index.slice_locs(ftd_year)[0])
    yearly_train_samples = [
        np.where(dataset._indices < ytl - dataset.lookback - dataset.delay)[0]
        for ytl in yearly_testing_loc
    ]
    # filter the years in which there are no such samples
    yearly_train_samples = [
        yts for yts in yearly_train_samples if len(yts) > 0
    ]
    if not yearly_train_samples:
        raise ValueError("No sample precedes the first day of testing.")
    # for each year but the last take the last "val_len // n_years" samples
    yearly_val_len = length // len(yearly_train_samples)
    yearly_val_lens = [
        min(yearly_val_len, len(yts)) for yts in yearly_train_samples[:-1]
    ]
    # For the last year, take the remaining number of samples needed to reach
    # val_len. This value is always greater or equals to the other, so we have
    # at least the same number of validation samples coming from the last year
    # than the maximum among all the other years.
    yearly_val_lens.append(length - sum(yearly_val_lens))
    # finally extracts the validation samples; the explicit, clamped start
    # avoids ``idxs[-0:]`` selecting the whole group when val_len is 0
    val_idxs = [
        idxs[max(len(idxs) - val_len, 0):]
        for idxs, val_len in zip(yearly_train_samples, yearly_val_lens)
    ]
    val_idxs = np.concatenate(val_idxs)
    # recompute training indices, keeping test (masked) samples out of them
    all_idxs = np.arange(len(dataset))
    train_idxs = np.setdiff1d(all_idxs, np.union1d(val_idxs, mask))
    return train_idxs, val_idxs
def last_month(dataset, mask=None):
    """Split at the first day of the month of a reference time step.

    The reference is the time step of the sample with the smallest position
    in ``mask`` if given, otherwise the last time step of ``dataset.index``.
    The split itself is delegated to ``split_at_ts`` (``mask`` is forwarded).
    """
    if mask is None:
        ref_day = dataset.index[-1]
    else:
        first_masked = np.min(mask)
        ref_day = dataset.index[dataset._indices[first_masked]]
    return split_at_ts(dataset, (ref_day.year, ref_day.month, 1), mask)
# Short aliases for the splitter classes.
temporal, at_ts = TemporalSplitter, AtTimeStepSplitter