Functional

expand_then_cat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int = -1) → Tensor

Expands the tensors in the input list so that their dimensions match, then concatenates them along dim.

Parameters:
  • tensors (list or tuple of Tensor) – Tensors to concatenate.

  • dim (int) – Dimension along which to concatenate. (default: -1)
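A minimal PyTorch sketch of the behaviour described above, assuming that "matching the dimensions" means broadcasting every non-concatenation dimension to the largest size among the inputs, and that all inputs have the same number of dimensions. expand_then_cat_sketch is a hypothetical name, not the library's implementation.

```python
from typing import List, Tuple, Union

import torch
from torch import Tensor


def expand_then_cat_sketch(tensors: Union[Tuple[Tensor, ...], List[Tensor]],
                           dim: int = -1) -> Tensor:
    # Assumption: all tensors have the same number of dimensions, and every
    # non-concat dimension either matches or has size 1 (broadcastable).
    ndim = tensors[0].dim()
    d = dim % ndim  # normalise a negative `dim`
    # Target size for every dimension: the largest size among the inputs.
    target = [max(t.size(i) for t in tensors) for i in range(ndim)]
    expanded = []
    for t in tensors:
        shape = list(target)
        shape[d] = t.size(d)  # each tensor keeps its own size along `dim`
        expanded.append(t.expand(*shape))  # size-1 dims are broadcast, no copy
    return torch.cat(expanded, dim=d)


x = torch.randn(2, 1, 4)
y = torch.randn(1, 3, 5)
out = expand_then_cat_sketch([x, y], dim=-1)  # x -> (2, 3, 4), y -> (2, 3, 5)
```

Using expand rather than repeat keeps the broadcast views copy-free until torch.cat materialises the result; mismatched sizes other than 1 raise an error from expand.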

gated_tanh(input: Tensor, dim: int = -1) → Tensor

The gated tanh unit. Computes:

\[\text{GatedTanh}(a, b) = \text{tanh}(a) \otimes \sigma(b)\]

where input is split in half along dim to form \(a\) and \(b\), \(\text{tanh}\) is the hyperbolic tangent function, \(\sigma\) is the sigmoid function, and \(\otimes\) is the element-wise product.

Parameters:
  • input (Tensor) – Input tensor.

  • dim (int) – Dimension on which the input is split. (default: -1)
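A minimal sketch of the gated tanh unit, assuming the first half of input along dim plays the role of \(a\) and the second half plays \(b\). gated_tanh_sketch is a hypothetical helper, not the library's source.

```python
import torch
from torch import Tensor


def gated_tanh_sketch(x: Tensor, dim: int = -1) -> Tensor:
    # Assumption: the first half along `dim` is `a`, the second half is `b`;
    # the size of `x` along `dim` should be even so the halves match.
    a, b = x.chunk(2, dim=dim)
    # GatedTanh(a, b) = tanh(a) * sigmoid(b), computed element-wise.
    return torch.tanh(a) * torch.sigmoid(b)


x = torch.randn(4, 10)
y = gated_tanh_sketch(x)  # each half has shape (4, 5), so y does too
```

The structure mirrors torch.nn.functional.glu, which computes \(a \otimes \sigma(b)\) on the same kind of split; the gated tanh unit additionally squashes \(a\) with tanh, so every output lies in \((-1, 1)\).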

sparse_softmax(src: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, num_nodes: Optional[int] = None, dim: int = -2) → Tensor

Extension of softmax() with index broadcasting to compute a sparsely evaluated softmax over multiple broadcast dimensions.

Given a value tensor src, this function first groups the values along dimension dim based on the indices specified in index, and then computes the softmax individually for each group.

Parameters:
  • src (Tensor) – The source tensor.

  • index (Tensor, optional) – The indices of elements for applying the softmax. (default: None)

  • ptr (Tensor, optional) – If given, computes the softmax based on sorted inputs in CSR representation. (default: None)

  • num_nodes (int, optional) – The number of nodes, i.e., max_val + 1 of index. (default: None)

  • dim (int) – The dimension on which to normalize, i.e., the edge dimension. (default: -2)
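A minimal sketch of the index path only (the ptr/CSR path is not covered), assuming index is a 1-D tensor holding one group id per position along dim, broadcast over every other dimension of src. It uses the usual max-subtraction trick for numerical stability; sparse_softmax_sketch is a hypothetical name, not the library's implementation.

```python
from typing import Optional

import torch
from torch import Tensor


def sparse_softmax_sketch(src: Tensor, index: Tensor,
                          num_nodes: Optional[int] = None,
                          dim: int = -2) -> Tensor:
    # Assumption: `index` is 1-D with one group id per position along `dim`;
    # it is broadcast over every other dimension of `src`.
    dim = dim % src.dim()
    if num_nodes is None:
        num_nodes = int(index.max()) + 1
    view = [1] * src.dim()
    view[dim] = -1
    idx = index.view(view).expand_as(src)

    out_shape = list(src.shape)
    out_shape[dim] = num_nodes
    # Per-group maximum, subtracted for numerical stability.
    group_max = src.new_full(out_shape, float("-inf")).scatter_reduce(
        dim, idx, src, reduce="amax")
    exp = (src - group_max.gather(dim, idx)).exp()
    # Per-group sum of exponentials (the softmax denominator).
    denom = src.new_zeros(out_shape).scatter_add(dim, idx, exp)
    return exp / denom.gather(dim, idx)


src = torch.randn(5, 3)                  # 5 edges, 3 broadcast channels (e.g. heads)
index = torch.tensor([0, 1, 0, 2, 1])    # group (e.g. target node) of each edge
out = sparse_softmax_sketch(src, index)  # dim=-2 is the edge dimension here
```

Within each group and channel the outputs sum to 1, which is the same result as calling torch.softmax on that group's slice on its own.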

sparse_multi_head_attention(q: Tensor, k: Tensor, v: Tensor, index: Tensor, dim_size: Optional[int] = None, dropout_p: float = 0.0)

Computes multi-head scaled dot-product attention on the query, key and value tensors, applying dropout if a probability greater than 0 is specified. index specifies, for each query in q, the sequence it belongs to in the original batched dense tensor. Returns a pair of tensors: the attended values and the attention weights.

Parameters:
  • q (Tensor) – Query tensor. See the Shapes section for shape details.

  • k (Tensor) – Key tensor. See the Shapes section for shape details.

  • v (Tensor) – Value tensor. See the Shapes section for shape details.

  • index (Tensor) – For each query in q, the index of the sequence it belongs to in the original batched dense tensor. See the Shapes section for shape details.

  • dim_size (int, optional) – The total batched target sequence length, i.e., max_val + 1 of index. (default: None)

  • dropout_p (float) – Dropout probability. Dropout is applied only if it is greater than 0. (default: 0.0)

Shapes:
  • q – \((S, H, E)\), where S is the sparse dimension, H is the number of heads, and E is the embedding dimension.

  • k – \((S, H, E)\), where S is the sparse dimension, H is the number of heads, and E is the embedding dimension.

  • v – \((S, H, O)\), where S is the sparse dimension, H is the number of heads, and O is the output dimension.

  • index – \((S)\), where S is the sparse dimension.

  • dim_size – must be \(B \times Nt\), where B is the batch size and Nt is the target sequence length.

  • Output – attention values have shape \((B, Nt, E)\); attention weights have shape \((S, H)\)
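A minimal sketch of the computation described above, under stated assumptions: every sparse entry s carries its own query and key, the softmax runs (per head) over all entries that share the same index value, and the weighted values are summed into that slot. sparse_mha_sketch is hypothetical; it returns attended values of shape (dim_size, H, O), one row per target slot, and does not attempt to reproduce the exact batched output layout listed above.

```python
import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def sparse_mha_sketch(q: Tensor, k: Tensor, v: Tensor, index: Tensor,
                      dim_size: Optional[int] = None,
                      dropout_p: float = 0.0) -> Tuple[Tensor, Tensor]:
    # q, k: (S, H, E); v: (S, H, O); index: (S,) target slot of each entry.
    if dim_size is None:
        dim_size = int(index.max()) + 1
    # Scaled dot-product score for every (entry, head) pair: (S, H).
    scores = (q * k).sum(dim=-1) / math.sqrt(q.size(-1))
    # Softmax over all entries that share the same target slot, per head.
    idx = index.unsqueeze(-1).expand_as(scores)
    group_max = scores.new_full((dim_size, scores.size(1)), float("-inf"))
    group_max = group_max.scatter_reduce(0, idx, scores, reduce="amax")
    exp = (scores - group_max.gather(0, idx)).exp()
    denom = scores.new_zeros((dim_size, scores.size(1))).scatter_add(0, idx, exp)
    attn = exp / denom.gather(0, idx)
    # Assumption: dropout acts on the normalised weights, as in standard SDPA.
    attn = F.dropout(attn, p=dropout_p, training=dropout_p > 0)
    # Weighted values summed into their target slot: (dim_size, H, O).
    out = v.new_zeros((dim_size, v.size(1), v.size(2)))
    out.index_add_(0, index, attn.unsqueeze(-1) * v)
    return out, attn


S, H, E, O = 6, 2, 4, 3
q, k, v = torch.randn(S, H, E), torch.randn(S, H, E), torch.randn(S, H, O)
index = torch.tensor([0, 0, 0, 1, 1, 2])  # target slot of each sparse entry
out, attn = sparse_mha_sketch(q, k, v, index, dim_size=3)
```

A slot that owns a single entry (slot 2 here) receives a weight of exactly 1, so its output equals that entry's value vector.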