from typing import Optional, Sequence, Union
import torch
from einops import rearrange
from torch import Tensor, nn
from tsl.utils import ensure_list
from ...blocks import MLP, LinearReadout
from ...layers import NodeEmbedding
from .. import BaseModel
class STIDModel(BaseModel):
    """The Spatial-Temporal Identity (STID) model from the paper
    `"Spatial-Temporal Identity: A Simple yet Effective Baseline for
    Multivariate Time Series Forecasting" <https://arxiv.org/abs/2208.05233>`_
    (Shao et al., CIKM 2022 short paper).

    Args:
        input_size (int): Size of the input.
        n_nodes (int): Number of nodes.
        window (int): Size of the input window (this model cannot process
            sequences of variable length).
        horizon (int): Number of forecasting steps.
        n_exog_emb (int or list): Number of embeddings to be learned for the
            optional exogenous variables. Can be an integer or a list of
            integers, where each element is the cardinality of the
            corresponding covariate.
            (default: :obj:`None`)
        output_size (int): Size of the output. If :obj:`None`, then defaults
            to :obj:`input_size`.
            (default: :obj:`None`)
        hidden_size (int): Number of hidden units in each hidden layer.
            (default: :obj:`32`)
        n_layers (int): Number of layers in the MLP encoder.
            (default: :obj:`3`)
        dropout (float): Dropout probability in the MLP hidden layers.
            (default: :obj:`0.15`)
    """
    def __init__(self,
                 input_size: int,
                 n_nodes: int,
                 window: int,
                 horizon: int,
                 n_exog_emb: Optional[Union[Sequence[int], int]] = None,
                 output_size: Optional[int] = None,
                 hidden_size: int = 32,
                 n_layers: int = 3,
                 dropout: float = 0.15):
        super().__init__()
        self.input_size = input_size
        self.n_nodes = n_nodes
        self.window = window
        self.horizon = horizon
        self.output_size = output_size or input_size
        # spatial identity: one learnable embedding per node
        self.node_emb = NodeEmbedding(n_nodes, hidden_size)
        mlp_size = 2 * hidden_size
        # temporal embeddings, one table per exogenous covariate
        if n_exog_emb is not None:
            n_exog_emb = ensure_list(n_exog_emb)
            self.exog_embs = nn.ModuleList(
                [NodeEmbedding(size, hidden_size) for size in n_exog_emb])
            mlp_size += len(n_exog_emb) * hidden_size
            self.exog_size = n_exog_emb
        else:  # keep attributes defined so that reset_parameters() works
            self.exog_embs = nn.ModuleList()
            self.exog_size = None
        # embedding layer
        self.input_encoder = nn.Linear(input_size * window, hidden_size)
        # encoding
        self.mlp_list = nn.ModuleList([
            MLP(input_size=mlp_size,
                hidden_size=mlp_size,
                output_size=mlp_size,
                n_layers=1,
                activation="relu",
                dropout=dropout) for _ in range(n_layers)
        ])
        # regression
        self.readout = LinearReadout(mlp_size, self.output_size, horizon)
        self.reset_parameters()
    def reset_parameters(self):
        with torch.no_grad():
            nn.init.xavier_uniform_(self.node_emb.emb)
            for exog_emb in self.exog_embs:
                nn.init.xavier_uniform_(exog_emb.emb)
            self.input_encoder.reset_parameters()
            for mlp in self.mlp_list:
                mlp.reset_parameters()
            self.readout.reset_parameters()
    def forward(self, x: Tensor, u: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            x (Tensor): The input data.
            u (Tensor, optional): Optional indices of the exogenous
                variables. Each channel contains the index into the
                corresponding embedding table. If :obj:`u` has a time
                dimension, then its length must match that of :obj:`x`
                (i.e., :obj:`self.window`) and only the last step is used;
                otherwise, it must be synchronized with the last step of
                :obj:`x`.
                (default: :obj:`None`)

        Shapes:
            x: :math:`(B, T, N, F)`, where :math:`B` is the batch size,
                :math:`T` is the number of time steps in the lookback
                window, :math:`N` is the number of nodes, and :math:`F` is
                the number of features/channels.
            u: :math:`(B, [T,] F)`, where :math:`B` is the batch size,
                :math:`T` is the (optional) number of time steps in the
                lookback window, and :math:`F` is the number of covariates.
        """
        # x: [b t n f]
        # u: [b t f]
        b = x.size(0)  # batch size
        # flatten the time dimension into the feature dimension
        assert x.size(1) == self.window
        x = rearrange(x, 'b s n f -> b n (s f)')
        h = self.input_encoder(x)  # h: [b n h]
        n_emb = self.node_emb(expand=(b, -1, -1))  # n_emb: [b n h]
        z = [h, n_emb]
        if u is not None:
            assert (self.exog_size is not None and u.dim() <= 3
                    and u.size(-1) == len(self.exog_size))
            if u.dim() == 3:
                assert u.size(1) == self.window
                u = u[:, -1]  # select only the last step: b t f -> b f
            # look up one temporal embedding per covariate and broadcast it
            # across the node dimension
            for u_idx, u_emb in enumerate(self.exog_embs):
                t_emb = u_emb(expand=(-1, self.n_nodes, -1),
                              node_index=u[..., u_idx])
                z.append(t_emb)
        z = torch.cat(z, dim=-1)
        # encoding with residual connections
        for mlp in self.mlp_list:
            z = mlp(z) + z
        # regression
        out = self.readout(z)
        return out
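
if __name__ == "__main__":
    # Minimal smoke test, a sketch under assumed shapes (the relative imports
    # above mean this file must be run as a module inside its package, not as
    # a standalone script).
    batch, window, horizon, n_nodes, channels = 8, 12, 6, 20, 1
    model = STIDModel(input_size=channels, n_nodes=n_nodes, window=window,
                      horizon=horizon, n_exog_emb=[288, 7])
    x = torch.randn(batch, window, n_nodes, channels)
    # One index column per covariate: time-of-day in [0, 288) and
    # day-of-week in [0, 7).
    u = torch.stack([torch.randint(0, 288, (batch, window)),
                     torch.randint(0, 7, (batch, window))], dim=-1)
    y = model(x, u)
    # Expected output shape: (batch, horizon, n_nodes, output_size).
    print(y.shape)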