Einsum
Einsum provides a unified syntax for multi-dimensional linear-algebraic array operations. It is easier to describe with math than with words, so I will give a few examples of its usage in PyTorch.
Notation:
- The entry in row \(i\), column \(j\) of matrix \(A\) is denoted by \(A_{i,j}\).
- An entry in a three-dimensional tensor \(X\) will be similarly denoted \(X_{b, i, j}\).
- \(X_{b, :, :}\) is a matrix corresponding to the batch element \(b\) of the 3D tensor \(X\).
We have matrices \(X\) and \(Y\), and want to implement \(Z = XY\), where
\begin{equation} Z_{i,k} = \sum_{j} X_{i,j} Y_{j,k} \end{equation}
In PyTorch, this can be implemented as
Z = torch.einsum('ij, jk -> ik', X, Y)
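As a sanity check, here is a minimal sketch with small hypothetical matrices, verifying that the einsum contraction matches PyTorch's built-in matrix product:

```python
import torch

# Small example matrices (values chosen for illustration only).
X = torch.tensor([[1., 2.], [3., 4.]])
Y = torch.tensor([[5., 6.], [7., 8.]])

# 'ij, jk -> ik' sums over the shared index j: Z[i,k] = sum_j X[i,j] * Y[j,k].
Z = torch.einsum('ij, jk -> ik', X, Y)

# Agrees with the built-in matrix multiplication.
assert torch.equal(Z, X @ Y)
```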
Now, suppose that \(X\) and \(Y\) represent 3-D tensors, where the first dimension represents a batch dimension, and we want to perform matrix multiplication over the last two dimensions, for each batch element. \begin{equation} Z_{b, i,k} = \sum_{j} X_{b, i,j} Y_{b, j,k} \end{equation}
In PyTorch, this can be implemented as
Z = torch.einsum('bij, bjk -> bik', X, Y)
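The following sketch (with arbitrary example shapes) checks this against `torch.bmm`, which performs the same batched matrix multiplication:

```python
import torch

# A batch of 2 matrices: X is 2 x (3x4), Y is 2 x (4x5).
X = torch.randn(2, 3, 4)
Y = torch.randn(2, 4, 5)

# The batch index b is carried through; j is summed over.
Z = torch.einsum('bij, bjk -> bik', X, Y)

assert Z.shape == (2, 3, 5)
assert torch.allclose(Z, torch.bmm(X, Y))
```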
In the final example, we compute the Gram matrix \(G_{b,:,:}=X_{b,:,:}^T X_{b,:,:}\) over the last two dimensions of 3-D array \(X\). For each batch element \(b\), we have a matrix \(X_{b,:,:}\), and compute its Gram matrix \begin{equation} G_{b,i,k} = \sum_j X_{b,j,i} X_{b,j,k} \end{equation}
So, we can implement this either as
G = torch.einsum('bji, bjk -> bik', X, X)
or
G = torch.einsum('bij, bjk -> bik', X.permute(0,2,1), X)
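A quick check (with an arbitrary example shape) that the two formulations agree, and that each Gram matrix is symmetric as expected of \(X^T X\):

```python
import torch

# A batch of 2 matrices, each 4x3; the Gram matrices are then 3x3.
X = torch.randn(2, 4, 3)

# Transpose expressed inside the subscript string ...
G1 = torch.einsum('bji, bjk -> bik', X, X)
# ... or via an explicit permute of the last two dimensions.
G2 = torch.einsum('bij, bjk -> bik', X.permute(0, 2, 1), X)

assert torch.allclose(G1, G2)
# X^T X is symmetric for each batch element.
assert torch.allclose(G1, G1.transpose(1, 2))
```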