Ridge Regression
Introduction
In this post, we implement the Ridge regression estimator described in the paper
- Allan Raventós, Mansheej Paul, Feng Chen, and Surya Ganguli. Pretraining task diversity and the emergence of non-Bayesian in-context learning for regression.
- Large Language Models and Transformers Workshop, Simons Institute for the Theory of Computing, August 2023.
Data distribution
For each batch element, we generate the data as follows \begin{equation} y = {\bf{w}}^T {\bf{x}} + \epsilon \end{equation}
- \({\bf{w}} \in \mathbb{R}^D\), \({\bf{x}} \in \mathbb{R}^D\).
- \({\bf{x}} \sim \mathcal{N}({\bf 0}, {\bf{I}}_D)\), which means \(E[\bf{x}]=0\) and \(E[{\bf xx}^T]={\bf I}_D\).
- “Task vector”: \({\bf w} \sim \mathcal{T}_{true} = \mathcal{N}({\bf 0}, {\bf I}_D)\), and the test data is generated according to this model.
- Observation noise: \(\epsilon \sim \mathcal{N}(0, \sigma^2)\), which means \(E[\epsilon]=0\) and \(E[\epsilon^2]= \sigma^2\).
- \(K\) = the number of examples \(({\bf x}, y)\) generated for the prompt in each batch element.
- The last of these examples serves as the query: its \({\bf x}\) is given to the model and its \(y\) is the prediction target (see the sketch after this list).
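As a minimal sketch of this data-generation process (the helper name sample_batch and its signature are my own choices, not from the paper), matching the test-time setting where each batch element draws its own \({\bf w} \sim \mathcal{T}_{true}\):

```python
import torch

def sample_batch(batch_size, K, D, sigma):
    """Sample K examples (x, y) per batch element, one task vector per element."""
    w = torch.randn(batch_size, D)                 # w ~ N(0, I_D)
    x = torch.randn(batch_size, K, D)              # x ~ N(0, I_D)
    eps = sigma * torch.randn(batch_size, K)       # eps ~ N(0, sigma^2)
    y = torch.einsum('bd, bkd -> bk', w, x) + eps  # y = w^T x + eps
    return x, y.unsqueeze(-1)                      # x: (B, K, D), y: (B, K, 1)
```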
For the pretraining dataset, however, we sample \(\{ {\bf w}^{(1)}, {\bf w}^{(2)}, \cdots, {\bf w}^{(M)} \}\) once from \(\mathcal{T}_{true}\).
- Each batch element is then generated using one of the \(M\) tasks \(\{ {\bf w}^{(1)}, {\bf w}^{(2)}, \cdots, {\bf w}^{(M)} \}\).
In [Raventos et al. 2023], the distribution of \({\bf w}\) in the pretraining data is denoted by \begin{equation} \mathcal{T}_{pretrain} = \mathcal{U} ({\bf w}^{(1)}, {\bf w}^{(2)}, \cdots, {\bf w}^{(M)} ) \end{equation}
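For the pretraining data, a similar sketch (again with illustrative names of my own): the task pool is sampled once, and each batch element picks its task uniformly from it:

```python
import torch

def sample_task_pool(M, D):
    # {w^(1), ..., w^(M)}, drawn once from T_true = N(0, I_D)
    return torch.randn(M, D)

def sample_pretrain_batch(task_pool, batch_size, K, sigma):
    M, D = task_pool.shape
    w = task_pool[torch.randint(M, (batch_size,))]  # w ~ U(w^(1), ..., w^(M))
    x = torch.randn(batch_size, K, D)
    y = torch.einsum('bd, bkd -> bk', w, x) + sigma * torch.randn(batch_size, K)
    return x, y.unsqueeze(-1)                       # (B, K, D), (B, K, 1)
```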
For each batch element, a “prompt” consists of \(K-1\) examples \(({\bf x}, y)\) generated from a single task vector \({\bf w}\), together with a query \({\bf{x}}_{query}\).
Given a prompt, we want to estimate the underlying \({\bf w}\), denoted by \(\hat{\bf{w}}\), and predict the output \(\hat{y}\) corresponding to the query \({\bf{x}}_{query}\), i.e.
\begin{equation} \hat{y} = \hat{\bf{w}}^T {\bf{x}}_{query}. \end{equation}
To quantify the accuracy of the predictions, we use the mean squared error (MSE): the average squared difference between the true \(y\) and its prediction \(\hat{y}\).
Ridge regression estimator
The Ridge regression estimator \({\bf w}_{ridge}\) is derived assuming that
- we know how \(y\) is generated from \({\bf x}\), i.e. \(y={\bf w}^T{\bf x} + \epsilon\),
- we do not know \({\bf w}\) or \(\epsilon\).
- But we know the statistical distributions that generate the data \(({\bf x},y)\) that we observe:
- \[{\bf w} \sim \mathcal{T}_{true}=\mathcal{N}({\bf 0}, {\bf I}_D)\]
- \[\epsilon \sim \mathcal{N}(0, \sigma^2)\]
So the performance of ridge regression is what we would achieve had we known that \({\bf w} \sim \mathcal{T}_{true}=\mathcal{N}({\bf 0}, {\bf I}_D)\) and \(\epsilon \sim \mathcal{N}(0, \sigma^2)\).
Since we generate the test data according to exactly these distributions, the ridge regressor is the best predictor in terms of minimizing the MSE.
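To make this optimality claim concrete (a standard Bayesian linear regression step, sketched here rather than quoted from the paper): with prior \({\bf w} \sim \mathcal{N}({\bf 0}, {\bf I}_D)\) and Gaussian observation noise, the posterior over \({\bf w}\) given the prompt examples (stacked into \(X\) and \({\bf y}\) as defined in the next section) is
\[p({\bf w} \mid X, {\bf y}) \propto \exp\left( -\frac{1}{2\sigma^2} \left\| {\bf y} - X{\bf w} \right\|^2 - \frac{1}{2} \left\| {\bf w} \right\|^2 \right),\]
which is Gaussian with mean
\[E[{\bf w} \mid X, {\bf y}] = (X^T X + \sigma^2 {\bf I}_D)^{-1} X^T {\bf y},\]
i.e. exactly the ridge estimator below; since the posterior mean minimizes the expected squared error, no predictor can do better on average.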
We will evaluate the accuracy of predicted \(y_{K}\) in response to the \({\bf x}_{query}\), given the in-context learning examples in the prompt
\(\{ ({\bf x}_1, y_1), ({\bf x}_2, y_2), \cdots, ({\bf x}_{K-1}, y_{K-1}) \}\).
Implementation
For each batch element, we want to compute equation (3) of [Raventos et al. 2023], \begin{equation} {\bf w}_{ridge} = (X^T X + \sigma^2 {\bf I}_D)^{-1} X^T {\bf y} \end{equation}
- \(X\) is a \((K-1) \times D\) matrix whose \(k\)-th row is \({\bf x}_k^T\).
- \({\bf y}\) is a \((K-1) \times 1\) vector whose \(k\)-th entry is \(y_k\).
In the PyTorch code, we use 3-D tensors with the following shapes:
- \(X\) = (batch_size, K-1, D)
- \({\bf y}\) = (batch_size, K-1, 1)
From now on, \(X\) denotes the 3-D tensor, and \(X_{b,:,:}\) the matrix corresponding to batch element \(b\) of the 3-D tensor.
We now write out the equations that help us implement the ridge estimator with the einsum function.
To compute the Gram matrix for \(X_{b,:,:}\), where \(b\) is the batch element index,
\[[X_{b,:,:}^T X_{b,:,:}]_{i,k} = \sum_j (X_{b,:,:}^T)_{i,j} X_{b,j,k},\]
\[C_{b, :, :} = X_{b, :, :}^T X_{b, :, :} + \sigma^2 {\bf I}_D.\]
To compute \((X^T X + \sigma^2 {\bf I}_D)^{-1} X^T\) for each batch element,
\[[C_{b,:,:}^{-1} X_{b,:,:}^T]_{i,k} = \sum_j [C_{b,:,:}^{-1}]_{i,j} [X_{b,:,:}^T]_{j,k}.\]
Finally, to compute \({\bf w}_{ridge}\) for each batch element,
\[w_{ridge}[b, i] = [C_{b,:,:}^{-1} X_{b,:,:}^T {\bf y}_{b,:, 0}]_{i} = \sum_{j}[C_{b,:,:}^{-1} X_{b,:,:}^T]_{i,j} {\bf y}_{b,j, 0}.\]
Predict \(y\) using the ridge regressor \({\bf w}_{ridge}\) for each batch element:
\[\hat{y} = {\bf w}_{ridge}^T {\bf x}_{query}.\]
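As a quick sanity check (my own snippet, not from the notebook) that the einsum pattern 'bij, bjk -> bik' used below is exactly the batched matrix product:

```python
import torch

A = torch.randn(4, 3, 5)                       # batch of 4 matrices, each 3 x 5
B = torch.randn(4, 5, 2)                       # batch of 4 matrices, each 5 x 2
out = torch.einsum('bij, bjk -> bik', A, B)    # (4, 3, 2)
assert torch.allclose(out, torch.bmm(A, B), atol=1e-6)
```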
```python
import torch

def ridge(x, y, M, D, K, sigma_sq):
    """Batched ridge regression predictor.

    x: (batch_size, K, D); y: (batch_size, K, 1).
    M (the number of pretraining tasks) is not used by the estimator.
    """
    X = x[:, 0:K-1, :]                            # prompt inputs: (batch_size, K-1, D)
    I = sigma_sq * torch.eye(D, device=x.device)  # sigma^2 I_D on the same device as x
    # C = X^T X + sigma^2 I_D for each batch element
    Cxx = torch.einsum('bij, bjk -> bik', X.permute([0, 2, 1]), X) + I[None, :, :]
    Cxx_inv = torch.linalg.inv(Cxx)               # (batch_size, D, D)
    # (X^T X + sigma^2 I_D)^{-1} X^T
    Cxx_inv_X = torch.einsum('bij, bjk -> bik', Cxx_inv, X.permute([0, 2, 1]))  # (batch_size, D, K-1)
    # w_ridge = (X^T X + sigma^2 I_D)^{-1} X^T y
    w_ridge = torch.einsum('bij, bj -> bi', Cxx_inv_X, y[:, 0:K-1, 0])  # (batch_size, D)
    query = x[:, K-1, :]                          # query input: (batch_size, D)
    yhat = torch.sum(w_ridge * query, dim=1).unsqueeze(1)  # (batch_size, 1)
    return yhat
```
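A usage sketch, reusing the hypothetical sample_batch helper from the data-distribution section (the estimator ignores M, so any placeholder can be passed):

```python
batch_size, K, D, sigma_sq = 256, 16, 8, 0.25
x, y = sample_batch(batch_size, K, D, sigma=sigma_sq ** 0.5)
yhat = ridge(x, y, M=None, D=D, K=K, sigma_sq=sigma_sq)
mse = torch.mean((y[:, K-1, 0] - yhat[:, 0]) ** 2)
print(f'K={K}: MSE = {mse.item():.4f}')
```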
Jupyter Notebook
A portion of the Basic MNIST Example code is reused.
Observations
We observe that when \(K=2\), the MSE is larger than when \(K=16\). This is because when \(K-1 < D\), we have fewer equations than unknowns, the unknowns being the \(D\) components of \({\bf w}\):
\[y_1 = {\bf w}^T {\bf x}_1 + \epsilon_1\]
\[y_2 = {\bf w}^T {\bf x}_2 + \epsilon_2\]
\[\vdots\]
\[y_{K-1} = {\bf w}^T {\bf x}_{K-1} + \epsilon_{K-1}\]
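A quick numerical check of this observation, again reusing the sketches above:

```python
for K in (2, 16):
    x, y = sample_batch(batch_size=4096, K=K, D=8, sigma=0.5)
    yhat = ridge(x, y, M=None, D=8, K=K, sigma_sq=0.25)
    mse = torch.mean((y[:, K-1, 0] - yhat[:, 0]) ** 2)
    print(f'K={K:2d}: MSE = {mse.item():.3f}')
```

The MSE at \(K=2\) should come out substantially larger than at \(K=16\).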