In [1]:
import torch

## Matrix multiplication

In [2]:
# Seed PyTorch's global RNG so "Restart & Run All" draws the same tensors every
# time; in a top-to-bottom run this also makes the later torch.randn calls
# reproducible.
SEED = 0
torch.manual_seed(SEED)

# Random operands with compatible inner dimensions: (4, 3) @ (3, 2) -> (4, 2).
X = torch.randn(4, 3)
Y = torch.randn(3, 2)
print("X=", X, "\nY=", Y)
X= tensor([[ 0.1742,  0.8217,  0.8242],
        [ 0.0720,  1.7807,  1.5268],
        [ 0.5374,  0.6169, -0.7690],
        [ 0.1421,  0.7180, -0.3730]]) 
Y= tensor([[ 0.1103, -0.9541],
        [-0.1714,  0.6822],
        [ 0.0751,  1.2649]])
In [3]:
# Matrix product written as an einsum: Z[i, k] = sum_j X[i, j] * Y[j, k].
# The repeated index j is summed out because it does not appear after '->'.
Z = torch.einsum('ij, jk -> ik', X, Y)
# Cross-check against the built-in matrix product. Compare with a tolerance,
# since the two code paths are not guaranteed to round identically.
assert torch.allclose(Z, X @ Y, atol=1e-6), "einsum result differs from X @ Y"
print("Z=", Z)
Z= tensor([[-0.0597,  1.4370],
        [-0.1826,  3.0774],
        [-0.1042, -1.0646],
        [-0.1354, -0.1175]])

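## Gram matrix

For each matrix $X_b$ in a batch, the Gram matrix $G_b = X_b^T X_b$ holds the inner products between every pair of columns of $X_b$.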

In [4]:
# Batched Gram matrices: for every batch index b, G[b] = X[b]^T @ X[b], the
# 2x2 matrix of inner products between the columns of the 4x2 matrix X[b].
# NOTE: this rebinds X, replacing the (4, 3) matrix from the previous section.
batch_size = 16
X = torch.randn(batch_size, 4, 2)

# Direct form: the shared row index j is contracted, leaving column pairs (i, k).
G = torch.einsum('bji, bjk -> bik', X, X)

# Equivalent form: transpose each matrix explicitly, then use a plain batched
# matrix-product equation.
G_alt = torch.einsum('bij, bjk -> bik', X.transpose(1, 2), X)

G.shape
Out[4]:
torch.Size([16, 2, 2])
In [5]:
# Check that the two einsum formulations agree. Use a float tolerance rather
# than relying on the squared error being exactly zero, and fail loudly on a
# mismatch instead of only displaying a number.
assert torch.allclose(G, G_alt, atol=1e-6), "G and G_alt disagree"
# Also display the total squared difference (0 when the results are bit-identical).
torch.sum((G - G_alt) ** 2)
Out[5]:
tensor(0.)

The inner product of the two columns of the $4\times 2$ matrix $X_{3,:,:}$ is the off-diagonal Gram entry $$G_{3,0,1} = X_{3,:,0}^T X_{3,:,1}.$$

In [6]:
# Explicit column inner product for batch item 3, printed beside the matching
# off-diagonal Gram entry so the two values can be compared.
print(
    (X[3, :, 0] * X[3, :, 1]).sum(),
    '\t',
    G[3, 0, 1],
)
tensor(3.0322) 	 tensor(3.0322)