Abstract
OOD (out-of-distribution) generalization has been an unsolved problem for deep learning. While self-attention-based variants have taken great strides in recovering the underlying structured representations in data, they lack key components needed for arbitrary generalization: adaptive computation and strong priors.
This project presents preliminary results from an attempt to generalize the ideas in prior literature to attention-based variants. We present a modified (recurrent) architecture that is both effective and parallelizable, displays interesting properties, and is able to OOD-extrapolate on toy tasks. We believe that incorporating strong inductive biases and priors into traditional DL models may go a long way towards eventually being able to arbitrarily extrapolate.
The biggest challenge this line of approach faces is stability. As the dynamics of the task become harder, the approximation learnt by the NN has to become more stable - otherwise it risks compounding errors over more and more iterations. This report highlights some of the ideas used in prior literature to combat this, and integrates some novel ideas in an attempt to make these models maximally stable - even if we fall short of completely solving the problem.
Related Work
This section is a rapid-fire primer on the work done by Bansal et al. and Kaiser et al. on embedding adaptive-computation properties in recurrent-style models.
Bansal and Schwarzschild demonstrated in their paper how their architecture, dubbed “Deep Thinking Network”, was able to OOD generalize, solving 201x201 mazes despite being trained only on 9x9 ones. The network learnt the dead-end filling algorithm:
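For intuition, here is the classical dead-end filling procedure itself, sketched in Python. This is the textbook algorithm the network appears to have rediscovered, not anything extracted from the trained model:

```python
from typing import List, Tuple

def dead_end_fill(maze: List[List[int]], start: Tuple[int, int],
                  goal: Tuple[int, int]) -> List[List[int]]:
    """Dead-end filling on a grid maze: 1 = open cell, 0 = wall (illustrative sketch).

    Repeatedly seal off open cells (other than start/goal) that have at most one
    open neighbour; for a perfect maze, only the solution path remains.
    """
    h, w = len(maze), len(maze[0])
    grid = [row[:] for row in maze]

    def open_neighbours(r: int, c: int) -> int:
        return sum(
            1
            for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1))
            if 0 <= r + dr < h and 0 <= c + dc < w and grid[r + dr][c + dc] == 1
        )

    changed = True
    while changed:
        changed = False
        for r in range(h):
            for c in range(w):
                if grid[r][c] == 1 and (r, c) not in (start, goal) and open_neighbours(r, c) <= 1:
                    grid[r][c] = 0  # fill the dead end
                    changed = True
    return grid
```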
This was achieved by using the same network and applying it multiple times, iteratively. Such a formulation is similar to DEQs (deep equilibrium networks), but even simpler and more scalable. This is performed recursively for a maximum bound of `max_iters: int` iterations. It is expected that the model learns to extrapolate both iteration-wise and length-wise, if it's to OOD generalize at all.
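Schematically, the weight-tied recurrence looks something like this. This is a minimal sketch that omits details such as the recall connection discussed later; the module names mirror the ones used in the rest of this post rather than any verbatim code:

```python
import torch.nn as nn

class IterativeSolver(nn.Module):
    """Weight-tied recurrence, DTNet-style: one block applied over and over (sketch)."""

    def __init__(self, embed: nn.Module, recur_block: nn.Module,
                 out_head: nn.Module, max_iters: int):
        super().__init__()
        self.embed = embed                # input tokens -> latent sequence
        self.recur_block = recur_block    # the single block reused at every iteration
        self.out_head = out_head          # latent sequence -> output logits
        self.max_iters = max_iters

    def forward(self, input_sequence, iters_to_do=None):
        interim_thought = self.embed(input_sequence)
        for _ in range(iters_to_do or self.max_iters):
            # Same weights at every step; extrapolation = running more steps at test time.
            interim_thought = self.recur_block(interim_thought)
        return self.out_head(interim_thought)
```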
This paper, however, utilizes vanilla CNNs inspired by a ResNet block. While the inductive biases integrated by convolutions are helpful, they are ill-suited for arbitrary seq2seq tasks. This is where `ReAct` comes in.
Through this simple idea, they were able to gain OOD length extrapolation on a variety of tasks: from maze-solving to prefix sums and even chess!
Another issue here is that transformers simply can't handle unseen positional encodings - when OOD positional encodings enter the model, much like “SolidGoldMagikarp” tokens, the circuits aren't able to handle them effectively. This has been addressed in some works, which interpolate positional encodings or use alternative positional encoding schemes (like RoPE), but they still fall short of actually extrapolating.
We don't consider “interpolating” positional encodings to be an effective way to combat length generalization, as the model doesn't actually extrapolate - rather, we adjust the inputs to be in-distribution, which can be a brittle process. `ReAct`, on the other hand, uses vanilla sinusoidal embeddings and can still length-extrapolate. Thus, length extrapolation can be achieved through inductive biases/priors alone instead of resorting to unholy tricks.
Architecture
`ReAct` is heavily inspired by `DTNet` in its bid to integrate attention into the architecture and operate on arbitrary seq2seq tasks. This is useful as it makes the model flexible across various tasks. The architecture is kept simple to aid scalability, following “The Bitter Lesson”.
While the original idea was envisioned as a short weekend project to swap out the `nn.Conv1d` layers for `nn.MultiheadAttention` layers, more techniques and tricks were required to get it to work. There were some interesting results - such as self-attention being next-to-useless, and the model being resistant to overfitting/memorisation. I will elaborate on those changes below.
A Lighter Attention - LiteAttention
One of the most counterintuitive findings of this (short) project was how ineffective Multi-head self-attention is, especially for OOD generalization. The problem with MHSA is simple - it works too well!
MHSA simply can't length-extrapolate. The circuits it learns are hardcoded to the length of the outputs it has seen during training. As a result, if for example the task is `reverse_string` on an OOD input:
# In distribution:
[1, 2, 3, 0, 0, 0 ...] ==> [..0, 0, 0, 3, 2, 1]
# Out of distribution
[1, 2, 3, 4, 0, 0, ...] ==> [..0, 0, 3, 2, 1]
The model simply learns to ignore that OOD token position. This is well studied, and is believed to be because of the sinusoidal positional encodings.
To remedy this, we opted to use a simplified version of “attention” that acts closer to an information gate. Because it's simpler, we suspected it wouldn't learn to be overly reliant on spurious features, which should help generalization.
The main benefit of using `LiteAttention` is that it works as a more explicit information-gating mechanism, handles OOD sinusoidal embeddings well, and doesn't require any other heavy modifications - such as an explicit decay mechanism like ALiBi.
As explained, `LiteAttention` simply replaces the matmul used in `SelfAttention` with a Hadamard product. We use `Softmax` to bound the final attention scores, which are derived from a simple `Linear` layer to create data-dependency, (hopefully) acting as a diluted proxy for in-context learning.
I doubt `LiteAttention` holds up with scale. Because it was intended to explicitly remove a matmul, there is little inter-token information mixing - it's closer to a gate used in RNNs. However, it (surprisingly) holds up on rather complicated tasks given its simplicity. I suspect the added data-dependency is what makes it so expressive.
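A minimal sketch of how such a layer could be written, as I read the description above; the softmax axis and exact wiring are assumptions, not the verbatim module:

```python
import torch
import torch.nn as nn

class LiteAttention(nn.Module):
    """Element-wise gating 'attention': no token-mixing matmul, just a learned,
    data-dependent gate applied via a Hadamard product (illustrative sketch)."""

    def __init__(self, dim: int):
        super().__init__()
        self.score = nn.Linear(dim, dim)   # data-dependent gate logits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim)
        attn = self.score(x).softmax(dim=-1)   # bounded scores (axis choice is an assumption)
        return attn * x                        # Hadamard product instead of attn @ V
```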
Implicit error correction using n+k
`n+k` sampling refers to how we handle the training loop. Because we have a recurrent architecture, we want to encourage the model to learn an error-correcting mechanism. The basic idea is simple:
- We run the model for `n` iterations while throwing away the gradients; `n` is sampled s.t. `n ≤ max_iters`.
- Then, we run the model for `k` iterations, which are backpropagated to improve on the last `n` iterations; `k` is sampled s.t. `k ≤ max_iters - n`. This technique helps the model to be somewhat stable and converge to an equilibrium.
This loss is interpolated with the standard loss (which just runs the model for `max_iters`), with the mix determined by $\alpha$, the interpolation coefficient.
The `torch` code used for this loop is sketched below:
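This is a minimal reconstruction rather than the verbatim code: it operates directly on the `recur_block` and `out_head` pieces described elsewhere in this post, the sampling bounds follow the bug discussion below, and details such as which term $\alpha$ scales are assumptions.

```python
import torch
import torch.nn.functional as F

def nk_step(x, targets, recur_block, out_head, max_iters, alpha, optimizer):
    """One n+k training step (sketch). x = embedded inputs (+ positional encodings)."""
    n = int(torch.randint(0, max_iters, (1,)))           # n ~ U(0, max_iters)
    k = int(torch.randint(1, max_iters - n + 1, (1,)))   # k ~ U(1, max_iters - n + 1)

    # 1) Run n iterations with the gradients thrown away.
    interim_thought = x
    with torch.no_grad():
        for _ in range(n):
            interim_thought = recur_block(interim_thought)

    # 2) Run k more iterations from that state; only these are backpropagated.
    thought = interim_thought.detach()
    for _ in range(k):
        thought = recur_block(thought)
    loss_nk = F.cross_entropy(out_head(thought).flatten(0, 1), targets.flatten())

    # 3) Standard loss: a full max_iters rollout from scratch.
    full = x
    for _ in range(max_iters):
        full = recur_block(full)
    loss_full = F.cross_entropy(out_head(full).flatten(0, 1), targets.flatten())

    # 4) Interpolate the two losses with alpha (which term alpha scales is a guess).
    loss = alpha * loss_nk + (1 - alpha) * loss_full
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss
```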
This code, however, has a subtle bug: notice that $n \sim U(0, \text{max\_iters})$ and $k \sim U(1, \text{max\_iters} - n + 1)$. The sum of two such draws is NOT uniform - since the range of $k$ depends on $n$, the distribution of $n + k$ is skewed, and with our bounds it resembles the truncated half of a Gaussian:
To fix this, in practice we weight the loss so as to roughly compensate for this distribution.
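One way such a compensation could be implemented (a sketch of the idea, not necessarily the exact scheme used): compute the distribution of the total iteration count $n + k$ analytically and weight each step's loss by its inverse frequency.

```python
import torch

def total_iter_distribution(max_iters: int) -> torch.Tensor:
    """P(n + k = m) for n ~ U{0..max_iters-1}, k ~ U{1..max_iters-n} (sketch)."""
    probs = torch.zeros(max_iters + 1)
    for n in range(max_iters):
        for k in range(1, max_iters - n + 1):
            probs[n + k] += (1.0 / max_iters) * (1.0 / (max_iters - n))
    return probs  # probs[m] = probability that a step trains at m total iterations

probs = total_iter_distribution(max_iters=20)   # max_iters=20 is an arbitrary example
weights = torch.zeros_like(probs)
weights[1:] = 1.0 / probs[1:]                   # inverse-frequency weights (m = 0 never occurs)
weights[1:] = weights[1:] / weights[1:].mean()  # normalise so the average weight is 1
# loss = weights[n + k] * loss_nk               # applied inside the training step
```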
This is a hack. If someone has a better way to do this, please give me a DM. My contact details will be below.
Recall
The `recall` mechanism introduced by Kaiser et al. is a simple modification. As indicated in (Figure 4), it is a long skip connection which helps with the “overthinking” problem; i.e., even if you perform more iterations than necessary, the model is able to recover the accuracy and not deviate too much.
That's because there are no explicit constraints put on the architecture to be Lyapunov stable. However, due to the `n + k` training loop, SGD learns a model that is (approximately) somewhat stable and converges to an equilibrium. This may need to be reinforced in future work, but for these tasks it's adequate.
In code, this long skip connection is represented as a `concat` op. `x: torch.Tensor` is the tensor obtained after `embed`-ing and adding positional encodings to the `input_sequence`. This is highlighted in the code below:
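A minimal sketch of how that `concat` sits inside the recurrent loop, assuming the first layer of `recur_block` accepts the doubled width (the interface is illustrative, not the verbatim code):

```python
import torch

def iterate_with_recall(x, recur_block, iters_to_do, interim_thought=None):
    """Recurrent iterations with the recall (long skip) connection (sketch).

    x: embedded inputs + positional encodings, shape (batch, seq_len, dim).
    """
    if interim_thought is None:
        interim_thought = x
    for _ in range(iters_to_do):
        # Recall: concatenate the original embedded input with the current thought,
        # so the block always sees the problem statement, not just its own state.
        recalled = torch.cat([x, interim_thought], dim=-1)   # (batch, seq_len, 2 * dim)
        interim_thought = recur_block(recalled)              # projects back down to dim
    return interim_thought
```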
Adversarial Perturbation
In an attempt to guide the model towards learning a more explicit error-correction methodology and to aid the `n + k` training loop, we introduce “Adversarial Perturbation”. Effectively, when we iterate the `recur_block`, we get an intermediate representation which is passed to the next iteration. That representation is termed the `interim_thought`.
The `output_sequence` is obtained by applying `output_head(interim_thought)` and then taking `.softmax().argmax(dim=-1)` over the logits.
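In code, that decoding step is roughly the following (a self-contained toy snippet with made-up shapes):

```python
import torch
import torch.nn as nn

# Hypothetical shapes: batch of 4, sequence length 8, hidden dim 32, vocab of 12.
output_head = nn.Linear(32, 12)
interim_thought = torch.randn(4, 8, 32)

logits = output_head(interim_thought)                      # (batch, seq_len, vocab)
output_sequence = logits.softmax(dim=-1).argmax(dim=-1)    # (batch, seq_len) token ids
```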
However, we want to introduce mild errors in the final decoded sequence - such that the sequence is off by a few bits and the model learns to correct errors arising due to corrupted bits.
So the task becomes:
Find the smallest perturbation to `interim_thought` such that the decoded sequence, `output_head(interim_thought).softmax().argmax()`, has the few corrupted bits that we desire.
This `perturbed_thought` is now fed back into the model for a few more iterations (slightly transgressing the upper bound of `max_iters`) so that the model is forced to locate and correct the errors.
This can be accomplished in a number of ways. We chose to run a small inner loop where we:
- Obtain the vanilla decoded `output_sequence` and then corrupt it.
- Take the corrupted sequence as the targets and make `interim_thought` trainable. These would be the parameters, and `out_head` should remain static/untrainable.
- Backprop using an adaptive optimizer (`AdamW`) and update for a few steps.
Now, because we're running very few steps ($\leq 10$) with a small-ish learning rate, we won't converge to our target representation. Nevertheless, in the process of gradient descent, errors still creep in as it tries to converge to the corrupted version. We terminate this convergence early because we don't care where the errors are - just their number.
This is important for performance. Using this, corrupting a batch of $300$ samples takes $\sim 150\text{ms}$ on a T4. This is still a pretty heavy hit, but at the scales we're operating at, it is next-to-nothing.
The entire code can be found here. A simplified snippet is attached below:
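The sketch below reconstructs the inner loop from the steps above; names like `output_head`, the number of corrupted bits, and the optimizer hyperparameters are illustrative assumptions rather than the verbatim snippet.

```python
import torch
import torch.nn.functional as F

def adversarial_perturbation(interim_thought, output_head, n_bits=2, steps=10, lr=1e-2):
    """Perturb interim_thought so its decoding is off by a few 'bits' (sketch)."""
    with torch.no_grad():
        # Vanilla decoded sequence, then corrupt roughly n_bits random positions per sample.
        logits = output_head(interim_thought)            # (batch, seq_len, vocab)
        targets = logits.argmax(dim=-1)                  # (batch, seq_len)
        batch, seq_len = targets.shape
        vocab = logits.shape[-1]
        idx = torch.randint(0, seq_len, (batch, n_bits))
        noise = torch.randint(0, vocab, (batch, n_bits))
        corrupted = targets.clone()
        corrupted.scatter_(1, idx, noise)                # flip a few tokens

    # interim_thought becomes the trainable parameter; only it is updated, so
    # output_head stays static (in practice you'd also guard its accumulated grads).
    perturbed = interim_thought.detach().clone().requires_grad_(True)
    opt = torch.optim.AdamW([perturbed], lr=lr)
    for _ in range(steps):                               # few steps: we never fully converge
        opt.zero_grad()
        loss = F.cross_entropy(output_head(perturbed).flatten(0, 1), corrupted.flatten())
        loss.backward()
        opt.step()
    return perturbed.detach()
```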
And here is a sample of a few of the distributions generated during training, as reported in WandB, for $n = 2$ and $n = 4$ respectively. These are the distributions of errors on a single batch of sequences as handled by the model. Due to the stochastic nature of the process, the distribution fluctuates around quite a bit: