Source code for tsl.nn.layers.norm.batch_norm

import torch
from einops import rearrange
from torch import Tensor


class BatchNorm(torch.nn.Module):
    r"""Applies graph-wise batch normalization.

    The input is reshaped so that the batch and node dimensions are merged
    into a single batch dimension and channels are moved to position 1; the
    result is normalized with :class:`torch.nn.BatchNorm1d`. Statistics are
    therefore computed per channel over batches, nodes and any dimension
    lying between the batch and node dimensions (e.g., time steps).

    Args:
        in_channels (int): Size of each input sample.
        eps (float, optional): A value added to the denominator for
            numerical stability. (default: :obj:`1e-5`)
        momentum (float, optional): Momentum used to update the running
            mean and variance. Can be set to :obj:`None` to use a
            cumulative moving average instead. (default: :obj:`0.1`)
        affine (bool, optional): If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`.
            (default: :obj:`True`)
        track_running_stats (bool, optional): If set to :obj:`True`, the
            module tracks running mean and variance and uses them in
            evaluation mode; otherwise batch statistics are always used.
            (default: :obj:`True`)
    """

    def __init__(self, in_channels: int, eps: float = 1e-5,
                 momentum: float = 0.1, affine: bool = True,
                 track_running_stats: bool = True):
        super().__init__()
        # Arguments are passed positionally in BatchNorm1d's own order:
        # (num_features, eps, momentum, affine, track_running_stats).
        self.module = torch.nn.BatchNorm1d(in_channels, eps, momentum,
                                           affine, track_running_stats)

    def reset_parameters(self):
        """Reset the running statistics and affine parameters of the
        wrapped :class:`torch.nn.BatchNorm1d`."""
        self.module.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x``.

        Args:
            x (Tensor): Input laid out as ``(batch, ..., nodes, channels)``.
                Because :class:`torch.nn.BatchNorm1d` accepts only 2-D or
                3-D inputs, at most one extra dimension may appear between
                the batch and node dimensions.

        Returns:
            Tensor: The normalized tensor, with the same shape as ``x``.
        """
        b = x.size(0)
        # Fold nodes into the batch dim and move channels to position 1,
        # the layout BatchNorm1d expects: (b*n, c) or (b*n, c, extra).
        x = rearrange(x, 'b ... n c -> (b n) c ...')
        x = self.module(x)
        # Undo the folding to restore the original (b, ..., n, c) layout.
        return rearrange(x, '(b n) c ... -> b ... n c', b=b)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.module.num_features})'