Source code for tsl.nn.layers.base.embedding
import math
from typing import List, Optional, Union
import torch
from torch import Tensor, nn
from torch_geometric.typing import OptTensor
[docs]class NodeEmbedding(nn.Module):
r"""Creates a table of node embeddings with the specified size.
Args:
n_nodes (int): Number of elements for which to store an embedding.
emb_size (int): Size of the embedding.
initializer (str or Tensor): Initialization methods.
(default :obj:`'uniform'`)
requires_grad (bool): Whether to compute gradients for the embeddings.
(default :obj:`True`)
"""
def __init__(self,
n_nodes: int,
emb_size: int,
initializer: Union[str, Tensor] = 'uniform',
requires_grad: bool = True):
super(NodeEmbedding, self).__init__()
self.n_nodes = int(n_nodes)
self.emb_size = int(emb_size)
if isinstance(initializer, Tensor):
self.initializer = "from_values"
self.register_buffer('_default_values', initializer.float())
else:
self.initializer = initializer
self.register_buffer('_default_values', None)
self.emb = nn.Parameter(Tensor(self.n_nodes, self.emb_size),
requires_grad=requires_grad)
self.reset_emb()
def __repr__(self) -> str:
return "{}(n_nodes={}, embedding_size={})".format(
self.__class__.__name__, self.n_nodes, self.emb_size)
def reset_emb(self):
with torch.no_grad():
if self.initializer == 'uniform' or self.initializer is None:
bound = 1.0 / math.sqrt(self.emb.size(-1))
self.emb.data.uniform_(-bound, bound)
elif self.initializer == 'from_values':
self.emb.data.copy_(self._default_values)
else:
raise RuntimeError(
f"Embedding initializer '{self.initializer}'"
" is not supported.")
def reset_parameters(self):
self.reset_emb()
def get_emb(self):
return self.emb
def forward(self,
expand: Optional[List] = None,
node_index: OptTensor = None,
nodes_first: bool = True):
""""""
emb = self.get_emb()
if node_index is not None:
emb = emb[node_index]
if not nodes_first:
emb = emb.T
if expand is None:
return emb
shape = [*emb.size()]
view = [
1 if d > 0 else shape.pop(0 if nodes_first else -1) for d in expand
]
return emb.view(*view).expand(*expand)