import functools
import os
from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple,
Union)
import numpy as np
from numpy import ndarray
from pandas import DataFrame, Series
from scipy.sparse import coo_matrix, csc_matrix, csr_matrix
import tsl
from tsl import config, logger
from ...data.datamodule import Splitter, splitters
from ...typing import ScipySparseMatrix
from ...utils.io import load_pickle, save_pickle
from ...utils.python_utils import ensure_list, files_exist, hash_dict
class Dataset(object):
    """Base class for Datasets in tsl.

    Args:
        name (str, optional): Name of the dataset. If :obj:`None`, use name of
            the class. (default: :obj:`None`)
        similarity_score (str, optional): Default method used to compute the
            similarity matrix. If :obj:`similarity_options` is not
            :obj:`None`, it must be one of those options.
            (default: :obj:`None`)
        temporal_aggregation (str): Function (as string) used for aggregation
            along temporal dimension. (default: :obj:`'sum'`)
        spatial_aggregation (str): Permutation invariant function (as string)
            used for aggregation along nodes' dimension.
            (default: :obj:`'sum'`)
        default_splitting_method (str): Default splitting method used by
            :obj:`get_splitter` when no method is specified.
            (default: :obj:`'temporal'`)
    """
    # Root directory for the dataset's files. If None, defaults to
    # ``os.path.join(config.data_dir, cls.__name__)`` (see ``root_dir``).
    root: Optional[str] = None
    # Set of valid values for ``similarity_score``. If None, no validity
    # check is performed on the score passed to the constructor.
    similarity_options: Optional[Set] = None

    def __init__(self,
                 name: Optional[str] = None,
                 similarity_score: Optional[str] = None,
                 temporal_aggregation: str = 'sum',
                 spatial_aggregation: str = 'sum',
                 default_splitting_method: str = 'temporal'):
        # Set name
        self.name = name if name is not None else self.__class__.__name__
        # Set similarity method (validated against class-level options, if any)
        if self.similarity_options is not None:
            if similarity_score not in self.similarity_options:
                raise ValueError("{} is not a valid similarity method.".format(
                    similarity_score))
        self.similarity_score = similarity_score
        # Set aggregation methods
        self.temporal_aggregation = temporal_aggregation
        self.spatial_aggregation = spatial_aggregation
        # Set splitting method
        self.default_splitting_method = default_splitting_method

    def __new__(cls, *args, **kwargs) -> "Dataset":
        obj = super().__new__(cls)
        # Decorate `get_splitter` on the *instance*, so that splitters
        # provided by the library can be looked up when the subclass does
        # not handle the requested method (see `_wrap_method`).
        obj.get_splitter = cls._wrap_method(obj, obj.get_splitter)
        return obj

    @staticmethod
    def _wrap_method(obj: "Dataset", fn: callable) -> callable:
        """A decorator that extends functionalities of some methods.

        - When ``ds.get_splitter(...)`` is called, if no method is specified
          or if the method is not dataset-specific (specified by overriding
          the method), the method is looked-up among the ones provided by the
          library. Notice that this happens whether or not this method is
          overridden.

        Args:
            obj: Object whose function will be tracked.
            fn: Function that will be wrapped.

        Returns:
            Decorated method to extend functionalities.
        """

        @functools.wraps(fn)
        def get_splitter(method: Optional[str] = None, *args, **kwargs) \
                -> Splitter:
            if method is None:
                method = obj.default_splitting_method
            # Give the dataset-specific implementation the first chance...
            splitter = fn(method, *args, **kwargs)
            if splitter is None:
                # ...then fall back to the splitters shipped with tsl.
                try:
                    splitter = getattr(splitters, method)(*args, **kwargs)
                except AttributeError:
                    raise NotImplementedError(f'Splitter option "{method}" '
                                              f'does not exist.')
            return splitter

        if fn.__name__ == 'get_splitter':
            return get_splitter
        # FIX: previously any other function silently mapped to None,
        # clobbering the attribute. Return it unchanged instead.
        return fn

    def __getstate__(self) -> dict:
        # avoids _pickle.PicklingError: Can't pickle <...>: it's not the same
        # object as <...> -- the instance-level `get_splitter` closure added
        # in `__new__` is not picklable and is re-created on instantiation.
        d = self.__dict__.copy()
        del d['get_splitter']
        return d

    # Data dimensions

    @property
    def length(self) -> int:
        """Returns the length -- in terms of time steps -- of the dataset.

        Returns:
            int: Temporal length of the dataset.
        """
        raise NotImplementedError

    @property
    def n_nodes(self) -> int:
        """Returns the number of nodes in the dataset. In case of dynamic
        graph, :obj:`n_nodes` is the total number of nodes present in at
        least one time step.

        Returns:
            int: Total number of nodes in the dataset.
        """
        raise NotImplementedError

    @property
    def n_channels(self) -> int:
        """Returns the number of node-level channels of the main signal in
        the dataset.

        Returns:
            int: Number of channels of the main signal.
        """
        raise NotImplementedError

    def __repr__(self):
        return "{}(length={}, n_nodes={}, n_channels={})" \
            .format(self.name, self.length, self.n_nodes, self.n_channels)

    def __len__(self):
        """Returns the length -- in terms of time steps -- of the dataset.

        Returns:
            int: Temporal length of the dataset.
        """
        return self.length

    # Directory information

    @property
    def root_dir(self) -> str:
        """Absolute path of the directory holding the dataset's files."""
        if isinstance(self.root, str):
            root = os.path.expanduser(os.path.normpath(self.root))
        elif self.root is None:
            # Default location inside tsl's configured data directory.
            root = os.path.join(config.data_dir, self.__class__.__name__)
        else:
            raise ValueError
        return root

    @property
    def raw_file_names(self) \
            -> Union[str, Sequence[str], Mapping[str, str]]:
        """The name of the files in the :obj:`self.root_dir` folder that must
        be present in order to skip downloading."""
        return []

    @property
    def required_file_names(self) \
            -> Union[str, Sequence[str], Mapping[str, str]]:
        """The name of the files in the :obj:`self.root_dir` folder that must
        be present in order to skip building."""
        return self.raw_file_names

    def _files_to_paths(self, files: Union[str, Sequence[str],
                                           Mapping[str, str]]) \
            -> Union[List[str], Mapping[str, str]]:
        """Joins file names (as list or mapping) with :obj:`self.root_dir`,
        preserving the container type."""
        if not isinstance(files, Mapping):
            files = ensure_list(files)
        if isinstance(files, list):
            return [os.path.join(self.root_dir, f) for f in files]
        return {k: os.path.join(self.root_dir, f) for k, f in files.items()}

    @property
    def raw_files_paths(self) -> Union[List[str], Mapping[str, str]]:
        """The absolute filepaths that must be present in order to skip
        downloading."""
        return self._files_to_paths(self.raw_file_names)

    @property
    def required_files_paths(self) -> Union[List[str], Mapping[str, str]]:
        """The absolute filepaths that must be present in order to skip
        building."""
        return self._files_to_paths(self.required_file_names)

    @property
    def raw_files_paths_list(self) -> List[str]:
        """The list of absolute filepaths that must be present in order to
        skip downloading."""
        files = self.raw_files_paths
        if isinstance(files, Mapping):
            files = list(files.values())
        return files

    @property
    def required_files_paths_list(self) -> List[str]:
        """The list of absolute filepaths that are required to load the
        dataset."""
        files = self.required_files_paths
        if isinstance(files, Mapping):
            files = list(files.values())
        return files

    # Loading pipeline: load() → load_raw() → build() → download()

    def maybe_download(self):
        """Downloads the dataset's files only if they are not all present."""
        if not files_exist(self.raw_files_paths_list):
            os.makedirs(self.root_dir, exist_ok=True)
            self.download()

    def maybe_build(self):
        """Builds the dataset only if required files are not all present."""
        if not files_exist(self.required_files_paths_list):
            os.makedirs(self.root_dir, exist_ok=True)
            self.build()

    def download(self) -> None:
        """Downloads dataset's files to the :obj:`self.root_dir` folder."""
        raise NotImplementedError

    def build(self) -> None:
        """Eventually build the dataset from raw data to :obj:`self.root_dir`
        folder."""
        pass

    def load_raw(self, *args, **kwargs):
        """Loads raw dataset without any data preprocessing."""
        raise NotImplementedError

    def load(self, *args, **kwargs):
        """Loads raw dataset and preprocess data. Default to
        :obj:`load_raw`."""
        return self.load_raw(*args, **kwargs)

    def clean_downloads(self):
        """Removes downloaded files that are not required to load the
        dataset."""
        for file in self.raw_files_paths_list:
            if file not in self.required_files_paths_list:
                if os.path.exists(file):
                    os.unlink(file)

    def clean_root_dir(self):
        """Removes every file in :obj:`self.root_dir` that is neither a raw
        nor a required file of the dataset. Failures are logged, not
        raised."""
        import shutil
        total_files = self.required_files_paths_list + self.raw_files_paths_list
        for filename in os.listdir(self.root_dir):
            file_path = os.path.join(self.root_dir, filename)
            if file_path in total_files:
                continue
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                # Best-effort cleanup: report and keep going.
                logger.warning(
                    'Failed to delete %s. Reason: %s' % (file_path, e))

    # Representations

    def dataframe(self) -> Union[DataFrame, List[DataFrame]]:
        """Returns a pandas representation of the dataset in the form of a
        :class:`~pandas.DataFrame`. May be a list of DataFrames if the
        dataset has a dynamic structure."""
        raise NotImplementedError

    def numpy(
            self,
            return_idx: bool = False
    ) -> Union[ndarray, List[ndarray], Tuple[ndarray, Series], Tuple[
            List[ndarray], Series]]:
        """Returns a numpy representation of the dataset in the form of a
        :class:`~numpy.ndarray`. If :obj:`return_idx` is :obj:`True`, it
        returns also a :class:`~pandas.Series` that can be used as index. May
        be a list of ndarrays (and Series) if the dataset has a dynamic
        structure."""
        raise NotImplementedError

    # IO

    def save_pickle(self, filename: str) -> None:
        """Save :obj:`Dataset` to disk.

        Args:
            filename (str): path to filename for storage.
        """
        save_pickle(self, filename)

    @classmethod
    def load_pickle(cls, filename: str) -> "Dataset":
        """Load instance of :obj:`Dataset` from disk.

        Args:
            filename (str): path of :obj:`Dataset`.
        """
        obj = load_pickle(filename)
        if not isinstance(obj, cls):
            raise TypeError(f"Loaded file is not of class {cls}.")
        return obj

    # Similarity pipeline: get_adj() → get_similarity() → compute_similarity()

    def compute_similarity(self, method: str,
                           **kwargs) -> Optional[np.ndarray]:
        r"""Implements the options for the similarity matrix
        :math:`\mathbf{S} \in \mathbb{R}^{N \times N}` computation, according
        to :obj:`method`.

        Args:
            method (str): Method for the similarity computation.
            **kwargs (optional): Additional optional keyword arguments.

        Returns:
            ndarray: The similarity dense matrix.
        """
        raise NotImplementedError

    def get_similarity(self,
                       method: Optional[str] = None,
                       save: bool = False,
                       **kwargs) -> ndarray:
        r"""Returns the matrix :math:`\mathbf{S} \in \mathbb{R}^{N \times N}`,
        where :math:`N=`:obj:`self.n_nodes`, with the pairwise similarity
        scores between nodes.

        Args:
            method (str, optional): Method for the similarity computation. If
                :obj:`None`, defaults to dataset-specific default method.
                (default: :obj:`None`)
            save (bool): Whether to save similarity matrix in dataset's
                directory after computation.
                (default: :obj:`False`)
            **kwargs (optional): Additional optional keyword arguments.

        Returns:
            ndarray: The similarity dense matrix.

        Raises:
            ValueError: If the similarity method is not valid.
        """
        if method is None:
            method = self.similarity_score
        # FIX: raise the documented ValueError (instead of a TypeError from
        # `in None`) when no similarity options are declared for the dataset.
        if self.similarity_options is None \
                or method not in self.similarity_options:
            raise ValueError("Similarity method '{}' not valid".format(method))
        if save:
            # Cache key depends on method, dataset identity and kwargs.
            enc = hash_dict(
                dict(method=method,
                     class_name=self.__class__.__name__,
                     name=self.name,
                     **kwargs))
            name = "sim_{}.npy".format(enc)
            path = os.path.join(self.root_dir, name)
            if os.path.exists(path):
                logger.warning("Loading cached similarity matrix.")
                return np.load(path)
        # get similarity method
        sim = self.compute_similarity(method, **kwargs)
        if save:
            np.save(path, sim)
            logger.info(f"Similarity matrix saved at {path}.")
        return sim

    def get_connectivity(self,
                         method: Optional[str] = None,
                         threshold: Optional[float] = None,
                         knn: Optional[int] = None,
                         binary_weights: bool = False,
                         include_self: bool = True,
                         force_symmetric: bool = False,
                         normalize_axis: Optional[int] = None,
                         layout: str = 'edge_index',
                         **kwargs) -> Union[ndarray, Tuple,
                                            ScipySparseMatrix]:
        r"""Returns the weighted adjacency matrix :math:`\mathbf{A} \in
        \mathbb{R}^{N \times N}`, where :math:`N=`:obj:`self.n_nodes`. The
        element :math:`a_{i,j} \in \mathbf{A}` is 0 if there not exists an
        edge connecting node :math:`i` to node :math:`j`. The return type
        depends on the specified :obj:`layout` (default: :obj:`edge_index`).

        Args:
            method (str, optional): Method for the similarity computation. If
                :obj:`None`, defaults to dataset-specific default method.
                (default: :obj:`None`)
            threshold (float, optional): If not :obj:`None`, set to 0 the
                values below the threshold.
                (default: :obj:`None`)
            knn (int, optional): If not :obj:`None`, keep only :math:`k=`
                :obj:`knn` nearest incoming neighbors.
                (default: :obj:`None`)
            binary_weights (bool): If :obj:`True`, the positive weights of
                the adjacency matrix are set to 1.
                (default: :obj:`False`)
            include_self (bool): If :obj:`False`, self-loops are never taken
                into account. (default: :obj:`True`)
            force_symmetric (bool): Force adjacency matrix to be symmetric by
                taking the maximum value between the two directions for each
                edge. (default: :obj:`False`)
            normalize_axis (int, optional): Divide edge weight :math:`a_{i,
                j}` by :math:`\sum_k a_{i, k}`, if :obj:`normalize_axis=0` or
                :math:`\sum_k a_{k, j}`, if :obj:`normalize_axis=1`.
                :obj:`None` for no normalization.
                (default: :obj:`None`)
            layout (str): Convert matrix to a dense/sparse format. Available
                options are:

                - :obj:`dense`: keep matrix dense :math:`\mathbf{A} \in
                  \mathbb{R}^{N \times N}`.
                - :obj:`edge_index`: convert to (edge_index, edge_weight)
                  tuple, where edge_index has shape :math:`[2, E]` and
                  edge_weight has shape :math:`[E]`, being :math:`E` the
                  number of edges.
                - :obj:`coo`/:obj:`csr`/:obj:`csc`: convert to specified
                  scipy sparse matrix type.

                (default: :obj:`edge_index`)
            **kwargs (optional): Additional optional keyword arguments for
                similarity computation.

        Returns:
            The similarity dense matrix.
        """
        if 'sparse' in kwargs:
            import warnings
            warnings.warn("The argument 'sparse' is deprecated and will be "
                          "removed in future version of tsl. Please use "
                          "the argument `layout` instead.")
            # FIX: pop the deprecated kwarg so it is not forwarded to
            # `get_similarity`/`compute_similarity` (and the cache hash).
            layout = 'edge_index' if kwargs.pop('sparse') else 'dense'
        if method == 'full':
            adj = np.ones((self.n_nodes, self.n_nodes))
        elif method == 'identity':
            adj = np.eye(self.n_nodes)
        else:
            adj = self.get_similarity(method, **kwargs)
        if knn is not None:
            from tsl.ops.similarities import top_k
            adj = top_k(adj,
                        knn,
                        include_self=include_self,
                        keep_values=not binary_weights)
        elif binary_weights:
            adj = (adj > 0).astype(adj.dtype)
        if threshold is not None:
            adj[adj < threshold] = 0
        if not include_self:
            np.fill_diagonal(adj, 0)
        if force_symmetric:
            # Keep the maximum weight between the two edge directions.
            adj = np.maximum.reduce([adj, adj.T])
        if normalize_axis:
            adj = adj / (adj.sum(normalize_axis, keepdims=True) + tsl.epsilon)
        if layout == 'dense':
            return adj
        elif layout == 'edge_index':
            from tsl.ops.connectivity import adj_to_edge_index
            return adj_to_edge_index(adj)
        elif layout in ['coo', 'sparse_matrix']:
            return coo_matrix(adj)
        elif layout == 'csr':
            return csr_matrix(adj)
        elif layout == 'csc':
            return csc_matrix(adj)
        else:
            raise ValueError(
                f"Invalid format for connectivity: {layout}. Valid"
                " options are [dense, edge_index, coo, csr, csc].")

    # Cross-validation splitting options

    def get_splitter(self,
                     method: Optional[str] = None,
                     *args,
                     **kwargs) -> Splitter:
        """Returns the splitter for a
        :class:`~tsl.data.SpatioTemporalDataset`. A
        :class:`~tsl.data.preprocessing.Splitter` provides the splits of the
        dataset -- in terms of indices -- for cross validation."""
        # Intentionally returns None: the wrapper installed in `__new__`
        # falls back to the library-provided splitters in that case.

    # Data methods

    def aggregate(self, node_index: Optional[Iterable[Iterable]] = None):
        """Aggregates nodes given an index of cluster assignments (spatial
        aggregation).

        Args:
            node_index: Sequence of grouped node ids.
        """
        raise NotImplementedError

    # Getters for SpatioTemporalDataset

    def get_config(self) -> dict:
        """Returns the keywords arguments (as dict) for instantiating a
        :class:`~tsl.data.SpatioTemporalDataset`."""
        raise NotImplementedError