Source code for tsl.datasets.prototypes.dataset

import functools
import os
from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple,
                    Union)

import numpy as np
from numpy import ndarray
from pandas import DataFrame, Series
from scipy.sparse import coo_matrix, csc_matrix, csr_matrix

import tsl
from tsl import config, logger

from ...data.datamodule import Splitter, splitters
from ...typing import ScipySparseMatrix
from ...utils.io import load_pickle, save_pickle
from ...utils.python_utils import ensure_list, files_exist, hash_dict


class Dataset(object):
    """Base class for Datasets in tsl.

    Args:
        name (str, optional): Name of the dataset. If :obj:`None`, use the
            name of the class. (default: :obj:`None`)
        similarity_score (str, optional): Function (as string) used to
            compute the similarity scores between nodes. Must be one of
            :obj:`similarity_options`. (default: :obj:`None`)
        temporal_aggregation (str): Function (as string) used for aggregation
            along the temporal dimension. (default: :obj:`'sum'`)
        spatial_aggregation (str): Permutation-invariant function (as string)
            used for aggregation along the nodes' dimension.
            (default: :obj:`'sum'`)
        default_splitting_method (str): Default method used by
            :obj:`get_splitter` when no method is specified.
            (default: :obj:`'temporal'`)
    """

    root: Optional[str] = None
    similarity_options: Optional[Set] = None

    def __init__(self,
                 name: Optional[str] = None,
                 similarity_score: Optional[str] = None,
                 temporal_aggregation: str = 'sum',
                 spatial_aggregation: str = 'sum',
                 default_splitting_method: str = 'temporal'):
        # Set name
        self.name = name if name is not None else self.__class__.__name__
        # Set similarity method
        if self.similarity_options is not None:
            if similarity_score not in self.similarity_options:
                raise ValueError("{} is not a valid similarity method.".format(
                    similarity_score))
        self.similarity_score = similarity_score
        # Set aggregation methods
        self.temporal_aggregation = temporal_aggregation
        self.spatial_aggregation = spatial_aggregation
        # Set splitting method
        self.default_splitting_method = default_splitting_method

    def __new__(cls, *args, **kwargs) -> "Dataset":
        obj = super().__new__(cls)
        # decorate `get_splitter`
        obj.get_splitter = cls._wrap_method(obj, obj.get_splitter)
        return obj

    @staticmethod
    def _wrap_method(obj: "Dataset", fn: callable) -> callable:
        """A decorator that extends the functionalities of some methods.

        - When ``ds.get_splitter(...)`` is called, if no method is specified
          or if the method is not dataset-specific (i.e., it is not handled
          by overriding the method), the method is looked up among the ones
          provided by the library. Notice that this happens whether or not
          this method is overridden.

        Args:
            obj: Object whose function will be tracked.
            fn: Function that will be wrapped.

        Returns:
            Decorated method with extended functionalities.
        """

        @functools.wraps(fn)
        def get_splitter(method: Optional[str] = None, *args, **kwargs) \
                -> Splitter:
            if method is None:
                method = obj.default_splitting_method
            splitter = fn(method, *args, **kwargs)
            if splitter is None:
                try:
                    splitter = getattr(splitters, method)(*args, **kwargs)
                except AttributeError:
                    raise NotImplementedError(f'Splitter option "{method}" '
                                              f'does not exist.')
            return splitter

        if fn.__name__ == 'get_splitter':
            return get_splitter

    def __getstate__(self) -> dict:
        # avoids _pickle.PicklingError: Can't pickle <...>: it's not the same
        # object as <...>
        d = self.__dict__.copy()
        del d['get_splitter']
        return d

    # Data dimensions

    @property
    def length(self) -> int:
        """Returns the length -- in terms of time steps -- of the dataset.

        Returns:
            int: Temporal length of the dataset.
        """
        raise NotImplementedError

    @property
    def n_nodes(self) -> int:
        """Returns the number of nodes in the dataset. In the case of a
        dynamic graph, :obj:`n_nodes` is the total number of nodes present
        in at least one time step.

        Returns:
            int: Total number of nodes in the dataset.
        """
        raise NotImplementedError

    @property
    def n_channels(self) -> int:
        """Returns the number of node-level channels of the main signal in
        the dataset.

        Returns:
            int: Number of channels of the main signal.
        """
        raise NotImplementedError

    #

    def __repr__(self):
        return "{}(length={}, n_nodes={}, n_channels={})" \
            .format(self.name, self.length, self.n_nodes, self.n_channels)

    def __len__(self):
        """Returns the length -- in terms of time steps -- of the dataset.

        Returns:
            int: Temporal length of the dataset.
""" return self.length # Directory information @property def root_dir(self) -> str: if isinstance(self.root, str): root = os.path.expanduser(os.path.normpath(self.root)) elif self.root is None: root = os.path.join(config.data_dir, self.__class__.__name__) else: raise ValueError return root @property def raw_file_names(self) \ -> Union[str, Sequence[str], Mapping[str, str]]: """The name of the files in the :obj:`self.root_dir` folder that must be present in order to skip downloading.""" return [] @property def required_file_names(self) \ -> Union[str, Sequence[str], Mapping[str, str]]: """The name of the files in the :obj:`self.root_dir` folder that must be present in order to skip building.""" return self.raw_file_names @property def raw_files_paths(self) -> Union[List[str], Mapping[str, str]]: """The absolute filepaths that must be present in order to skip downloading.""" files = self.raw_file_names if not isinstance(files, Mapping): files = ensure_list(files) if isinstance(files, list): return [os.path.join(self.root_dir, f) for f in files] else: return { k: os.path.join(self.root_dir, f) for k, f in files.items() } @property def required_files_paths(self) -> Union[List[str], Mapping[str, str]]: """The absolute filepaths that must be present in order to skip building.""" files = self.required_file_names if not isinstance(files, Mapping): files = ensure_list(files) if isinstance(files, list): return [os.path.join(self.root_dir, f) for f in files] else: return { k: os.path.join(self.root_dir, f) for k, f in files.items() } @property def raw_files_paths_list(self) -> List[str]: """The list of absolute filepaths that must be present in order to skip downloading.""" files = self.raw_files_paths if isinstance(files, Mapping): files = list(files.values()) return files @property def required_files_paths_list(self) -> List[str]: """The list of absolute filepaths that are required to load the dataset.""" files = self.required_files_paths if isinstance(files, Mapping): files = list(files.values()) return files # Loading pipeline: load() → load_raw() → build() → download() def maybe_download(self): if not files_exist(self.raw_files_paths_list): os.makedirs(self.root_dir, exist_ok=True) self.download() def maybe_build(self): if not files_exist(self.required_files_paths_list): os.makedirs(self.root_dir, exist_ok=True) self.build()

    def download(self) -> None:
        """Downloads the dataset's files to the :obj:`self.root_dir`
        folder."""
        raise NotImplementedError

    def build(self) -> None:
        """Builds the dataset, if needed, from the raw data into the
        :obj:`self.root_dir` folder."""
        pass

    def load_raw(self, *args, **kwargs):
        """Loads the raw dataset without any data preprocessing."""
        raise NotImplementedError

    def load(self, *args, **kwargs):
        """Loads the raw dataset and preprocesses the data. Defaults to
        :obj:`load_raw`."""
        return self.load_raw(*args, **kwargs)
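
    # A minimal sketch of the loading pipeline in a subclass. `MyDataset`,
    # the file name, and the download source are hypothetical; the sketch
    # only relies on the hooks defined above (`raw_file_names`, `download`,
    # `load_raw`) and on `maybe_download` calling `download` when the raw
    # files are missing:
    #
    #     import pandas as pd
    #
    #     class MyDataset(Dataset):
    #         raw_file_names = 'data.csv'
    #
    #         def download(self) -> None:
    #             # fetch 'data.csv' into self.root_dir (hypothetical source)
    #             ...
    #
    #         def load_raw(self) -> pd.DataFrame:
    #             self.maybe_download()
    #             return pd.read_csv(self.raw_files_paths[0])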

    def clean_downloads(self):
        for file in self.raw_files_paths_list:
            if file not in self.required_files_paths_list:
                if os.path.exists(file):
                    os.unlink(file)

    def clean_root_dir(self):
        import shutil
        total_files = self.required_files_paths_list + \
            self.raw_files_paths_list
        for filename in os.listdir(self.root_dir):
            file_path = os.path.join(self.root_dir, filename)
            if file_path in total_files:
                continue
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print('Failed to delete %s. Reason: %s' % (file_path, e))

    # Representations

    def dataframe(self) -> Union[DataFrame, List[DataFrame]]:
        """Returns a pandas representation of the dataset in the form of a
        :class:`~pandas.DataFrame`. May be a list of DataFrames if the
        dataset has a dynamic structure."""
        raise NotImplementedError

    def numpy(
        self,
        return_idx: bool = False
    ) -> Union[ndarray, List[ndarray], Tuple[ndarray, Series],
               Tuple[List[ndarray], Series]]:
        """Returns a numpy representation of the dataset in the form of a
        :class:`~numpy.ndarray`. If :obj:`return_idx` is :obj:`True`, it
        also returns a :class:`~pandas.Series` that can be used as index.
        May be a list of ndarrays (and Series) if the dataset has a dynamic
        structure."""
        raise NotImplementedError
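
    # Usage sketch for the representations above, assuming a concrete
    # subclass instance `ds` implementing both methods:
    #
    #     df = ds.dataframe()                     # pandas DataFrame(s)
    #     data, idx = ds.numpy(return_idx=True)   # ndarray + index Series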

    # IO

    def save_pickle(self, filename: str) -> None:
        """Save :obj:`Dataset` to disk.

        Args:
            filename (str): path to filename for storage.
        """
        save_pickle(self, filename)

    @classmethod
    def load_pickle(cls, filename: str) -> "Dataset":
        """Load instance of :obj:`Dataset` from disk.

        Args:
            filename (str): path of :obj:`Dataset`.
        """
        obj = load_pickle(filename)
        if not isinstance(obj, cls):
            raise TypeError(f"Loaded file is not of class {cls}.")
        return obj
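
    # Pickle round-trip sketch (the filename is arbitrary); note that
    # `load_pickle` raises a TypeError if the stored object is not an
    # instance of the calling class:
    #
    #     ds.save_pickle('dataset.pkl')
    #     ds = MyDataset.load_pickle('dataset.pkl')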

    # Similarity pipeline:
    # get_connectivity() → get_similarity() → compute_similarity()

    def compute_similarity(self, method: str,
                           **kwargs) -> Optional[np.ndarray]:
        r"""Implements the options for the computation of the similarity
        matrix :math:`\mathbf{S} \in \mathbb{R}^{N \times N}`, according to
        :obj:`method`.

        Args:
            method (str): Method for the similarity computation.
            **kwargs (optional): Additional optional keyword arguments.

        Returns:
            ndarray: The similarity dense matrix.
        """
        raise NotImplementedError
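
    # A sketch of how a subclass may implement `compute_similarity`. The
    # `dist` attribute and the 'distance' option are hypothetical; the
    # Gaussian kernel is one common choice for distance-based similarities:
    #
    #     similarity_options = {'distance'}
    #
    #     def compute_similarity(self, method: str, **kwargs) -> np.ndarray:
    #         if method == 'distance':
    #             # self.dist: [N, N] matrix of pairwise node distances
    #             sigma = kwargs.get('sigma', self.dist.std())
    #             return np.exp(-(self.dist / sigma) ** 2)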

    def get_similarity(self, method: Optional[str] = None,
                       save: bool = False, **kwargs) -> ndarray:
        r"""Returns the matrix :math:`\mathbf{S} \in \mathbb{R}^{N \times N}`,
        where :math:`N=`:obj:`self.n_nodes`, with the pairwise similarity
        scores between nodes.

        Args:
            method (str, optional): Method for the similarity computation. If
                :obj:`None`, defaults to the dataset-specific default method.
                (default: :obj:`None`)
            save (bool): Whether to save the similarity matrix in the
                dataset's directory after computation.
                (default: :obj:`False`)
            **kwargs (optional): Additional optional keyword arguments.

        Returns:
            ndarray: The similarity dense matrix.

        Raises:
            ValueError: If the similarity method is not valid.
        """
        if method is None:
            method = self.similarity_score
        if method not in self.similarity_options:
            raise ValueError("Similarity method '{}' not valid".format(method))
        if save:
            enc = hash_dict(
                dict(method=method,
                     class_name=self.__class__.__name__,
                     name=self.name,
                     **kwargs))
            name = "sim_{}.npy".format(enc)
            path = os.path.join(self.root_dir, name)
            if os.path.exists(path):
                logger.warning("Loading cached similarity matrix.")
                return np.load(path)
        # compute the similarity matrix
        sim = self.compute_similarity(method, **kwargs)
        if save:
            np.save(path, sim)
            logger.info(f"Similarity matrix saved at {path}.")
        return sim

    def get_connectivity(self,
                         method: Optional[str] = None,
                         threshold: Optional[float] = None,
                         knn: Optional[int] = None,
                         binary_weights: bool = False,
                         include_self: bool = True,
                         force_symmetric: bool = False,
                         normalize_axis: Optional[int] = None,
                         layout: str = 'edge_index',
                         **kwargs) -> Union[ndarray, Tuple,
                                            ScipySparseMatrix]:
        r"""Returns the weighted adjacency matrix
        :math:`\mathbf{A} \in \mathbb{R}^{N \times N}`, where
        :math:`N=`:obj:`self.n_nodes`. The element :math:`a_{i,j} \in
        \mathbf{A}` is 0 if there is no edge connecting node :math:`i` to
        node :math:`j`. The return type depends on the specified
        :obj:`layout` (default: :obj:`edge_index`).

        Args:
            method (str, optional): Method for the similarity computation. If
                :obj:`None`, defaults to the dataset-specific default method.
                (default: :obj:`None`)
            threshold (float, optional): If not :obj:`None`, set to 0 the
                values below the threshold. (default: :obj:`None`)
            knn (int, optional): If not :obj:`None`, keep only the
                :math:`k=`:obj:`knn` nearest incoming neighbors.
                (default: :obj:`None`)
            binary_weights (bool): If :obj:`True`, the positive weights of
                the adjacency matrix are set to 1. (default: :obj:`False`)
            include_self (bool): If :obj:`False`, self-loops are never taken
                into account. (default: :obj:`True`)
            force_symmetric (bool): Force the adjacency matrix to be
                symmetric by taking the maximum value between the two
                directions for each edge. (default: :obj:`False`)
            normalize_axis (int, optional): Divide edge weight
                :math:`a_{i,j}` by :math:`\sum_k a_{i,k}`, if
                :obj:`normalize_axis=0`, or by :math:`\sum_k a_{k,j}`, if
                :obj:`normalize_axis=1`. :obj:`None` for no normalization.
                (default: :obj:`None`)
            layout (str): Convert the matrix to a dense/sparse format.
                Available options are:

                - :obj:`dense`: keep the matrix dense
                  :math:`\mathbf{A} \in \mathbb{R}^{N \times N}`.
                - :obj:`edge_index`: convert to an
                  (edge_index, edge_weight) tuple, where edge_index has
                  shape :math:`[2, E]` and edge_weight has shape
                  :math:`[E]`, being :math:`E` the number of edges.
                - :obj:`coo`/:obj:`csr`/:obj:`csc`: convert to the
                  specified scipy sparse matrix type.

                (default: :obj:`edge_index`)
            **kwargs (optional): Additional optional keyword arguments for
                the similarity computation.

        Returns:
            The connectivity matrix in the specified layout.
        """
        if 'sparse' in kwargs:
            import warnings
            warnings.warn("The argument 'sparse' is deprecated and will be "
                          "removed in future versions of tsl. Please use "
                          "the argument `layout` instead.")
            layout = 'edge_index' if kwargs['sparse'] else 'dense'
        if method == 'full':
            adj = np.ones((self.n_nodes, self.n_nodes))
        elif method == 'identity':
            adj = np.eye(self.n_nodes)
        else:
            adj = self.get_similarity(method, **kwargs)
            if knn is not None:
                from tsl.ops.similarities import top_k
                adj = top_k(adj, knn,
                            include_self=include_self,
                            keep_values=not binary_weights)
            elif binary_weights:
                adj = (adj > 0).astype(adj.dtype)
            if threshold is not None:
                adj[adj < threshold] = 0
        if not include_self:
            np.fill_diagonal(adj, 0)
        if force_symmetric:
            adj = np.maximum.reduce([adj, adj.T])
        if normalize_axis is not None:
            adj = adj / (adj.sum(normalize_axis, keepdims=True) + tsl.epsilon)
        if layout == 'dense':
            return adj
        elif layout == 'edge_index':
            from tsl.ops.connectivity import adj_to_edge_index
            return adj_to_edge_index(adj)
        elif layout in ['coo', 'sparse_matrix']:
            return coo_matrix(adj)
        elif layout == 'csr':
            return csr_matrix(adj)
        elif layout == 'csc':
            return csc_matrix(adj)
        else:
            raise ValueError(
                f"Invalid format for connectivity: {layout}. Valid"
                " options are [dense, edge_index, coo, csr, csc].")
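
    # Usage sketch (assuming `ds` exposes a 'distance' similarity option):
    #
    #     # sparse connectivity: keep the 8 nearest neighbors of each node
    #     edge_index, edge_weight = ds.get_connectivity(method='distance',
    #                                                   knn=8,
    #                                                   layout='edge_index')
    #     # dense adjacency matrix, thresholded and normalized
    #     adj = ds.get_connectivity(method='distance', threshold=0.1,
    #                               normalize_axis=1, layout='dense')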

    # Cross-validation splitting options

    def get_splitter(self, method: Optional[str] = None, *args,
                     **kwargs) -> Splitter:
        """Returns the splitter for a
        :class:`~tsl.data.SpatioTemporalDataset`. A
        :class:`~tsl.data.preprocessing.Splitter` provides the splits of the
        dataset -- in terms of indices -- for cross validation."""
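
    # Usage sketch: if this method is not overridden (or returns None), the
    # wrapper installed in `__new__` looks :obj:`method` up among the
    # splitters provided by the library. The keyword arguments below are
    # illustrative:
    #
    #     splitter = ds.get_splitter('temporal', val_len=0.1, test_len=0.2)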

    # Data methods

    def aggregate(self, node_index: Optional[Iterable[Iterable]] = None):
        """Aggregates nodes given an index of cluster assignments (spatial
        aggregation).

        Args:
            node_index: Sequence of grouped node ids.
        """
        raise NotImplementedError

    # Getters for SpatioTemporalDataset

    def get_config(self) -> dict:
        """Returns the keyword arguments (as dict) for instantiating a
        :class:`~tsl.data.SpatioTemporalDataset`."""
        raise NotImplementedError
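
    # Usage sketch: the returned dict can be unpacked to instantiate a
    # :class:`~tsl.data.SpatioTemporalDataset`; `window` and `horizon` are
    # illustrative keyword arguments, and the exact keys depend on the
    # subclass:
    #
    #     from tsl.data import SpatioTemporalDataset
    #     torch_dataset = SpatioTemporalDataset(**ds.get_config(),
    #                                           window=12, horizon=3)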