- Sang Michael Xie, Aditi Raghunathan, Percy Liang and Tengyu Ma (2021) An Explanation of In-context Learning as Implicit Bayesian Inference

I liked this paper because it relates to one of my favourite concepts and ideas: exchangeability. And it took me back to thoughts I had back in 2015 (pre-historic by deep learning standards) about leveraging exchangeable sequence models to implement giant general-purpose learning machines. In that old post I made this observation about exchangeable models:

If we had an exchangeable RNN, we could train it on multiple unsupervised learning problems over the same input space. Such a system actually learns to learn. If you want to use it on a new dataset, you just feed it into the RNN, and it will give you Bayesian predictive probabilities without any additional computation. So it would be an ultimate general inference machine™.

Fast forwarding a bit, *ultimate general inference machine* (fortunately I trademarked it) is not actually all that far from how OpenAI's GPT-3 is sometimes branded and used. It's been demonstrated that you can repurpose such language models as few-shot (or in some cases zero-shot) learners in a surprising variety of tasks (Brown et al, 2020). This ability of language models to solve different tasks by feeding them cleverly designed prompts is sometimes referred to as prompt-hacking or *in-context learning*.

Quite honestly, I never connected these dots until I read this paper: the motivations for leveraging one single big exchangeable sequence model as a general-purpose learner, and the more recent trend of in-context learning using GPT-3. In fact, I was deeply skeptical about the latter, thinking of it as another hack that must be somehow fundamentally flawed. But this paper by Xie et al (2021) connected those dots for me, which is why I found it so fascinating, and I will never think of 'prompt hacking' or in-context learning quite the same way.

Before talking about the paper, let me first refresh those old ideas about exchangeable sequences and implicit learning. An exchangeable sequence model is a probability distribution $p(x_1, x_2, \ldots)$ over sequences that is invariant to permutations of the tokens within the sequence, i.e. $p(x_1, x_2, \ldots, x_N) = p(x_{\pi_1}, x_{\pi_2}, \ldots, x_{\pi_N})$ for any permutation $\pi$.

The de Finetti theorem connects such sequence models to Bayesian inference, saying that any such distribution can be decomposed as a mixture over i.i.d. sequence models:

$$

p(x_1, x_2, \ldots, x_N) = \int \prod_{n=1}^N p(x_n\vert \theta) d\pi(\theta)

$$

As a consequence, the one-step-ahead predictive distribution (which predicts the next token in the sequence) also always has a decomposition as Bayesian integration:

$$

p(x_N \vert x_1, x_2, \ldots, x_{N-1}) = \int p(x_N\vert \theta) d\pi(\theta\vert x_1, \ldots, x_{N-1}),

$$

where $\pi(\theta\vert x_1, \ldots, x_{N-1})$ is the Bayesian posterior obtained from the prior $\pi(\theta)$ via the Bayes rule:

$$

\pi(\theta \vert x_1, x_2, \ldots, x_{N-1}) \propto \pi(\theta) \prod_{n=1}^{N-1}p(x_n\vert \theta)

$$

So in this sense, if we have an exchangeable sequence model, we can think of these one-step-ahead predictive distributions as *implicitly performing Bayesian inference*. Crucially, this happens even if we don't know what $\theta$ is, or what our prior $\pi$ is, or what the likelihood $p(x_i\vert \theta)$ is. We don't have to explicitly specify what those components are: the de Finetti theorem guarantees that they exist, so long as the predictions $p(x_N \vert x_1, x_2, \ldots, x_{N-1})$ are consistent with an exchangeable sequence model.
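To make this concrete, here is a minimal sketch under an assumption of my own choosing: a binary sequence model defined only through its one-step-ahead predictions, using Laplace's rule of succession. Nothing in the code mentions $\theta$ or a prior, yet the joint it induces is permutation invariant and matches the de Finetti mixture of i.i.d. Bernoulli models under a uniform prior. The function names are hypothetical.

```python
from fractions import Fraction
from itertools import permutations
from math import comb

def predictive(prefix):
    """p(next token = 1 | prefix) under Laplace's rule of succession."""
    return Fraction(1 + sum(prefix), 2 + len(prefix))

def joint(seq):
    """p(x_1, ..., x_N) built purely from one-step-ahead predictions."""
    prob = Fraction(1)
    for n, x in enumerate(seq):
        p_one = predictive(seq[:n])
        prob *= p_one if x == 1 else 1 - p_one
    return prob

def de_finetti_joint(seq):
    """Integral of prod_n p(x_n | theta) over a uniform prior on theta.

    For k ones among N tokens the integral is the Beta function
    B(k + 1, N - k + 1) = 1 / ((N + 1) * C(N, k)).
    """
    n, k = len(seq), sum(seq)
    return Fraction(1, (n + 1) * comb(n, k))

seq = (1, 0, 1, 1, 0)
# Exchangeability: every reordering of the tokens has the same probability.
is_exchangeable = all(joint(p) == joint(seq) for p in set(permutations(seq)))
# Implicit Bayes: the predictive-only model equals the explicit mixture.
matches_mixture = joint(seq) == de_finetti_joint(seq)
print(is_exchangeable, matches_mixture)
```

Exact fractions are used so that the equality checks are not blurred by floating point rounding.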

This thought motivated me to try and design RNNs (remember this was in pre-transformer times) that always produce exchangeable distributions by construction. This turned out to be very difficult, but the idea eventually evolved into BRUNO (named after Bruno de Finetti), a flexible meta-trained model for exchangeable data which exhibited few-shot concept learning abilities. This idea then got extended in a number of ways in Ira Korshunova's PhD thesis.

But GPT-3 is a language model, and clearly language tokens are not exchangeable. So what's the connection?

There are interesting extensions to the concept of exchangeability, which come with interesting generalisations of the de Finetti-type theorems. *Partial exchangeability,* as defined by Diaconis and Freedman (1980), is an invariance property of a distribution over sequences which guarantees that the distribution can be decomposed as a mixture of Markov chains. Thus, one can say that a partially exchangeable process implicitly performs Bayesian inference over Markov chains, much the same way exchangeable processes can be said to be performing inference over i.i.d. data generating processes.

In this new paper, Xie et al (2021) assume that the sequence model we work with is a mixture of hidden Markov models (HMMs). This is more general still than the partially exchangeable mixture of Markov chains of Diaconis and Freedman. I don't know if mixtures of HMMs (MoHMMs) can be characterised by an exchangeability-like invariance property, but that's somewhat irrelevant now. In fact, Xie et al (2021) never mention exchangeability in the paper. The core argument about implicit Bayesian inference holds every time we work with a sequence model which is a mixture of simpler distributions: you can interpret the one-step-ahead predictions as implicitly performing Bayesian inference over some parameter. While it is unlikely that the distribution of human language from the internet follows a MoHMM distribution, it is reasonable to assume that the distribution over sequences that comes out of GPT-3 is perhaps a mixture of some sort. And if that is the case, predicting the next token implicitly performs Bayesian inference over some parameter $\theta$, which the authors refer to as a 'concept'.
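The mixture argument can be sketched in a few lines. To keep it short I assume a mixture of two plain Markov chains over binary tokens rather than HMMs, with made-up transition probabilities and concept names, and I condition on the first token. The point is only that the mixture's next-token prediction, computed as a ratio of sequence probabilities, coincides with explicit posterior-weighted prediction over the concept.

```python
# Two hypothetical "concepts": binary Markov chains given by p(next = 1 | current).
concepts = {"sticky": {0: 0.1, 1: 0.9}, "flippy": {0: 0.9, 1: 0.1}}
prior = {"sticky": 0.5, "flippy": 0.5}

def likelihood(seq, p_one):
    """p(x_2, ..., x_N | x_1, concept) for a single Markov chain."""
    prob = 1.0
    for cur, nxt in zip(seq, seq[1:]):
        prob *= p_one[cur] if nxt == 1 else 1 - p_one[cur]
    return prob

def mixture(seq):
    """Probability of the sequence under the mixture model."""
    return sum(prior[c] * likelihood(seq, concepts[c]) for c in concepts)

def predict_by_ratio(prefix):
    """Next-token prediction of the mixture; no concept is ever named."""
    return mixture(prefix + [1]) / mixture(prefix)

def posterior(prefix):
    """Explicit Bayesian posterior over the concept given the prefix."""
    evidence = mixture(prefix)
    return {c: prior[c] * likelihood(prefix, concepts[c]) / evidence for c in concepts}

def predict_by_posterior(prefix):
    """The same prediction, written as Bayesian model averaging."""
    post = posterior(prefix)
    return sum(post[c] * concepts[c][prefix[-1]] for c in concepts)

prompt = [0, 1, 0, 1, 0]  # an alternating "prompt" that favours the flippy concept
assert abs(predict_by_ratio(prompt) - predict_by_posterior(prompt)) < 1e-12
print(posterior(prompt), predict_by_ratio(prompt))
```

The prompt plays the role of the in-context examples: after a handful of tokens the posterior concentrates on one concept, and the prediction follows it.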

The core idea of this paper is that perhaps in-context learning exploits this implicit Bayesian inference, inherent to statistical models of language, to solve tasks. Language models learn to implicitly make probabilistic inferences about concepts - whatever those are - because learning to carry out such inferences is needed to do well on next-token-prediction tasks. Once that implicit learning capability is there, one can hijack it to perform other tasks that also require such inferences, including few-shot classification.

I think this is a very intriguing general idea. But then the key question the authors focus on is somewhat disappointingly specific and artificial: although a MoHMM can be used to complete sequences drawn from a specific HMM (one of the mixture components), what happens if we ask the MoHMM to complete sequences that it would never natively generate, for example an artificially constructed sequence that has a few-shot classification task embedded inside? This then becomes a question about distribution mismatch. The key finding is that, despite this distribution mismatch, the implicit inference machinery inside MoHMMs is able to identify the right concept and use it to make correct predictions in the few-shot task.

However - and please read the paper for specific details - the analysis makes very strong assumptions about how the in-context-learning task embedded in the sequence is related to the MoHMM distribution. In a way, the in-context task the authors study is in fact more like a few-shot sequence completion task than, say, a classification task.

All in all, this was a fun paper to think about, and one that definitely changed my way of thinking about the whole in-context-learning and language-models-as-few-shot-learners direction.

Even objectively outstanding students often receive dull, short, factual, almost negative-sounding reference letters. This is a result of (A) cultural differences - we are very good at sarcasm, painfully good at giving direct negative feedback, not so good at praising others and (B) the fact that reference letters play no role in Eastern Europe and most professors have never written or seen a good one before.

Poor reference letters hurt students. They give us no insight into the applicant's true strengths, and no ammunition to support the best candidates in scholarship competitions or the admission process in general. I decided to write this guide for students so they can share it with their professors when asking for reference letters. Although reading letters from the region is what triggered me to write this, most of this advice should be generally useful for many other people who don't know how to write good academic reference letters.

- **Help the supervisor to make a case for admitting a student:** The reference letter is very important in the whole admissions process. In competitive places in Europe, there is often competition not just between applicants, but also between different research groups and supervisors about whose student gets funding. Reference letters are often used as ammunition to justify decisions internally, and to determine who gets prioritised for various scholarship and funding competitions.
- **Help put candidate's profile into context:** If you write a reference letter from a region like Eastern Europe, keep in mind how difficult it is to compare candidates from wildly different education systems and backgrounds. Is someone with a 4.9/5.0 GPA from Hungary more impressive than someone with a 9.5/10.0 GPA from Serbia? Your job, partly, is to explain to the admissions committee what the student's achievements mean in a global context. Do not use abbreviations that are not internationally obvious. Do not assume the reader has ever heard of your institution. *Explain everything.*

- **Confidentiality:** Please do not ask the student to write their own recommendation letter. Sadly, many professors do it, but this is *not acceptable*, especially for the best students who apply to a top institution. You can also assume your reference letter is confidential. Don't share it with the student directly (Why? You probably want to write nicer things than you are comfortable sharing with them directly.)
- **Length:** Reference letters for the best candidates are often *2 full pages* long. Something that's half a page or just two paragraphs is interpreted as 'weak support' or worse.
- **Format:** Although plain text is often accepted on submission forms, when possible, please submit a PDF in letter-headed format (where the institution's logo, name, etc. appear in the header). The format should follow the layout of a formal letter. You may address it 'To Whom It May Concern,' or 'To the Admissions Committee,' or 'Dear Colleagues,' or if you know the potential supervisor, by all means make it personal and address it to them. Obviously sign the letter with your name and title.
- **Basic contents:** Make sure that the letter mentions your full job title, affiliation, the candidate's full name, and the name of the programme/job/scholarship they are applying for.

Below is an example structure that is often used. (I'll use Marta as an example because I don't have a student called Marta so it won't get personal)

- **Introduction:** A few sentences mentioning who you're recommending and for what program, for example "I'm writing to recommend Marta Somethingova for the Cambridge MPhil in Advanced Computer Science." The second sentence should clearly indicate how strongly you are recommending this candidate. Factual statements signal this is a lukewarm recommendation (they asked me and I had to write something). To convey your enthusiasm, you can write something like "Marta is the strongest student I've worked with in the last couple years".
- **Context, How do I know Marta:** Since when, in what capacity and how closely you have known Marta. This is important - a reference from a thesis supervisor who has worked with the student for a year is more informative than a reference from someone who only met them in one exam. If you've done a project together, include details on how many times you've met, etc. What was the project about, how challenging was it, what was the student's contribution.
- **Marta's academic results/performance, in context:** How good is Marta, compared to other students/persons in a similar context? Be aware that whoever is reading your letter may not know your country's marking scheme, so something like a GPA of 4.8 out of 5 isn't all that informative. Try to put that in context as much as you can: how many other students would achieve similar results in your institution? Best if you can give a rank index (#8 out of a cohort of 300) relative to the whole cohort.
- **Context on your institution:** Similarly, assume the reader has no idea how selective your institution is, so include a few details like 'top/most selective computer science program in the country' or something. *Try to put this in context by making a prediction* about how well the student will do in the course you're recommending them to, or how well they would have done in a more challenging program. Do your research here, if you can.
- **Details of research/project, if applicable:** If you're recommending someone who has worked on a research project with you, include enough technical information (ideally with references or pointers) so that the reader can judge how serious that project was, and what Marta's contribution was. Don't worry, nobody is going to steal your research idea if you write it down in a recommendation letter - we are way too busy reading reference letters to do any research :D
- **Marta's specific strengths:** What quality of Marta do you think will be first noticed in an interview? Is Marta particularly good at understanding complex ideas fast? Is Marta very good at getting things done? Or writing clean code, mentoring others? Where appropriate, try to focus on talent and potential, before commenting on diligence or effort: If the first thing you write is "Marta is very hard working" it may be misinterpreted as a covert way of saying she tries very hard because she is not as good as the students who just get it without much effort. Be conscious of possible gender stereotypes that often come up here. E.g. she is quiet. *Make a prediction* about Marta's career prospects: She's on a good track to an academic career/well positioned for a career in industry. Please consider what the people reading your letter will want to see. If you recommend someone for an academic pure Maths program, you don't want to say the student is well positioned to end up in a boring finance job. If you feel like you MUST, you can include relative weaknesses here, but please phrase these as opportunities of growth, and what Marta needs to improve.
- **Other/extracurricular activities:** If you're aware of other things the student is doing - like organising meetups, volunteering, competitions, whatever - you can include them here if you feel they are relevant. Your job, again, is to put these in context.
- **Further background on Marta's education history:** This may be useful to support candidates who achieved impressive things in their country, but whose achievements may not make a lot of sense in an international context. For example, did they go to a very selective secondary school known for some specialization? Or, to the contrary, did they do exceptionally well despite not having access to the best education? Did they participate in country-specific olympiads or competitions? If so, what do those results mean? How many students do those things? Did they get a scholarship for their academic performance? If so, how many students get those? Did they participate in some kind of university activity? If so, what's the relevance of that? The most important assumption to remember is: Whoever reads the applicant's CV or your recommendation letter will know absolutely nothing about your country. You have to fill in the blanks, and explain everything from the ground up. *NO ACRONYMS!*
- **Your mini-CV:** It's worth including one paragraph about yourself, the referee. What is your job title, how long you've been doing what you're doing, what's your specialty, etc. The purpose of this is to prove you are qualified and able to spot talent. Make this as internationally attractive and meaningful as you can.
- **Conclusion:** Here is your chance to reiterate the strength of your recommendation. If you think you're describing a not-to-miss candidate, say so explicitly. One sentence we often include here is along the lines of 'If Marta were to apply for a PhD/Masters under my supervision I would not hesitate to take her as a student'.

- Often, reference submission websites ask you to place the student in the top X% of students you've worked with. More depends on this than you might think. Be honest, but be aware that these judgments often go into a formula for scoring or pre-filtering applications. In a competitive program, if you say someone is top 20%, that is likely a death sentence for the student's chances of getting a scholarship. Again, don't lie, just make sure you don't put the student in a lower bucket than they really deserve to be in.

- Be aware of cultural differences in how we praise others and give direct feedback to/on colleagues. I often recommend the Culture Map book by Erin Meyer on this topic. Though individuals are individuals, by and large, those who socialise in the U.S. academic system tend to write recommendation letters with a higher baseline level of enthusiasm. If you feel your letter is too positive, that may be appropriate compensation for these differences, so long as your letter is honest, of course.
- Writing style and tone are the most difficult to get right if you haven't seen examples before. I suggest you write a draft a couple weeks before submitting a letter, and then return to it before submitting. Re-reading after a week often allows you to better notice where the letter isn't conveying what you wanted.
- Ask for help! If you have a candidate you enthusiastically support, don't be afraid to ask for help writing the reference letter. Ideally, ask someone who is experienced, doesn't know the candidate, and who is not part of the decision making at the institution the student is applying to.

In summary, please take time to write strong recommendation letters for your best students. There may not be many students at your institution who apply to top programs, but those who do are likely the ones who really need and deserve your attention.

This post is written with my PhD student and now **guest author Patrik Reizinger** and is part 4 of a series of posts on causal inference:

- Part 1: Intro to causal inference and do-calculus
- Part 2: Illustrating Interventions with a Toy Example
- Part 3: Counterfactuals
- ➡️️ Part 4: Causal Diagrams, Markov Factorization, Structural Equation Models

One way to think about causal inference is that causal models require more fine-grained models of the world compared to statistical models. Many causal models are equivalent to the same statistical model, yet support different causal inferences. This post elaborates on this point, and makes the relationship between causal and statistical models more precise.

Do you remember those combinatorics problems from school where the question was how many ways exist to get from a start position to a target field on a chessboard? And you can only move one step right or one step down. If you remember, then I need to admit that we will not consider problems like that. But its (one possible) takeaway actually can help us to understand Markov factorizations.

You know, it does not matter how you traversed the chessboard, the result is the same. So we can say that - from the perspective of target position and the process of getting there - this is a many-to-one mapping. The same holds for random variables and causal generative models.

If you have a bunch of random variables - let's call them $X_1, X_2, \dots, X_n$ - their joint distribution is $p \left(X_1, X_2, \dots, X_n \right) $. If you invoke the *chain rule of probability,* you will have several options to express this joint as a product of factors:

$$

p \left(X_1, X_2, \dots, X_n \right) = \prod_{i=1}^{n} p(X_{\pi_i}\vert X_{\pi_1}, \ldots, X_{\pi_{i-1}}),

$$

where $\pi$ is a permutation of the indices. Since you can do this for any permutation $\pi$, the mapping between such factorizations and the joint distribution they express is many-to-one. You can see this in the image below. The different factorizations induce different graphs, but have the same joint distribution.
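The many-to-one mapping can be checked numerically on the smallest possible case. This is only a sketch with a made-up joint table over two binary variables: the two chain-rule orderings produce different factors (and correspond to the graphs $X_1 \rightarrow X_2$ and $X_2 \rightarrow X_1$), yet both multiply back to the same joint.

```python
# Hypothetical joint distribution p(x1, x2) over two binary variables.
joint = {(0, 0): 0.3, (0, 1): 0.2, (1, 0): 0.1, (1, 1): 0.4}

def marginal(var):
    """Marginal of the variable at position `var`, summing out the other one."""
    out = {0: 0.0, 1: 0.0}
    for values, p in joint.items():
        out[values[var]] += p
    return out

def factorize(first):
    """Chain rule with `first` leading: returns p(X_first) and p(X_other | X_first)."""
    p_first = marginal(first)
    conditional = {v: joint[v] / p_first[v[first]] for v in joint}
    return p_first, conditional

def rebuild(first, p_first, conditional):
    """Multiply the two factors back together."""
    return {v: p_first[v[first]] * conditional[v] for v in joint}

factors_12 = factorize(0)  # p(x1) p(x2 | x1)
factors_21 = factorize(1)  # p(x2) p(x1 | x2)

for first, factors in ((0, factors_12), (1, factors_21)):
    rebuilt = rebuild(first, *factors)
    assert all(abs(rebuilt[v] - joint[v]) < 1e-12 for v in joint)

# The factors themselves differ, even though the joint they express does not.
factors_differ = factors_12[1] != factors_21[1]
print(factors_differ)
```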

Since you are reading this post, you may already be aware that in causal inference we often talk about a causal factorization, which looks like

$$

p \left(X_1, X_2, \dots, X_n \right) = \prod_{i=1}^{n} p\left(X_i \vert \mathrm{pa}(X_i)\right),

$$

where $\mathrm{pa}(X_i)$ denotes the causal parents of node $X_i$. This is one of many possible ways you can factorize the joint distribution, but we consider this one special. In recent work, Schölkopf et al. call it a *disentangled* model. What are disentangled models? Disentangled factors describe independent aspects of the mechanism that generated the data. And they are not independent because you factored them in this way; rather, you were looking for this factorization because its factors are independent.

In other words, for every joint distribution there are many possible factorizations, but we assume that only one, the causal or disentangled factorization, describes the true underlying process that generated the data.

Let's consider an example for *disentangled* models. We want to model the joint distribution of altitude $A$ and temperature $T$. In this case, the causal direction is $A \rightarrow T$ - if the altitude changes, the distribution of the temperature will change too. But you cannot change the altitude by artificially heating a city - otherwise we all would enjoy views as in Miami; global warming is real but fortunately has no altitude-changing effect.

In the end, we get the factorization of $p(A)p(T|A)$. The important insights here are the answers to the question: What do we expect from these factors? The previously-mentioned Schölkopf et al. paper calls the main takeaway the **Independent Causal Mechanisms (ICM) Principle**, i.e.

By conditioning on the parents of any factor in the disentangled model, the factor will neither be able to give you further information about other factors nor be able to influence them.

In the above example, this means that if you consider different countries with their altitude distributions, you can still use the same $p(T|A)$, i.e., the factors **generalize** well. For *no influence*, the example just above the ICM Principle applies: heating a city does not change its altitude. Furthermore, knowing any of the factors - e.g. $p(A)$ - won't tell you anything about the other *(no information)*. If you know which country you are in, you will still have no clue about the climate (if you consulted the website of the corresponding weather agency, that's what I call cheating). In the other direction, despite being the top-of-class student in climate matters, you won't be able to tell the country if somebody says to you that here the altitude is 350 meters and the temperature is 7°C!

We discussed Markov factorizations, as they help us understand the philosophical difference between statistical and causal inference. The beauty, and a source of confusion, is that one can use Markov factorizations in both paradigms.

However, while using Markov factorizations is optional for statistical inference, it is a must for causal inference.

So why would a statistical inference person use Markov factorizations? Because they make life easier in the sense that you do not need to worry about too high electricity costs. Namely, factorized models of data can be computationally much more efficient. Instead of modeling a joint distribution directly, which has a lot of parameters (in the case of $n$ binary variables, that is $2^n-1$ different values), a factorized version can be pretty lightweight and parameter-efficient. If you are able to factorize the joint in a way that you have 8 factors with $n/8$ variables each, then you can describe your model with $8\times2^{n/8}-1$ parameters. If $n=16$, that is $65,535$ vs $31$. Similarly, representing your distribution in a factorized form gives rise to efficient, general-purpose message-passing algorithms, such as belief propagation or expectation propagation.

On the other hand, causal inference people really need this, otherwise, they are lost. Because without Markov factorizations, they cannot really formulate causal claims.

A causal practitioner uses Markov factorizations, because this way she is able to reason about interventions.

If you do not have the disentangled factorization, you cannot model the effect of interventions *on the real mechanisms* that make the system tick.

In plain machine learning lingo, what you want to do is *domain adaptation,* that is, you want to draw conclusions about a distribution you did not observe (these are the interventional ones). The Markov factorization prescribes ways in which you expect the distribution to change - one factor at a time - and thus the set of distributions you want to be able to robustly generalise to or draw inferences about.

Do-calculus, the topic of the first post in the series, can be described relatively simply using Markov factorizations. As you remember, $\mathrm{do}(X=x)$ means that we set the variable $X$ to the value $x$, meaning that the distribution of that variable $p(X)$ collapses to a point mass. We can model this intervention mathematically by replacing the factor $p( x \vert \mathrm{pa}(X))$ by a Dirac delta $\delta_x$, resulting in the deletion of all incoming edges of the intervened variable in the graphical model. We then marginalise over $x$ to calculate the joint distribution of the remaining variables. For example, if we have two variables $x$ and $y$, with $y$ being the parent of $x$, we can write:

$$

p(y\vert \mathrm{do}(X=x_0)) = \int p(x,y) \frac{\delta(x - x_0)}{p(x\vert y)} dx

$$
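A discrete analogue of this formula is easy to play with. The sketch below assumes a made-up joint over two binary variables and replaces the integral with a sum and the Dirac delta with an indicator. It divides out the factor belonging to $X$, which is $p(x\vert y)$ when $y$ is the parent of $x$ (as in the formula above) and $p(x)$ when $x$ has no parents. The graph labels and function names are hypothetical.

```python
# Hypothetical joint p(x, y) over two binary variables, keyed by (x, y).
joint = {(0, 0): 0.3, (0, 1): 0.2, (1, 0): 0.1, (1, 1): 0.4}

def p_x(x):
    """Marginal p(x)."""
    return sum(p for (xi, _), p in joint.items() if xi == x)

def p_y(y):
    """Marginal p(y)."""
    return sum(p for (_, yi), p in joint.items() if yi == y)

def do_x(x0, y, graph):
    """p(y | do(X = x0)): swap X's own factor for a point mass at x0."""
    total = 0.0
    for x in (0, 1):
        delta = 1.0 if x == x0 else 0.0
        if graph == "Y->X":      # X's factor is p(x | y)
            factor = joint[(x, y)] / p_y(y)
        else:                    # graph "X->Y": X's factor is p(x)
            factor = p_x(x)
        total += joint[(x, y)] * delta / factor
    return total

conditional = joint[(1, 1)] / p_x(1)  # observational p(y=1 | x=1)

# Intervening on an effect leaves its cause untouched: we recover p(y).
assert abs(do_x(1, 1, "Y->X") - p_y(1)) < 1e-12
# Intervening on a cause with no confounding matches conditioning: p(y | x).
assert abs(do_x(1, 1, "X->Y") - conditional) < 1e-12
```

The same joint table gives two different answers depending on which factorization is taken to be the causal one, which is the sense in which the joint alone cannot answer interventional questions.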

If you've read the previous parts in this series, you'll know that Markov factorizations aren't the only tool we use in causal inference. For counterfactuals, we used structural equation models (SEMs). In this part we will illustrate the connection between these with a cheesy reference to the reparametrization trick used in VAEs among others.

But before that, let's recap SEMs. In this case, you define the relationship between the child node and its parents via a functional assignment. For node $X$ with parents $\mathrm{pa}(X)$ it has the form of

$$

X = f(\mathrm{pa}(X), \epsilon),

$$

with some noise $\epsilon$. Here, you should read "=" in the sense of an assignment (like in Python); in mathematical notation, this would be ":=".

The above equation expresses the conditional probability $ p\left(X| \mathrm{pa}(X)\right)$ as a *deterministic* function of $\mathrm{pa}(X)$ and some noise variable $\epsilon$. Wait a second..., isn't that the same thing the reparametrization trick does? Yes it is.

So the SEM formulation (called the implicit distribution) is related via the reparametrization trick to the conditional probability of $X$ given its parents.

Thus, we can say that a SEM is a conditional distribution, and vice versa. Okay, but how do the sets of these constructs relate to each other?

If you have a SEM, then you can read off the conditional, which is **unique**. On the other hand, you can find more than one SEM for the same conditional. Just as you can express a conditional distribution in multiple different ways using different reparametrizations, it is possible to express the same Markov factorization by multiple SEMs. Consider for example that if your conditional is $\mathcal{N}(\mu,\sigma)$, then both $X = \mu + \sigma\epsilon$ and $X = \mu - \sigma\epsilon$ with standard normal $\epsilon$ express it: multiplying the noise by -1 gives you the same distribution. In this sense, SEMs are a richer class of models than Markov factorizations, thus they allow us to make inferences (counterfactual ones) which we weren't able to express in the more coarse-grained language of Markov factorizations.
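Here is a small sketch of why the extra detail in a SEM matters. It assumes a binary parent and two hypothetical SEMs that differ only in whether the noise sign is flipped when the parent equals 1. A symmetric grid of noise values stands in for samples from $\mathcal{N}(0, 1)$; this is a simplification for illustration, not a sampler.

```python
def sem_a(parent, eps):
    """X := parent + eps."""
    return parent + eps

def sem_b(parent, eps):
    """X := parent + eps if parent == 0, else parent - eps."""
    return parent + (1 - 2 * parent) * eps

# Symmetric grid standing in for eps ~ N(0, 1): eps and -eps are equally
# likely, so flipping the sign leaves the conditional p(X | parent) unchanged.
noise = [k / 4 for k in range(-8, 9)]
for parent in (0, 1):
    xs_a = sorted(sem_a(parent, e) for e in noise)
    xs_b = sorted(sem_b(parent, e) for e in noise)
    assert xs_a == xs_b  # both SEMs imply the same conditional

def counterfactual(sem, parent, x, new_parent):
    """Abduct the noise from the observation, then replay it under new_parent."""
    eps = next(e for e in noise if sem(parent, e) == x)
    return sem(new_parent, eps)

# Observed: parent = 0 and X = 0.5. What would X have been had parent been 1?
cf_a = counterfactual(sem_a, 0, 0.5, 1)
cf_b = counterfactual(sem_b, 0, 0.5, 1)
print(cf_a, cf_b)  # prints: 1.5 0.5
```

Both SEMs agree on every observational and interventional distribution, yet they disagree on the counterfactual, which is exactly the kind of statement a Markov factorization alone cannot express.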

As we discussed above, a single joint distribution has multiple valid Markov factorizations, and the same Markov factorization can be expressed as different SEMs. We can think of joint distributions, Markov factorizations, and SEMs as increasingly fine-grained model classes: joint distributions $\subset$ Markov factorizations $\subset$ SEMs. The more aspects of the data generating process you model, the more elaborate the set of inferences you can make becomes. Thus, joint distributions allow you to make predictions under no mechanism shift, Markov factorizations allow you to model interventions, and SEMs allow you to make counterfactual statements.

The price you pay for more expressive models is that they also get generally much harder to estimate from data. In fact, some aspects of causal models are impossible to infer from i.i.d. observational data. Moreover, some counterfactual inferences are experimentally not verifiable.

A few days ago we had a talk by Gergely Neu, who presented his recent work:

I'm writing this post mostly to annoy him, by presenting this work using super hand-wavy intuitions and cartoon figures. If this isn't enough, I will even find a way to mention GANs in this context.

But truthfully, I'm just excited because for once, there is a little bit of learning theory that I half-understand, at least at an intuitive level, thanks to its reliance on KL divergences and the mutual information.

Let's start with a simple thought experiment to illustrate why and how mutual information may be useful in describing an algorithm's ability to generalize. Say we're given two datasets, $\mathcal{D}_{train}$ and $\mathcal{D}_{test}$, of the same size for simplicity. We play the following game: we both have access to $\mathcal{D}_{train}$ and $\mathcal{D}_{test}$, and we both know what learning algorithm, $\operatorname{Alg}$, we're going to use.

Now I toss a coin and I keep the result (recorded as random variable $Y$) a secret. If it's heads, I run $\operatorname{Alg}$ on the training set $\mathcal{D}_{train}$. If it's tails, I run $\operatorname{Alg}$ on the test data $\mathcal{D}_{test}$ instead. I don't tell you which of these I did, I only reveal to you the final parameter value $W$. Can you guess, just by looking at $W$, whether I trained on training or test data?

If you cannot guess $Y$, that means that the algorithm returns the same random $W$ irrespective of whether you train it on training or test data. So the training and test losses become interchangeable. This implies that the algorithm will generalize very well (on average) and not overfit to the data it's trained on.

The mutual information, in this case between $W$ and $Y$, quantifies your theoretical ability to guess $Y$ from $W$. The higher this value is, the easier it is to tell which dataset the algorithm was trained on. If it's easy to reverse engineer my coin toss from parameters, it means that the algorithm's output is very sensitive to the input dataset it's trained on. And that likely implies poor generalization.

Side note: an algorithm generalizing well on average doesn't mean it works well on average. It just means that there won't be a large gap between the expected training and expected test error. Take for example an algorithm that returns a randomly initialized neural network, without even touching the data. That algorithm generalizes extremely well on average: it does just as poorly on test data as it does on training data.

Below is an illustration of my thought experiment for SGD.

In the top row, I doodled the distribution of the parameter $W_t$ at various timesteps $t=0,1,2,\ldots,T$ of SGD. We start the algorithm by initializing $W$ randomly from a Gaussian (left panel). Then, each stochastic gradient update changes the distribution of $W_t$ a bit compared to the distribution of $W_{t-1}$. How the shape of the distribution changes depends on the data we use in the SGD steps. In the top row, let's say I ran SGD on $\mathcal{D}_{train}$ and in the bottom row, I ran it on $\mathcal{D}_{test}$. The distributions $p(W_t\vert \mathcal{D})$ I drew here describe where the SGD iterate is likely to be after $t$ steps of SGD started from random initialization. They are not to be confused with Bayesian posteriors, for example.

We know that running the algorithm on the test set would produce low test error. Therefore, sampling a weight vector $W$ from $p(W_T\vert \mathcal{D}_{test})$ would be great if we could do that. But in practice, we can't train on the test data, all we have the ability to sample from is $p(W_T\vert \mathcal{D}_{train})$. So what we'd like for good generalization, is if $p(W_T\vert \mathcal{D}_{test})$ and $p(W_T\vert \mathcal{D}_{train})$ were as close as possible. The mutual information between $W_T$ and my coinflip $Y$ measures this closeness in terms of the Jensen-Shannon divergence:

$$

\mathbb{I}[Y, W_T] = \operatorname{JSD}[p(W_T\vert \mathcal{D}_{test})\|p(W_T\vert \mathcal{D}_{train})]

$$

So, in summary, if we can guarantee that the final parameter an algorithm comes up with doesn't reveal too much information about what dataset it was trained on, we can hope that the algorithm has good generalization properties.
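The identity between the mutual information and the Jensen-Shannon divergence is easy to check numerically when $W_T$ takes only a few discrete values. The sketch below is only an illustration: the two distributions `p_w_train` and `p_w_test` are made-up numbers, and the coin $Y$ is assumed to be fair.

```python
import numpy as np

def kl(p, q):
    """KL divergence (in nats) between two discrete distributions."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def jsd(p, q):
    """Jensen-Shannon divergence between p and q."""
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def mutual_information(joint):
    """Mutual information (in nats) computed from a joint probability table."""
    p_row = joint.sum(axis=1, keepdims=True)
    p_col = joint.sum(axis=0, keepdims=True)
    mask = joint > 0
    return float(np.sum(joint[mask] * np.log(joint[mask] / (p_row * p_col)[mask])))

# Hypothetical distributions of W_T over four possible values.
p_w_train = np.array([0.5, 0.3, 0.15, 0.05])
p_w_test = np.array([0.1, 0.2, 0.3, 0.4])

# Fair coin Y: heads -> train, tails -> test. Rows of the joint index Y.
joint = 0.5 * np.stack([p_w_train, p_w_test])

mi = mutual_information(joint)
js = jsd(p_w_test, p_w_train)
print(mi, js)  # the two numbers agree up to floating point error
```

Because the coin is fair, the mutual information can never exceed $\log 2$ nats (one bit), which is reached exactly when the two distributions have disjoint support.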

These vague intuitions can be formalized into real information-theoretic generalization bounds. These were first presented in a general context in (Russo and Zou, 2016) and in a more clearly machine learning context in (Xu and Raginsky, 2017). I'll give a quick - and possibly somewhat handwavy - overview of the main results.

Let $\mathcal{D}$ and $\mathcal{D}'$ be random datasets of size $n$, drawn i.i.d. from some underlying data distribution $P$. Let $W$ be a parameter vector, which we obtain by running a learning algorithm $\operatorname{Alg}$ on the training data $\mathcal{D}$: $W = \operatorname{Alg}(\mathcal{D})$. The algorithm may be non-deterministic, i.e. it may output a random $W$ given a dataset. Let $\mathcal{L}(W, \mathcal{D})$ denote the loss of model $W$ on dataset $\mathcal{D}$. The expected generalization error of $\operatorname{Alg}$ is defined as follows:

$$

\text{gen}( \operatorname{Alg}, P) = \mathbb{E}_{\mathcal{D}\sim P^n,\mathcal{D}'\sim P^n, W\vert \mathcal{D}\sim \operatorname{Alg}(\mathcal{D})}[\mathcal{L}(W, \mathcal{D}') - \mathcal{L}(W, \mathcal{D})]

$$

If we unpack this, we have two datasets $\mathcal{D}$ and $\mathcal{D}'$, the former taking the role of the training dataset, the latter of the test data. We look at the expected difference between the training and test losses ($\mathcal{L}(W, \mathcal{D})$ and $\mathcal{L}(W, \mathcal{D}')$), where $W$ is obtained by running $\operatorname{Alg}$ on the training data $\mathcal{D}$. The expectation is taken over all possible random training sets, test sets, and over all possible random outcomes of the learning algorithm.
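To make the definition concrete, here is a rough Monte Carlo sketch of the expected generalization error for two toy algorithms. Everything in it is an assumption made for illustration: the data is 1D Gaussian, the "model" is a single number $w$, and the loss is a squared error clipped to $[0, 1]$. The first algorithm fits the sample mean, the second ignores the data entirely.

```python
import numpy as np

rng = np.random.default_rng(0)
n, trials, sigma = 5, 20000, 0.3

def loss(w, data):
    """Squared error clipped to [0, 1], averaged over each dataset.

    w has shape (trials,), data has shape (trials, n).
    """
    return np.minimum((data - w[:, None]) ** 2, 1.0).mean(axis=1)

# One row per trial: an independent training set D and test set D'.
train = rng.normal(0.0, sigma, size=(trials, n))
test = rng.normal(0.0, sigma, size=(trials, n))

# Algorithm 1: fit the sample mean of the training data.
w_fit = train.mean(axis=1)
# Algorithm 2: ignore the data and return a random parameter.
w_random = rng.normal(0.0, 1.0, size=trials)

gen_fit = np.mean(loss(w_fit, test) - loss(w_fit, train))
gen_random = np.mean(loss(w_random, test) - loss(w_random, train))
print(gen_fit, gen_random)
```

The fitted algorithm should show a small positive gap, while the data-independent one should show a gap close to zero, even though its loss is much worse in absolute terms.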

The information theoretic bound states that for any learning algorithm, and any loss function that's bounded by $1$, the following inequality holds:

$$

\text{gen}(\operatorname{Alg}, P) \leq \sqrt{\frac{\mathbb{I}[W, \mathcal{D}]}{n}}

$$

The main term in the RHS of this bound is the mutual information between the training data $\mathcal{D}$ and the parameter vector $W$ the algorithm finds. It essentially quantifies the number of bits of information the algorithm leaks about the training data into the parameters it learns. The lower this number, the better the algorithm generalizes.

The problem with applying these nice, intuitive bounds to SGD is that SGD, in fact, leaks too much information about the specific minibatches it is trained on. Let's go back to my illustrative example of having to guess if we ran the algorithm on training or test data. Consider the scenario where we start from some parameter value $w_t$ and we update either with a random minibatch of training data (blue) or a random minibatch of test data (orange).

Since the training and test datasets are assumed to be of finite size, there are only a finite number of possible minibatches. Each of these minibatches can take the parameter to a unique new location. The problem is, the set of locations you can reach with one dataset (blue dots) does not overlap with the set of locations you can reach if you update with the other dataset (orange dots). Suddenly, if I give you $w_{t+1}$, you can immediately tell if it's an orange or blue dot, so you can immediately reconstruct my coinflip $Y$.

In the more general case, the problem with SGD in the context of information-theoretic bounds is that the amount of information SGD leaks about the dataset it was trained on is high, and in some cases may even be infinite. This is actually related to the problem that several of us noticed in the context of GANs, where the true and fake distributions may have non-overlapping support, making the KL divergence infinite, and saturating out the Jensen-Shannon divergence. The first trick we came up with to solve this problem was to smooth things out by adding Gaussian noise. Indeed, adding noise is key to what researchers have been doing to apply these information-theoretic bounds to SGD.

The first thing people did (Pensia et al, 2018) is to study a noisy cousin of SGD: stochastic gradient Langevin dynamics (SGLD). SGLD is like SGD but in each iteration we add a bit of Gaussian noise to the parameters in addition to the gradient update. To understand why SGLD leaks less information, consider the previous example with the orange and blue point clouds. SGLD makes those point clouds overlap by convolving them with Gaussian noise.
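As a minimal sketch of the difference between the two updates, the snippet below implements one SGD step and one noisy step for a linear least-squares model. The model, the minibatch and the free `noise_std` parameter are assumptions chosen for illustration; proper SGLD ties the noise scale to the step size, which is not done here.

```python
import numpy as np

def minibatch_grad(w, x, y):
    """Gradient of the mean squared error of a linear model on one minibatch."""
    return 2.0 * x.T @ (x @ w - y) / len(y)

def sgd_step(w, x, y, lr):
    """Plain SGD update: deterministic given the minibatch."""
    return w - lr * minibatch_grad(w, x, y)

def sgld_step(w, x, y, lr, noise_std, rng):
    """SGD update plus isotropic Gaussian noise added to the parameters."""
    noise = noise_std * rng.standard_normal(w.shape)
    return sgd_step(w, x, y, lr) + noise

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 3))   # hypothetical minibatch of 8 inputs
y = rng.standard_normal(8)
w = np.zeros(3)

w_sgd = sgd_step(w, x, y, lr=0.1)
w_sgld = sgld_step(w, x, y, lr=0.1, noise_std=0.05, rng=rng)
print(w_sgd, w_sgld)
```

Given a minibatch, the SGD step lands on a single point, whereas the noisy step lands in a Gaussian blob around it. It is these blobs that overlap between the two datasets.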

However, SGLD is not exactly SGD, and it's not really used as much in practice. In order to say something about SGD specifically, Neu (2021) did something else, while still relying on the idea of adding noise. Instead of baking the noise in as part of the algorithm, Neu only adds noise as part of the analysis. The algorithm being analysed is still SGD, but the quantity we measure is $\mathbb{I}[W + \xi; \mathcal{D}]$, the mutual information between the noise-perturbed parameters and the data, where $\xi$ is Gaussian noise.

I leave it to you to check out the details of the paper. While the findings fall short of explaining whether SGD has any tendency to find solutions that generalise well, some of the results are nice and interpretable: they connect the generalization of SGD to the noisiness of gradients as well as the smoothness of the loss along the specific optimization path that was taken.

I wanted to highlight an intriguing paper I presented at a journal club recently:

- Samuel L Smith, Benoit Dherin, David Barrett, Soham De (2021) On the Origin of Implicit Regularization in Stochastic Gradient Descent

There's actually a related paper that came out simultaneously, studying full-batch gradient descent instead of SGD:

- David G.T. Barrett, Benoit Dherin (2021) Implicit Gradient Regularization

One of the most important insights in machine learning over the past few years relates to the importance of optimization algorithms in generalization performance.

In order to understand why deep learning works as well as it does, it is insufficient to reason about the loss function or the model class, which is what classical generalisation theory focussed on. Instead, the algorithms we use to find minima (namely, stochastic gradient descent) seem to play an important role. In many tasks, powerful neural networks are able to interpolate training data, i.e. achieve near-0 training loss. There are in fact several minima of the training loss which are virtually indistinguishably good on the training data. Some of these minima generalise well (i.e. result in low test error), others can be arbitrarily badly overfit.

What seems to be important then is not whether the optimization algorithm converges quickly to a local minimum, but which of the available "virtually global" minima it prefers to reach. It seems to be the case that the optimization algorithms we use to train deep neural networks *prefer* some minima over others, and that this preference results in better generalisation performance. The preference of optimization algorithms to converge to certain minima while avoiding others is described as *implicit regularization*.

I wrote this note as an overview on how we/I currently think about why deep networks generalize.

One of the interesting new theories that helped me imagine what happens in deep learning training is that of neural tangent kernels. In this framework we study neural network training in the limit of infinitely wide layers, full-batch training and infinitesimally small learning rate, i.e. when gradient descent becomes continuous gradient flow, described by an ordinary differential equation. Although the theory is useful and appealing, full-batch training with infinitesimally small learning rates is very much a cartoon version of what we actually do in practice. Firstly, the smallest learning rate doesn't always work best. Secondly, the stochasticity of gradient updates in minibatch-SGD seems to be of importance as well.

What Smith et al (2021) do differently in this paper is they try to study minibatch-SGD, for small, but not infinitesimally small, learning rates. This is much closer to practice. The toolkit that allows them to study this scenario is borrowed from the study of differential equations and is called backward error analysis. The cartoon illustration below shows what backward error analysis tries to achieve:

Let's say we have a differential equation $\dot{\omega} = f(\omega)$. The solution to this ODE with initial condition $\omega_0$ is a continuous trajectory $\omega_t$, shown in the image in black. We usually can't compute this solution in closed form, and instead simulate the ODE using Euler's method, $\omega_{k+1} = \omega_k + \epsilon f(\omega_k)$. This results in a discrete trajectory shown in teal. Due to discretization error, for finite stepsize $\epsilon$, this discrete path may not lie exactly where the continuous black path lies. Errors accumulate over time, as shown in this illustration. The goal of backward error analysis is to find a different ODE, $\dot{\omega} = \tilde{f}(\omega)$, such that the approximate discrete path we got from Euler's method lies near the continuous path which solves this new ODE. Our goal is to reverse engineer a modified $\tilde{f}$ such that the discrete iteration can be well-modelled by an ODE.

Why is this useful? Because the form $\tilde{f}$ takes can reveal interesting aspects of the behaviour of the discrete algorithm, particularly if it has any implicit bias towards moving into different areas of the space. When the authors apply this technique to (full-batch) gradient descent, it already suggests the kind of implicit regularization bias gradient descent has.

In gradient descent with a cost function $C$, the original ODE is $\dot{\omega} = f(\omega)$ with $f(\omega) = -\nabla C (\omega)$. The modified ODE which corresponds to a finite stepsize $\epsilon$ takes the form $\dot{\omega} = -\nabla\tilde{C}_{GD}(\omega)$ where

$$

\tilde{C}_{GD}(\omega) = C(\omega) + \frac{\epsilon}{4} \|\nabla C(\omega)\|^2

$$

So, gradient descent with finite stepsize $\epsilon$ is like running gradient flow, but with an added penalty that penalises the gradients of the loss function. The second term is what Barrett and Dherin (2021) call implicit gradient regularization.
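A one-dimensional quadratic makes this easy to check by hand. In the sketch below, the loss $C(\omega) = \frac{a}{2}\omega^2$ and the values of `a`, `eps` and `steps` are assumptions picked for illustration. For this loss both gradient flows have closed-form solutions, so we can compare where discrete gradient descent ends up against the original flow and against the flow on the modified loss.

```python
import numpy as np

a, eps, steps = 1.0, 0.1, 10      # curvature, step size, number of GD steps
w0 = 1.0

# Discrete gradient descent (Euler's method on the gradient flow ODE).
w_gd = w0
for _ in range(steps):
    w_gd = w_gd - eps * a * w_gd

t = eps * steps                    # continuous time covered by the steps

# Gradient flow on the original loss: dw/dt = -a * w.
w_flow = w0 * np.exp(-a * t)

# Gradient flow on the modified loss C + (eps / 4) * (a * w) ** 2,
# whose gradient is a * (1 + eps * a / 2) * w.
w_modified = w0 * np.exp(-a * (1.0 + eps * a / 2.0) * t)

print(abs(w_gd - w_flow), abs(w_gd - w_modified))
```

The gap to the modified flow should come out more than an order of magnitude smaller than the gap to the original flow, which is the sense in which the modified ODE models the discrete iteration better.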

Analysing SGD in this framework is a bit more difficult because the trajectory in stochastic gradient descent is, well, stochastic. Therefore, you don't have a single discrete trajectory to model, but instead you have a distribution of different trajectories which you'd traverse if you randomly reshuffle your data. Here's a picture illustrating this situation:

Starting from the initial point $\omega_0$ we now have multiple trajectories. These correspond to different ways we can shuffle data (the paper assumes a fixed allocation of datapoints to minibatches, with the randomness coming from the order in which the minibatches are considered). The two teal trajectories illustrate two potential paths. Each path ends at a random location; the teal dots show additional random endpoints where trajectories may end up. The teal star shows the mean of the distribution of random trajectory endpoints.

The goal in (Smith et al, 2021) is to reverse-engineer an ODE so that the continuous (orange) path lies close to this mean location. The corresponding ODE is of the form $\dot{\omega} = -\nabla \tilde{C}_{SGD}(\omega)$, where

$$

\tilde{C}_{SGD}(\omega) = C(\omega) + \frac{\epsilon}{4m} \sum_{k=1}^{m} \|\nabla \hat{C}_k(\omega)\|^2,

$$

where $\hat{C}_k$ is the loss function on the $k^{th}$ minibatch. There are $m$ minibatches in total. Note that this is similar to what we had for gradient descent, but instead of the squared norm of the full-batch gradient we now have the average squared norm of minibatch gradients as the implicit regularizer. Another interesting view on this is to look at the difference between the GD and SGD regularizers:

$$

\tilde{C}_{SGD} = \tilde{C}_{GD} + \frac{\epsilon}{4m} \sum_{k=1}^{m} \|\nabla \hat{C}_k(\omega) - \nabla C(\omega)\|^2

$$

This additional regularization term, $\frac{1}{m}\sum_{k=1}^{m} \|\nabla \hat{C}_k(\omega) - \nabla C(\omega)\|^2$, is something like the total variance of minibatch gradients (the trace of the empirical Fisher information matrix). Intuitively, this regularizer term will avoid parts of the parameter-space where the variance of gradients calculated over different minibatches is high.
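The relationship between the two regularizers is just the usual "mean of squares equals square of the mean plus variance" identity, applied to gradient vectors: $\frac{1}{m}\sum_k \|\nabla \hat{C}_k\|^2 = \|\nabla C\|^2 + \frac{1}{m}\sum_k \|\nabla \hat{C}_k - \nabla C\|^2$. Here is a quick numerical sanity check, assuming equally sized minibatches so that the full-batch gradient is the plain average of minibatch gradients. The gradients themselves are random stand-ins rather than gradients of any real model.

```python
import numpy as np

rng = np.random.default_rng(0)
m, dim = 6, 4

# Hypothetical minibatch gradients at some fixed parameter value, one per row.
minibatch_grads = rng.standard_normal((m, dim))
# The full-batch loss is the average of the minibatch losses, so the
# full-batch gradient is the average of the minibatch gradients.
full_grad = minibatch_grads.mean(axis=0)

mean_sq_norm = np.mean(np.sum(minibatch_grads ** 2, axis=1))
full_sq_norm = np.sum(full_grad ** 2)
variance = np.mean(np.sum((minibatch_grads - full_grad) ** 2, axis=1))

print(mean_sq_norm, full_sq_norm + variance)  # equal up to rounding
```

Since the variance term is non-negative, the SGD regularizer is never smaller than the GD one, and the two coincide only where all minibatch gradients agree.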

Importantly, while $\tilde{C}_{GD}$ has the same minima as $C$, this is no longer true for $\tilde{C}_{SGD}$. Some minima of $C$ where the variance of gradients is high are no longer minima of $\tilde{C}_{SGD}$. As an implication, not only does SGD follow different trajectories than full-batch GD, it may also converge to completely different solutions.

As a sidenote, there are many versions of SGD, based on how data is sampled for the gradient updates. Here, it is assumed that the datapoints are assigned to minibatches, but then the minibatches are randomly sampled. This is different from randomly sampling datapoints with replacement from the training data (Li et al (2015) consider that case), and indeed an analysis of that variant may well lead to different results.

Why would an implicit regularization effect avoiding high minibatch gradient variance be useful for generalisation? Well, let's consider a cartoon illustration of two local minima below:

Both minima are the same as far as the average loss $C$ is concerned: the value of the minimum is the same, and the width of the two minima is the same. Yet, in the left-hand situation, the wide minimum arises as the average of several minibatch losses, which all look the same, and which all are relatively wide themselves. In the right-hand minimum, the wide average loss minimum arises as the average of a lot of sharp minibatch losses, which all disagree on where exactly the location of the minimum is.

If we have these two options, it is reasonable to expect the left-hand minimum to generalise better, because the loss function seems to be less sensitive to whichever specific minibatch we are evaluating it on. As a consequence, the loss function also may be less sensitive to whether a datapoint is in the training set or in the test set.

In summary, this paper is a very interesting analysis of stochastic gradient descent. While it has its limitations (which the authors don't try to hide and discuss transparently in the paper), it nevertheless contributes a very interesting new technique for analysing optimization algorithms with finite stepsize. I found the paper to be well-written, with the explanation of somewhat tedious details of the analysis clearly laid out. But perhaps I liked this paper most because it confirmed my intuitions about why SGD works, and what type of minima it tends to prefer.

This is a half-guest-post written jointly with Dóra, a fellow participant in a reading group where we recently discussed the original paper on $\beta$-VAEs:

- Irina Higgins et al (ICLR 2017): $\beta$-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework

On the surface of it, $\beta$-VAEs are a straightforward extension of VAEs where we are allowed to directly control the tradeoff between the reconstruction and KL loss terms. In an attempt to better understand where the $\beta$-VAE objective comes from, and to further motivate why it makes sense, here we derive $\beta$-VAEs from different first principles than those presented in the paper. Over to mostly Dóra for the rest of this post:

- $p_\mathcal{D}(\mathbf{x})$: data distribution
- $q_\psi(\mathbf{z}\vert \mathbf{x})$: representation distribution
- $q_\psi(\mathbf{z}) = \int p_\mathcal{D}(\mathbf{x})q_\psi(\mathbf{z}\vert \mathbf{x})\,d\mathbf{x}$: aggregate posterior - marginal distribution of representation $Z$
- $q_\psi(\mathbf{x}\vert \mathbf{z}) = \frac{q_\psi(\mathbf{z}\vert \mathbf{x})p_\mathcal{D}(\mathbf{x})}{q_\psi(\mathbf{z})}$: "inverted posterior"

Learning disentangled representations that recover the independent data generative factors has been a long-term goal for unsupervised representation learning.

$\beta$-VAEs were introduced in 2017 with a proposed modification to the original VAE formulation that can achieve better disentanglement in the posterior $q_\psi(\mathbf{z}\vert \mathbf{x})$. An assumption of $\beta$-VAEs is that there are two sets of latent factors, $\mathbf{v}$ and $\mathbf{w}$, that contribute to generating observations $\mathbf{x}$ in the real world. One set, $\mathbf{v}$, is coordinate-wise conditionally independent given the observed variable, i.e., $\log p(\mathbf{v}\vert \mathbf{x}) = \sum_k \log p(v_k\vert \mathbf{x})$. At the same time, we don't assume anything about the remaining factors $\mathbf{w}$.

The factors $\mathbf{v}$ are going to be the main object of interest for us. The conditional independence assumption allows us to formulate what it means to disentangle these factors of variation. Consider a representation $\mathbf{z}$ which *entangles* coordinates of $\mathbf{v}$, in that each coordinate of $\mathbf{z}$ depends on multiple coordinates of $\mathbf{v}$, e.g. $z_1 = f_1(v_1, v_2)$ and $z_2 = f_2(v_1, v_2)$. Such a $\mathbf{z}$ won't necessarily satisfy coordinate-wise conditional independence $\log p(\mathbf{z}\vert \mathbf{x}) = \sum_k \log p(z_k\vert \mathbf{x})$. However, if each component of $\mathbf{z}$ depended only on one corresponding coordinate of $\mathbf{v}$, for example $z_1 = g_1(v_1)$ and $z_2 = g_2(v_2)$, the component-wise conditional independence would hold for $\mathbf{z}$ too.

Thus, under these assumptions we can encourage disentanglement to happen by encouraging the posterior $q_\psi(\mathbf{z}\vert \mathbf{x})$ to be coordinate-wise conditionally independent. This can be done by adding a new hyperparameter $\beta$ to the original VAE formulation

$$

\mathcal{L}(\theta, \phi; x, z, \beta) = -\mathbb{E}_{q_{\phi}(z\vert x)p_\mathcal{D}(x)}[\log p_{\theta}(x \vert z)] + \beta \operatorname{KL} (q_{\phi}(z\vert x)\| p(z)),

$$

where $\beta$ controls the trade-off between the capacity of the latent information channel and learning conditionally independent latent factors. When $\beta$ is higher than 1, we encourage the posterior $q_\psi(z\vert x)$ to be close to the isotropic unit Gaussian $p(z) = \mathcal{N}(0, I)$, which itself is coordinate-wise independent.

In this post, we revisit the conditional independence assumption of latent factors, and argue that a more appropriate objective would be to have marginal independence in the latent factors. To show you our intuition, let's revisit the "Explaining Away" phenomenon from Probabilistic Graphical Models.

Consider three random variables:

$A$: Ferenc is grading exams

$B$: Ferenc is in a good mood

$C$: Ferenc is tweeting a meme

with the following graphical model $A \rightarrow C \leftarrow B$.

Here we could assume that Ferenc grading exams is independent of him being in a good mood, i.e., $A \perp B$. Both make a meme more likely: a good mood does so directly, while the pressure of marking exams increases the likelihood of procrastination, which increases the chances of tweeting memes, too.

However, as soon as we see a meme being tweeted by him, we know that he is either in a good mood or grading exams. If we know he is grading exams, that explains why he is tweeting memes, so it's less likely he's tweeting memes because he's in a good mood. Consequently, $A \not\perp B \vert C$.

In all seriousness, if we have a graphical model $A \rightarrow C \leftarrow B$, then once $C$ is observed, independence between $A$ and $B$ no longer holds.
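Explaining away is easy to verify by brute-force enumeration in a small discrete model. The probabilities below are entirely made up for the sake of the example. The only structural assumption is the graphical model $A \rightarrow C \leftarrow B$ with $A$ and $B$ independent a priori.

```python
import itertools

p_a = 0.5   # P(A=1): grading exams (made-up number)
p_b = 0.5   # P(B=1): in a good mood (made-up number)
# P(C=1 | A, B): probability of tweeting a meme (made-up numbers).
p_c_given = {(0, 0): 0.05, (0, 1): 0.8, (1, 0): 0.8, (1, 1): 0.95}

def joint(a, b, c):
    """Joint probability under A -> C <- B with A independent of B."""
    pa = p_a if a else 1 - p_a
    pb = p_b if b else 1 - p_b
    pc = p_c_given[(a, b)] if c else 1 - p_c_given[(a, b)]
    return pa * pb * pc

def prob_b_given(c, a=None):
    """P(B=1 | C=c), or P(B=1 | C=c, A=a) if a is given."""
    a_values = [0, 1] if a is None else [a]
    num = sum(joint(av, 1, c) for av in a_values)
    den = sum(joint(av, bv, c) for av, bv in itertools.product(a_values, [0, 1]))
    return num / den

p_b_given_c = prob_b_given(c=1)
p_b_given_c_and_a = prob_b_given(c=1, a=1)
print(p_b_given_c, p_b_given_c_and_a)
```

With these numbers, seeing a meme raises the probability of a good mood from 0.5 to roughly 0.67, but additionally learning that exams are being graded pulls it back down to roughly 0.54. So $A$ and $B$ are dependent once $C$ is observed.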

We argue that the explaining away phenomenon makes the conditional independence of latent factors undesirable. A much more reasonable assumption about the generative process of the data is that the factors of variation $\mathbf{v}$ are drawn independently, and then the observations are generated conditioned on them. However, if we consider two coordinates of $\mathbf{v}$ and the observation $\mathbf{x}$, we now have a $V_1 \rightarrow \mathbf{X} \leftarrow V_2$ graphical model, thus, conditional independence cannot hold.

Instead, we argue that to recover the generative factors of the data, we should encourage latent factors to be marginally independent. In the next section, we set out to derive an algorithm that encourages marginal independence in the representation $Z$. We will also show how the resulting loss function from this new derivation is actually equivalent to the original $\beta$-VAE formulation.

We'll start from desired properties of the representation distribution $q_\psi(z\vert x)$. We'd like this representation to satisfy two properties:

- Marginal independence: We would like the aggregate posterior $q_\psi(z)$ to be close to some fixed and factorized unit Gaussian prior $p(z) = \prod_i p(z_i)$. This encourages $q_\psi(z)$ to exhibit coordinate-wise independence.
- Maximum Information: We'd like the representation $Z$ to retain as much information as possible about the input data $X$.

Note that without marginal independence, maximum information is insufficient, because then any deterministic and invertible function of $X$ would contain maximum information about $X$ but that wouldn't make it a useful or disentangled representation. Similarly, without maximum information, marginal independence is insufficient, because if we set $q_\psi(z\vert x) = p(z)$ it would give us a latent representation $Z$ that is coordinate-wise independent, but it is also independent of the data, which is not very useful.

We can achieve a combination of these desiderata by optimizing an objective with the weighted combination of two terms corresponding to the two goals we set out above:

$$

\mathcal{L}(\psi) = \operatorname{KL}[q_\psi(z)\| p(z)] - \lambda \mathbb{I}_{q_\psi(z\vert x) p_\mathcal{D}(x)}[X, Z]

$$

Remember, we use $q_\psi(z)$ to denote the aggregate posterior. We will refer to this as the InfoMax objective. Now we're going to show how this objective can be related to the $\beta$-VAE objective. Let's first consider the KL term in the above objective:

\begin{align}

\operatorname{KL}[q_\psi(z)\| p(z)] &= \mathbb{E}_{q_\psi(z)} \log \frac{q_\psi(z)}{p(z)}\\

&= \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log \frac{q_\psi(z)}{p(z)}\\

&= \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log \frac{q_\psi(z)}{q_\psi(z\vert x)} + \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log \frac{q_\psi(z\vert x)}{p(z)}\\

&= \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log \frac{q_\psi(z)p_\mathcal{D}(x)}{q_\psi(z\vert x)p_\mathcal{D}(x)} + \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log \frac{q_\psi(z\vert x)}{p(z)}\\

&= -\mathbb{I}_{q_\psi(z\vert x)p_\mathcal{D}(x)}[X,Z] + \mathbb{E}_{p_\mathcal{D}}\operatorname{KL}[q_\psi(z\vert x)\| p(z)]

\end{align}

This is interesting. If the mutual information between $X$ and $Z$ is non-zero (which is ideally the case), the above equation shows that latent factors cannot be both marginally and conditionally independent at the same time. It also gives us a way to relate the KL terms representing marginal and conditional independence.
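The decomposition of the aggregate KL term into an expected conditional KL minus the mutual information can be sanity-checked numerically. The sketch below uses small discrete distributions as stand-ins for $p_\mathcal{D}(x)$, $q_\psi(z\vert x)$ and $p(z)$. These are random tables generated for illustration, not anything from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
n_x, n_z = 5, 3

# Hypothetical discrete stand-ins for the distributions in the text.
p_data = rng.dirichlet(np.ones(n_x))              # p_D(x)
q_z_given_x = rng.dirichlet(np.ones(n_z), n_x)    # q(z|x), one row per x
prior = np.full(n_z, 1.0 / n_z)                   # p(z)

joint = p_data[:, None] * q_z_given_x             # q(z|x) p_D(x)
q_z = joint.sum(axis=0)                           # aggregate posterior q(z)

kl_aggregate = np.sum(q_z * np.log(q_z / prior))
mutual_info = np.sum(joint * np.log(q_z_given_x / q_z))
expected_kl = np.sum(joint * np.log(q_z_given_x / prior))

print(kl_aggregate, expected_kl - mutual_info)  # equal up to rounding
```

Since the mutual information is non-negative, the expected conditional KL is always at least as large as the aggregate KL, with equality only when $Z$ carries no information about $X$.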

Putting this back into the InfoMax objective, we have that

\begin{align}

\mathcal{L}(\psi) &= \operatorname{KL}[q_\psi(z)\| p(z)] - \lambda \mathbb{I}_{q_\psi(z\vert x)p_\mathcal{D}(x)}[X, Z]\\

&= \mathbb{E}_{p_\mathcal{D}}\operatorname{KL}[q_\psi(z\vert x)\| p(z)] - (\lambda + 1) \mathbb{I}_{q_\psi(z\vert x)p_\mathcal{D}(x)}[X, Z]

\end{align}

Using the KL term in the InfoMax objective, we were able to recover the KL-divergence term that also appears in the $\beta$-VAE (and consequently, VAE) objective.

At this point, we still haven't defined the generative model $p_\theta(x\vert z)$: the above objective expresses everything in terms of the data distribution $p_\mathcal{D}$ and the posterior/representation distribution $q_\psi$.

We will now focus on the 2nd term in our desired objective, the weighted mutual information, which we still can't easily evaluate. We will now show that we can recover the reconstruction term in $\beta$-VAEs by doing a variational approximation to the mutual information.

Note the following equality:

\begin{equation}

\mathbb{I}[X,Z] = \mathbb{H}[X] - \mathbb{H}[X\vert Z]

\end{equation}

Since we sample $X$ from the data distribution $p_\mathcal{D}$, we see that the first term $\mathbb{H}[X]$, the entropy of $X$, is constant with respect to the variational parameter $\psi$. We are left to focus on finding a good approximation to the second term $\mathbb{H}[X\vert Z]$. We can do so by minimizing the KL divergence between $q_\psi(x\vert z)$ and an auxiliary distribution $p_\theta(x\vert z)$ to make a variational approximation to the mutual information:

$$\mathbb{H}[X\vert Z] = - \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log q_\psi(x\vert z) \leq \inf_\theta - \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log p_\theta(x\vert z)$$

Finding this lower bound to MI, we have now recovered the reconstruction term from the $\beta$-VAE objective:

$$

\mathcal{L}(\psi) + \text{const} \leq - (1 + \lambda) \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log p_\theta(x\vert z) + \mathbb{E}_{p_\mathcal{D}}\operatorname{KL}[q_\psi(z\vert x)\| p(z)]

$$

This is essentially the same as the $\beta$-VAE objective function, where $\beta$ is related to our previous $\lambda$. In particular, $\beta = \frac{1}{1 + \lambda}$. Thus, since we assumed $\lambda>0$ for the InfoMax objective to make sense, we can say that the $\beta$-VAE objective encourages disentanglement in the InfoMax sense for values of $0<\beta<1$.
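To spell out where this value of $\beta$ comes from (a short sketch that only uses the fact that dividing an objective by the positive constant $1 + \lambda$ does not change its minimisers), rescale the upper bound on the InfoMax objective:

$$
\frac{1}{1+\lambda}\left( - (1 + \lambda)\, \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log p_\theta(x\vert z) + \mathbb{E}_{p_\mathcal{D}}\operatorname{KL}[q_\psi(z\vert x)\| p(z)] \right) = - \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log p_\theta(x\vert z) + \frac{1}{1+\lambda}\, \mathbb{E}_{p_\mathcal{D}}\operatorname{KL}[q_\psi(z\vert x)\| p(z)]
$$

The right-hand side is a reconstruction term plus a weighted KL term, with weight $\frac{1}{1+\lambda}$ playing the role of $\beta$. As $\lambda \rightarrow 0$ the weight approaches $1$, which corresponds to the standard VAE objective, and as $\lambda$ grows the weight shrinks towards $0$.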

Conceptually, this derivation is interesting because the main object of interest is now the recognition model, $q_\psi(z\vert x)$. That is, the posterior becomes the focus of the objective function - something that is not the case when we are maximizing model likelihood alone (as explained here). In this respect, this derivation of the $\beta$-VAE makes more sense from a representation learning viewpoint than the derivation of VAE from maximum likelihood.

There is a nice symmetry to these two views. There are two joint distributions over latents and observable variables in a VAE. On one hand we have $q_\psi(z\vert x)p_\mathcal{D}(x)$ and on the other we have $p(z)p_\theta(x\vert z)$. The "latent variable model" $q_\psi(z\vert x)p_\mathcal{D}(x)$ is a family of LVMs whose marginal distribution on the observable $\mathbf{x}$ is exactly the data distribution $p_\mathcal{D}$. So one can say $q_\psi(z\vert x)p_\mathcal{D}(x)$ is a parametric family of latent variable models whose likelihood is maximal - and we want to choose from this family a model where the representation $q_\psi(z\vert x)$ has nice properties.

On the flipside, $p(z)p_\theta(x\vert z)$ is a parametric set of models where the marginal distribution of latents is coordinate-wise independent, but we would like to choose from this family a model that has good data likelihood.

The VAE objective tries to move these two latent variable models closer to one another. From the perspective of $q_\psi(z\vert x)p_\mathcal{D}(x)$ this amounts to reproducing the prior $p(z)$ with the aggregate posterior. From the perspective of $p(z)p_\theta(x\vert z)$, it amounts to maximising the data likelihood. When the $\beta$-VAE objective is used, we additionally wish to maximise the mutual information between the observed data and the representation.

This dual role of information maximization and maximum likelihood has been pointed out before, for example in this paper about the IM algorithm. The symmetry of variational learning has been exploited a few times, for example in yin-yang machines, and more recently also in methods like adversarially learned inference.

Neural tangent kernels are a useful tool for understanding neural network training and implicit regularization in gradient descent. But it's not the easiest concept to wrap your head around. The paper that I found to have been most useful for me to develop an understanding is this one:

In this post I will illustrate the concept of neural tangent kernels through a simple 1D regression example. Please feel free to peruse the google colab notebook I used to make these plots.

Let's start from a very boring case to begin with. Let's say we have a function defined over integers between -10 and 20. We parametrize our function as a look-up table, that is the value of the function $f(i)$ at each integer $i$ is described by a separate parameter $\theta_i = f(i)$. I'm initializing the parameters of this function as $\theta_i = 3i+2$. The function is shown by the black dots below:

Now, let's consider what happens if we observe a new datapoint, $(x, y) =(10, 50)$, shown by the blue cross. We're going to take a gradient descent step updating $\theta$. Let's say we use the squared error loss function $(f(10; \theta) - 50)^2$ and a learning rate $\eta=0.1$. Because the function's value at $x=10$ only depends on one of the parameters, $\theta_{10}$, only this parameter will be updated. The rest of the parameters, and therefore the rest of the function values, remain unchanged. The red arrows illustrate the way function values move in a single gradient descent step: most values don't move at all, only one of them moves closer to the observed data. Hence only one visible red arrow.
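For concreteness, here is a small sketch of this single update using the numbers from the text (integers from -10 to 20, $\theta_i = 3i + 2$, observation $(10, 50)$, $\eta = 0.1$). The variable names are mine, and the gradient is written out by hand rather than computed by an autodiff library.

```python
import numpy as np

xs = np.arange(-10, 21)            # integers from -10 to 20
theta = 3.0 * xs + 2.0             # one parameter per function value
x_obs, y_obs, lr = 10, 50.0, 0.1

idx = int(np.where(xs == x_obs)[0][0])   # position of theta_10 in the table

# Gradient of (f(10) - 50)^2 with respect to theta: zero everywhere
# except at the entry that stores f(10).
grad = np.zeros_like(theta)
grad[idx] = 2.0 * (theta[idx] - y_obs)

theta_new = theta - lr * grad
moved = xs[theta_new != theta]
print(moved, theta[idx], theta_new[idx])
```

Only the entry at $x = 10$ moves: it goes from $32$ to $35.6$, a step towards the observed value of $50$, while every other function value stays where it was.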

However, in machine learning we rarely parametrize functions as lookup tables of individual function values. This parametrization is pretty useless as it doesn't allow you to interpolate let alone extrapolate to unseen data. Let's see what happens in a more familiar model: linear regression.

Let's now consider the linear function $f(x, \theta) = \theta_1 x + \theta_2$. I initialize the parameters to $\theta_1=3$ and $\theta_2=1$, so at initialisation, I have exactly the same function over integers as I had in the first example. Let's look at what happens to this function as I update $\theta$ by performing a single gradient descent step incorporating the observation $(x, y) =(10, 50)$ as before. Again, red arrows show how function values move:

Whoa! What's going on now? Since individual function values are no longer independently parametrized, we can't move them independently. The model binds them together through its global parameters $\theta_1$ and $\theta_2$. If we want to move the function closer to the desired output $y=50$ at location $x=10$ the function values elsewhere have to change, too.

In this case, updating the function with an observation at $x=10$ changes the function value far away from the observation. It even changes the function value in the opposite direction than what one would expect. This might seem a bit weird, but that's really how linear models work.
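The same single step can be sketched for the linear model. The initial parameters and the observation are taken from the text. The learning rate of `1e-3` is my own assumption, chosen small enough that the step does not overshoot the target, and the gradient is again written out by hand.

```python
import numpy as np

def f(x, theta):
    """Linear model f(x, theta) = theta_1 * x + theta_2."""
    return theta[0] * x + theta[1]

xs = np.arange(-10, 21)
theta = np.array([3.0, 1.0])
x_obs, y_obs = 10.0, 50.0
lr = 1e-3   # assumed small learning rate, chosen so the step does not overshoot

# Gradient of (f(x_obs) - y_obs)^2 with respect to (theta_1, theta_2).
residual = f(x_obs, theta) - y_obs
grad = 2.0 * residual * np.array([x_obs, 1.0])
theta_new = theta - lr * grad

change = f(xs, theta_new) - f(xs, theta)   # the "red arrows"
print(change[xs == 10], change[xs == -10])
```

Every function value moves this time. The value at $x = 10$ goes up towards the observation, while the value at $x = -10$ goes down by almost the same amount, because the step mostly increases the slope $\theta_1$.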

Now we have a little bit of background to start talking about this neural tangent kernel thing.

Given a function $f_\theta(x)$ which is parametrized by $\theta$, its *neural tangent kernel* $k_\theta(x, x')$ quantifies how much the function's value at $x$ changes as we take an infinitesimally small gradient step in $\theta$ incorporating a new observation at $x'$. Another way of phrasing this is: $k(x, x')$ measures how sensitive the function value at $x$ is to prediction errors at $x'$.

In the plots before, the size of the red arrow at each location $x$ was given by the following equation:

$$

\eta \tilde{k}_\theta(x, x') = f\left(x, \theta + \eta \frac{df_\theta(x')}{d\theta}\right) - f(x, \theta)

$$

In neural network parlance, this is what's going on: The loss function tells me to increase the function value $f_\theta(x')$. I back-propagate this through the network to see what change in $\theta$ I have to make to achieve this. However, moving $f_\theta(x')$ this way also simultaneously moves $f_\theta(x)$ at other locations $x \neq x'$. $\tilde{k}_\theta(x, x')$ expresses by how much.

The neural tangent kernel is basically the limit of $\tilde{k}$ as the stepsize becomes infinitesimally small. In particular:

$$

k(x, x') = \lim_{\eta \rightarrow 0} \frac{f\left(x, \theta + \eta \frac{df_\theta(x')}{d\theta}\right) - f(x, \theta)}{\eta}

$$
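To see this limit numerically, here is a small sketch. The two-parameter model $f(x, \theta) = \theta_1 \tanh(\theta_2 x)$, the parameter values and the function names are hypothetical, picked only so the quotient is cheap to evaluate.

```python
import math

# Hypothetical two-parameter model, used only to illustrate the limit.
def f(x, theta):
    return theta[0] * math.tanh(theta[1] * x)

def grad_f(x, theta):
    """Analytic gradient of f(x, theta) with respect to theta."""
    t = math.tanh(theta[1] * x)
    return [t, theta[0] * x * (1.0 - t * t)]

def finite_step_kernel(x, x_prime, theta, eta):
    """Change in f(x) per unit step size after moving theta along grad f(x')."""
    g = grad_f(x_prime, theta)
    stepped = [p + eta * gi for p, gi in zip(theta, g)]
    return (f(x, stepped) - f(x, theta)) / eta

theta = [2.0, 0.5]
for eta in (1.0, 0.1, 0.01, 0.001, 1e-6):
    print(eta, finite_step_kernel(1.0, 2.0, theta, eta))
```

For large $\eta$ the quotient depends noticeably on the step size, and as $\eta$ shrinks it settles to a fixed number, which is the value of $k(x, x')$ at this $\theta$.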

Using a 1st order Taylor expansion of $f_\theta(x)$, it is possible to show that

$$

k_\theta(x, x') = \left\langle \frac{df_\theta(x)}{d\theta} , \frac{df_\theta(x')}{d\theta} \right\rangle

$$
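To fill in the Taylor expansion step, here is a short sketch of the argument, writing $g$ as shorthand (my notation, not the paper's) for the gradient at $x'$:

$$

f(x, \theta + \eta g) = f(x, \theta) + \eta \left\langle \frac{df_\theta(x)}{d\theta}, g \right\rangle + O(\eta^2), \qquad g = \frac{df_\theta(x')}{d\theta}

$$

Subtracting $f(x, \theta)$, dividing by $\eta$ and letting $\eta \rightarrow 0$ removes the $O(\eta^2)$ term and leaves exactly the inner product above.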

As homework for you: find $k(x, x')$ and/or $\tilde{k}(x, x')$ for a fixed $\eta$ in the linear model in the previous example. Is it linear? Is it something else?

Note that this is a different derivation from what's in the paper (which starts from a continuous differential equation version of gradient descent).

Now, I'll go back to the examples to illustrate two more important properties of this kernel: sensitivity to parametrization, and changes during training.

It's well known that neural networks can be reparametrized in ways that don't change the actual output of the function, but which may lead to differences in how optimization works. Batchnorm is a well-known example of this. Can we see the effect of reparametrization in the neural tangent kernel? Yes we can. Let's look at what happens if I reparametrize the linear function I used in the second example as:

$$

f_\theta(x) = \theta_1 x + \color{blue}{10\cdot}\theta_2

$$

but now with parameters $\theta_1=3, \theta_2=\color{blue}{0.1}$. I highlighted in blue what changed. The function itself, at initialization, is the same since $10 \cdot 0.1 = 1$. The function class is the same, too, as I can still implement arbitrary linear functions. However, when we look at the effect of a single gradient step, we see that the function changes differently when gradient descent is performed in this parametrisation.

In this parametrisation, it became easier for gradient descent to push the whole function up by a constant, while in the previous parametrisation it decided to change the slope. What this demonstrates is that the neural tangent kernel $k_\theta(x, x')$ is sensitive to reparametrization.
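One way to check this numerically is to evaluate the inner-product formula for $k_\theta(x, x')$ in both parametrisations. In this sketch the gradients are written out by hand, and the function names and the handful of evaluation points are my own choices.

```python
# Two parametrisations of the same linear function at initialisation.
# Each grad_* returns the gradient of f(x) with respect to (theta1, theta2).
def grad_original(x):
    return (x, 1.0)      # f(x) = theta1 * x + theta2

def grad_rescaled(x):
    return (x, 10.0)     # f(x) = theta1 * x + 10 * theta2

def kernel(grad, x, x_prime):
    """Inner product of the parameter gradients at x and x_prime."""
    return sum(a * b for a, b in zip(grad(x), grad(x_prime)))

def normalised_profile(grad, x_prime, xs):
    """k(x, x_prime) / k(x_prime, x_prime) for each x in xs."""
    scale = kernel(grad, x_prime, x_prime)
    return [kernel(grad, x, x_prime) / scale for x in xs]

xs = [-10, 0, 10]
print(normalised_profile(grad_original, 10, xs))
print(normalised_profile(grad_rescaled, 10, xs))
```

In the original parametrisation the value at $x=-10$ moves almost as much as the value at $x=10$ but in the opposite direction, which is a change of slope. In the rescaled one the offset term dominates, so nearby values mostly move in the same direction.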

While the linear models may be a good illustration, let's look at what $k_\theta(x, x')$ looks like in a nonlinear model. Here, I'll consider a model with two squared exponential basis functions:

$$

f_\theta(x) = \theta_1 \exp\left(-\frac{(x - \theta_2)^2}{30}\right) + \theta_3 \exp\left(-\frac{(x - \theta_4)^2}{30}\right) + \theta_5,

$$

with initial parameter values $(\theta_1, \theta_2, \theta_3, \theta_4, \theta_5) = (4.0, -10.0, 25.0, 10.0, 50.0)$. These are chosen somewhat arbitrarily and to make the result visually appealing:

We can visualise the function $\tilde{k}_\theta(x, 10)$ directly, rather than plotting it on top of the function. Here I also normalize it by dividing by $\tilde{k}_\theta(10, 10)$.

What we can see is that this starts to look a bit like a *kernel* function in that it has higher values near $10$ and decreases as you go farther away. However, a few things are worth noting. First, the maximum of this kernel function is not at $x=10$, but at $x=7$. This means that the function value $f(7)$ changes more in reaction to an observation at $x'=10$ than the value $f(10)$ does. Secondly, there are some negative values. In this case the previous figure provides a visual explanation why: we can increase the function value at $x=10$ by pushing the valley centred at $\theta_1=4$ away from it, to the left. This parameter change in turn decreases function values on the left-hand wall of the valley. Third, the kernel function converges to a positive constant at its tails - this is because of the offset $\theta_5$.

Now I'm going to illustrate another important property of the neural tangent kernel: in general, the kernel depends on the parameter value $\theta$, and therefore it changes as the model is trained. Here I show what happens to the kernel as I take 15 gradient ascent steps trying to increase $f(10)$. The purple curve is the one I had at initialization (above), and the orange one shows the kernel at the last gradient step.

The corresponding changes to the function $f_{\theta_t}$ are shown below:

So we can see that as the parameter changes, the kernel also changes. The kernel becomes flatter. An explanation of this is that eventually we reach a region of parameter space, where $\theta_4$ changes the fastest.

It turns out the neural tangent kernel becomes particularly useful when studying learning dynamics in infinitely wide feed-forward neural networks. Why? Because in this limit, two things happen:

- First: if we initialize $\theta_0$ randomly from appropriately chosen distributions, the initial NTK of the network $k_{\theta_0}$ approaches a deterministic kernel as the width increases. This means that, at initialization, $k_{\theta_0}$ doesn't really depend on $\theta_0$ but is a fixed kernel independent of the specific initialization.
- Second: in the infinite limit the kernel $k_{\theta_t}$ stays constant over time as we optimise $\theta_t$. This removes the parameter dependence during training.

These two facts put together imply that gradient descent in the infinitely wide and infinitesimally small learning rate limit can be understood as a pretty simple algorithm called *kernel gradient descent* with a fixed kernel function that depends only on the architecture (number of layers, activations, etc).
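To give a feel for what kernel gradient descent does, here is a rough sketch for the squared loss. An RBF kernel stands in for the fixed limiting kernel, and the data, step size, number of steps and function names are all made up for illustration; none of this comes from the paper.

```python
import math

# Assumptions for illustration: an RBF kernel stands in for the fixed limiting
# kernel, and the data, step size and number of steps are made up.
def kernel(x, x_prime, length_scale=2.0):
    return math.exp(-((x - x_prime) ** 2) / (2.0 * length_scale ** 2))

def kernel_gradient_descent(train_x, train_y, query_x, lr=0.1, steps=500):
    """Function-space gradient descent on the loss 0.5 * sum_i (f(x_i) - y_i)**2.

    The function starts at f = 0 and is tracked only at the training and query points.
    """
    points = list(train_x) + list(query_x)
    f = [0.0] * len(points)
    n = len(train_x)
    for _ in range(steps):
        residuals = [f[i] - train_y[i] for i in range(n)]
        # Each function value moves by the kernel-weighted sum of training residuals
        f = [
            fx - lr * sum(kernel(x, train_x[i]) * residuals[i] for i in range(n))
            for x, fx in zip(points, f)
        ]
    return f[:n], f[n:]

train_x, train_y = [-2.0, 0.0, 3.0], [1.0, -1.0, 2.0]
fitted, predicted = kernel_gradient_descent(train_x, train_y, query_x=[1.0])
print(fitted, predicted)
```

Note that no parameters appear anywhere: the update acts directly on function values, and the kernel alone decides how a prediction error at one training point moves the function elsewhere.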

These results, taken together with an older known result by Neal (1994), allow us to characterise the probability distribution of minima that gradient descent converges to in this infinite limit as a Gaussian process. For details, see the paper mentioned above.

There are two somewhat related sets of results both involving infinitely wide neural networks and kernel functions, so I just wanted to clarify the difference between them:

- the older, well-known result by Neal (1994), later extended by others, is that the distribution of $f_\theta$ under random initialization of $\theta$ converges to a Gaussian process. This Gaussian process has a kernel or covariance function which is not, in general, the same as the neural tangent kernel. This old result doesn't say anything about gradient descent, and is typically used to motivate the use of Gaussian process-based Bayesian methods.
- the new, NTK, result is that the evolution of $f_{\theta_t}$ during gradient descent training can be described in terms of a kernel, the neural tangent kernel, and that in the infinite limit this kernel stays constant during training and is deterministic at initialization. Using this result, it is possible to show that in some cases the distribution of $f_{\theta_t}$ is a Gaussian process at every timestep $t$, not just at initialization. This result also allows us to identify the Gaussian process which describes the limit as $t \rightarrow \infty$. This limiting Gaussian process however is not the same as the posterior Gaussian process which Neal and others would calculate on the basis of the first result.

So I hope this post helps a bit by building some intuition about what the neural tangent kernel is. If you're interested, check out the simple colab notebook I used for these illustrations.

I recently encountered this cool paper in a reading group presentation:

- Rezende et al (2020) Causally Correct Partial Models for Reinforcement Learning

It's frankly taken me a long time to understand what was going on, and it took me weeks to write this half-decent explanation of it. The first notes I wrote followed the logic of the paper more closely, but in this post I thought I'd just focus on the high-level idea, after which hopefully the paper is more straightforward. I wanted to capture the key idea, without the distractions of RNN hidden states, etc, which I found confusing to think about.

To start with the basics, this paper deals with the partially observed Markov decision process (POMDP) setup. The diagram below illustrates what's going on:

The grey nodes $e_t$ show the unobserved state of the environment at each timestep $t=0,1,2\ldots$. At each timestep the agent observes $y_t$, which depends on the current state of the environment (red-ish nodes). The agent then updates its state $s_t$ based on its past state $s_{t-1}$, the new observation $y_t$, and the previous action taken $a_{t-1}$. This is shown by the blue squares (they're squares, signifying that this node depends deterministically on its parents). Then, based on its state, the agent chooses an action $a_t$ by sampling from the policy $\pi(a_t\vert s_t)$. The action influences how the environment's state, $e_{t+1}$, changes.

We assume that the agent's ultimate goal is to maximise reward at the last state at time $T$, which we assume is a deterministic function of the observation $r(y_T)$. Think of this reward as the score in an Atari game, which is written on the screen whose contents are made available in $y_t$.

Let's start by stating what we ultimately would like to estimate from the data we have. The assumption is that we sampled the data using some policy $\pi$, but we would like to be able to say how well a different policy $\tilde{\pi}$ would do, in other words, what would be the expected score at time $T$ if instead of $\pi$ we used a different policy $\tilde{\pi}$.

What we are interested in, is a causal/counterfactual query:

$$

\mathbb{E}_{\tau\sim\tilde{p}}[r(y_T)],

$$

where $\tau = [(s_t, y_t, e_t, a_t) : t=0\ldots T]$ denotes a trajectory or rollout up to time $T$, and $\tilde{p}$ denotes the generative process when using policy $\tilde{\pi}$, that is:

$$

\tilde{p}(\tau) = p(e_0)p(y_0\vert e_0) \tilde{\pi}(a_0\vert s_0) p(s_0)\prod_{t=1}^T p(e_t\vert a_{t-1}, e_{t-1}) p(y_t\vert e_t) \tilde{\pi}(a_t\vert s_t) \delta (s_t - g(s_{t-1}, y_t))

$$

I called this a causal or counterfactual query, because we are interested in making predictions under a different distribution $\tilde{p}$ than $p$ which we have observations from. The difference between $\tilde{p}$ and $p$ can be called an intervention, where we replace specific factors in the data generating process with different ones.

There are - at least - two ways one could go about estimating such a counterfactual distribution:

- model-free, via *importance sampling*. This method tries to directly estimate the causal query by calculating a weighted average over the observed data. The weights are given by the ratio between $\tilde{\pi}(a_t\vert s_t)$, the probability by which $\tilde{\pi}$ would choose an action and $\pi(a_t\vert s_t)$, the probability it was chosen by the policy that we used to collect the data. A great paper explaining how this works is (Bottou et al, 2013). Importance sampling has the advantage that we don't have to build any model of the environment, we can directly evaluate the average reward from the samples we have, using only $\pi$ and $\tilde{\pi}$ to calculate the weights. The downside, however, is that importance sampling often gives an incredibly high variance estimate, and is only reliable if $\tilde{\pi}$ and $\pi$ are very close.
- model-based, via *causal calculus*. If possible, we can use do-calculus to express the causal query in an alternative way, using various conditional distributions estimated from the observed data. This approach has the disadvantage that it requires us to build a model from the data first. We then use the conditional distributions learned from the data to approximate the quantity of interest by plugging them into the formula we got from do-calculus. If our models are imperfect, these imperfections/approximation errors can compound when the causal estimand is calculated, potentially leading to large biases and inaccuracies. On the other hand, our models may be accurate enough to extrapolate to situations where importance weighting would be unreliable.
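The importance sampling option can be sketched in a few lines. The one-step problem below is entirely hypothetical (a binary state the agent sees directly, a binary action, reward 1 when they match), and the two policies and function names are made up; the point is only how the policy ratio reweights the observed rewards.

```python
import random

# Toy one-step problem, entirely hypothetical: the agent observes a binary state s,
# picks a binary action a, and receives reward 1 when a == s.
def behaviour_policy(a, s):
    return 0.5                       # pi(a | s): uniform over the two actions

def target_policy(a, s):
    return 0.9 if a == s else 0.1    # pi_tilde(a | s): usually matches the state

def importance_sampling_estimate(num_samples, seed=0):
    """Estimate the expected reward under target_policy from behaviour_policy data."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        s = rng.randint(0, 1)
        a = rng.randint(0, 1)        # sampled from the uniform behaviour policy
        reward = 1.0 if a == s else 0.0
        weight = target_policy(a, s) / behaviour_policy(a, s)
        total += weight * reward
    return total / num_samples

estimate = importance_sampling_estimate(20000)
print(estimate)   # should land near 0.9, the true expected reward under target_policy
```

In this toy case the weights are at most 1.8, so the estimate is well behaved. Over a long trajectory the per-step ratios multiply, which is where the high variance mentioned above comes from.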

In this paper, we focus on solving the problem with causal calculus. This requires us to build a model of observed data, which we can then use to make causal predictions. The key question this paper asks is

How much of the data do we have to model to be able to make the kinds of causal inferences we would like to make?

One way we can answer the query above is to model the joint distribution of everything, or mostly everything, that we can observe. For example, we could build a full autoregressive model of observations $y_t$ conditioned on actions $a_t$. In essence this would amount to fitting a model to $p(y_{0:T}\vert a_{0:T})$.

If we had such a model, we would theoretically be able to make causal predictions, for reasons I will explain later. However, this option is ruled out in the paper because we assume the observations $y_t$ are very high dimensional, such as images rendered in a computer game. Thus, modelling the joint distribution of the whole observation sequence $y_{1:T}$ accurately is hopeless and would require a lot of data. Therefore, we would like to get away without modelling the whole observation sequence $y_{1:T}$, which brings us to partial models.

Partial models try to avoid modelling the joint distribution of high-dimensional observations $y_{1:T}$ or agent-state sequences $s_{0:T}$, and focus on modelling directly the distribution of $r(y_T)$ - i.e. only the reward component of the last observation, given the action-sequence $a_{0:T}$. This is clearly a lot easier to do, because $r(y_T)$ is assumed to be a low-dimensional aspect of the full observation $y_T$, so all we have to learn is a model of a scalar conditioned on a sequence of actions $q_\theta(r(y_T)\vert a_{0:T})$. We know very well how to fit such models to realistic amounts of data.

However, if we don't include either $y_t$ or $s_t$ in our model, we won't be able to make the counterfactual inferences we wanted to make in the first place. Why? Let's look at the data generating process once more:

We are trying to model the causal impact of actions $a_0$ and $a_1$ on the outcome $y_2$. Let's focus on $a_1$. $y_2$ is clearly statistically dependent on $a_1$. However, this statistical dependence emerges due to completely different effects:

- **causal association:** $a_1$ influences the state of the environment $e_2$, resulting in an observation $y_2$. Therefore, $a_1$ has a causal effect on $y_2$, mediated by $e_2$.
- **spurious association due to confounding:** The unobserved hidden state $e_1$ is a confounder between the action $a_1$ and the observation $y_2$. The state $e_1$ has an indirect causal effect on $a_1$ mediated by the observation $y_1$ and the agent's state $s_1$. Similarly $e_1$ has an indirect effect on $y_2$ mediated by $e_2$.

I illustrated these two sources of statistical association by colour-coding the different paths in the causal graph. The blue path is the confounding path: correlation is induced because both $a_1$ and $y_2$ have $e_1$ as causal ancestor. The red path is the causal path: $a_1$ indirectly influences $y_2$ via the hidden state $e_2$.

If we would like to correctly evaluate the consequence of changing policies, we have to be able to disambiguate between these two sources of statistical association, get rid of the blue path, and only take the red path into account. Unfortunately, this is not possible in a partial model, where we only model the distribution of $y_2$ conditional on $a_0$ and $a_1$.

If we want to draw causal inferences, we **must model the distribution of at least one variable along the blue path.** Clearly, $y_1$ and $s_1$ are theoretically observable, and are on the confounding path. Adding either of these to our model would allow us to use the backdoor adjustment formula (explained in the paper). However, this would take us back to Option 1, where we have to model the joint distribution of either sequences of observations $y_{0:T}$ or sequences of states $s_{0:T}$, both assumed to be high-dimensional and difficult to model.

So we finally got to the core of what is proposed in the paper: a kind of halfway-house between modeling everything and modeling too little. We are going to model *enough* variables to be able to evaluate causal queries, while keeping the dimensionality of the model we have to fit low. To do this, we change the data generating process slightly - by splitting the policy into two stages:

The agent first generates $z_t$ from the state $s_t$, and then uses the sampled $z_t$ value to make a decision $a_t$. One can understand $z_t$ as being a stochastic bottleneck between the agent's high-dimensional state $s_t$, and the low-dimensional action $a_t$. The assumption is that the sequence $z_{0:T}$ should be a lot easier to model than either $y_{0:T}$ or $s_{0:T}$. However, if we now build a model $p(r(y_T), z_{0:T} \vert a_{0:T})$, we are able to use this model to evaluate the causal queries of interest, thanks to the backdoor adjustment formula. For how to precisely do this, please refer to the paper.

Intuitively, this approach helps by adding a low-dimensional stochastic node along the confounding path. This allows us to compensate for confounding, without having to build a full generative model of sequences of high-dimensional variables. It allows us to solve the problem we need to solve without having to solve a ridiculously difficult subproblem.
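Here is a one-step caricature of this idea, computed exactly by enumeration. Everything in it is hypothetical: a binary hidden state $e$, a binary bottleneck $z$ sampled from the agent's view of $e$, an action that depends on $z$ only, and an expected reward table indexed by $(e, a)$. It is not the paper's algorithm, just the backdoor adjustment on the smallest example I could think of.

```python
from itertools import product

# Hypothetical one-step caricature: hidden state e, bottleneck z sampled from the
# agent's view of e, action a sampled from z only. All numbers are made up.
p_e = {0: 0.5, 1: 0.5}
p_z_given_e = {0: {0: 0.8, 1: 0.2}, 1: {0: 0.2, 1: 0.8}}   # p_z_given_e[e][z]
pi_a_given_z = {0: {0: 0.9, 1: 0.1}, 1: {0: 0.1, 1: 0.9}}  # pi_a_given_z[z][a]
reward = {(0, 0): 0.2, (0, 1): 0.4, (1, 0): 0.6, (1, 1): 0.8}  # expected r given (e, a)

# Joint distribution of (e, z, a) under the behaviour policy
joint = {
    (e, z, a): p_e[e] * p_z_given_e[e][z] * pi_a_given_z[z][a]
    for e, z, a in product((0, 1), repeat=3)
}

def expected_reward_given(a, z=None):
    """E[r | a] if z is None, else E[r | a, z], computed from the behaviour data."""
    keys = [k for k in joint if k[2] == a and (z is None or k[1] == z)]
    mass = sum(joint[k] for k in keys)
    return sum(joint[k] * reward[(k[0], a)] for k in keys) / mass

def p_z(z):
    return sum(p for (e, zz, a), p in joint.items() if zz == z)

a = 1
naive = expected_reward_given(a)                                      # models r given a only
adjusted = sum(p_z(z) * expected_reward_given(a, z) for z in (0, 1))   # backdoor via z
truth = sum(p_e[e] * reward[(e, a)] for e in (0, 1))                   # effect of forcing a
print(naive, adjusted, truth)
```

The model that only conditions on the action overestimates the reward, because it mixes in the confounding path through $e$. Adjusting for the low-dimensional $z$ recovers the true effect of forcing the action, without ever modelling $e$ or the agent's full state.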

Last night on the train I read this nice paper by David Duvenaud and colleagues. Around midnight I got a calendar notification "it's David Duvenaud's birthday". So I thought it's time for a David Duvenaud birthday special (don't get too excited David, I won't make it an annual tradition...)

- Jonathan Lorraine, Paul Vicol, David Duvenaud (2019) Optimizing Millions of Hyperparameters by Implicit Differentiation

I recently covered iMAML: the meta-learning algorithm that makes use of implicit gradients to sidestep backpropagating through the inner loop optimization in meta-learning/hyperparameter tuning. The method presented in (Lorraine et al, 2019) uses the same high-level idea, but introduces a different - on the surface less fiddly - approximation to the crucial inverse Hessian. I won't spend a lot of time introducing the whole meta-learning setup from scratch, you can use the previous post as a starting point.

Many - though not all - meta-learning or hyperparameter optimization problems can be stated as nested optimization problems. If we have some hyperparameters $\lambda$ and some parameters $\theta$ we are interested in

$$

\operatorname{argmin}_\lambda \mathcal{L}_V (\operatorname{argmin}_\theta \mathcal{L}_T(\theta, \lambda)),

$$

where $\mathcal{L}_T$ is some training loss and $\mathcal{L}_V$ a validation loss. The optimal parameter to the training problem, $\theta^\ast$, implicitly depends on the hyperparameters $\lambda$:

$$

\theta^\ast(\lambda) = \operatorname{argmin}_\theta \mathcal{L}_T(\theta, \lambda)

$$

If this implicit function mapping $\lambda$ to $\theta^\ast$ is differentiable, and subject to some other conditions, the implicit function theorem states that its derivative is

$$

\left.\frac{\partial\theta^{\ast}}{\partial\lambda}\right\vert_{\lambda_0} = \left.-\left[\frac{\partial^2 \mathcal{L}_T}{\partial \theta \partial \theta}\right]^{-1}\frac{\partial^2\mathcal{L}_T}{\partial \theta \partial \lambda}\right\vert_{\lambda_0, \theta^\ast(\lambda_0)}

$$

The formula we obtained for iMAML is a special case of this where $\frac{\partial^2\mathcal{L}_T}{\partial \theta \partial \lambda}$ is the identity up to a constant factor. This is because there, the hyperparameter controls a quadratic regularizer $\frac{1}{2}\|\theta - \lambda\|^2$, and indeed if you differentiate this with respect to both $\lambda$ and $\theta$ you are left with a constant times identity.
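The implicit function theorem formula is easy to sanity-check on a scalar toy problem. The training loss below, $\mathcal{L}_T(\theta, \lambda) = (\theta - 1)^2 + \frac{\lambda}{2}\theta^2$, is a hypothetical example of mine in which $\lambda$ acts like a weight-decay strength; it has a closed-form minimiser, so the implicit gradient can be compared against a finite difference.

```python
# Hypothetical scalar example: L_T(theta, lam) = (theta - 1)**2 + 0.5 * lam * theta**2,
# where lam plays the role of a weight-decay hyperparameter.
def theta_star(lam):
    """Closed-form minimiser of L_T over theta for this toy loss."""
    return 2.0 / (2.0 + lam)

def implicit_gradient(lam):
    """d theta_star / d lam from the implicit function theorem."""
    theta = theta_star(lam)
    hessian = 2.0 + lam    # d^2 L_T / d theta^2
    mixed = theta          # d^2 L_T / (d theta d lam), evaluated at theta_star
    return -mixed / hessian

def finite_difference_gradient(lam, eps=1e-6):
    """Numerical derivative of theta_star, used only as a sanity check."""
    return (theta_star(lam + eps) - theta_star(lam - eps)) / (2.0 * eps)

lam0 = 0.5
print(implicit_gradient(lam0), finite_difference_gradient(lam0))
```

The two numbers agree to several decimal places. In a real network neither the minimiser nor the inverse Hessian is available in closed form, which is why the approximation discussed next matters.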

The primary difficulty of course is approximating the inverse Hessian, or indeed matrix-vector products involving this inverse Hessian. This is where iMAML and the method proposed by Lorraine et al, (2019) differ. iMAML uses a conjugate gradient method to iteratively approximate the gradient. In this work, they use a Neumann series approximation, which, for a matrix $U$ looks as follows:

$$

U^{-1} = \sum_{i=0}^{\infty}(I - U)^i

$$

This is basically a generalization of the better known sum of a geometric series: if you have a scalar $q$ with $\vert q \vert<1$ then

$$

\sum_{i=0}^\infty q^i = \frac{1}{1-q}.

$$

Using a finite truncation of the Neumann series one can approximate the inverse Hessian in the following way:

$$

\left[\frac{\partial^2 \mathcal{L}_T}{\partial \theta \partial \theta}\right]^{-1} \approx \sum_{i=0}^j \left(I - \frac{\partial^2 \mathcal{L}_T}{\partial \theta \partial \theta}\right)^i.

$$

This Neumann series approximation, at least on the surface, seems significantly less hassle to implement than running a conjugate gradient optimization step.
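As a rough illustration of why this is simple to implement, here is a sketch of the truncated Neumann sum $\sum_{i=0}^{j}(I - U)^i v$ using nothing but matrix-vector products. The $2 \times 2$ matrix, the vector and the function names are made up, and the series only converges when the eigenvalues of $I - U$ are smaller than 1 in magnitude, which holds for this example.

```python
def matvec(matrix, vector):
    """Plain matrix-vector product for lists of lists."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]

def neumann_inverse_vector_product(U, v, num_terms):
    """Approximate U^{-1} v by the truncated series sum_{i=0}^{num_terms} (I - U)^i v."""
    term = list(v)     # (I - U)^0 v
    total = list(v)
    for _ in range(num_terms):
        u_term = matvec(U, term)
        term = [t - ut for t, ut in zip(term, u_term)]   # multiply by (I - U)
        total = [s + t for s, t in zip(total, term)]
    return total

# Made-up matrix standing in for the Hessian; I - U has small eigenvalues here
U = [[0.9, 0.1], [0.1, 0.8]]
v = [1.0, 2.0]
approx = neumann_inverse_vector_product(U, v, num_terms=30)
print(approx)
print(matvec(U, approx))   # should reproduce v if approx is close to U^{-1} v
```

In the hyperparameter setting each `matvec` would be a Hessian-vector product computed by automatic differentiation, so the Hessian itself never has to be formed.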

One of the fun bits of this paper is the interesting set of experiments the authors used to demonstrate the versatility of this approach. For example, in this framework, one can treat the training dataset as a hyperparameter. Optimizing pixel values in a small training dataset, one image per class, allowed the authors to "distill" a dataset into a set of prototypical examples. If you train your neural net on this distilled dataset, you get relatively good validation performance. The results are not quite as image-like as one would imagine, but for some classes, like bikes, you even get recognisable shapes:

In another experiment the authors trained a network to perform data augmentation, treating parameters of this network as a hyperparameter of a learning task. In both of these cases, the number of hyperparameters optimized was in the hundreds of thousands, way beyond the number we usually consider as hyperparameters.

This method inherits some of the limitations I already discussed with iMAML. Please also see the comments where various people gave pointers to work that overcomes some of these limitations.

Most crucially, methods based on implicit gradients assume that your learning algorithm (inner loop) finds a unique, optimal parameter that minimises some loss function. This is simply not a valid assumption for SGD where different random seeds might produce very different and differently behaving optima.

Secondly, this assumption only allows for hyperparameters that control the loss function, but not for ones that control other aspects of the optimization algorithm, such as learning rates, batch sizes or initialization. For those kinds of situations, explicit differentiation may still be the most competitive solution. On that note, I also recommend reading this recent paper on generalized inner-loop meta-learning and the associated PyTorch package higher.

Happy birthday David. Nice work!

My parents didn't raise me in a religious tradition. It all started to change when a great scientist took me under his wing and taught me the teachings of Bayes. I travelled the world and spent 4 years in a Bayesian monastery in Cambridge, UK. This particular place practiced the nonparametric Bayesian doctrine.

We were religious Bayesians. We looked at the world and all we saw was the face of Bayes: if something worked, it did because it had a Bayesian interpretation. If an algorithm did not work, we shunned its creator for being unfaithful to Bayes. We scoffed at point estimates, despised p-values. Bayes had the answer to everything. But above all, we believed in our models.

At a convention dominated by Bayesian thinkers I was approached by a frequentist, let's call him Lucifer (in fact his real name is Laci so not that far off). "Do you believe your data exists?" - he asked. "Yes" I answered. "Do you believe your model and its parameters exist?" "Well, not really, it's just a model I use to describe reality" I said. Then he told me the following, poisoning my pure Bayesian heart forever: "If you use Bayes' rule, you assume that a joint distribution between model parameters and data exists. This, however, only exists if your data and your parameters both exist, in the same $\sigma$-algebra. You can't have it both ways. You have to think your model really exists somewhere."

I never forgot this encounter, but equally I didn't think much about it since then. Over the years, I started to doubt more and more aspects of my Bayesian faith. I realised the likelihood was important, but not the only thing that exists. There were scoring rules, loss functions which couldn't be written as a log-likelihood. I noticed nonparametric Bayesian models weren't automatically more useful than large parametric ones. I worked on weird stuff like loss-calibrated Bayes. I started having thoughts about model misspecification, kind of a taboo topic in the Bayesian church.

Over the years I came to terms with my Bayesian heritage, and I now live my life as a secular Bayesian. Certain elements of the Bayesian way are no doubt useful: engineering inductive biases explicitly into a prior distribution, using probabilities, divergences, information, variational bounds as tools for developing new algorithms. Posterior distributions can capture model uncertainty which can be exploited for active learning or exploration in interactive learning. Bayesian methods often - though not always - lead to increased robustness, better calibration, and so much more. At the same time, I can carry on living my life, use gradient descent to find local minima, use bootstrap to capture uncertainty. And first and foremost, I do not have to believe that my models really exist or perfectly describe reality anymore. I am free to think about model misspecification.

Lately, I have started to familiarize myself with a new body of work, which I call secular Bayesianism, that combines Bayesian inference with more frequentist ideas about learning from observation. In this body of work, people study model misspecification (see e.g. M-open Bayesian inference). And, I found a resolution to the "you have to believe in your model, you can't have it both ways" problem that bothered me all these years.

After this rather long intro, let me present the paper this post is really about and which, as a secular Bayesian, I found very interesting:

- P.G. Bissiri, C.C. Holmes and S.G. Walker (2016) A General Framework for Updating Belief Distributions

This paper basically asks: can we take the belief out of belief distributions? Let's say we want to estimate some parameter of interest $\theta$ from data. Does it still make sense to specify a prior distribution over this parameter, and then update it in light of data using some kind of Bayes rule-like update mechanism to form posterior distributions, all without assuming that the parameter of interest $\theta$ and the observations $x_i$ are linked to one another via a probabilistic model? And if it is meaningful, what form would that update rule take?

First of all, for simplicity, let's assume that data $x_i$ is sampled i.i.d from some distribution $P$. That's right, not exchangeable, actually i.i.d. like in frequentist settings. Let's also assume that we have some parameter of interest $\theta$. Unlike in Bayesian analysis where $\theta$ usually parametrises some kind of generative model for data $x_i$, we don't assume anything like that. All we assume is that there is a loss function $\ell$ which connects the parameter to the observations: $\ell(\theta, x)$ measures how well the estimate $\theta$ agrees with observation $x$.

Let's say that a priori, without seeing any datapoints we have a prior distribution $\pi$ over $\theta$. Now we observe a datapoint $x_1$. How should we make use of our observation $x_1$, the loss function $\ell$ and the prior $\pi$ to come up with some kind of posterior over this parameter? Let's denote this update rule $\psi(\ell(\cdot, x_1), \pi)$. There are many ways we could do this, but is there one which is better than the rest?

The paper lists a number of desiderata - desired properties the update rule $\psi$ should satisfy. These are all meaningful assumptions to have. The main one is coherence, which is a property somewhat analogous to exchangeability: if we observe a sequence of observations, we would like the resulting posterior to be the same, irrespective of the order in which the observations are presented. The coherence property can be written as follows

$$

\psi\left(\ell(\cdot, x_2), \psi\left(\ell(\cdot, x_1), \pi\right)\right) = \psi\left(\ell(\cdot, x_1), \psi\left(\ell(\cdot, x_2), \pi \right)\right)

$$

As a desired property, this makes a lot of sense, and Bayes rule obviously satisfies it. However, this is not really how the authors actually define coherence. In Equation (3) they use a more restrictive definition of coherence, further restricting the set of acceptable update rules as follows:

$$

\psi\left(\ell(\cdot, x_2), \psi\left(\ell(\cdot, x_1), \pi\right)\right) = \psi\left(\ell(\cdot, x_1) + \ell(\cdot, x_2), \pi \right)

$$

By combining losses from the two observations in an additive way, one can indeed ensure permutation invariance. However, the sum is not the only way to do this. Any pooling operation over observations would also have satisfied this. For example, one could replace the $\ell(\cdot, x_1) + \ell(\cdot, x_2)$ bit by $\max(\ell(\cdot, x_1), \ell(\cdot, x_2))$ and still satisfy the general principle of coherence. The most general class of permutation invariant functions which would satisfy the general coherence desideratum are discussed in DeepSets. Overall, my hunch is that going with the sum is a design choice, rather than a general desideratum. This choice is the real reason why the resulting update rule will end up very Bayes-rule like, as we will see later.

The other desiderata the paper proposes are actually discussed separately in Section 1.2 of (Bissiri et al, 2016), and called assumptions instead. These are much more basic requirements for the update function. Assumption 2 for example talks about how restricting the prior to a subset should result in a posterior which is also the restricted version of the original posterior. Assumption 3 requires that lower evidence (larger loss) for a parameter should yield smaller posterior probabilities - a monotonicity property.

One contribution of the paper is showing that all the desiderata mentioned above pinpoint a specific update rule $\psi$ which satisfies all the desired properties. This update takes the following form:

$$

\pi(\theta\vert x_{1:N}) = \psi(\ell(\cdot, x_{1:N}), \pi) \propto \exp\left\{-\sum_{n=1}^N\ell(\theta, x_n)\right\}\pi(\theta)

$$

Just like Bayes rule we have a normalized product of the prior with something that takes the role of the likelihood term. If the loss is the logarithmic loss of a probabilistic model, we recover the Bayes rule, but this update rule makes sense for arbitrary loss functions.
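Here is a small sketch of this update on a discrete grid, which also checks the coherence property numerically. The grid of $\theta$ values, the uniform prior, the absolute-error loss and the two observations are arbitrary choices of mine for illustration.

```python
import math

# Assumptions for illustration: a grid of candidate theta values, a uniform prior,
# and the absolute-error loss, all chosen arbitrarily.
thetas = [i / 10.0 for i in range(-50, 51)]
prior = [1.0 / len(thetas)] * len(thetas)

def loss(theta, x):
    return abs(theta - x)

def update(belief, observations):
    """Multiply the belief by exp(-sum of losses) and renormalise."""
    unnormalised = [
        b * math.exp(-sum(loss(t, x) for x in observations))
        for t, b in zip(thetas, belief)
    ]
    total = sum(unnormalised)
    return [u / total for u in unnormalised]

batch = update(prior, [1.0, 2.5])
sequential = update(update(prior, [1.0]), [2.5])
reversed_order = update(update(prior, [2.5]), [1.0])
print(max(abs(a - b) for a, b in zip(batch, sequential)))
print(max(abs(a - b) for a, b in zip(batch, reversed_order)))
```

Both printed differences are at the level of floating point error: updating on the two observations one at a time, in either order, gives the same belief as updating on the summed loss, which is the additive form of coherence from Equation (3).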

Again, this solution is unique under the very strong and specific desideratum that we'd like the losses from i.i.d. observations to combine in an additive way. I presume that, had we chosen a different permutation invariant function, we would end up with a similar generalization of Bayes rule with that permutation invariant function appearing in the exponent.

Now that we have an update rule which satisfies our desiderata, can we say if it's actually a good or useful update rule? It seems it is, in the following sense.

Let's think about a way to measure the usefulness of a posterior $\nu$. Suppose we have data sampling distribution $P$, losses are still measured by $\ell$, and our prior is $\pi$. A good posterior does two things well: it allows us to make good decisions in some kind of downstream test scenario, and it is informed by our prior. It therefore makes sense to define a loss function over the posterior $\nu$ as a sum of two terms:

$$

L(\nu; \ell, \pi, P) = h_1(\nu; \ell, P) + h_2(\nu; \pi)

$$

The first term, $h_1$ measures the posterior's usefulness at test time, and $h_2$ measures how well it's influenced by the prior. The authors define $h_1$ to be as follows:

$h_1(\nu; \ell, P) = \mathbb{E}_{x\sim P} \mathbb{E}_{\theta\sim\nu} \ell(\theta, x)$

So basically, we will sample from the posterior, and then evaluate the random sample parameter $\theta$ on a randomly chosen test datapoint $x$ using our loss $\ell$. I would say this is a rather narrow view on what it means for a posterior to do well on a downstream task, more about it later in the criticism section. In any case it's one possible goal for a posterior to try to achieve.

Now we turn to choosing $h_2$, and the authors note something very interesting. If we want the resulting optimal posterior to possess the coherence property (as defined in their Eqn. (3)), it turns out the only choice for $h_2$ is the KL divergence between the prior and posterior. Any other choice would lead to incoherent updates. This, I believe, is only true for the additive definition of coherence, not the more general definition I gave above.

Putting $h_1$ and $h_2$ together it turns out that the posterior that minimizes this loss function is precisely of the form $\pi(\theta\vert x_{1:N}) \propto \exp\{-\sum_{n=1}^N \ell(\theta, x_n)\}\pi(\theta)$. So, not only is this update rule the only update rule that satisfies the desired properties, it is also optimal under this particular definition of optimality/rationality.

This work is interesting because it gives a new justification for Bayes rule-like updates to belief distributions, and as a result it also provides a different/new perspective on Bayesian inference. Crucially, never in this derivation did we have to reason about a joint distribution between $\theta$ and the observations $x$ (or conditionals of one given the other). Even though I wrote $\pi(\theta \vert x_{1:N})$ to denote a posterior, this is really just a shorthand notation, syntactic sugar. This is important. One of the main technical criticisms of the Bayesian terminology is that in order to reason about the joint distribution between two random variables ($x$ and $\theta$), these variables have to live in the same probability space, so if you believe that your data exists, you have to believe in your model, and model parameters exist as well. This framework sidesteps that.

It allows rational updates of belief distributions, without forcing you to believe in anything.

From a practical viewpoint, this work also extends Bayesian inference in a meaningful way. While Bayesian inference only made sense if you inferred the whole set of parameters jointly, here you are allowed to specify any loss function, and really focus on the parameter of importance. For example, if you're only interested in estimating the median of a distribution in a Bayesian way, without assuming it follows a certain distribution, you can now do this by specifying your loss to be $\vert x-\theta\vert$. This is explained a lot more clearly in the paper, so I encourage you to read it.
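To make the median example concrete, here is a minimal sketch of the update rule above on a grid of candidate values for $\theta$, using the loss $\vert x-\theta\vert$. The grid range, the flat prior and the data are all made up for illustration, and `gibbs_update` is a hypothetical helper name, not something from the paper. The sketch also checks coherence: updating on two observations and then on the remaining three gives the same belief distribution as updating on all five at once.

```python
import numpy as np

def gibbs_update(prior, grid, data):
    """Return prior(theta) * exp(-sum_n |x_n - theta|), normalised over the grid."""
    total_loss = np.abs(np.asarray(data)[:, None] - grid[None, :]).sum(axis=0)
    unnormalised = np.exp(-total_loss) * prior
    return unnormalised / unnormalised.sum()

grid = np.linspace(-2.0, 5.0, 701)             # candidate values of theta (assumed range)
prior = np.full(grid.shape, 1.0 / grid.size)   # flat prior over the grid (assumption)
data = [0.1, 0.4, 0.5, 0.9, 3.0]               # made-up observations, sample median 0.5

# Batch update on all observations at once.
posterior = gibbs_update(prior, grid, data)
posterior_mode = grid[np.argmax(posterior)]

# Coherence: update on the first two points, then on the remaining three.
two_step = gibbs_update(gibbs_update(prior, grid, data[:2]), grid, data[2:])
```

With a flat prior the mode of this belief distribution sits at the sample median, and no distributional assumption about the data was needed to get there.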

My main criticism of this work is that it made a number of assumptions that ultimately limited the range of acceptable solutions, and to my cynical eye it appears that these choices were specifically made so that Bayes rule-like update rules came out winning. So rather than really deriving Bayesian updates from first principles, we engineered principles under which Bayesian updates are optimal. In other words, the top-down analysis was rigged in favour of familiar Bayes-like updates. There are two specific assumptions which I would personally like to see relaxed:

The first one is the restrictive notion of coherence, which requires losses to combine additively from multiple observations. I think this very clearly gives rise to the convenient exponential, log-additive form in the end. It would be interesting to see whether other types of permutation invariant update rules also make sense in practice.

Secondly, the way the authors defined optimality, in terms of the loss $h_1$ above, is very limiting. We rarely use posterior distributions in this way (take a random sample). Instead, we might be interested in integrating over the posterior, and evaluating the loss of that classifier. This is a loss that cannot be written in the bilinear form that is the formula for $h_1$ above. I wonder if using more elaborate losses for the posterior, perhaps along the lines of general decision problems as in (Lacoste-Julien et al, 2011), could lead to more interesting update rules which don't look at all like Bayes rule but are still rational.

Yesterday I read this intriguing paper about the mindboggling fact that it is possible to use an exponentially growing learning rate schedule when training neural networks with batch normalization:

- Zhiyuan Li and Sanjeev Arora (2019) An Exponential Learning Rate Schedule for Deep Learning

The paper provides both theoretical insights as well as empirical demonstration of this remarkable property.

The reason why this works boils down to the observation that batch-normalization renders the loss function of neural networks scale invariant - scaling the weights by a constant does not change the output, or the loss, of the batch normalized network. It turns out that this property alone might result in somewhat unexpected and potentially helpful properties for optimization. I will use this post to illustrate some of the properties of scale invariant loss functions - and gradient descent trajectories on them - using a 2D toy example:

Here, I drew a loss function which has the scale invariance property. The value of the loss depends only on the angle, not on the magnitude of the weight vector. The value of the loss along any radial line from the origin outwards is constant. Simple consequences of scale invariance (Lemma 1 of the paper) are that

- the gradient of this function is always orthogonal to the current value of the parameter vector, and that
- the farther you are from the origin, the smaller the magnitude of the gradient. This is perhaps less intuitive but think about how the function behaves on a circle around the origin. The function is the same, but as you increase the radius you stretch the same function round a larger circle - it gets fatter, therefore its gradients decrease.

Here is a - somewhat messy - quiver plot showing the gradients of the function above:

The quiver plot is messy because the gradients around the origin explode. But you can perhaps see how the gradients get larger and larger as you approach the origin - and remain perpendicular to the value itself.

So imagine doing vanilla gradient descent (no momentum, no weight decay, fixed learning rate) on such a loss surface. Because the gradient is always perpendicular to the current value of the parameter, by the Pythagorean theorem, the norm of the parameter vector increases with each iteration. So gradient descent takes you away from the origin. However, the weight vector won't completely blow up to infinity, because the gradients also get smaller and smaller as the weight vector grows, so it settles at some point. Here is what a gradient descent path looks like starting from the coordinate $(-0.7, 0.7)$:
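These properties are easy to check numerically. The sketch below uses a stand-in scale-invariant loss, the negative cosine similarity with a fixed direction, which is an assumption of mine and not the function drawn in the plots. It checks that the gradient is orthogonal to the weight vector, that doubling the weights halves the gradient, and that the norm never shrinks under vanilla gradient descent.

```python
import numpy as np

# Hypothetical direction in which the toy loss is minimised.
TARGET = np.array([1.0, 0.5]) / np.linalg.norm([1.0, 0.5])

def loss(w):
    """Negative cosine similarity with TARGET: depends only on the direction of w."""
    return -(TARGET @ w) / np.linalg.norm(w)

def grad(w):
    """Analytic gradient of the toy loss."""
    n = np.linalg.norm(w)
    return -(TARGET / n - (TARGET @ w) * w / n**3)

w = np.array([-0.7, 0.7])

# Property 1: the gradient has no component along w.
radial_component = w @ grad(w)

# Property 2: scaling w by 2 scales the gradient by 1/2.
scale_ratio = np.linalg.norm(grad(2 * w)) / np.linalg.norm(grad(w))

# Vanilla gradient descent: fixed learning rate, no momentum, no weight decay.
lr = 0.1
norms = [np.linalg.norm(w)]
for _ in range(200):
    w = w - lr * grad(w)
    norms.append(np.linalg.norm(w))
```

Since each step is perpendicular to the current weights, $\|w_{t+1}\|^2 = \|w_t\|^2 + \eta^2\|\nabla L(w_t)\|^2$, which is why `norms` is non-decreasing.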

In fact, you can't really see it but the optimization kind of gets stuck in there, and doesn't move any longer. It's interesting to see what happens if we add weight decay, which is the same as adding L2 regularizer over the weights:

We can see that once the trajectory is about to get stuck in a local minimum, weight decay pulls it back towards the origin, which is where gradients become larger. This, in turn, perturbs the trajectory often pushing it out of the current local minimum. So in a way, we can start to build the intuition that weight decay on a scale-invariant loss function acts as a kind of learning rate adjustment.

In fact, what the paper works out is an equivalence between two things:

- weight decay with constant learning rate and
- no weight decay and an exponentially growing learning rate

On the plot below I show the trajectory with the exponentially growing learning rate which is equivalent to the one I showed before with weight decay. This one has no weight decay, and its learning rate keeps growing:

We can see that the trajectory blows up, and quickly gets out of bound on this animation. How can this be equivalent to the weight decay trajectory? Well, from the perspective of the loss function, the magnitude of the weight vector is irrelevant, and we only care about the angle when viewed from the origin. Turns out, if you look at those angles, the two trajectories are the same. To illustrate this, I use the normalization formula from Theorem 2.1 to project this trajectory back to the same magnitude the weight decay one would have. I obtain something that indeed looks very much like the trajectory above:

After a while, the trajectories start to behave differently, which I think is probably due to the accumulation of numerical errors in my implementation of the toy example. I could probably fix this, but I'm not sure it's worth the effort. The authors show much more convincing empirical evidence that this works in real, complicated neural network losses that people actually want to optimize.
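For the momentum-free case the equivalence can be derived directly from scale invariance, and checked in a few lines. Write $\rho = 1 - \eta\lambda$ for the per-step shrinkage of weight decay $\lambda$ at learning rate $\eta$. Substituting $\tilde{w}_t = \rho^{-t} w_t$ into the weight decay update and using $\nabla L(c w) = \nabla L(w)/c$ gives plain gradient descent on $\tilde{w}$ with learning rate $\eta\rho^{-(2t+1)}$. This is my own sketch of the simplest case on an assumed toy loss (negative cosine similarity); the paper's Theorem 2.1 is more general and may state the constants differently.

```python
import numpy as np

# Hypothetical direction in which the toy loss is minimised.
TARGET = np.array([1.0, 0.5]) / np.linalg.norm([1.0, 0.5])

def grad(w):
    """Gradient of the scale-invariant toy loss -cos(angle between w and TARGET)."""
    n = np.linalg.norm(w)
    return -(TARGET / n - (TARGET @ w) * w / n**3)

def run_weight_decay(w0, lr, wd, steps):
    """Gradient descent with constant learning rate and weight decay."""
    w, path = w0.copy(), [w0.copy()]
    for _ in range(steps):
        w = (1 - lr * wd) * w - lr * grad(w)
        path.append(w.copy())
    return np.array(path)

def run_exponential_lr(w0, lr, wd, steps):
    """Gradient descent with no weight decay and learning rate lr * rho**-(2t+1)."""
    rho = 1 - lr * wd
    w, path = w0.copy(), [w0.copy()]
    for t in range(steps):
        w = w - lr * rho ** (-(2 * t + 1)) * grad(w)
        path.append(w.copy())
    return np.array(path)

def directions(path):
    """Project every point of a trajectory onto the unit circle."""
    return path / np.linalg.norm(path, axis=1, keepdims=True)

w0 = np.array([-0.7, 0.7])
lr, wd, steps = 0.1, 0.1, 50
path_wd = run_weight_decay(w0, lr, wd, steps)
path_exp = run_exponential_lr(w0, lr, wd, steps)

# Zooming out by rho**t at step t maps the exploding trajectory onto the decayed one.
rho = 1 - lr * wd
rescaled = path_exp * rho ** np.arange(steps + 1)[:, None]
```

On this well-behaved toy loss the two trajectories agree in direction to floating point precision over these 50 steps; the loss with several local minima used in the plots is less forgiving of rounding errors.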

You can think of this renormalization I did above as "constantly zooming out" on the loss landscape to keep up with the exponentially exploding parameter. I tried to illustrate this below:

On the left-hand plot, I show the original, weight-decayed gradient descent with a constant learning rate. On the right-hand plot I show the equivalent trajectory with exponentially growing learning rate and no weight decay, and I also added a constant zoom to counteract the explosion of the parameter's norm, in line with Theorem 2.1. We can see that, especially initially, the two paths behave the same way when viewed from the origin. They then work differently which I believe is down to the numerical precision issue that could probably be worked out.

The paper shows a similar equivalence in the presence of momentum as well. If you are interested, read the details in the paper.

I thought this observation was very cool, and may well lead to a better understanding of the mechanisms by which batchnorm and other weight normalization schemes work. It also explains why the combination of weight decay with weight normalization schemes results in a relatively robust gradient descent regime where constant learning rate works well.

Here's a paper someone has pointed me to, along the lines of "everything that works, works because it's Bayesian":

- Edwin Fong, Chris Holmes (2019) On the marginal likelihood and cross-validation

I found this paper to be lacking on the accessibility front, mostly owing to the fact that it is a mixture of two somewhat related but separate things:

- (A) a simple-in-hindsight and cool observation about the relationship between marginal likelihoods and cross validation which I will present in this post, and
- (B) a somewhat tangential sidetrack about a generalized form of Bayesian inference and prequential analysis which I think is mostly there to advertise an otherwise interesting line of research Chris Holmes and colleagues have been working on for some time. The advertising worked for sure, as I found the underlying paper (Bissiri et al, 2016) quite interesting as well. I will leave discussing that to a different post, maybe next week.

To discuss the connection between marginal likelihoods and (Bayesian) cross validation, let's first define what is what.

First of all, we are in the world of exchangeable data, assuming we model a sequence of observations $x_1,\ldots,x_N$ by a probabilistic model which renders them conditionally independent given some global parameter $\theta$. Our model is thus specified by the observation model $p(x\vert \theta)$ and prior $p(\theta)$. The marginal likelihood is the probability mass this model assigns to a given sequence of observations:

$$

p(x_1,\ldots,x_N) = \int \prod_{i=1}^{N} p(x_i \vert \theta) p(\theta) d\theta

$$

Important for the discussion of its connection with cross-validation, the marginal likelihood, like any multivariate distribution, can be decomposed by the chain rule:

$$

p(x_1,\ldots,x_N) = p(x_1)\prod_{n=1}^{N-1}p(x_{n+1}\vert x_1,\ldots, x_{n})

$$

And, of course, a similar decomposition exists for any arbitrary ordering of the observations $x_n$.

Another related quantity is a single-fold leave-$P$-out cross-validation. Here, we set the last $P \leq N$ observations aside, fit our model to the first $N-P$ observations, and then we calculate the total predictive log loss on the held-out points. This can be written as:

$$

- \sum_{p=1}^{P} \log p(x_{N-p+1}\vert x_1, \ldots, x_{N-P})

$$

Importantly, here, we assume that we perform Bayesian cross-validation of the model. I.e. in this formula, the parameter $\theta$ is integrated out. In fact what we're looking at is:

$$

- \sum_{p=1}^{P} \log \int p(x_{N-p+1}\vert \theta) p(\theta \vert x_1, \ldots, x_{N-P}) d\theta

$$

Now of course, we could leave any other subset of size $P$ of the observations out. If we repeat this process $K$ times with a uniform random subset of datapoints left out each time, and average the results over the $K$ trials, we have $K$-fold leave-$P$-out cross validation. If $K$ is large enough, we might be trying all possible subsets of size $P$ with the same probability. I will cheesily call this $\infty$-fold cross-validation. Mathematically, $\infty$-fold leave-$P$-out Bayesian cross-validation is the following quantity:

$$

- \frac{1}{\binom{N}{P}} \sum_{\substack{S \subset \{1\ldots N\}\\|S|=P}} \sum_{i \in S} \log p(x_i\vert x_j : j \notin S),

$$

which is Eqn (10) in the paper with slightly different notation.

The connection I think is best illustrated in the following way. Let's consider three observations, and all the possible ways we can permute them. There are $3! = 6$ different permutations. For each of these permutations we can decompose the marginal likelihood as a product of conditionals, or equivalently we can write the log marginal likelihood as a sum of logs of the same conditionals. Let's arrange these log conditionals into a table as follows:

Each column corresponds to a different ordering of variables, and summing up the terms in each column gives the log marginal likelihood. So, the sum of all the terms in this matrix gives the log marginal likelihood times 6 (as there are 6 columns). In general it gives $N!$ times the log marginal likelihood for $N$ observations. Now look at the sums of the terms in each row. The first row is full of terms you'd see in leave-$3$-out cross validation (which doesn't make too much sense with $3$ observations). In the second row, you see terms for leave-$2$-out CV. The third row corresponds to leave-$1$-out CV. So, if you do some careful combinatorics (homework) and count how many duplicate terms you'll find in each row, you can conclude that the sum of leave-$P$-out $\infty$-fold Bayesian cross-validation errors over all values of $P$, each averaged per held-out point, gives you the log marginal likelihood times a constant. Which is the main point of the paper.
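If you would rather not do the combinatorics by hand, the identity can be checked by brute force on a model with closed-form predictives. The sketch below assumes a Beta-Bernoulli model with a uniform prior and a made-up binary dataset; the function names are mine. It works with log predictive scores (no leading minus sign), so with the $\infty$-fold quantity defined above as a sum over held-out points, dividing the leave-$P$-out score by $P$ gives the per-point average, and these averages add up to the log marginal likelihood.

```python
from itertools import combinations
from math import comb, log

def log_pred(x, train, a=1.0, b=1.0):
    """Log posterior predictive of a Bernoulli outcome x under a Beta(a, b) prior."""
    p_one = (a + sum(train)) / (a + b + len(train))
    return log(p_one if x == 1 else 1.0 - p_one)

def log_marginal(xs):
    """Chain rule: sum of one-step-ahead log predictives in the given order."""
    return sum(log_pred(x, xs[:n]) for n, x in enumerate(xs))

def cv_score(xs, P):
    """Exhaustive leave-P-out score: average over all test sets S of size P of the
    summed log predictives of the held-out points given the remaining points."""
    N = len(xs)
    total = 0.0
    for S in combinations(range(N), P):
        train = [xs[j] for j in range(N) if j not in S]
        total += sum(log_pred(xs[i], train) for i in S)
    return total / comb(N, P)

xs = [1, 0, 1, 1, 0]   # made-up binary observations
lhs = log_marginal(xs)
rhs = sum(cv_score(xs, P) / P for P in range(1, len(xs) + 1))
```

For this dataset the marginal likelihood is $1/60$, and `lhs` and `rhs` agree up to rounding.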

This observation gives a really good motivation for using the marginal likelihood, and also gives a new perspective on how it works. For $N$ datapoints, there are $2^N-1$ different ways of selecting a non-empty test set and corresponding training set. Calculating the marginal likelihood amounts to evaluating the average predictive score on all of these exponentially many 'folds'.

Before we jump to the conclusion that cross-validation, too, works only because it is essentially an approximation to Bayesian model selection, we must remind ourselves that this connection only holds for Bayesian cross-validation. What this means is that in each fold of cross-validation, we integrate $\theta$ in a Bayesian fashion.

In practice, when cross-validating neural networks, we usually optimize over the parameters rather than integrate in a Bayesian way. Or, at best, we use a variational approximation to the posterior and integrate over that approximately. As the relationship only holds in theory, when exact parameter marginalization is performed, it remains to be seen how useful and robust this connection will prove in potential applications.

This week I read this cool new paper on meta-learning: it takes a slightly different approach compared to its predecessors, based on some observations about differentiating the optima of regularized optimization.

- Aravind Rajeswaran, Chelsea Finn, Sham Kakade, Sergey Levine (2019) Meta-Learning with Implicit Gradients

Another paper that came out at the same time has discovered similar techniques, so I thought I'd update the post and mention it, although I won't cover it in detail and the post was written primarily about Rajeswaran et al (2019):

- Yutian Chen, Abram L. Friesen, Feryal Behbahani, David Budden, Matthew W. Hoffman, Arnaud Doucet, Nando de Freitas (2019) Modular Meta-Learning with Shrinkage

- I will give a high-level overview of the meta-learning setup, where our goal is to learn a good initialization or regularization strategy for SGD so it converges to better minima across a range of tasks.
- I illustrate how iMAML works on a 1D toy-example, and discuss the behaviour and properties of the meta-objective.
- I will then discuss a limitation of iMAML: that it only considers the location of minima, and not the probability with which a stochastic algorithm ends up in a specific minimum.
- I will finally relate iMAML to a variational approach to meta-learning.

Meta-learning has several possible formulations, I will try to explain the setup of this paper following my own interpretation and notation that differs from the paper but will make my explanations clearer (hopefully).

In meta-learning we have a series of independent tasks, with associated training and validation loss functions $f_i$ and $g_i$, respectively. We have a set of model parameters $\theta$ which are shared across the tasks, and the loss functions $f_i(\theta)$ and $g_i(\theta)$ evaluate how well the model with parameters $\theta$ does on the training and validation cases of task $i$. We have an algorithm that has access to the training loss $f_i$ and some meta-parameters $\theta_0$, and outputs some optimal or learned parameters $\theta_i^\ast = Alg(f_i, \theta_0)$. The goal of the meta-learning algorithm is to optimize the meta-objective

$$

\mathcal{M}(\theta_0) = \sum_i g_i(Alg(f_i, \theta_0))

$$

with respect to the meta-parameters $\theta_0$.

In early versions of this work, MAML, the algorithm was chosen to be stochastic gradient descent, $f_i$ and $g_i$ being the training and test loss of a neural network, for example. The meta-parameter $\theta_0$ was the point of initialization for the SGD algorithm, shared between all the tasks. Since SGD updates are differentiable, one can compute the gradient of the meta-objective with respect to the initial value $\theta_0$ by simply backpropagating through the SGD steps. This was essentially what MAML did.

However, the effect of initialization on the final value of $\theta$ is pretty weak, and difficult - if at all possible - to characterise analytically. If we allow the SGD to go on for many steps, we might converge to a better parameter, but the trajectory will be very long, and the gradients with respect to the initial value vanish. If we make the trajectories short enough, the gradients w.r.t. $\theta_0$ are informative but we may not reach a very good final value.

This is why Rajeswaran et al opted to make the dependence of the final point of the trajectory on the meta-parameter $\theta_0$ much stronger: instead of simply initializing SGD from $\theta_0$, they also anchor the parameter to stay in the vicinity of $\theta_0$ by adding a quadratic regularizer $\frac{1}{2}\|\theta - \theta_0\|^2$ to their loss. Because of this, two things happen:

- now all steps of the SGD depend on $\theta_0$, not just the initial point
- now the location of the minimum SGD eventually converges to also depends on $\theta_0$

It is this second property that iMAML exploits. Let me illustrate what that dependence looks like:

In the figure above, let's say that we would like to minimise an objective function $f(\theta)$. This would be the training loss of one of the tasks the meta-learning algorithm has to solve. Our current meta-parameter $\theta_0$ is marked on the x axis, and the orange curve shows the associated quadratic penalty. The teal curve shows the sum of the objective with the penalty. The red star shows the location of the minimum, which is what the learning algorithm finds.

Now let's animate this plot. I'm going to move the anchor point $\theta_0$ around, and reproduce the same plots. You can see that, as we move $\theta_0$ and the associated penalty, the local (and therefore global) minima of the regularized objective change:

So it's clear that there is a non-trivial, non-linear relationship between the anchor-point $\theta_0$ and the location of a local minimum $\theta^\ast$. Let's plot this relationship as a function of the anchor point:

We can see that this function is not at all nice to work with: it has sharp jumps when the closest local minimum to $\theta_0$ changes, and it is relatively flat between these jumps. In fact, you can observe that the sharper the local minimum nearest to $\theta_0$ is, the flatter the relationship between $\theta_0$ and $\theta^\ast$. This is because if $f$ has a sharp local minimum near $\theta_0$, then the location of the regularized minimum will be mostly determined by $f$, and the location of the anchor point $\theta_0$ doesn't matter much. If the local minimum of $f$ is wide, there's a lot of wiggle room for the optimum and the effect of the regularization will be larger.

And now we come to the whole point of the iMAML procedure. The gradient of this function $\theta^\ast(\theta_0)$ in fact can be calculated in closed form. It is, indeed, related to the curvature, or second derivative, of $f$ around the minimum we find:

$$

\frac{d\theta^\ast}{d\theta_0} = \frac{1}{1 + f''(\theta^\ast)}

$$

In order to check that this formula works, I calculated the derivative numerically and compared it with what the theory predicts, they match perfectly:

When the parameter space is high-dimensional, we have a similar formula involving the inverse of the Hessian plus the identity. In high dimensions, inverting or even calculating and storing the Hessian is not very practical. One of the main contributions of the iMAML paper is a practical way to approximate gradients, using a conjugate gradient inner optimization loop. For details, please read the paper.
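The 1D formula follows from differentiating the stationarity condition $f'(\theta^\ast) + (\theta^\ast - \theta_0) = 0$ with respect to $\theta_0$, and a check of the kind described above takes only a few lines. The sketch below is my own reconstruction, not the code behind the plots: it assumes a stand-in objective $f(\theta) = \sin(3\theta)$, a penalty of $\frac{1}{2}(\theta - \theta_0)^2$, and plain gradient descent as the inner solver, with a step size small enough that the iterates move monotonically to the nearest stationary point.

```python
import numpy as np

def f_prime(theta):
    """First derivative of the assumed toy loss f(theta) = sin(3 * theta)."""
    return 3.0 * np.cos(3.0 * theta)

def f_second(theta):
    """Second derivative of the same toy loss."""
    return -9.0 * np.sin(3.0 * theta)

def inner_solve(theta0, lr=0.05, steps=500):
    """Gradient descent on f(theta) + 0.5 * (theta - theta0)**2, started at theta0."""
    theta = theta0
    for _ in range(steps):
        theta = theta - lr * (f_prime(theta) + (theta - theta0))
    return theta

theta0 = 0.0
theta_star = inner_solve(theta0)

# Implicit derivative: needs only the curvature of f at the solution.
implicit = 1.0 / (1.0 + f_second(theta_star))

# Finite-difference derivative: re-solves the inner problem at nearby anchors.
eps = 1e-4
numerical = (inner_solve(theta0 + eps) - inner_solve(theta0 - eps)) / (2 * eps)
```

The two estimates agree as long as $\theta_0$ is not at one of the jumps, where $\theta^\ast(\theta_0)$ is discontinuous and the finite difference is meaningless.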

When optimizing the anchor point in a meta-learning setting, it is not the location $\theta^\ast$ we are interested in, only the value that the function $f$ takes at this location. (In reality, we would now use the validation loss in place of the training loss used for gradient descent, but for simplicity I assume the two losses overlap.) The value of $f$ at its local optimum is plotted below:

Oh dear. This function is not very pretty. The meta-objective $f(\theta^\ast(\theta_0))$ becomes a piecewise continuous function, a connection of neighbouring basins, with non-smooth boundaries. The local gradients of this function contain very little information about the global structure of the loss function; they only tell you where to go to reach the nearest local minimum. So I wouldn't say this is the nicest function to optimize.

Thankfully, though, this function is not what we have to optimize. In meta-learning, we have a distribution over functions $f$ we optimize, so the actual meta-objective is something like $\sum_i f_i(\theta_i^\ast(\theta_0))$. And the sum of a bunch of ugly functions might well turn into something smooth and nice. In addition, the 1-D function I use for this blog post is not representative of the high-dimensional loss functions of neural networks which we want to apply iMAML to. Take for example the concept of mode connectivity (see e.g. Garipov et al, 2018): it seems that the modes found by SGD using different random seeds are not just isolated basins, but they are connected by smooth valleys along which the training and test error are low. This may in turn make the meta-objective behave more smoothly between minima.

An important aspect that MAML or iMAML do not consider is the fact that we usually use stochastic optimization algorithms. Rather than deterministically finding a particular local minimum, SGD samples different minima: when run with different random seeds it will find different minima.

A more generous formulation of the meta-objective would allow for stochastic algorithms. If we denote by $\mathcal{Alg}(f_i, \theta_0)$ the distribution over solutions the algorithm finds, the meta-objective would be

$$

\mathcal{M}_{\text{stochastic}}(\theta_0) = \sum_i \mathbb{E}_{\theta \sim \mathcal{Alg}(f_i, \theta_0)} g_i(\theta)

$$

Allowing for stochastic behaviour might actually be a great feature for meta-learning. While the position of the global minimum of the regularized objective can change abruptly as a function of $\theta_0$ (as illustrated in the third figure above), allowing for stochastic behaviour might smooth out the meta-learning objective.

Now suppose that SGD anchored to $\theta_0$ converges to one of a finite set of local minima. The meta-learning objective now depends on $\theta_0$ in two different ways:

- as we change the anchor $\theta_0$, the locations of the minima change, as illustrated above. This change is differentiable, and we know its derivative.
- as we change the anchor $\theta_0$, the probability with which we find the different solutions changes. Some solutions will be found more often, some less often.

iMAML accounts for the first influence, but it ignores the influence through the second mechanism. This is not to say that iMAML is broken, but that it misses a possibly crucial contribution of stochastic behaviour that MAML or explicitly differentiating through the algorithm does not.

Of course this work reminded me of a Bayesian approach. Whenever someone describes quadratic penalties, all I see are Gaussian distributions.

In a Bayesian interpretation of iMAML, one can think of the anchor point $\theta_0$ as the mean of a prior distribution over the neural network's weights. The inner loop of the algorithm, or $Alg(f_i, \theta_0)$ then finds the maximum-a-posteriori (MAP) approximation to the posterior over $\theta$ given the dataset in question. This is assuming that the loss is a log likelihood of some kind. The question is, how should one update the meta-parameter $\theta_0$?

In the Bayesian world, we would seek to optimize $\theta_0$ by maximising the marginal likelihood. As this is usually intractable, it is common to turn to a variational approximation, which in this case would look something like this:

$$

\mathcal{M}_{\text{variational}}(\theta_0, Q_i) = \sum_i \left( KL[Q_i \| \mathcal{N}_{\theta_0}] + \mathbb{E}_{\theta \sim Q_i} f_i(\theta) \right),

$$

where $Q_i$ approximates the posterior over model parameters for task $i$. A specific choice of $Q_i$ is a Dirac delta distribution centred at a specific point, $Q_i(\theta) = \delta(\theta - \theta^{\ast}_i)$. If we generously ignore some constants that blow up to infinity, the KL divergence between the Gaussian prior and the degenerate point-posterior is a simple squared Euclidean distance, and our variational objective reduces to:

$$

\mathcal{M}_{\text{variational}}(\theta_0, \theta_i) = \sum_i \left( \|\theta_i - \theta_0\|^2 + f_i(\theta_i) \right)

$$

Now this objective function looks very much like the optimization problem that the inner loop of iMAML attempts to solve. If we were working in the pure variational framework, this may be where we leave things, and we could jointly optimize all the $\theta_i$s as well as $\theta_0$. Someone in the know, please comment below pointing me to the best references where this is being done for meta-learning.

This objective is significantly easier to optimize and involves no inner-loop optimization or black magic. It simply ends up pulling $\theta_0$ closer to the centre of gravity of the various optima found for each task $i$. Not sure if this is such a good idea though for meta-learning, as the final values of $\theta_i$ which we reach by jointly optimizing over everything may not be reachable by doing SGD from $\theta_0$ from scratch. But who knows. A good idea may be, given the observations above, to jointly minimize the variational objective with respect to $\theta_0$ and $\theta_i$, but every once in a while reinitialize $\theta_i$ to be $\theta_0$. But at this point, I'm really just making stuff up...

Anyway, back to iMAML, which does something interesting with this variational objective, and I think it can be understood as a kind of amortized computation: Instead of treating $\theta_i$ as separate auxiliary parameters, it specifies that $\theta_i$ are in fact a deterministic function of $\theta_0$. As the variational objective is a valid upper bound for any value of $\theta_i$, it is also a valid upper bound if we make $\theta_i$ explicitly dependent on $\theta_0$. The variational objective thus becomes a function of $\theta_0$ only (and also of hyperparameters of the algorithm $Alg$ if it has any):

$$

\mathcal{M}_{\text{variational}}(\theta_0) = \sum_i \left( \|Alg(f_i, \theta_0) - \theta_0\|^2 + f_i(Alg(f_i, \theta_0)) \right)

$$

And there we have it. A variational objective for meta-learning $\theta_0$ which is very similar to the MAML/iMAML meta-objective, except it also has the $\|Alg(f_i, \theta_0) - \theta_0\|^2$ term which factors into updating $\theta_0$ which we didn't have before. Also notice that I did not use separate training and validation loss $f_i$ and $g_i$ but that would be a very justified choice as well.

What is cool about this is that this provides extra justification and interpretation for what iMAML is trying to do, and suggests directions in which iMAML could perhaps be improved. On the flipside, the implicit differentiation trick in iMAML might be useful in other situations where we want to amortize the variational posterior similarly.

I'm pretty sure I missed many references, please comment below if you think I should add anything, especially on the variational bit.

I finally got around to reading this new paper by Arjovsky et al. It debuted on Twitter with a big splash, being described as 'beautiful' and 'long awaited' 'gem of a paper'. It almost felt like a new superhero movie or Disney remake just came out.

- Martin Arjovsky, Léon Bottou, Ishaan Gulrajani, David Lopez-Paz (2019) Invariant Risk Minimization

The paper is, indeed, very well written, and describes a very elegant idea, a practical algorithm, some theory and lots of discussion around how this is related to various bits. Here, I will describe the main idea and then provide an information theoretic view on the same topic.

We would like to learn robust predictors that are based on invariant causal associations between variables, rather than spurious surface correlations that might be present in our data. If we only observe i.i.d. data from a generative process, this is generally not possible.

In this paper, the authors assume that we have access to data sampled from different environments $e$. The data distribution differs across these environments, but there is an underlying causal dependence of the variable of interest $Y$ on some of the observed features $X$ that remains constant, or invariant, across all environments. The question is, can we exploit the variability across different environments to learn this underlying invariant association?

Usual empirical risk minimisation (ERM) approaches cannot distinguish between statistical associations that correspond to causal connections, and those that are just spurious correlations. Invariant Risk Minimization can, in certain situations. It does this by finding a representation $\phi$ of features, such that the optimal predictor is simultaneously Bayes optimal in all environments.

The authors then propose a practical loss function that tries to capture this property:

$$

\min_\phi \sum_e \left[ \mathcal{R}^e(\phi) + \lambda \left\|\nabla_{w\vert w=1}\mathcal{R}^e(w \cdot \phi)\right\|^2_2 \right]

$$

The first term is the usual ERM: we're trying to minimize average risk across all environments, using a single predictor $\phi$. The second term is where the interesting bit happens. I said before that what we want this term to encourage is that $\phi$ is simultaneously Bayes-optimal in all environments. What the term actually looks at is whether $\phi$ is locally optimal, i.e. whether it can be improved locally by scaling by a constant $w$. For details, I recommend reading the paper where a lot of intuitive explanation is provided. In this post, I'll focus on an information theoretic interpretation of what's going on.
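
To make the penalty term concrete, here is a minimal sketch of how it could be computed. It assumes a scalar representation $\phi(x)$ and squared loss, for which the derivative with respect to the scaling constant $w$ at $w = 1$ has a closed form. The function names and the two toy environments are hypothetical, chosen only for illustration.

```python
import numpy as np

def risk(w, phi_x, y):
    """Mean squared error of the scaled predictor w * phi(x)."""
    return np.mean((w * phi_x - y) ** 2)

def irm_penalty(phi_x, y):
    """Squared derivative of the risk with respect to the scalar w, evaluated at w = 1."""
    grad_w = np.mean(2.0 * (phi_x - y) * phi_x)  # d/dw mean((w * phi - y)^2) at w = 1
    return grad_w ** 2

def irmv1_objective(envs, lam):
    """Sum over environments of the risk plus lam times the penalty."""
    return sum(risk(1.0, p, y) + lam * irm_penalty(p, y) for p, y in envs)

rng = np.random.default_rng(0)
envs = []
for scale in (0.5, 2.0):                       # assumed: environments differ in the scale of x
    x = rng.normal(0.0, scale, size=10_000)
    y = x + rng.normal(0.0, 1.0, size=10_000)  # assumed: the same mechanism for y everywhere
    envs.append((x, y))                        # here phi(x) = x

good = sum(irm_penalty(p, y) for p, y in envs)
bad = sum(irm_penalty(3.0 * p, y) for p, y in envs)  # a mis-scaled representation
print(good, bad)
```

The penalty is close to zero when no rescaling of $\phi$ can reduce the risk in any environment, and grows large when a single constant $w \neq 1$ would improve the fit.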

Unlike the authors, who treat the environment index $e$ as something outside of the structural equation model, I prefer to think of $E$ as also being part of the generative process: an observable random variable. This may not be the most useful formulation in all circumstances, but it will help when trying to derive IRM through the lens of conditional dependence relationships.

The most general generative model of data in the IRM setup looks something like this:

There are three (sets of) observable variables: $E$, the environment index, $X$, the features describing the datapoint and $Y$, the label we wish to predict. I also assume the existence of a hidden confounder $W$. In the above graph, I separated $X$ into upstream dimensions $X_1$ and downstream dimensions $X_2$ based on where they are in the causal chain relative to $Y$. In reality, we don't know what this breakdown is, and breaking $X$ up into $X_1$ and $X_2$ may not be trivial due to entanglement, but it is still reasonable to assume that some components of the input $X$ encode causal parents of $Y$, and others encode causal descendants of $Y$.

The environment $E$ influences every factor in this generative model, except the factor $p(Y\vert X_1, W)$: notice there is no arrow from $E$ to $Y$. In other words, it is assumed that the relationship of variable $Y$ to its observable causal parents $X_1$ and hidden variables $W$ is stable across all environments. This is the primary underlying assumption of IRM. Now, let's read out some conditional independence relationships from this graph:

- $Y \cancel{\perp\mkern-13mu\perp} E$: this is simply saying that the marginal distribution of $Y$ can, generally, change across environments.
- $Y \perp\mkern-13mu\perp E\vert X_1, W$: The observable $X_1$ and latent $W$ shield the label $Y$ from the influence of the environment. I already said that this is the key assumption on which IRM is based: that there is an underlying causal mechanism determining the value of $Y$ from its causal parents, which does not change across environments.
- $Y \cancel{\perp\mkern-13mu\perp} E\vert X_1$: If we leave $W$ out of the conditioning, the above environment-independence no longer holds. This is because the confounder $W$ introduces a spurious association between $X_1$ and $Y$. This spurious correlation is assumed to be environment-dependent.
- $Y \cancel{\perp\mkern-13mu\perp} E\vert X_1, X_2$: this is, perhaps, the most important point. This dependence statement says that the way $Y$ depends on the observable variables $X = (X_1, X_2)$ is environment-dependent. This can be verified by noticing that $X_2$ is a collider between $E$ and $Y$. Conditioning on a collider introduces spurious correlations (for example, explaining away).
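
The last point can be checked numerically with a small simulation. The setup below is a deliberately simplified, hypothetical instance of the graph: there is no confounder $W$, the environment only changes the noise level of $X_2$, and all relationships are linear-Gaussian. Under these assumptions, regressing $Y$ on $X_1$ alone should be stable across environments, while regressing on $(X_1, X_2)$ should not.

```python
import numpy as np

def sample_env(noise_x2, n, rng):
    """Hypothetical environment: X1 -> Y -> X2, with E changing only the X2 noise."""
    x1 = rng.normal(0.0, 1.0, size=n)
    y = x1 + rng.normal(0.0, 1.0, size=n)        # p(Y | X1) is the same everywhere
    x2 = y + rng.normal(0.0, noise_x2, size=n)   # the child of Y depends on E
    return x1, x2, y

def ols(features, y):
    """Least-squares coefficients of y on the given feature columns (no intercept)."""
    coef, *_ = np.linalg.lstsq(np.column_stack(features), y, rcond=None)
    return coef

rng = np.random.default_rng(0)
results = {}
for noise_x2 in (0.5, 2.0):
    x1, x2, y = sample_env(noise_x2, 200_000, rng)
    results[noise_x2] = (ols([x1], y), ols([x1, x2], y))
    print(noise_x2, *results[noise_x2])
```

In this toy model the coefficient on $X_1$ alone stays near $1$ in both environments, whereas the coefficients on $(X_1, X_2)$ move from roughly $(0.2, 0.8)$ to roughly $(0.8, 0.2)$ as the noise on $X_2$ grows: conditioning on the descendant $X_2$ makes the dependence of $Y$ on $X$ environment-dependent.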

In summary, the association between $X$ and $Y$ will be a result of three sources of correlation:

- real causal relationship between some components of $X$ and $Y$.
- spurious association introduced by the unobserved confounder $W$.
- spurious association introduced by conditioning on parts of $X$ which are causally influenced by $Y$, rather than the other way around.

In general, if our generative model describes the world accurately, the conditional independence statements we observed tell us that while the real causal association is stable across environments ($Y \perp\mkern-13mu\perp E\vert X_1, W$), the other two are environment-dependent ($Y \cancel{\perp\mkern-13mu\perp} E\vert X_1$ and $Y \cancel{\perp\mkern-13mu\perp} E\vert X_1, X_2$). Thus, we can eliminate the spurious associations by seeking associations that are stable across environments, i.e. independent of $E$.

One can interpret the objective of Invariant Risk Minimisation as seeking a representation of observable variables $\phi(X)$, such that:

- $Y \perp\mkern-13mu\perp E\vert \phi(X)$, and
- $\phi$ is informative about $y$, i.e. we can predict $y$ accurately from $\phi(x)$

This is a bit similar to the information bottleneck criterion which would seek a stochastic representation $Z = \phi(X, \nu)$, $\nu$ being some random noise, by solving the following optimization problem:

$$

\max_{\phi} \left\{ I[Y, Z] - \beta I[X, Z] \right\}

$$

We could similarly attempt to find an invariant representation $Z = \phi(X)$ by maximizing an objective like:

$$

\max_{\phi} \left\{ I[Y, Z] - \beta I[Y, E \vert Z] \right\}

$$

Notice that one can write the conditional mutual information $I[Y, E \vert \phi(X)]$ in the following variational form:

$$

I[Y, E \vert \phi(x)] = \max_q \min_r \mathbb{E}_{x,y,e} [\log q(y\vert \phi(x), e) - \log r(y\vert \phi(x))]

$$
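
As a sanity check of this variational form, here is a minimal sketch on a small discrete example. It assumes all three variables are discrete and that we have the full joint table over $(z, e, y)$ with $z = \phi(x)$, so that the optimal $q$ and $r$ are simply the true conditionals. The table itself is random and purely illustrative.

```python
import numpy as np

def gap(p_zey, q, r):
    """E[log q(y | z, e) - log r(y | z)] under the joint table p[z, e, y], in nats."""
    return float(np.sum(p_zey * (np.log(q) - np.log(r))))

def optimal_predictors(p_zey):
    """The maximizing q and minimizing r are the true conditionals of the joint."""
    q = p_zey / p_zey.sum(axis=2, keepdims=True)    # p(y | z, e)
    p_zy = p_zey.sum(axis=1, keepdims=True)
    r = p_zy / p_zy.sum(axis=2, keepdims=True)      # p(y | z)
    return q, r

rng = np.random.default_rng(0)
p = rng.random((3, 2, 4))           # arbitrary strictly positive joint over (z, e, y)
p /= p.sum()
q_opt, r_opt = optimal_predictors(p)
cmi = gap(p, q_opt, r_opt)          # this is I[Y, E | Z]

uniform = np.full((3, 1, 4), 0.25)  # a deliberately poor predictor of y
worse_r = gap(p, q_opt, uniform)    # a suboptimal r can only increase the gap
worse_q = gap(p, uniform, r_opt)    # a suboptimal q can only decrease the gap
print(worse_q, cmi, worse_r)
```

The ordering of the three printed numbers reflects the $\max_q \min_r$ structure: the information is what remains of the gap once both predictors are as good as they can be.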

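With a little bit of math juggling, we can write the above stability objective in the following form (the equality holds up to an additive constant, the entropy of $Y$, and a positive scaling factor of $1 + \beta$, neither of which changes the optimal $\phi$):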

\begin{align}

\max_{\phi} \left\{ I[Y, \phi(X)] - \beta I[Y, E \vert \phi(X)] \right\} = \max_\phi \max_r \min_q \mathbb{E}_{x,y,e} \left[\log r(y\vert \phi(x)) - \frac{\beta}{1 + \beta}\log q(y\vert \phi(x), e) \right]

\end{align}
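
One way to fill in the juggling, as a sketch: write $H[Y]$ for the entropy of $Y$, and assume $q$ and $r$ range over all conditional distributions, so that the best log-likelihoods they can reach are the negative conditional entropies. Then:

\begin{align}
I[Y, \phi(X)] &= H[Y] + \max_r \mathbb{E}_{x,y} \log r(y\vert \phi(x)) \\
I[Y, E \vert \phi(X)] &= \max_q \mathbb{E}_{x,y,e} \log q(y\vert \phi(x), e) - \max_r \mathbb{E}_{x,y} \log r(y\vert \phi(x)) \\
I[Y, \phi(X)] - \beta I[Y, E \vert \phi(X)] &= H[Y] + (1+\beta) \max_r \mathbb{E}_{x,y} \log r(y\vert \phi(x)) - \beta \max_q \mathbb{E}_{x,y,e} \log q(y\vert \phi(x), e) \\
&= H[Y] + (1+\beta) \max_r \min_q \mathbb{E}_{x,y,e} \left[\log r(y\vert \phi(x)) - \frac{\beta}{1+\beta} \log q(y\vert \phi(x), e)\right]
\end{align}

Since $H[Y]$ does not depend on $\phi$ and $1 + \beta > 0$, maximizing the left-hand side over $\phi$ is the same problem as maximizing the bracketed max-min expression.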

This optimization problem is intuitively already very similar to IRM: we would like a representation $\phi$ so we can predict $y$ from it accurately, but we shouldn't be able to build a much better predictor by overfitting to one of the environments. In other words: here, too, we seek a representation so that the Bayes optimal predictor of $y$ is close to Bayes optimal simultaneously in all environments.

Sadly, this is a minimax type of problem, so $\phi$ can only be found with a GAN-like iterative algorithm - something we don't really like, but we're kind of getting used to. It would be interesting to see how such an algorithm would work, and if you're aware of this being done already, please feel free to point me to the relevant references in the comments section.

I wanted to add a sidenote here that I think we can recover something very similar in spirit to (IRMv1). I leave developing the full connection as homework for you. Let me just illustrate the basic idea here.

Say we have a parametric family of functions $f(y\vert \phi(x); \theta)$ for predicting $y$ from $\phi(x)$. The conditional information can be approximated as follows:

\begin{align}

I[Y, E \vert \phi(x)] &\approx \min_\theta \mathbb{E}_{x,y} \ell(f(y\vert \phi(x); \theta)) - \mathbb{E}_e \min_{\theta_e} \mathbb{E}_{x,y\vert e} \ell(f(y\vert \phi(x); \theta_e))\\

&= \min_\theta \mathbb{E}_e \mathcal{R}^e(f_\theta\circ\phi) - \mathbb{E}_e \min_{\theta_e} \mathcal{R}^e(f_{\theta_e}\circ\phi)

\end{align}

where $\ell$ is the log-loss, if we want to recover Shannon's information. If we assume that $f$ is a universal function approximator, an equality holds. If, instead of globally optimizing $\theta_e$, we only search locally within a trust region around $\theta$, we can create the following (approximate) lower bound to the information.

\begin{align}

I[Y, E \vert \phi(x)] &\geq \min_\theta \mathbb{E}_{x,y} \ell(f(y\vert \phi(x); \theta)) - \mathbb{E}_e \min_{\|d_e\|\leq \epsilon} \mathbb{E}_{x,y\vert e} \ell(f(y\vert \phi(x); \theta + d_e)) \\

&= \min_\theta \mathbb{E}_e \left\{ \mathcal{R}^e(f_\theta\circ\phi) - \min_{\|d_e\|\leq \epsilon}\mathcal{R}^e(f_{\theta + d_e}\circ\phi) \right\}

\end{align}

Now, we can approximate the risk $\mathcal{R}^e(f_{\theta + d_e}\circ\phi)$ locally by a first-order Taylor approximation around $\theta$, and show that, as $\epsilon \rightarrow 0$, we obtain (up to a multiplicative factor of $\epsilon$, which can be absorbed into the weight of the penalty):

$$

I[Y, E \vert \phi(x)] \geq \min_\theta \mathbb{E}_e \| \nabla_\theta \mathbb{E}_{x,y\vert e} [\ell(f(y\vert \phi(x); \theta))] \|_2 = \min_\theta \mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2

$$
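
The Taylor step can be spelled out as follows, as a sketch that assumes the risk is differentiable in $\theta$ and that the trust region is a Euclidean ball of radius $\epsilon$. The linearized risk is minimized over the ball by stepping a distance $\epsilon$ against the gradient, which gives:

\begin{align}
\mathcal{R}^e(f_{\theta + d_e}\circ\phi) &\approx \mathcal{R}^e(f_\theta\circ\phi) + d_e^\top \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \\
\min_{\|d_e\|\leq \epsilon} \mathcal{R}^e(f_{\theta + d_e}\circ\phi) &\approx \mathcal{R}^e(f_\theta\circ\phi) - \epsilon \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2 \\
\mathcal{R}^e(f_\theta\circ\phi) - \min_{\|d_e\|\leq \epsilon} \mathcal{R}^e(f_{\theta + d_e}\circ\phi) &\approx \epsilon \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2
\end{align}

So the local improvement available in each environment is proportional to the gradient norm, with $\epsilon$ as the constant of proportionality.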

Compare this with the second term in Eqn IRMv1 of the paper. If we now add back in the requirement that we would like to be able to predict $y$ from $\phi(x)$, we get an optimization problem of the following form:

$$

\min_\phi \left\{ \min_\theta \mathbb{E}_e \mathcal{R}^e(f_\theta\circ\phi) + \lambda \min_\theta \mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2 \right\},

$$

which is almost like the IRM objective. Technically, there are two minimizations over $\theta$ and there's no reason why the two shouldn't be done separately. Note however, that a global minimum of the second term $\mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2$ is always a local minimum of the first term. This justifies connecting the two minimization problems together:

$$

\min_\phi \min_\theta \left\{ \mathbb{E}_e \mathcal{R}^e(f_\theta\circ\phi) + \lambda \mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2 \right\},

$$

This is no longer an ugly minimax problem; however, it is still not amazing. It is a lower bound to the objective which we originally wished to minimize. The lower bound was created when we replaced global minimization with local minimization. Thus, the bound is actually tight if all local minima with respect to $\theta$ of $\mathcal{R}^e(f_\theta\circ\phi)$ are also global minima, e.g. if the loss is convex. In non-convex problems, all bets are off. It may still work, but who knows.

This is indeed a nice paper, with lots of great insights. Unfortunately, I am not sure how realistic the assumptions are that we can sample from a multitude of different environments, which differ from each other sufficiently so that the invariant causal quantities can be identified.

I would like to mention that this is not the first time that invariance and causality have been connected and exploited for domain adaptation. I personally first encountered this idea in a talk by Jonas Peters at the Causality Workshop in 2018. Here is a related paper by him that I wanted to highlight here:

- Peters, Bühlmann, Meinshausen (2016) Causal inference by using invariant prediction: identification and confidence intervals

And here are two more papers which propose a causal treatment of the domain adaptation problem:

- Adarsh Subbaswamy, Peter Schulam, Suchi Saria (2018) Preventing Failures Due to Dataset Shift: Learning Predictive Models That Transport
- Christina Heinze-Deml, Nicolai Meinshausen (2019) Conditional Variance Penalties and Domain Shift Robustness

The second paper, which commenters also pointed out to me, is perhaps the most closely related, but it is based on slightly different assumptions about what is invariant across the domains.

Finally, commenters asked me about domain-adversarial learning, so I wanted to include a pointer here for completeness:

- Yaroslav Ganin, Victor Lempitsky (2014) Unsupervised Domain Adaptation by Backpropagation

On this paper, I agree with Arjovsky et al (2019)'s discussion in the IRM paper: it promotes the wrong invariance property by trying to learn a data representation that is marginally independent of the domain index. See the discussion on this in the comments section below.

Finally, I wanted to point out another slightly looser connection: non-stationarity, or the availability of data from multiple environments, has been exploited by Hyvarinen and Morioka (2016) for unsupervised feature learning. It turns out that this non-stationarity and the availability of different environments makes otherwise non-identifiable nonlinear ICA models identifiable.

]]>

Welcome to my ICML 2019 jetlag special - because what else do you do when you wake up earlier than anyone else, other than write a blog post? Here's a paper that was presented yesterday which I really liked.

- Ruiz and Titsias (2019) A Contrastive Divergence for Combining Variational Inference and MCMC

First, some background on why I found this paper particularly interesting. When AlphaGo Zero came out, I wrote a post about the *principle of minimal improvement*: Suppose you have an operator which may be computationally expensive to run, but which can take a policy and improve it. Using such an improvement operator you can define an objective function for policies by measuring the extent to which the operator changes a policy. If your policy is already optimal, the operator can't improve it any further, so the change will be zero. In the case of AlphaGo Zero, the improvement operator is Monte Carlo Tree Search (MCTS). I noted in that post how the same principle may be applicable to approximate inference: expectation propagation and contrastive divergence can both be cast in a similar light.

The paper I'm talking about uses a very similar argument to come up with a contrastive divergence for variational inference, where the improvement operator is an MCMC step.

The two dominant ways of performing inference in latent variable models are variational inference (including amortized inference, such as in VAE), and Markov Chain Monte Carlo (MCMC). VI approximates the posterior with a parametric distribution. This can be computationally efficient and practically convenient as VI results in end-to-end differentiable, unbiased estimates of the evidence lower bound (ELBO). As a drawback, the parametric posterior approximation often can't approximate the posterior perfectly, and an approximation error almost always remains. By contrast, an MCMC approximate posterior can always be improved by running the chains longer, and obtaining more independent samples, but it is more difficult to work with and computationally more demanding than VI.

A lot of great work has been done recently on combining VI with MCMC (see references in Ruiz and Titsias, 2019). Usually, one starts from a crude, parametric variational approximation to the posterior, and then improves it by running a couple of steps of MCMC. Crucially, one can view the transition kernel, $\Pi$, of the MCMC as an improvement operator: given any distribution $q$, taking an MCMC step should take you closer to the posterior. In other words, $\Pi q$ is an improvement over $q$. Taking multiple steps, i.e. $\Pi^t q$, should provide an even greater improvement. Improvement is $0$ only when $\Pi q = q$, which only holds for the true posterior. It is a reasonable criterion therefore to seek a posterior approximation $q_\theta$ such that the improvement in $\Pi^t q_\theta$ over $q_\theta$ is minimized.
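
The improvement-operator view is easy to see on a toy example. The sketch below assumes a discrete state space with four states, so that distributions are vectors and $\Pi$ is a matrix, and it uses a Metropolis kernel with uniform proposals as a hypothetical choice of MCMC. The distributions `p` and `q` are made-up stand-ins for the true posterior and a crude approximation.

```python
import numpy as np

def metropolis_kernel(p):
    """Transition matrix of a Metropolis chain with uniform proposals and target p."""
    k = len(p)
    accept = np.minimum(1.0, p[None, :] / p[:, None])  # accept[i, j] = min(1, p_j / p_i)
    kernel = accept / k                                 # propose j uniformly, then accept
    np.fill_diagonal(kernel, 0.0)
    kernel += np.diag(1.0 - kernel.sum(axis=1))         # rejected mass stays where it is
    return kernel

def kl(a, b):
    """KL[a || b] for strictly positive probability vectors."""
    return float(np.sum(a * (np.log(a) - np.log(b))))

p = np.array([0.5, 0.3, 0.15, 0.05])  # stand-in for the true posterior
q = np.array([0.1, 0.1, 0.2, 0.6])    # stand-in for a crude approximation q_theta
kernel = metropolis_kernel(p)

kls = []
qt = q.copy()
for t in range(6):
    kls.append(kl(qt, p))              # KL[Pi^t q || p]
    qt = qt @ kernel                   # one more application of the operator
print(np.round(kls, 4))
```

Each application of the kernel brings the distribution closer to `p` in KL divergence, and `p` itself is left unchanged by the kernel, which is the fixed-point property the criterion relies on.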

There are two ways of quantifying the amount of improvement, or change, that MCMC provides over a parametric posterior $q_\theta$. The first measures how much closer we got to the posterior $p$ by comparing the KL divergences:

$$

\mathcal{L}_1(\theta) = \operatorname{KL}\left[q_\theta\middle\|p\right] - \operatorname{KL}\left[\Pi^tq_\theta\middle\|p\right]

$$

The second one measures the amount of change between $\Pi^tq_\theta$ and $q_\theta$, measured as the KL divergence. This objective function merely tries to identify fixed points of the improvement operator:

$$

\mathcal{L}_2(\theta) = \operatorname{KL}\left[\Pi^tq_\theta\middle\|q_\theta\right]

$$

Either of these objectives would make sense and is justified in its own right, but sadly neither of them can be evaluated or optimized easily in this case. Both require taking expectations over $\Pi^tq_\theta$, as well as evaluating $\log \Pi^tq_\theta$. However, the brilliant insight in this paper is that when you sum them together, the most problematic terms cancel out, leaving you with a tractable objective to minimise.

$$

\mathcal{L}_1(\theta) + \mathcal{L}_2(\theta) = \mathbb{E}_{z\sim \Pi^t q_\theta} f_\theta(z) - \mathbb{E}_{z\sim q_\theta} f_\theta(z),

$$

where $f_\theta(z) = \log p(z,x) - \log q_\theta(z)$, $x$ is the observed variable and $z$ the hidden variable.
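
The cancellation can be verified in a few lines. As shorthand (my notation, not the paper's), write $q^{(t)} = \Pi^t q_\theta$ and let $p$ denote the posterior $p(z\vert x)$:

\begin{align}
\mathcal{L}_1(\theta) + \mathcal{L}_2(\theta) &= \mathbb{E}_{q_\theta}\left[\log q_\theta(z) - \log p(z\vert x)\right] - \mathbb{E}_{q^{(t)}}\left[\log q^{(t)}(z) - \log p(z\vert x)\right] + \mathbb{E}_{q^{(t)}}\left[\log q^{(t)}(z) - \log q_\theta(z)\right] \\
&= \mathbb{E}_{q^{(t)}}\left[\log p(z\vert x) - \log q_\theta(z)\right] - \mathbb{E}_{q_\theta}\left[\log p(z\vert x) - \log q_\theta(z)\right] \\
&= \mathbb{E}_{q^{(t)}}\left[\log p(z, x) - \log q_\theta(z)\right] - \mathbb{E}_{q_\theta}\left[\log p(z, x) - \log q_\theta(z)\right]
\end{align}

The intractable $\log q^{(t)}$ terms cancel between the second and third KL divergences, and in the last step the constant $\log p(x)$ appears in both expectations with opposite signs, so the unnormalized joint $p(z, x)$ is all that is needed.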

There is one more technical hurdle to overcome, which is to calculate or estimate the derivative of this objective with respect to $\theta$. The authors propose a REINFORCE-like score function gradient estimator in Eqn. (12), which is somewhat worrying as it is known to have very high variance. The authors propose overcoming this using a control variate. For more details, please refer to the paper.

There is further discussion on the behaviour of this objective function in the limit of infinitely long MCMC paths, i.e. $t\rightarrow\infty$. It turns out that, in this limit, the criterion works like the symmetrized KL divergence $\operatorname{KL}[q\|p] + \operatorname{KL}[p\|q]$. The difference between this objective and the usual conservative, mode-seeking VI objective is neatly illustrated in Figure 1 of the paper:

Variational Contrastive Divergence (VCD) favours posterior approximations which have a much higher coverage of the true posterior compared to VI, which tries to cover the modes while avoiding allocating mass to areas where the true posterior has none.

]]>