Organizing data

DataModule

SpatioTemporalDataModule

Base LightningDataModule for SpatioTemporalDataset.

class SpatioTemporalDataModule(dataset: SpatioTemporalDataset, scalers: Optional[Mapping] = None, mask_scaling: bool = True, splitter: Optional[Splitter] = None, batch_size: int = 32, workers: int = 0, pin_memory: bool = False)

Base LightningDataModule for SpatioTemporalDataset.

Parameters:
  • dataset (SpatioTemporalDataset) – The complete dataset.

  • scalers (dict, optional) – Mapping from names to Scaler objects used to rescale the data after splitting. Each scaler is given as input the dataset attribute whose name matches its key (e.g., the key 'target' refers to dataset.target). If None, no scaling is performed. (default None)

  • mask_scaling (bool) – If True, compute the statistics of the scaler for dataset.target (if any) using only the valid values, as indicated by dataset.mask. (default True)

  • splitter (Splitter, optional) – The Splitter to be used for splitting dataset into train/validation/test sets. (default None)

  • batch_size (int) – Size of the mini-batches for the dataloaders. (default 32)

  • workers (int) – Number of workers to use in the dataloaders. (default 0)

  • pin_memory (bool) – If True, the dataloader returned by train_dataloader() uses pinned (page-locked) host memory, which can speed up host-to-GPU transfers. (default False)
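The effect of mask_scaling can be illustrated with a minimal NumPy sketch. The helper `masked_mean_std` is hypothetical and not part of tsl; it assumes target and mask arrays of shape (time, nodes, features), a boolean mask that is True on valid entries, statistics taken from the training split only, and a single global mean/std rather than whatever per-feature statistics a real Scaler may compute.

```python
import numpy as np


def masked_mean_std(target, mask, train_idx, mask_scaling=True):
    """Hypothetical helper: mean and std of ``target`` over training steps.

    target: float array of shape (time, nodes, features).
    mask:   bool array of the same shape, True where a value is valid.
    """
    x = target[train_idx]  # statistics come from the training split only
    if mask_scaling:
        values = x[mask[train_idx]]  # keep only valid (observed) entries
    else:
        values = x.reshape(-1)  # use every entry, including invalid ones
    return float(values.mean()), float(values.std())


# Four time steps, one node, one feature; step 2 is missing (stored as 0).
target = np.array([1.0, 3.0, 0.0, 100.0]).reshape(4, 1, 1)
mask = np.array([True, True, False, True]).reshape(4, 1, 1)
train_idx = np.arange(3)  # the first three steps form the training split

print(masked_mean_std(target, mask, train_idx))         # (2.0, 1.0)
print(masked_mean_std(target, mask, train_idx, False))  # mean is 4/3, skewed by the placeholder 0
```

Without masking, the placeholder used for the missing value is treated as real data and biases the statistics; with masking, only the two observed training values contribute.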

setup(stage: Optional[Literal['fit', 'validate', 'test', 'predict']] = None)

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:
  • stage (str, optional) – One of 'fit', 'validate', 'test', or 'predict'.

Example:

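import torch.nn as nn
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        # Called on only one process per node: download/preprocess here.
        # download_data() and tokenize() are placeholders for your own code.
        download_data()
        tokenize()

        # Don't assign state here (e.g. self.something = ...):
        # it will not be available on the other processes.

    def setup(self, stage):
        # Called on every process: safe to build layers that depend on data.
        data = load_data(...)  # placeholder for your own data loader
        self.l1 = nn.Linear(28, data.num_classes)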
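For a data module, a common pattern is to split the data and fit the scalers only once, the first time setup runs, and to fit each scaler on the training slice of the attribute named by its key, in line with the description of scalers as applied after splitting. The sketch below is written under those assumptions; `ZScoreScaler` and `ToyDataModule` are hypothetical plain-NumPy stand-ins, not tsl classes, and the validation/test lengths are fixed numbers of time steps for simplicity.

```python
from typing import Dict, Optional

import numpy as np


class ZScoreScaler:
    """Hypothetical minimal scaler (not tsl's Scaler API)."""

    def fit(self, x: np.ndarray) -> "ZScoreScaler":
        self.mean = x.mean(axis=0)
        self.std = x.std(axis=0)
        return self

    def transform(self, x: np.ndarray) -> np.ndarray:
        return (x - self.mean) / self.std


class ToyDataModule:
    """Hypothetical stand-in mimicking the split-then-scale order."""

    def __init__(self, data: Dict[str, np.ndarray],
                 scalers: Optional[Dict[str, ZScoreScaler]] = None,
                 val_len: int = 1, test_len: int = 1):
        self.data = data
        self.scalers = scalers or {}
        self.val_len, self.test_len = val_len, test_len
        self._is_setup = False

    def setup(self, stage: Optional[str] = None) -> None:
        # setup may run once per stage ('fit', then 'test', ...);
        # split and fit only on the first call.
        if self._is_setup:
            return
        n = len(self.data["target"])
        n_train = n - self.val_len - self.test_len
        self.train_idx = np.arange(n_train)
        self.val_idx = np.arange(n_train, n_train + self.val_len)
        self.test_idx = np.arange(n_train + self.val_len, n)
        # Fit each scaler on the training slice of the attribute named by
        # its key, so validation/test values never affect the statistics.
        for key, scaler in self.scalers.items():
            scaler.fit(self.data[key][self.train_idx])
        self._is_setup = True


data = {"target": np.arange(10, dtype=float).reshape(10, 1)}
dm = ToyDataModule(data, scalers={"target": ZScoreScaler()}, val_len=2, test_len=2)
dm.setup("fit")
print(dm.scalers["target"].mean)  # [2.5], the mean of steps 0..5 only
dm.setup("test")  # no-op: splits and scalers are already fitted
```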

Splitters

Splitter

Base class for splitter modules.

CustomSplitter

Create a Splitter from custom functions that split off the validation and test sets.

TemporalSplitter

Split the data sequentially with specified lengths.

AtTimeStepSplitter

Split the data at given time steps (only for a SpatioTemporalDataset whose index is a DatetimeIndex).

class Splitter(*args, **kwargs)

Base class for splitter modules.

class CustomSplitter(*args, **kwargs)

Create a Splitter from custom functions that split off the validation and test sets.
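One way to compose custom splitting functions is sketched below: a test-split function first carves the test set out of all indices, then a validation-split function carves the validation set out of the remainder. The function `custom_split`, the helpers `take_last` and `every_other`, and the `(rest, held_out)` return convention are hypothetical choices for illustration, not the signature that tsl's CustomSplitter expects.

```python
from typing import Callable, Tuple

import numpy as np

SplitFn = Callable[[np.ndarray], Tuple[np.ndarray, np.ndarray]]


def custom_split(n: int, val_split_fn: SplitFn, test_split_fn: SplitFn):
    """Hypothetical sketch: apply user-supplied split functions in sequence.

    Each function receives an array of indices and returns (rest, held_out).
    """
    indices = np.arange(n)
    rest, test_idx = test_split_fn(indices)   # hold out the test set first
    train_idx, val_idx = val_split_fn(rest)   # then the validation set
    return train_idx, val_idx, test_idx


def take_last(k: int) -> SplitFn:
    """Example split function: hold out the last k indices (k must be > 0)."""
    return lambda idx: (idx[:-k], idx[-k:])


def every_other(idx: np.ndarray):
    """Example split function: hold out every second index (non-contiguous)."""
    return idx[::2], idx[1::2]


train, val, test = custom_split(10, val_split_fn=every_other, test_split_fn=take_last(2))
print(train, val, test)  # [0 2 4 6] [1 3 5 7] [8 9]
```

Passing functions rather than lengths lets each set follow a different rule, here a contiguous test block and an interleaved validation set.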

class TemporalSplitter(*args, **kwargs)

Split the data sequentially with specified lengths.
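Splitting sequentially keeps the validation and test sets strictly later in time than the training data, so the model is never trained on the future. A minimal sketch, assuming lengths may be given either as a number of time steps (int) or as a fraction of the whole series (float); the function name `temporal_split` is hypothetical, and how tsl's TemporalSplitter interprets fractional lengths may differ.

```python
from typing import Union

import numpy as np

Length = Union[int, float]


def temporal_split(n: int, val_len: Length, test_len: Length):
    """Hypothetical sketch of a sequential (chronological) split.

    Integer lengths count time steps; floats are fractions of ``n``.
    """
    if isinstance(test_len, float):
        test_len = int(test_len * n)
    if isinstance(val_len, float):
        val_len = int(val_len * n)
    train_len = n - val_len - test_len
    if train_len <= 0:
        raise ValueError("val_len + test_len leaves no training steps")
    indices = np.arange(n)
    # Contiguous blocks in time order: train, then validation, then test.
    return (indices[:train_len],
            indices[train_len:train_len + val_len],
            indices[train_len + val_len:])


train, val, test = temporal_split(10, val_len=0.2, test_len=3)
print(train, val, test)  # [0 1 2 3 4] [5 6] [7 8 9]
```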

class AtTimeStepSplitter(*args, **kwargs)

Split the data at given time steps (only for a SpatioTemporalDataset whose index is a DatetimeIndex).
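Because the split points are timestamps rather than counts, this kind of splitter needs a datetime index to compare against, which explains the DatetimeIndex requirement. A minimal pandas sketch under stated assumptions: `split_at_time_steps` is hypothetical, the split points are anything `pd.Timestamp` accepts, and the sets are half-open time intervals; tsl's actual argument names and boundary conventions may differ.

```python
import numpy as np
import pandas as pd


def split_at_time_steps(index: pd.DatetimeIndex, first_val_ts, first_test_ts):
    """Hypothetical sketch: split positions of ``index`` at two timestamps.

    Steps before first_val_ts go to training, steps from first_val_ts up to
    (but excluding) first_test_ts go to validation, the rest to testing.
    """
    val_start = pd.Timestamp(first_val_ts)
    test_start = pd.Timestamp(first_test_ts)
    if not val_start < test_start:
        raise ValueError("first_val_ts must precede first_test_ts")
    train_idx = np.flatnonzero(index < val_start)
    val_idx = np.flatnonzero((index >= val_start) & (index < test_start))
    test_idx = np.flatnonzero(index >= test_start)
    return train_idx, val_idx, test_idx


index = pd.date_range("2020-01-01", periods=10, freq="D")
train, val, test = split_at_time_steps(index, "2020-01-06", "2020-01-09")
print(train, val, test)  # [0 1 2 3 4] [5 6 7] [8 9]
```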