# Source code for tsl.nn.layers.norm.norm
import torch
from torch import Tensor, nn
from .batch_norm import BatchNorm
from .instance_norm import InstanceNorm
from .layer_norm import LayerNorm
class Norm(torch.nn.Module):
    r"""Applies a normalization of the specified type.

    The concrete normalization layer is chosen from :obj:`norm_type` once, at
    construction time, and stored in :attr:`norm`; :meth:`forward` simply
    delegates to it.

    Args:
        norm_type (str): Which normalization to apply. One of
            :obj:`'instance'` (:class:`InstanceNorm`),
            :obj:`'batch'` (:class:`BatchNorm`),
            :obj:`'layer'` (:class:`LayerNorm`) or
            :obj:`'none'` (:class:`torch.nn.Identity`, i.e., no
            normalization).
        in_channels (int): Size of each input sample.
        **kwargs: Additional keyword arguments forwarded to the constructor
            of the selected normalization layer.

    Raises:
        NotImplementedError: If :obj:`norm_type` is not one of the options
            listed above.
    """

    def __init__(self, norm_type: str, in_channels: int, **kwargs):
        super().__init__()
        self.norm_type = norm_type
        self.in_channels = in_channels
        # An if/elif chain (rather than a lookup table) means only the class
        # for the requested option is ever referenced.
        if norm_type == 'instance':
            norm_layer = InstanceNorm
        elif norm_type == 'batch':
            norm_layer = BatchNorm
        elif norm_type == 'layer':
            norm_layer = LayerNorm
        elif norm_type == 'none':
            norm_layer = nn.Identity
        else:
            raise NotImplementedError(
                f'"{norm_type}" is not a valid normalization option.')
        # nn.Identity accepts and ignores any constructor arguments, so the
        # same call works uniformly for every option.
        self.norm = norm_layer(in_channels, **kwargs)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the configured normalization layer to :obj:`x`.

        Args:
            x (Tensor): Input tensor, passed unchanged to :attr:`norm`.

        Returns:
            Tensor: The output of :attr:`norm` on :obj:`x`.
        """
        return self.norm(x)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.norm_type},'
                f' {self.in_channels})')