Implicit Bayesian Inference in Large Language Models
This intriguing paper kept me thinking long enough that I decided it's time to resurrect my blogging (I started writing this during the ICLR review period, and realised it might be a good idea to wait until that concluded).
- Sang Michael Xie, Aditi Raghunathan, Percy Liang and Tengyu Ma (2021) An Explanation of In-context Learning as Implicit Bayesian Inference
I liked this paper because it relates to one of my favourite concepts and ideas: exchangeability. And it took me back to thoughts I had back in 2015 (pre-historic by deep learning standards) about leveraging exchangeable sequence models to implement giant general-purpose learning machines. In that old post I made this observation about exchangeable models:
If we had an exchangeable RNN, we could train it on multiple unsupervised learning problems over the same input space. Such system actually learns to learn. If you want to use it on a new dataset, you just feed it into the RNN, and it will give you Bayesian predictive probabilities without any additional computation. So it would be an ultimate general inference machine™.
Fast forwarding a bit, the ultimate general inference machine (fortunately I trademarked it) is not actually all that far from how OpenAI's GPT-3 is sometimes branded and used. It's been demonstrated that you can repurpose it as a few-shot (or in some cases zero-shot) learner for a surprising variety of tasks (Brown et al, 2020). This ability of language models to solve different tasks by feeding them cleverly designed prompts is sometimes referred to as prompt-hacking or in-context learning.
Quite honestly, I never connected these dots until I read this paper: the motivations for leveraging one single big exchangeable sequence model as a general-purpose learner, and the more recent trend of in-context learning using GPT-3. In fact, I was deeply skeptical about the latter, thinking of it as another hack that must be somehow fundamentally flawed. But this paper by Xie et al (2021) connected those dots for me, which is why I found it so fascinating, and I will never think of 'prompt hacking' or in-context learning quite the same way.
Exchangeable sequences as Implicit Learning Machines
Before talking about the paper, let me first refresh those old ideas about exchangeable sequences and implicit learning. An exchangeable sequence model is a probability distribution $p(x_1, x_2, \ldots)$ over sequences that is invariant to permutations of the tokens within the sequence, i.e. $p(x_1, x_2, \ldots, x_N) = p(x_{\pi_1}, x_{\pi_2}, \ldots, x_{\pi_N})$ for any permutation $\pi$.
The de Finetti theorem connects such sequence models to Bayesian inference, saying that any such distribution (over an infinitely exchangeable sequence) can be decomposed as a mixture of i.i.d. sequence models:
$$
p(x_1, x_2, \ldots, x_N) = \int \prod_{n=1}^N p(x_n\vert \theta) d\pi(\theta)
$$
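To make this concrete, here is a minimal numerical sketch with a toy Beta-Bernoulli model of my own choosing (it is not from the paper, and the Beta(2, 3) prior and the `joint_prob` helper are just my illustration): the marginal of a Beta mixture over i.i.d. Bernoulli sequences depends only on how many ones the sequence contains, so it is invariant to reordering the tokens.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
thetas = rng.beta(2.0, 3.0, size=200_000)   # Monte Carlo draws from the (toy) prior pi(theta)

def joint_prob(xs):
    """Monte Carlo estimate of p(x_1,...,x_N) = E_theta[prod_n Bernoulli(x_n | theta)]."""
    xs = np.asarray(xs)
    liks = thetas[:, None] ** xs * (1.0 - thetas[:, None]) ** (1 - xs)
    return liks.prod(axis=1).mean()

x = (1, 0, 0, 1, 1)
# Every permutation of x gets the same probability: the set below has a single element.
print({np.round(joint_prob(p), 6) for p in itertools.permutations(x)})
```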
As a consequence, the one-step-ahead predictive distribution (which predicts the next token in the sequence) also always has a decomposition as Bayesian integration:
$$
p(x_N \vert x_1, x_2, \ldots, x_{N-1}) = \int p(x_N\vert \theta) d\pi(\theta\vert x_1, \ldots, x_{N-1}),
$$
where $\pi(\theta\vert x_1, \ldots, x_{N-1})$ is the Bayesian posterior obtained from the prior $\pi(\theta)$ via Bayes' rule:
$$
\pi(\theta \vert x_1, x_2, \ldots, x_{N-1}) \propto \pi(\theta) \prod_{n=1}^{N-1}p(x_n\vert \theta)
$$
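Continuing the same toy Beta-Bernoulli example (my own illustration, not anything from the paper): the predictive computed purely from ratios of the exchangeable joint coincides with the explicit Bayesian posterior predictive, even though the left-hand side never mentions a prior or a likelihood.

```python
from scipy.special import beta as B   # the Beta function, for the closed-form marginal

a, b = 2.0, 3.0                       # toy Beta prior hyperparameters (my choice)
context = [1, 0, 0, 1, 1]             # the observed prefix x_1, ..., x_{N-1}

def marginal(xs):
    """Closed-form marginal of a Beta-Bernoulli exchangeable sequence."""
    k, n = sum(xs), len(xs)
    return B(a + k, b + n - k) / B(a, b)

# Predictive computed purely from ratios of the exchangeable joint ...
pred_from_joint = marginal(context + [1]) / marginal(context)
# ... versus the explicit Bayesian posterior predictive (posterior mean of theta).
pred_from_bayes = (a + sum(context)) / (a + b + len(context))

print(pred_from_joint, pred_from_bayes)   # both equal 0.5 for this context
```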
So in this sense, if we have an exchangeable sequence model, we can think of these one-step-ahead predictive distributions as implicitly performing Bayesian inference. Crucially, this happens even if we don't know what $\theta$ is, or what our prior $\pi$ is, or what the likelihood $p(x_i\vert \theta)$ is. We don't have to explicitly specify what those components are; the de Finetti theorem guarantees that they exist, so long as the predictions $p(x_N \vert x_1, x_2, \ldots, x_{N-1})$ are consistent with an exchangeable sequence model.
This thought motivated me to try and design RNNs (remember this was in pre-transformer times) that always produce exchangeable distributions by construction. This turned out to be very difficult, but the idea eventually evolved into BRUNO (named after Bruno de Finetti), a flexible meta-trained model for exchangeable data which exhibited few-shot concept learning abilities. This idea then got extended in a number of ways in Ira Korshunova's PhD thesis.
From Exchangeable sequences to Mixtures of HMMs
But GPT-3 is a language model, and clearly language tokens are not exchangeable. So what's the connection?
There are interesting extensions of the concept of exchangeability, which come with corresponding generalisations of the de Finetti-type theorems. Partial exchangeability, as defined by Diaconis and Freedman (1980), is an invariance property of a distribution over sequences which guarantees that the sequence can be decomposed as a mixture of Markov chains. Thus, one can say that a partially exchangeable process implicitly performs Bayesian inference over Markov chains, much the same way exchangeable processes can be said to be performing inference over i.i.d. data generating processes.
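To see what 'implicit inference over Markov chains' looks like mechanically, here is a small sketch with two toy transition matrices of my own choosing (not taken from Diaconis and Freedman or the paper): the mixture's next-token prediction is a posterior-weighted average of the two chains' transition rows, with the posterior over 'which chain' computed from the observed prefix by Bayes rule.

```python
import numpy as np

# Two toy transition matrices over states {0, 1}: chain A is "sticky", chain B alternates.
T_A = np.array([[0.9, 0.1],
                [0.1, 0.9]])
T_B = np.array([[0.1, 0.9],
                [0.9, 0.1]])
chains = [T_A, T_B]
prior = np.array([0.5, 0.5])           # mixture weights = prior over "which chain"

def next_token_predictive(prefix):
    """p(x_N | x_1, ..., x_{N-1}) under the mixture of Markov chains."""
    # Likelihood of the observed transitions under each chain (conditioning on the first state).
    liks = np.array([np.prod([T[i, j] for i, j in zip(prefix[:-1], prefix[1:])])
                     for T in chains])
    posterior = prior * liks
    posterior /= posterior.sum()       # Bayes rule over the two chains
    # Posterior-weighted average of each chain's next-step transition probabilities.
    predictive = sum(w * T[prefix[-1]] for w, T in zip(posterior, chains))
    return predictive, posterior

prefix = [0, 0, 0, 0, 1, 1, 1]         # long runs: looks like the sticky chain A
pred, post = next_token_predictive(prefix)
print(post)   # the posterior concentrates on chain A
print(pred)   # so the prediction is close to A's transition row from state 1
```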
In this new paper, Xie et al (2021) assume that the sequence model we work with is a mixture of hidden Markov models (HMMs). This is more general still than the partially exchangeable mixture of Markov chains of Diaconis and Freedman. I don't know if mixtures of HMMs (MoHMMs) can be characterised by an exchangeability-like invariance property, but that's somewhat irrelevant now. In fact, Xie et al (2021) never mention exchangeability in the paper. The core argument about implicit Bayesian inference holds every time we work with a sequence model which is a mixture of simpler distributions: you can interpret the one-step-ahead predictions as implicitly performing Bayesian inference over some parameter. While it is unlikely that the distribution of human language from the internet follows a MoHMM distribution, it is reasonable to assume that the distribution over sequences that comes out of GPT-3 is perhaps a mixture of some sort. And if that is the case, predicting the next token implicitly performs Bayesian inference over some parameter $\theta$, which the authors refer to as a 'concept'.
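Here is the analogous sketch for a mixture of HMMs, again with toy parameters I made up rather than the paper's construction: the forward algorithm gives each component's marginal likelihood of the observed prefix, and Bayes rule turns that into a posterior over which component, i.e. which 'concept', generated it.

```python
import numpy as np

def hmm_marginal(obs, init, trans, emit):
    """Marginal p(obs) under a single HMM, computed with the forward algorithm."""
    alpha = init * emit[:, obs[0]]             # forward message after the first token
    for o in obs[1:]:
        alpha = (alpha @ trans) * emit[:, o]   # propagate hidden state, weight by emission
    return alpha.sum()

# "Concept" 1: sticky hidden states, so it tends to emit long runs of the same token.
concept1 = dict(init=np.array([0.5, 0.5]),
                trans=np.array([[0.9, 0.1], [0.1, 0.9]]),
                emit=np.array([[0.8, 0.2], [0.2, 0.8]]))
# "Concept" 2: fast-switching hidden states, so it tends to emit alternating tokens.
concept2 = dict(init=np.array([0.5, 0.5]),
                trans=np.array([[0.1, 0.9], [0.9, 0.1]]),
                emit=np.array([[0.8, 0.2], [0.2, 0.8]]))

prior = np.array([0.5, 0.5])                   # prior over the two concepts
obs = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]     # an observed prefix (the "prompt")

liks = np.array([hmm_marginal(obs, **concept1), hmm_marginal(obs, **concept2)])
posterior = prior * liks / (prior * liks).sum()
print(posterior)   # concentrates on concept 1, which plausibly generated the long runs
```

Predicting the next token under the mixture then amounts to averaging each component's own predictive under this posterior, just like in the Markov-chain sketch above; that is what 'implicitly performing Bayesian inference over the concept' cashes out to.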
In-context learning
The core idea of this paper is that perhaps in-context learning exploits this implicit Bayesian inference, inherent to statistical models of language, to solve tasks. Language models learn to implicitly make probabilistic inferences about concepts - whatever those are - because learning to carry out such inferences is needed to do well on next-token prediction. Once that implicit inference capability is there, one can hijack it to perform other tasks that also require such inferences, including few-shot classification.
I think this is a very intriguing general idea. But then the key question the authors focus on is somewhat disappointingly specific and artificial: although a MoHMM can be used to 'complete sequences' drawn from a specific HMM (one of the mixture components), what happens if we ask it to complete sequences that it would never natively generate, for example an artificially constructed sequence that has a few-shot classification task embedded inside it? This then becomes a question about distribution mismatch. The key finding is that, despite this distribution mismatch, the implicit inference machinery inside MoHMMs is able to identify the right concept and use it to make correct predictions in the few-shot task.
However - and please read the paper for specific details - the analysis makes very strong assumptions about how the in-context-learning task embedded in the sequence is related to the MoHMM distribution. In a way, the in-context task the authors study is in fact more like a few-shot sequence completion task than, say, a classification task.
All in all, this was a fun paper to think about, and one that definitely changed my way of thinking about the whole in-context-learning and language-models-as-few-shot-learners direction.