# Source code for tsl.nn.blocks.encoders.conditional

from torch import nn as nn
from torch.nn import Module
from torch.nn import functional as F

from tsl.nn.layers.base import GatedTemporalConv, TemporalConv
from tsl.nn.utils import get_layer_activation


class ConditionalBlock(Module):
    r"""Simple layer to condition the input on a set of exogenous variables.

    Each stream goes through two affine maps with an activation in between.
    The two results are summed, activated and passed through dropout:

    .. math::
        \text{CondBlock}(\mathbf{x}, \mathbf{u}) = \text{Dropout}\left(
        \sigma\left(\mathbf{W}_x \sigma(\mathbf{A}_x \mathbf{x})
        + \mathbf{W}_u \sigma(\mathbf{A}_u \mathbf{u})\right)\right)

    Here :math:`\mathbf{A}_x, \mathbf{A}_u, \mathbf{W}_x` are affine layers
    and :math:`\mathbf{W}_u` is a linear layer without bias. When
    ``skip_connection`` is set, an affine projection of :math:`\mathbf{x}`
    is added to the result.

    Args:
        input_size (int): Input size.
        exog_size (int): Size of the covariates.
        output_size (int): Output size.
        dropout (float, optional): Dropout probability. (default: ``0.``)
        skip_connection (bool, optional): Whether to add a parametrized
            residual connection. (default: ``False``)
        activation (str, optional): Activation function.
            (default: ``'relu'``)
    """

    def __init__(self, input_size, exog_size, output_size, dropout=0.,
                 skip_connection=False, activation='relu'):
        super().__init__()
        self.d_in, self.d_u, self.d_out = input_size, exog_size, output_size

        activation_cls = get_layer_activation(activation)
        self.activation = activation_cls()
        self.dropout = nn.Dropout(dropout)

        # First projection of each stream into the output space.
        self.input_affinity = nn.Linear(input_size, output_size)
        self.condition_affinity = nn.Linear(exog_size, output_size)
        # Second projection; the two streams are summed afterwards.
        self.out_inputs_affinity = nn.Linear(output_size, output_size)
        self.out_cond_affinity = nn.Linear(output_size, output_size,
                                           bias=False)

        if not skip_connection:
            # Registered as a None parameter so `self.skip_conn` always exists.
            self.register_parameter('skip_conn', None)
        else:
            self.skip_conn = nn.Linear(input_size, output_size)

    def forward(self, x, u=None):
        """Condition ``x`` on ``u``.

        Args:
            x: Input tensor (features on the last dimension), or an
                ``(x, u)`` pair when ``u`` is omitted.
            u: Exogenous tensor (features on the last dimension).

        Returns:
            The conditioned representation with ``output_size`` features.
        """
        if u is None:
            # Both tensors were passed together as a single pair.
            x, u = x
        act = self.activation
        hidden_x = act(self.input_affinity(x))
        hidden_u = act(self.condition_affinity(u))
        mixed = (self.out_inputs_affinity(hidden_x)
                 + self.out_cond_affinity(hidden_u))
        result = self.dropout(act(mixed))
        if self.skip_conn is None:
            return result
        return self.skip_conn(x) + result
class ConditionalTCNBlock(nn.Module):
    r"""Mirrors the architecture of
    :class:`tsl.nn.blocks.encoders.ConditionalBlock` but using temporal
    convolutions instead of affine transformations.

    The input and the exogenous variables are each processed by a temporal
    convolution branch and dropout:

    - non-gated: convolution followed by ``activation``;
    - gated: a gated temporal convolution.

    The two branch outputs are projected by separate linear layers, summed,
    activated and passed through dropout. An optional parametrized skip
    connection from the input is added to the result.

    Args:
        input_size (int): Size of the input.
        exog_size (int): Size of the exogenous variables.
        output_size (int): Size of the output.
        kernel_size (int): Size of the convolution kernel.
        dilation (int, optional): Spacing between kernel elements.
            (default: ``1``)
        dropout (float, optional): Dropout probability. (default: ``0.``)
        gated (bool, optional): Whether to use `gated tanh` activations.
            (default: ``False``)
        activation (str, optional): Activation function. It is looked up in
            :mod:`torch.nn.functional` for the output nonlinearity, so it
            must name a function there. (default: ``'relu'``)
        weight_norm (bool, optional): Whether to apply weight normalization to
            the parameters of the filter. (default: ``False``)
        channel_last (bool, optional): If :obj:`True` input data must follow
            the `b t n f` layout, assumes `b f n t` otherwise.
            (default: ``True``)
        skip_connection (bool, optional): If :obj:`True` adds a parametrized
            skip connection from the input to the output.
            (default: ``False``)
    """

    def __init__(self, input_size, exog_size, output_size, kernel_size,
                 dilation=1, dropout=0., gated=False, activation='relu',
                 weight_norm=False, channel_last=True, skip_connection=False):
        super().__init__()
        # Used in forward: nn.Linear acts on the last dimension, which holds
        # the features only in the `b t n f` layout.
        self.channel_last = channel_last

        def conv_branch(in_channels):
            # Build one convolutional stream (conv [+ activation] + dropout).
            if gated:
                return nn.Sequential(
                    GatedTemporalConv(input_channels=in_channels,
                                      output_channels=output_size,
                                      kernel_size=kernel_size,
                                      dilation=dilation,
                                      weight_norm=weight_norm,
                                      channel_last=channel_last),
                    nn.Dropout(dropout))
            return nn.Sequential(
                # Forward channel_last explicitly, as the gated branch and
                # the skip connection do. Otherwise TemporalConv falls back
                # to its own default layout, which may not match.
                TemporalConv(input_channels=in_channels,
                             output_channels=output_size,
                             kernel_size=kernel_size,
                             dilation=dilation,
                             weight_norm=weight_norm,
                             channel_last=channel_last),
                get_layer_activation(activation)(),
                nn.Dropout(dropout))

        # inputs module
        self.inputs_conv = conv_branch(input_size)
        # conditions module
        self.conditions_conv = conv_branch(exog_size)

        self.out_input = nn.Linear(output_size, output_size)
        self.out_cond = nn.Linear(output_size, output_size, bias=False)
        self.activation = getattr(F, activation)
        self.dropout = nn.Dropout(dropout)

        if skip_connection:
            # Maps input_size -> output_size; third positional argument is
            # presumably kernel_size=1.
            self.skip_conn = TemporalConv(input_size, output_size, 1,
                                          channel_last=channel_last)
        else:
            # Registered as a None parameter so `self.skip_conn` always exists.
            self.register_parameter('skip_conn', None)

    def forward(self, x, u=None):
        """Condition ``x`` on ``u`` with temporal convolutions.

        Args:
            x: Input tensor in the layout selected by ``channel_last``, or an
                ``(x, u)`` pair when ``u`` is omitted.
            u: Exogenous tensor in the same layout as ``x``.

        Returns:
            Tensor with ``output_size`` features, in the same layout as
            ``x``.
        """
        if u is None:
            x, u = x

        # inputs block
        out = self.inputs_conv(x)
        # conditions block
        conditions = self.conditions_conv(u)

        # nn.Linear works on the last dimension: move features there when
        # data is `b f n t`. The default guards instances built without the
        # attribute.
        channel_last = getattr(self, 'channel_last', True)
        if not channel_last:
            out = out.movedim(1, -1)
            conditions = conditions.movedim(1, -1)

        # Each stream has its own output projection.
        out = self.out_input(out) + self.out_cond(conditions)
        out = self.dropout(self.activation(out))

        if not channel_last:
            out = out.movedim(-1, 1)

        if self.skip_conn is not None:
            # Residual: project the *input* and add it to the output.
            out = self.skip_conn(x) + out
        return out