Source code for tsl.ops.pattern

import re
from collections import Counter
from types import ModuleType
from typing import Iterable, List, Optional, Union

import numpy as np
import torch
from numpy import ndarray
from torch import Tensor

_PATTERNS = {
    'tnef': re.compile('^[1-2]?t?(n{0,2}|e?)f*$'),
    'btnef': re.compile('^b?t?(n{0,2}|e?)f*$'),
}
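
# Illustrative check (not part of the module): a pattern is validated in its
# "squeezed" form, i.e., with blanks removed and 'c' mapped to 'f', which
# must fully match one of the regexes above.
assert _PATTERNS['tnef'].match('tnnf') is not None   # 't n n f' is valid
assert _PATTERNS['tnef'].match('tef') is not None    # 't e f' is valid
assert _PATTERNS['tnef'].match('tnef') is None       # 'n' and 'e' never mix
assert _PATTERNS['btnef'].match('btnf') is not None  # leading batch token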

# PATTERN PARSING ############################################################


def check_pattern(pattern: str,
                  split: bool = False,
                  ndim: Optional[int] = None,
                  include_batch: bool = False) -> Union[str, list]:
    r"""Check that :attr:`pattern` is allowed. A pattern is a string of tokens
    interleaved with blank spaces, where each token specifies what an axis in
    a tensor refers to. The supported tokens are:

    * 't', for the time dimension
    * 'n', for the node dimension
    * 'e', for the edge dimension
    * 'f' or 'c', for the feature/channel dimension ('c' token is
      automatically converted to 'f')

    In order to be valid, a pattern must have:

    1. at most one 't' dimension, as the first token;
    2. at most two (consecutive) 'n' dimensions, right after the 't' token or
       at the beginning of the pattern;
    3. at most one 'e' dimension, either as the first token or after a 't';
    4. either 'n' or 'e' dimensions, but not both together;
    5. all further tokens must be 'c' or 'f'.

    Args:
        pattern (str): The input pattern, specifying with a token what an
            axis in a tensor refers to. The supported tokens are:

            * 't', for the time dimension
            * 'n', for the node dimension
            * 'e', for the edge dimension
            * 'f' or 'c', for the feature/channel dimension ('c' token is
              automatically converted to 'f')
        split (bool): If :obj:`True`, then return an ordered list of the
            tokens in the sanitized pattern.
            (default: :obj:`False`)
        ndim (int, optional): If it is not :obj:`None`, then check that
            :attr:`pattern` has :attr:`ndim` tokens.
            (default: :obj:`None`)
        include_batch (bool): If :obj:`True`, then allow the token :obj:`b`.
            (default: :obj:`False`)

    Returns:
        str or list: The sanitized pattern as a string, or a list of the
        tokens in the pattern.
    """
    pattern_squeezed = pattern.replace(' ', '').replace('c', 'f')
    # check 'c'/'f' follows 'n', 'n' follows 't'
    # allow for duplicate 'n' dims (e.g., 'n n', 't n n f')
    # allow for limitless 'c'/'f' dims (e.g., 't n f f')
    # if include_batch, then allow for batch dimension
    match_with = _PATTERNS['btnef' if include_batch else 'tnef']
    if not match_with.match(pattern_squeezed):
        raise RuntimeError(f'Pattern "{pattern}" not allowed.')
    elif ndim is not None and len(pattern_squeezed) != ndim:
        raise RuntimeError(f'Pattern "{pattern}" does not have '
                           f'{ndim} dimensions.')
    if split:
        return list(pattern_squeezed)
    return ' '.join(pattern_squeezed)
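

# Usage sketch for check_pattern (illustrative, not part of the module),
# following the rules listed in the docstring above.
assert check_pattern('t n c') == 't n f'             # 'c' is mapped to 'f'
assert check_pattern('t n f', split=True) == ['t', 'n', 'f']
assert check_pattern('b t n f', include_batch=True) == 'b t n f'
try:
    check_pattern('t n e f')                         # 'n' and 'e' never mix
except RuntimeError as err:
    print(err)  # Pattern "t n e f" not allowed.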


def infer_pattern(shape: tuple, t: Optional[int] = None,
                  n: Optional[int] = None, e: Optional[int] = None) -> str:
    out = []
    for dim in shape:
        if t is not None and dim == t:
            out.append('t')
        elif n is not None and dim == n:
            out.append('n')
        elif e is not None and dim == e:
            out.append('e')
        else:
            out.append('f')
    pattern = ' '.join(out)
    try:
        pattern = check_pattern(pattern)
    except RuntimeError:
        raise RuntimeError(f"Cannot infer pattern from shape: {shape}.")
    return pattern


def outer_pattern(patterns: Iterable[str]):
    dims = dict(t=0, n=0, e=0, f=0)
    for pattern in patterns:
        dim_count = Counter(check_pattern(pattern, split=True))
        for dim, count in dim_count.items():
            dims[dim] = max(dims[dim], count)
    dims = [d for dim, count in dims.items() for d in [dim] * count]
    if 'n' in dims and 'e' in dims:
        raise RuntimeError("Cannot join node-level and edge-level tensors.")
    return ' '.join(dims)


# PATTERN-BASED OPERATIONS ###################################################


def _infer_backend(obj, backend: ModuleType = None):
    if backend is not None:
        return backend
    elif isinstance(obj, Tensor):
        return torch
    elif isinstance(obj, np.ndarray):
        return np
    raise RuntimeError(f"Cannot infer valid backend from {type(obj)}. "
                       "Expected backends are 'torch' and 'numpy'.")


def _parse_indices(backend,
                   time_index: Union[List, ndarray, Tensor] = None,
                   node_index: Union[List, ndarray, Tensor] = None,
                   edge_mask: Union[List, ndarray, Tensor] = None):
    indices = [time_index, node_index, edge_mask]
    if backend is torch:
        for i, index in enumerate(indices):
            if index is not None:
                if not isinstance(index, Tensor):
                    index = torch.as_tensor(index)
                # convert a boolean mask to the index of its nonzero entries
                if index.ndim == 1 and index.dtype is torch.bool:
                    index = index.nonzero(as_tuple=True)[0]
                indices[i] = index
    elif backend is np:
        for i, index in enumerate(indices):
            if index is not None:
                index = np.asarray(index)
                # convert a boolean mask to the index of its nonzero entries
                if index.ndim == 1 and index.dtype == bool:
                    index = index.nonzero()[0]
                indices[i] = index
    return indices


def _get_select_fn(backend):
    select = None
    if backend is np:

        def select(obj, index, dim):
            return obj.take(index, dim)

    elif backend is torch:

        def select(obj, index, dim):
            return obj.index_select(dim, index)

    return select


def _get_expand_fn(backend):
    expand = None
    if backend is np:

        def expand(obj, size, dim):
            obj = np.expand_dims(obj, dim)
            if size > 1:
                obj = obj.repeat(size, dim)
            return obj

    elif backend is torch:

        def expand(obj, size, dim):
            obj = obj.unsqueeze(dim)
            shape = [size if i == dim else -1 for i in range(obj.ndim)]
            return obj.expand(shape)

    return expand


def take(x: Union[np.ndarray, torch.Tensor],
         pattern: str,
         time_index: Union[List, ndarray, Tensor] = None,
         node_index: Union[List, ndarray, Tensor] = None,
         edge_mask: Union[List, ndarray, Tensor] = None,
         *,
         backend: ModuleType = None):
    backend = _infer_backend(x, backend)
    dims = check_pattern(pattern, split=True)
    select = _get_select_fn(backend)
    time_index, node_index, edge_index = _parse_indices(
        backend, time_index=time_index, node_index=node_index,
        edge_mask=edge_mask)
    # assume that 't' can only be first dimension, then allow
    # multidimensional temporal indexing
    pad_dim = 0
    if dims[0] == 't':
        pad_dim = 1
        if time_index is not None:
            # time_index can be multidimensional
            pad_dim = time_index.ndim
            # pad pattern with 'batch' dimensions
            dims = ['b'] * (pad_dim - 1) + dims
            x = x[time_index]
    # broadcast array/tensor to pattern according to backend
    for pos, dim in list(enumerate(dims))[pad_dim:]:
        if dim == 'n' and node_index is not None:
            x = select(x, node_index, pos)
        elif dim == 'e' and edge_index is not None:
            x = select(x, edge_index, pos)
    return x


def broadcast(x: Union[np.ndarray, torch.Tensor],
              pattern: str,
              time_index: Union[List, ndarray, Tensor] = None,
              node_index: Union[List, ndarray, Tensor] = None,
              edge_mask: Union[List, ndarray, Tensor] = None,
              *,
              t: int = 1,
              n: int = 1,
              e: int = 1,
              f: int = 1,
              backend: ModuleType = None):
    # check patterns
    left, rght = pattern.split('->')
    left_dims = check_pattern(left, split=True)
    rght_dims = check_pattern(rght, split=True)
    if not set(left_dims).issubset(rght_dims):
        raise RuntimeError(f"Shape {left_dims} cannot be "
                           f"broadcasted to {rght.strip()}.")

    backend = _infer_backend(x, backend)
    select = _get_select_fn(backend)
    expand = _get_expand_fn(backend)
    time_index, node_index, edge_index = _parse_indices(
        backend, time_index=time_index, node_index=node_index,
        edge_mask=edge_mask)

    # build indices and default values for broadcasting
    dim_map = dict(t=t if time_index is None else len(time_index),
                   n=n if node_index is None else len(node_index),
                   e=e if edge_index is None else len(edge_index),
                   f=f)

    # assume that 't' can only be first dimension, then allow
    # multidimensional temporal indexing
    pad_dim = 1 if rght_dims[0] == 't' else 0
    if time_index is None:
        # n f -> t n f  ==>  dim_map['t'] n f
        if rght_dims[0] == 't' and left_dims[0] != 't':
            x = expand(x, dim_map['t'], 0)
            left_dims = ['t'] + left_dims
    else:
        pad_dim = time_index.ndim
        # t n f -> t n f  ==>  t[time_index] n f
        if left_dims[0] == 't':
            x = x[time_index]
        # n f -> t n f  ==>  t[time_index] n f
        elif rght_dims[0] == 't' and left_dims[0] != 't':
            for p in range(pad_dim):
                x = expand(x, time_index.size(p), p)
            left_dims = ['t'] + left_dims
        left_dims = ['b'] * (pad_dim - 1) + left_dims
        rght_dims = ['b'] * (pad_dim - 1) + rght_dims

    # broadcast array/tensor to pattern according to backend
    for pos, rght_dim in list(enumerate(rght_dims))[pad_dim:]:
        left_dim = left_dims[pos] if pos < len(left_dims) else None
        if left_dim != rght_dim:
            x = expand(x, dim_map[rght_dim], pos)
            left_dims.insert(pos, rght_dim)
        elif rght_dim == 'n' and node_index is not None:
            x = select(x, node_index, pos)
        elif rght_dim == 'e' and edge_index is not None:
            x = select(x, edge_index, pos)
    return x
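

# Usage sketch for take and broadcast (illustrative, not part of the module),
# assuming a toy torch tensor with pattern 't n f': 12 steps, 4 nodes,
# 8 features. take selects along the indexed axes; broadcast expands the
# left-hand pattern to the right-hand one.
_x = torch.randn(12, 4, 8)
assert take(_x, 't n f', node_index=[0, 2]).shape == (12, 2, 8)
assert broadcast(_x, 't n f -> t n n f', n=4).shape == (12, 4, 4, 8)
assert broadcast(torch.randn(4, 8), 'n f -> t n f', t=12).shape == (12, 4, 8)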