# Source code for tsl.nn.blocks.decoders.linear_readout
from einops.layers.torch import Rearrange
from torch import Tensor, nn
class LinearReadout(nn.Module):
    r"""Simple linear readout for multistep forecasting.

    If the input representation has a temporal dimension, this model will
    simply take the representation corresponding to the last step.

    Args:
        input_size (int): Input size.
        output_size (int): Output size.
        horizon (int): Number of steps to predict.
            (default: :obj:`1`)
        bias (bool): Whether to add a learnable bias.
            (default: :obj:`True`)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 horizon: int = 1,
                 bias: bool = True) -> None:
        super().__init__()
        # A single linear map produces every forecast step at once: its
        # output features are laid out as (horizon, output_size), with the
        # horizon index outermost, and are split apart by `self.rearrange`.
        self.readout = nn.Linear(input_size, output_size * horizon, bias=bias)
        self.rearrange = Rearrange('b n (h f) -> b h n f',
                                   f=output_size,
                                   h=horizon)

    def reset_parameters(self) -> None:
        """Re-initialize the parameters of the underlying linear layer."""
        self.readout.reset_parameters()

    def forward(self, h: Tensor) -> Tensor:
        """Map node representations to a multistep forecast.

        Args:
            h (Tensor): Representation of shape
                ``[batch, nodes, input_size]``, or
                ``[batch, steps, nodes, input_size]``. In the latter case
                only the last step is used.

        Returns:
            Tensor: Forecast of shape
            ``[batch, horizon, nodes, output_size]``.

        Raises:
            ValueError: If ``h`` is neither 3- nor 4-dimensional.
        """
        # h: [batches (steps) nodes features]
        if h.dim() == 4:
            # take last step representation
            h = h[:, -1]
        elif h.dim() != 3:
            # Fail early with a readable message instead of an opaque
            # einops pattern-mismatch error after the linear projection.
            raise ValueError(
                "LinearReadout expects a 3-D [batch, nodes, features] or "
                "4-D [batch, steps, nodes, features] input, got shape "
                f"{tuple(h.shape)}.")
        out = self.readout(h)
        return self.rearrange(out)