Functional#
- expand_then_cat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int = -1) Tensor [source]#
Match the dimensions of the tensors in the input sequence and then concatenate them along dim.
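The behavior described above can be sketched in plain PyTorch. This is a hypothetical re-implementation, not the library's own code: it assumes "match the dimensions" means broadcasting every dimension except dim to a common shape before concatenating.

```python
import torch

def expand_then_cat_sketch(tensors, dim=-1):
    # Broadcast-match every dimension except `dim`, then concatenate.
    # Assumption: each tensor keeps its own size along `dim`.
    shapes = []
    for t in tensors:
        s = list(t.shape)
        s[dim] = 1  # ignore the concatenation dimension when broadcasting
        shapes.append(torch.Size(s))
    common = list(torch.broadcast_shapes(*shapes))
    expanded = []
    for t in tensors:
        target = list(common)
        target[dim] = t.size(dim)
        expanded.append(t.expand(target))
    return torch.cat(expanded, dim=dim)

a = torch.randn(4, 3, 8)
b = torch.randn(1, 1, 2)  # broadcast to (4, 3, 2) before concatenation
out = expand_then_cat_sketch([a, b], dim=-1)
print(out.shape)  # torch.Size([4, 3, 10])
```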
- gated_tanh(input: Tensor, dim: int = -1) Tensor [source]#
The gated tanh unit. Computes:
\[\text{GatedTanh}(a, b) = \text{tanh}(a) \otimes \sigma(b)\]
where input is split in half along dim to form \(a\) and \(b\), \(\text{tanh}\) is the hyperbolic tangent function, \(\sigma\) is the sigmoid function, and \(\otimes\) is the element-wise product between matrices.
- Parameters:
input (Tensor) – Input tensor.
dim (int) – Dimension along which the input is split. (default: -1)
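The formula above maps directly to a few lines of PyTorch. A minimal sketch (not the library's implementation):

```python
import torch

def gated_tanh_sketch(input, dim=-1):
    # Split the input in half along `dim` to form a and b,
    # then compute tanh(a) * sigmoid(b) element-wise.
    a, b = input.chunk(2, dim=dim)
    return torch.tanh(a) * torch.sigmoid(b)

x = torch.randn(5, 8)
y = gated_tanh_sketch(x, dim=-1)
print(y.shape)  # torch.Size([5, 4]) -- half the size along dim
```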
- sparse_softmax(src: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, num_nodes: Optional[int] = None, dim: int = -2) Tensor [source]#
Extension of softmax() with index broadcasting to compute a sparsely evaluated softmax over multiple broadcast dimensions.
Given a value tensor src, this function first groups the values along the first dimension based on the indices specified in index, and then proceeds to compute the softmax individually for each group.
- Parameters:
src (Tensor) – The source tensor.
index (Tensor, optional) – The indices of elements for applying the softmax. (default: None)
ptr (Tensor, optional) – If given, computes the softmax based on sorted inputs in CSR representation. (default: None)
num_nodes (int, optional) – The number of nodes, i.e., max_val + 1 of index. (default: None)
dim (int) – The dimension on which to normalize, i.e., the edge dimension. (default: -2)
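The grouping semantics can be illustrated with a simplified sketch. This hypothetical version handles only the 1-D case with index (the documented function also supports ptr in CSR form and normalization along dim=-2); it stabilizes each group's softmax by subtracting the group maximum.

```python
import torch

def sparse_softmax_sketch(src, index, num_nodes=None):
    # Entries of `src` sharing the same value in `index` form one
    # softmax group. 1-D sketch only; `num_nodes` plays the role of
    # max_val + 1 of `index` when not given.
    N = int(index.max()) + 1 if num_nodes is None else num_nodes
    # Per-group maximum for numerical stability.
    group_max = src.new_full((N,), float('-inf')).scatter_reduce(
        0, index, src, reduce='amax')
    out = (src - group_max[index]).exp()
    group_sum = out.new_zeros(N).index_add(0, index, out)
    return out / group_sum[index]

src = torch.tensor([1.0, 2.0, 3.0, 1.0])
index = torch.tensor([0, 0, 1, 1])
p = sparse_softmax_sketch(src, index)
# p[:2] is softmax([1, 2]); p[2:] is softmax([3, 1]); each group sums to 1.
```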
- sparse_multi_head_attention(q: Tensor, k: Tensor, v: Tensor, index: Tensor, dim_size: Optional[int] = None, dropout_p: float = 0.0)[source]#
Computes multi-head, scaled, dot-product attention on query, key and value tensors, applying dropout if a probability greater than 0 is specified. index specifies, for each query in q, the sequence it belongs to in the original batched, dense tensor. Returns a tensor pair containing attended values and attention weights.
- Parameters:
q (Tensor) – Query tensor. See Shape section for shape details.
k (Tensor) – Key tensor. See Shape section for shape details.
v (Tensor) – Value tensor. See Shape section for shape details.
index (Tensor) – Tensor assigning each query in q to its sequence in the original batched, dense tensor. See Shape section for shape details.
dim_size (int, optional) – The batched target sequence length, i.e., max_val + 1 of index. (default: None)
dropout_p (float) – Dropout probability. If greater than 0, dropout is applied. (default: 0)
- Shapes:
q – \((S, H, E)\) where \(S\) is the sparse dimension, \(H\) is the number of heads, and \(E\) is the embedding dimension.
k – \((S, H, E)\) where \(S\) is the sparse dimension, \(H\) is the number of heads, and \(E\) is the embedding dimension.
v – \((S, H, O)\) where \(S\) is the sparse dimension, \(H\) is the number of heads, and \(O\) is the output dimension.
index – \((S)\) where \(S\) is the sparse dimension.
dim_size – must be \((B \times Nt)\)
Output – attention values have shape \((B, Nt, E)\); attention weights have shape \((S, H)\)
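A simplified sketch of the computation, combining the grouped softmax above with per-row scaled dot products. This hypothetical version omits dropout and returns per-slot outputs of shape \((N, H, O)\) rather than the reshaped dense \((B, N_t, E)\) output; the names are assumptions, not the library's code.

```python
import torch

def sparse_mha_sketch(q, k, v, index, dim_size=None):
    # q, k: (S, H, E); v: (S, H, O); index: (S,) maps each row to its
    # slot in the batched target of length dim_size (max_val + 1).
    S, H, E = q.shape
    N = int(index.max()) + 1 if dim_size is None else dim_size
    # Scaled dot-product score per (row, head).
    alpha = (q * k).sum(dim=-1) / E ** 0.5  # (S, H)
    # Softmax over rows sharing the same index, independently per head.
    idx = index.unsqueeze(-1).expand(S, H)
    amax = alpha.new_full((N, H), float('-inf')).scatter_reduce(
        0, idx, alpha, reduce='amax')
    alpha = (alpha - amax[index]).exp()
    denom = alpha.new_zeros(N, H).index_add(0, index, alpha)
    alpha = alpha / denom[index]  # attention weights, shape (S, H)
    # Weighted aggregation of values into the output slots.
    out = v.new_zeros(N, H, v.size(-1)).index_add(
        0, index, alpha.unsqueeze(-1) * v)
    return out, alpha

q = torch.randn(6, 2, 4)
k = torch.randn(6, 2, 4)
v = torch.randn(6, 2, 3)
index = torch.tensor([0, 0, 1, 1, 1, 2])  # three target slots
out, weights = sparse_mha_sketch(q, k, v, index)
print(out.shape, weights.shape)  # torch.Size([3, 2, 3]) torch.Size([6, 2])
```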