from typing import Literal, Mapping, Optional
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, Subset
import tsl
from ...typing import Index
from ..loader import StaticGraphLoader
from ..spatiotemporal_dataset import SpatioTemporalDataset
from .splitters import Splitter
# Stage names accepted as the ``stage`` argument of
# ``SpatioTemporalDataModule.setup`` (used there as the parameter annotation).
StageOptions = Literal['fit', 'validate', 'test', 'predict']
class SpatioTemporalDataModule(LightningDataModule):
    r"""Base :class:`~pytorch_lightning.core.LightningDataModule` for
    :class:`~tsl.data.SpatioTemporalDataset`.

    Args:
        dataset (SpatioTemporalDataset): The complete dataset.
        scalers (dict, optional): Named mapping of
            :class:`~tsl.data.preprocessing.scalers.Scaler`
            to be used for data rescaling after splitting. Every scaler is
            given as input the attribute of the dataset named as the scaler's
            key. If :obj:`None`, no scaling is performed.
            (default :obj:`None`)
        mask_scaling (bool): If :obj:`True`, then compute statistics for
            :obj:`dataset.target` scaler (if any) by considering only valid
            values (according to :obj:`dataset.mask`).
            (default :obj:`True`)
        splitter (Splitter, optional): The
            :class:`~tsl.data.datamodule.splitters.Splitter` to be used for
            splitting :obj:`dataset` into train/validation/test sets.
            (default :obj:`None`)
        batch_size (int): Size of the mini-batches for the dataloaders.
            (default :obj:`32`)
        workers (int): Number of workers to use in the dataloaders.
            (default :obj:`0`)
        pin_memory (bool): If :obj:`True`, then enable pinned GPU memory for
            :meth:`~tsl.data.datamodule.SpatioTemporalDataModule.train_dataloader`.
            (default :obj:`False`)
    """

    def __init__(self,
                 dataset: SpatioTemporalDataset,
                 scalers: Optional[Mapping] = None,
                 mask_scaling: bool = True,
                 splitter: Optional[Splitter] = None,
                 batch_size: int = 32,
                 workers: int = 0,
                 pin_memory: bool = False):
        super().__init__()
        self.torch_dataset = dataset
        # splitting: the three sets go through the property setters below,
        # which store them in `_trainset`, `_valset` and `_testset`
        self.splitter = splitter
        self.trainset = self.valset = self.testset = None
        # scaling
        self.scalers = dict() if scalers is None else scalers
        self.mask_scaling = mask_scaling
        # data loaders
        self.batch_size = batch_size
        self.workers = workers
        self.pin_memory = pin_memory

    def __getattr__(self, item):
        """Fall back to the wrapped dataset for attributes missing here.

        Raises:
            AttributeError: If the dataset is not set yet or does not expose
                :obj:`item` either.
        """
        # Read through `__dict__` so that a lookup made before
        # `torch_dataset` is assigned does not re-enter `__getattr__`.
        ds = self.__dict__.get('torch_dataset')
        if ds is not None and hasattr(ds, item):
            return getattr(ds, item)
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{item}'")

    def __repr__(self):
        return "{}(train_len={}, val_len={}, test_len={}, " \
               "scalers=[{}], batch_size={})" \
            .format(self.__class__.__name__,
                    self.train_len, self.val_len, self.test_len,
                    ', '.join(self.scalers.keys()), self.batch_size)

    # Split datasets ##########################################################

    @property
    def trainset(self):
        """The training set, or :obj:`None` if not set."""
        return self._trainset

    @trainset.setter
    def trainset(self, value):
        self._add_set('train', value)

    @property
    def valset(self):
        """The validation set, or :obj:`None` if not set."""
        return self._valset

    @valset.setter
    def valset(self, value):
        self._add_set('val', value)

    @property
    def testset(self):
        """The test set, or :obj:`None` if not set."""
        return self._testset

    @testset.setter
    def testset(self, value):
        self._add_set('test', value)

    @property
    def train_len(self):
        """Number of samples in the training set (:obj:`None` if not set)."""
        return len(self.trainset) if self.trainset is not None else None

    @property
    def val_len(self):
        """Number of samples in the validation set (:obj:`None` if not
        set)."""
        return len(self.valset) if self.valset is not None else None

    @property
    def test_len(self):
        """Number of samples in the test set (:obj:`None` if not set)."""
        return len(self.testset) if self.testset is not None else None

    @property
    def train_slice(self):
        """Slice stored when the training set was assigned from indices,
        :obj:`None` otherwise."""
        return self._train_slice if hasattr(self, '_train_slice') else None

    @property
    def val_slice(self):
        """Slice stored when the validation set was assigned from indices,
        :obj:`None` otherwise."""
        return self._val_slice if hasattr(self, '_val_slice') else None

    @property
    def test_slice(self):
        """Slice stored when the test set was assigned from indices,
        :obj:`None` otherwise."""
        return self._test_slice if hasattr(self, '_test_slice') else None

    def _add_set(self, split_type, _set):
        """Store a split, given either as a dataset or as indices.

        Args:
            split_type (str): One of :obj:`'train'`, :obj:`'val'`,
                :obj:`'test'`.
            _set: :obj:`None`, a :class:`~torch.utils.data.Dataset`, or
                indices into :obj:`torch_dataset`. Indices are wrapped in a
                :class:`~torch.utils.data.Subset` and the result of
                :obj:`torch_dataset.expand_indices(indices, merge=True)` is
                stored as the corresponding ``_<split>_slice``.
        """
        assert split_type in ['train', 'val', 'test']
        split_type = '_' + split_type
        name = split_type + 'set'
        if _set is None or isinstance(_set, Dataset):
            # NOTE(review): a previously stored `_<split>_slice` is left
            # untouched in this branch, so it may no longer match the set.
            setattr(self, name, _set)
        else:
            indices = _set
            assert isinstance(indices, Index.__args__), \
                f"type {type(indices)} of `{name}` is not a valid type. " \
                "It must be a dataset or a sequence of indices."
            _set = Subset(self.torch_dataset, indices)
            # presumably the time steps covered by the selected samples;
            # verify against `SpatioTemporalDataset.expand_indices`
            _slice = self.torch_dataset.expand_indices(_set.indices,
                                                       merge=True)
            setattr(self, name, _set)
            slice_name = split_type + '_slice'  # e.g. trainset > _train_slice
            setattr(self, slice_name, _slice)

    # Setup ###################################################################

    def setup(self, stage: Optional[StageOptions] = None):
        """Split the dataset and register the scalers on it.

        If a splitter is set, it is run on the dataset and its indices define
        the train/val/test sets. Then, unless :obj:`stage` is
        :obj:`'predict'`, every scaler is fit on the dataset attribute named
        by its key; attributes whose pattern contains :obj:`'t'` are
        restricted to the training slice when one is available. In every
        stage, each scaler is finally added to the dataset.

        Args:
            stage (str, optional): The stage for which the module is set up.
                (default :obj:`None`)

        Raises:
            RuntimeError: If a scaler key does not match any tensor in the
                dataset.
        """
        # splitting
        if self.splitter is not None:
            self.splitter.split(self.torch_dataset)
            self.trainset = self.splitter.train_idxs
            self.valset = self.splitter.val_idxs
            self.testset = self.splitter.test_idxs

        # `train_slice` is None when no index-based training set exists.
        # Indexing a tensor with None would add a leading axis rather than
        # select everything, so slicing is skipped in that case and the
        # scalers are fit on the whole tensor.
        train_slice = self.train_slice

        for key, scaler in self.scalers.items():
            if key not in self.torch_dataset:
                raise RuntimeError("Cannot find a tensor to scale matching "
                                   f"key '{key}'.")
            # set scalers
            if stage == 'predict':
                # the scaler is registered as given, without fitting
                tsl.logger.info(f'Set scaler for {key}: {scaler}')
            else:  # fit scalers before training
                data = getattr(self.torch_dataset, key)
                # get only training slice
                has_time = 't' in self.torch_dataset.patterns[key]
                if has_time and train_slice is not None:
                    data = data[train_slice]
                mask = None
                if key == 'target' and self.mask_scaling:
                    if self.torch_dataset.mask is not None:
                        mask = self.torch_dataset.get_mask()
                        if train_slice is not None:
                            mask = mask[train_slice]
                scaler = scaler.fit(data, mask=mask, keepdims=True)
                tsl.logger.info(f'Fit and set scaler for {key}: {scaler}')
            self.torch_dataset.add_scaler(key, scaler)

    # Dataloaders #############################################################

    def get_dataloader(self,
                       split: Optional[Literal['train', 'val', 'test']] = None,
                       shuffle: bool = False,
                       batch_size: Optional[int] = None) \
            -> Optional[DataLoader]:
        """Build the dataloader for the requested split.

        Args:
            split (str, optional): One of :obj:`'train'`, :obj:`'val'`,
                :obj:`'test'`. If :obj:`None`, the whole dataset is loaded.
                (default :obj:`None`)
            shuffle (bool): Whether to shuffle the samples.
                (default :obj:`False`)
            batch_size (int, optional): Size of the mini-batches. If
                :obj:`None` (or 0), :obj:`self.batch_size` is used.
                (default :obj:`None`)

        Returns:
            The dataloader, or :obj:`None` if the requested set is not set.

        Raises:
            ValueError: If :obj:`split` is not a valid split name.
        """
        if split is None:
            dataset = self.torch_dataset
        elif split in ['train', 'val', 'test']:
            dataset = getattr(self, f'{split}set')
        else:
            raise ValueError("Argument `split` must be one of "
                             "'train', 'val', or 'test'.")
        if dataset is None:
            return None
        is_train = split == 'train'
        # pinned memory and dropping the last incomplete batch are enabled
        # for the training loader only
        pin_memory = self.pin_memory if is_train else False
        return StaticGraphLoader(dataset,
                                 batch_size=batch_size or self.batch_size,
                                 shuffle=shuffle,
                                 drop_last=is_train,
                                 num_workers=self.workers,
                                 pin_memory=pin_memory)

    def train_dataloader(self, shuffle: bool = True,
                         batch_size: Optional[int] = None) \
            -> Optional[DataLoader]:
        """Return the dataloader for the training set (shuffled by default),
        or :obj:`None` if the training set is not set."""
        return self.get_dataloader('train', shuffle, batch_size)

    def val_dataloader(self, shuffle: bool = False,
                       batch_size: Optional[int] = None) \
            -> Optional[DataLoader]:
        """Return the dataloader for the validation set, or :obj:`None` if
        the validation set is not set."""
        return self.get_dataloader('val', shuffle, batch_size)

    def test_dataloader(self, shuffle: bool = False,
                        batch_size: Optional[int] = None) \
            -> Optional[DataLoader]:
        """Return the dataloader for the test set, or :obj:`None` if the
        test set is not set."""
        return self.get_dataloader('test', shuffle, batch_size)