Source code for tsl.nn.blocks.decoders.multi_step_mlp_decoder

import torch
from einops import rearrange, repeat
from torch import nn

from tsl.nn.blocks.encoders.mlp import MLP


class MultiHorizonMLPDecoder(nn.Module):
    r"""Decoder for multi-step forecasting based on the paper `"A Multi-Horizon
    Quantile Recurrent Forecaster" <https://arxiv.org/abs/1711.11053>`_ (Wen et
    al., 2018).

    It requires exogenous variables synchronized with the forecasting horizon.

    Args:
        input_size (int): Size of the input.
        exog_size (int): Size of the horizon exogenous variables.
        hidden_size (int): Number of hidden units.
        context_size (int): Number of units used to condition the forecasting
            of each step.
        output_size (int): Output channels.
        n_layers (int): Number of hidden layers.
        horizon (int): Forecasting horizon.
        activation (str, optional): Activation function.
        dropout (float, optional): Dropout probability.
    """

    def __init__(self,
                 input_size,
                 exog_size,
                 hidden_size,
                 context_size,
                 output_size,
                 n_layers,
                 horizon,
                 activation='relu',
                 dropout=0.):
        super(MultiHorizonMLPDecoder, self).__init__()
        # The global MLP outputs one context shared across the horizon plus
        # one context per horizon step.
        global_d_out = horizon * context_size + context_size
        self.d_context = context_size
        self.horizon = horizon
        self.global_mlp = MLP(input_size=input_size,
                              hidden_size=hidden_size,
                              output_size=global_d_out,
                              n_layers=n_layers,
                              activation=activation,
                              dropout=dropout)
        # The local MLP maps, for each horizon step, the concatenation of the
        # step context, the global context, and the step's exogenous variables
        # to the output.
        self.local_mlp = MLP(input_size=exog_size + 2 * context_size,
                             hidden_size=hidden_size,
                             output_size=output_size,
                             n_layers=n_layers,
                             activation=activation,
                             dropout=dropout)

    def forward(self, x: torch.Tensor, u: torch.Tensor):
        """"""
        # x: [batch, (steps), nodes, channels]
        # u: [batch, horizon, (nodes), channels]
        # out: [batch, horizon, nodes, channels]
        if x.dim() == 4:
            # Keep only the representation of the last input step.
            x = x[:, -1]
        n = x.size(1)
        if u.dim() == 3:
            # Broadcast node-independent exogenous variables to all nodes.
            u = repeat(u, 'b h f -> b h n f', n=n)
        u = rearrange(u, 'b h n f -> b n h f')
        out = self.global_mlp(x)
        # Split the global MLP output into the shared context and the
        # per-step contexts.
        global_context, contexts = torch.split(
            out, [self.d_context, self.horizon * self.d_context], -1)
        global_context = repeat(global_context, 'b n f -> b n h f',
                                h=self.horizon)
        contexts = rearrange(contexts, 'b n (h f) -> b n h f',
                             f=self.d_context, h=self.horizon)
        x_dec = torch.cat([contexts, global_context, u], -1)
        x_dec = self.local_mlp(x_dec)
        return rearrange(x_dec, 'b n h f -> b h n f')

    def reset_parameters(self):
        self.global_mlp.reset_parameters()
        self.local_mlp.reset_parameters()
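
The sketch below is not part of the original module; it is a minimal usage example that only illustrates the expected tensor shapes, with arbitrary sizes (batch 8, 12 input steps, 20 nodes, horizon 6) chosen for illustration.

if __name__ == '__main__':
    # Illustrative sketch with arbitrary sizes, not part of the library code.
    batch, steps, nodes = 8, 12, 20
    input_size, exog_size, horizon = 32, 4, 6

    decoder = MultiHorizonMLPDecoder(input_size=input_size,
                                     exog_size=exog_size,
                                     hidden_size=64,
                                     context_size=16,
                                     output_size=1,
                                     n_layers=2,
                                     horizon=horizon)

    # Encoder output: one representation per step and node (only the last
    # step is used by the decoder).
    x = torch.randn(batch, steps, nodes, input_size)
    # Exogenous variables known over the forecasting horizon, one per node.
    u = torch.randn(batch, horizon, nodes, exog_size)

    y = decoder(x, u)
    print(y.shape)  # torch.Size([8, 6, 20, 1])

Following the multi-horizon scheme of the referenced paper, each step of the horizon is decoded directly from its own context, a context shared across the horizon, and that step's exogenous variables, rather than by recursive one-step-ahead decoding.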