Recursive transformers: Aiding DL models to better extrapolate on unseen data

Name of Project: ReAct - Recurrence for Attention (Yes I’m looking for cooler names, suggestions welcome! :rocket:)

Proposal in one sentence: Attempting to explore how recurrence as a prior can help Transformers extrapolate to unseen datapoints during inference/deployment

Description of the project and what problem it is solving: Current DL/AI models can’t extrapolate well to datapoints unseen during training; that capability typically only emerges at scale. By imbuing recurrent priors, I aim to replicate the same ability in much cheaper, smaller models that can be deployed in practical settings and are thus feasible for consumers to train and research on without requiring a cluster of expensive GPUs.

It is highly experimental and technical, so please let me know in the comments what parts I should elaborate on - otherwise I’d be filling pages here with papers and explanations.

Grant Deliverables:
choose deliverables you can complete in a month’s worth of part-time work

  • PoC of extrapolation on an arithmetic task (simple addition), generalizing 1-2 digits beyond the training distribution.
  • In 3-4 months, scale the above and empirically show that the model can extrapolate to distinct OOD samples.

Spread the Love

If you plan to use some of the funds to reward contributions from other community members, please describe your desired roles/skillsets e.g. looking for a data scientist to work with me to develop a machine learning model. If successful, this role will be advertised on the Algovera opportunities board and DeWork.

Squad

Squad Lead:

  • Twitter handle: N/A
  • discord handle: Awesome_Ruler_007#7922

Additional notes for proposals

  • Here’s the WandB report for progress tracking: WandB Notes Dashboard
    → All Code, Configs, Notes, Ideas, Model Checkpoints, Logs and performance results are synced on WandB automatically

  • Grant money would go toward hardware and GPU-hours for training models. Since every iteration is expensive, I try to work as efficiently as possible, but there’s a limit here. I believe the grant can help me accelerate my research and hopefully get a PoC out to confirm whether this direction is worth pursuing.


I have a brief writeup here that is semi-technical. Again, I can go in-depth with explanations as warranted - lmk below.

Basically, it’s exploring how recurrence as a prior can aid OOD generalization and allow for dynamic compute and memory at inference time (dataset ratios and tokenization matter too, but broadly these are the key areas).

I’m basing my exploration/experimentation off this interesting result: https://twitter.com/tomgoldsteincs/status/1596210043019722752?cxt=HHwWgIDQkfek8KYsAAAA
The paper demonstrates how CNNs can extrapolate OOD at inference time just by scaling the number of iterations. They train a recurrent CNN to solve 9x9 mazes, and it extrapolates to mazes as large as 801x801 - while conventional convnets barely generalize to 13x13 ones. (While the method looks extremely similar to diffusion, the difference is that the entire process is framed as a single Markov chain and there is no denoising, thus saving on compute.)
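To make the mechanism concrete, here’s a minimal PyTorch sketch of the core pattern - not the paper’s recurrent CNN, just the generic weight-tied recurrence idea with made-up module names and dimensions: one shared block is applied repeatedly, and the number of iterations can simply be increased at inference to spend more compute on harder or larger inputs.

```python
# Minimal sketch (assumed names/dims, not the paper's code) of test-time
# iteration scaling: one weight-tied block reused k times, where k can be
# raised at inference to throw more compute at harder/larger inputs.
import torch
import torch.nn as nn

class RecurrentCore(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        # The same weights are reused on every iteration (depth = iterations).
        self.block = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)
        )

    def forward(self, x: torch.Tensor, iterations: int) -> torch.Tensor:
        h = x
        for _ in range(iterations):
            h = h + self.block(h)  # residual update keeps the recurrence stable
        return h

model = RecurrentCore(dim=64)
x = torch.randn(8, 64)
y_train = model(x, iterations=10)   # depth used during training
y_test = model(x, iterations=100)   # more "thinking steps" at inference time
```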

Interestingly, I’ve already run some preliminary experiments. It turns out that learning f(x)=x is much faster for this architecture than memorizing a simple sequence! So the hypothesis that recurrence is a strong prior against overfitting and promotes generalization has some empirical support.

(Very) initial runs on addition seem to generalize decently a couple of digits beyond the training distribution. Further experimentation is needed to see whether I can push that to its limits and generalize for more than a handful of digits.

My hypothesis is that recurrence introduces a prior for generalization and combats memorization; however, looking at these results and other works, it seems more a property of employing differential equations. Still, empirical results are a bit lacking here, so that’s what I’d be focusing on! :slight_smile:


This seems super interesting, although I cannot access the WandB page and would like more technical details on it. Any way I can get a preprint or some sort of summary on what exactly you did and the results?


Hey! Sorry, apparently that’s the wrong link. This is the correct WandB link.

Right now, I’m still experimenting, and a few insights are jotted down on the above WandB page as notes. I’ll write something up about this when it’s done (I should’ve put that as a deliverable, I guess :thinking:). As of right now, I don’t plan to publish anything to conferences/journals yet - but I would reconsider if I get some interesting results :wink:

For a more technical read, this is the paper that seeded the whole idea. Let me know which concepts you find foreign and I can provide a reading list of papers.

(There’s also a list of relevant papers at the end of the WandB report if you want to check them out.)


Very interesting. Do you plan to add in the recurrence prior via cross-attention? It will also be interesting to compare against longer-context transformers, since intuitively I feel that transformers attending over all parts avoids the forgetting issue.


Heya! No, we don’t need to incorporate cross-attention between the intermediate latents. The problem here is reframed as solving a PDE (partial differential equation): the model simply learns its Jacobian, which is applied to the latent multiple times. This is equivalent to numerically solving a differential equation discretely with some step size $\Delta T$, where $\Delta T \propto \frac{1}{i}$ and $i$ is the number of iterations. This makes intuitive sense in that increasing the iterations directly improves precision, as you take smaller and smaller steps :slight_smile:
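As a hedged illustration of that framing (function and variable names here are mine, not from the actual codebase), the shared block acts as a learned update applied repeatedly with a step size shrinking as 1/iterations - essentially an explicit Euler-style discretization:

```python
# Sketch of the learned-ODE view: the shared block `step_fn` is the learned
# update, and the step size Delta T = 1 / n_iters, so more iterations mean a
# finer discretization of the same trajectory (no new parameters needed).
import torch

def iterate_latent(step_fn, z0: torch.Tensor, n_iters: int) -> torch.Tensor:
    z = z0
    dt = 1.0 / n_iters              # Delta T proportional to 1 / iterations
    for _ in range(n_iters):
        z = z + dt * step_fn(z)     # explicit Euler-style update, shared weights
    return z

# Doubling n_iters halves the step size and refines the solution, e.g.:
# z_final = iterate_latent(model.block, z_init, n_iters=64)
```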

Cross-attention could be used here, but its main use is to attend to distinct vector subspaces. :thinking: I could modify the recall mechanism (basically concat(og_input, intermediate_tens)) to cross-attend instead, and it would work as well. But then I’d also be paying for that with extra compute…
Not to mention there would be less gradient stability (since the recall mechanism is just a giant skip-connection). For concreteness, see the sketch of the two recall variants below.
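Here’s how I picture the two recall variants (illustrative sketch; class and argument names are assumptions, not the actual implementation) - the concat skip-connection I currently use versus a hypothetical cross-attention recall:

```python
# Illustrative comparison of recall mechanisms (names are assumptions):
# concat recall = a big skip-connection fusing [original ; latent], versus a
# cross-attention recall where the latent attends back to the original input.
import torch
import torch.nn as nn

class ConcatRecall(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(2 * dim, dim)  # fuse the concat back down to dim

    def forward(self, og_input: torch.Tensor, intermediate: torch.Tensor) -> torch.Tensor:
        return self.proj(torch.cat([og_input, intermediate], dim=-1))

class CrossAttentionRecall(nn.Module):
    def __init__(self, dim: int, heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, og_input: torch.Tensor, intermediate: torch.Tensor) -> torch.Tensor:
        # Latent queries attend over the original input: more expressive, but
        # pays extra compute and gives up the plain skip-connection gradients.
        out, _ = self.attn(query=intermediate, key=og_input, value=og_input)
        return intermediate + out
```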

I might give the idea a try though as soon as I get ALiBi to work!

In the LeGO paper, they did have more than enough context length for the task at hand. They also used pure self-attention, with no approximations. But it still couldn’t generalize much. True, there was a gradual trend toward generalization, but it would take an insane amount of data and scale to reach that point - which is often infeasible in real-world scenarios, hence my project :slight_smile:

LMK if you have any more questions! :rocket:

Very interesting. Thanks for the reply! Let me read the paper and get back to you with qs.

The Twitter thread is surprising: the system learns to solve chess puzzles but can’t play a full game of chess.

There is some debate about what extrapolation means in the context of ML. For example, if the system learns something about the geometry of the topology of the training data, then it might look like it is extrapolating, if that geometry matches the unsampled parts of the distribution.

I wonder if an approach like ANML would be more effective in terms of making use of recurrence?

It is very much in the spirit of Algovera to experiment and share new ideas, so all power to you!
