This post is written with my PhD student and now **guest author Patrik Reizinger** and is part 4 of a series of posts on causal inference:

- Part 1: Intro to causal inference and do-calculus
- Part 2: Illustrating Interventions with a Toy Example
- Part 3: Counterfactuals
- ➡️️ Part 4: Causal Diagrams, Markov Factorization, Structural Equation Models

One way to think about causal inference is that causal models require a more fine-grained model of the world than statistical models do. Many causal models are equivalent to the same statistical model, yet support different causal inferences. This post elaborates on this point, and makes the relationship between causal and statistical models more precise.

Do you remember those combinatorics problems from school where the question was how many ways exist to get from a start position to a target field on a chessboard? And you can only move one step right or one step down. If you remember, then I need to admit that we will not consider problems like that. But one possible takeaway from them can actually help us understand Markov factorizations.

You know, it makes no difference how you traversed the chessboard: the result is the same. So we can say that - from the perspective of the target position and the process of getting there - this is a many-to-one mapping. The same holds for random variables and causal generative models.

If you have a bunch of random variables - let's call them $X_1, X_2, \dots, X_n$ -, their joint distribution is $p \left(X_1, X_2, \dots, X_n \right) $. If you invoke the *chain rule of probability,* you will have several options to express this joint as a product of factors:

$$

p \left(X_1, X_2, \dots, X_n \right) = \prod_{i=1}^{n} p(X_{\pi_i}\vert X_{\pi_1}, \ldots, X_{\pi_{i-1}}),

$$

where $\pi$ is a permutation of the indices $1, \ldots, n$ and $\pi_i$ is its $i$-th element. Since you can do this for any permutation $\pi$, the mapping between such factorizations and the joint distribution they express is many-to-one, as you can see in the image below. The different factorizations induce different graphs, but have the same joint distribution.
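To make the many-to-one claim concrete, here is a minimal sketch that rebuilds a joint distribution from its chain-rule factors for every variable ordering. The toy joint over three binary variables and the helper names `marginal` and `chain_rule_product` are assumptions made for illustration only.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
# A random joint distribution over three binary variables (hypothetical toy example).
joint = rng.random((2, 2, 2))
joint /= joint.sum()

def marginal(joint, keep):
    """Marginal over the axes in `keep`, with the other axes kept as size-1 dimensions."""
    drop = tuple(a for a in range(joint.ndim) if a not in keep)
    return joint.sum(axis=drop, keepdims=True) if drop else joint

def chain_rule_product(joint, order):
    """Multiply the chain-rule factors p(x_i | earlier variables) in the given order."""
    product = np.ones_like(joint)
    for i in range(len(order)):
        p_with = marginal(joint, order[: i + 1])   # p(x_i, earlier variables)
        p_before = marginal(joint, order[:i])      # p(earlier variables)
        product = product * (p_with / p_before)    # factor p(x_i | earlier variables)
    return product

reconstructions = [chain_rule_product(joint, order)
                   for order in itertools.permutations(range(3))]
# Every ordering gives a different set of factors, but the same joint.
all_equal = all(np.allclose(r, joint) for r in reconstructions)
```

Each of the six orderings corresponds to a different fully connected graph, yet all of them multiply back to the same table.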

Since you are reading this post, you may already be aware that in causal inference we often talk about a causal factorization, which looks like

$$

p \left(X_1, X_2, \dots, X_n \right) = \prod_{i=1}^{n} p\left(X_i \vert \mathrm{pa}(X_i)\right),

$$

where $\mathrm{pa}(X_i)$ denotes the causal parents of node $X_i$. This is one of many possible ways you can factorize the joint distribution, but we consider this one special. In recent work, Schölkopf et al. call it a *disentangled* model. What are disentangled models? Disentangled factors describe independent aspects of the mechanism that generated the data. And they are not independent because you factored them in this way; rather, you were looking for this factorization because its factors are independent.

In other words, for every joint distribution there are many possible factorizations, but we assume that only one, the causal or disentangled factorization, describes the true underlying process that generated the data.

Let's consider an example for *disentangled* models. We want to model the joint distribution of altitude $A$ and temperature $T$. In this case, the causal direction is $A \rightarrow T$ - if the altitude changes, the distribution of the temperature will change too. But you cannot change the altitude by artificially heating a city - otherwise we all would enjoy views as in Miami; global warming is real but fortunately has no altitude-changing effect.

In the end, we get the factorization of $p(A)p(T|A)$. The important insights here are the answers to the question: What do we expect from these factors? The previously-mentioned Schölkopf et al. paper calls the main takeaway the **Independent Causal Mechanisms (ICM) Principle**, i.e.

By conditioning on the parents of any factor in the disentangled model, the factor will neither be able to give you further information about other factors nor be able to influence them.

In the above example, this means that if you consider different countries with their altitude distributions, you can still use the same $p(T|A)$, i.e., the factors **generalize** well. This illustrates the *no influence* part of the ICM Principle: changing $p(A)$ does not change $p(T|A)$. Furthermore, knowing any of the factors - e.g. $p(A)$ - won't tell you anything about the other *(no information)*. If you know which country you are in, you will still have no clue about the climate (if you consulted the website of the corresponding weather agency, that's what I call cheating). In the other direction, despite being the top-of-class student in climate matters, you won't be able to tell the country if somebody says to you that here the altitude is 350 meters and the temperature is 7°C!

We discussed Markov factorizations, as they help us understand the philosophical difference between statistical and causal inference. The beauty, and a source of confusion, is that one can use Markov factorizations in both paradigms.

However, while using Markov factorizations is optional for statistical inference, it is a must for causal inference.

So why would a statistical inference person use Markov factorizations? Because they make life easier in the sense that you do not need to worry about too high electricity costs. Namely, factorized models of data can be computationally much more efficient. Instead of modeling a joint distribution directly, which has a lot of parameters - in the case of $n$ binary variables, that is $2^n-1$ different values -, a factorized version can be pretty lightweight and parameter-efficient. If you are able to factorize the joint into 8 independent factors with $n/8$ variables each, then you can describe your model with $8\times(2^{n/8}-1)$ parameters. If $n=16$, that is $65,535$ vs $24$. Similarly, representing your distribution in a factorized form gives rise to efficient, general-purpose message-passing algorithms, such as belief propagation or expectation propagation.
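The parameter savings can be checked with a few lines of arithmetic. This sketch counts free parameters under one particular assumption: the factors are independent blocks of binary variables, each stored as a full table with its own normalization constraint (so a block of $k$ variables costs $2^k - 1$ numbers). The function names are hypothetical.

```python
def joint_param_count(n):
    """Free parameters of an unrestricted joint over n binary variables."""
    return 2 ** n - 1

def factorized_param_count(n, n_factors):
    """Free parameters when the joint splits into independent blocks of equal size.

    Assumes n is divisible by n_factors and that every block is a full table
    with one normalization constraint.
    """
    block_size = n // n_factors
    return n_factors * (2 ** block_size - 1)

full = joint_param_count(16)
factorized = factorized_param_count(16, 8)
```

Other ways of counting (for example, conditional factors with parents) give different totals, but the exponential gap between the two representations remains.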

On the other hand, causal inference people really need this, otherwise, they are lost. Because without Markov factorizations, they cannot really formulate causal claims.

A causal practitioner uses Markov factorizations because this way she is able to reason about interventions.

If you do not have the disentangled factorization, you cannot model the effect of interventions *on the real mechanisms* that make the system tick.

In plain machine learning lingo, what you want to do is *domain adaptation,* that is, you want to draw conclusions about a distribution you did not observe (these are the interventional ones). The Markov factorization prescribes ways in which you expect the distribution to change - one factor at a time - and thus the set of distributions you want to be able to robustly generalise to or draw inferences about.

Do-calculus, the topic of the first post in the series, can be described relatively simply using Markov factorizations. As you remember, $\mathrm{do}(X=x)$ means that we set the variable $X$ to the value $x$, meaning that the distribution of that variable $p(X)$ collapses to a point mass. We can model this intervention mathematically by replacing the factor $p( x \vert \mathrm{pa}(X))$ by a Dirac-delta $\delta_x$, resulting in the deletion of all incoming edges of the intervened factors in the graphical model. We then marginalise over $x$ to calculate the joint distribution of the remaining variables. For example, if we have two variables $x$ and $y$, where $y$ is the parent of $x$, we can write:

$$

p(y\vert \mathrm{do}(X=x_0)) = \int p(x,y) \frac{\delta(x - x_0)}{p(x\vert y)} \,\mathrm{d}x

$$
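The formula above divides out $p(x\vert y)$, which corresponds to a graph where $y$ is the parent of $x$. The discrete sketch below contrasts this with the opposite graph, $x \rightarrow y$, to show that the same joint gives different interventional distributions depending on which factor is replaced. The numbers in the table and the function names are made up for illustration, and a one-hot point mass stands in for the Dirac delta.

```python
import numpy as np

# Joint p(x, y) over two binary variables; rows index x, columns index y (toy numbers).
joint = np.array([[0.3, 0.1],
                  [0.2, 0.4]])

p_x = joint.sum(axis=1)               # marginal p(x)
p_y = joint.sum(axis=0)               # marginal p(y)
p_y_given_x = joint / p_x[:, None]    # conditional p(y | x), one row per value of x

def do_x_when_x_causes_y(x0):
    """Graph X -> Y: the factor p(x) becomes a point mass at x0, p(y | x) is kept."""
    return p_y_given_x[x0]

def do_x_when_y_causes_x(x0):
    """Graph Y -> X: the factor p(x | y) becomes a point mass at x0, p(y) is kept."""
    return p_y

effect_if_x_causes_y = do_x_when_x_causes_y(0)
effect_if_y_causes_x = do_x_when_y_causes_x(0)
```

In the second graph, intervening on $x$ leaves the distribution of $y$ untouched, which is what the integral above evaluates to as well.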

If you've read the previous parts in this series, you'll know that Markov factorizations aren't the only tool we use in causal inference. For counterfactuals, we used structural equation models (SEMs). In this part we will illustrate the connection between these with a cheesy reference to the reparametrization trick used in VAEs among others.

But before that, let's recap SEMs. In this case, you define the relationship between the child node and its parents via a functional assignment. For node $X$ with parents $\mathrm{pa}(X)$ it has the form of

$$

X = f(\mathrm{pa}(X), \epsilon),

$$

with some noise $\epsilon.$ Here, you should read "=" in the sense of an assignment (like in Python); in mathematics, this would be written ":=".

The above equation expresses the conditional probability $ p\left(X| \mathrm{pa}(X)\right)$ via a *deterministic* function of $\mathrm{pa}(X)$ and some noise variable $\epsilon$. Wait a second..., isn't this the same thing the reparametrization trick does? Yes, it is.

So the SEM formulation (called the implicit distribution) is related via the reparametrization trick to the conditional probability of $X$ given its parents.

Thus, we can say that a SEM is a conditional distribution, and vice versa. Okay, but how do the sets of these constructs relate to each other?

If you have a SEM, then you can read off the conditional, which is **unique**. On the other hand, you can find more than one SEM for the same conditional. Just as you can express a conditional distribution in multiple different ways using different reparametrizations, it is possible to express the same Markov factorization by multiple SEMs. Consider for example that if your noise is $\epsilon \sim \mathcal{N}(0,1)$, then $X = \mu + \sigma\epsilon$ and $X = \mu - \sigma\epsilon$ both give you the same distribution $\mathcal{N}(\mu,\sigma^2)$. In this sense, SEMs are a richer class of models than Markov factorizations, thus they allow us to make (counterfactual) inferences which we weren't able to express in the more coarse-grained language of Markov factorizations.
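Here is a small sketch of two SEMs that share the same conditional distribution but disagree on a counterfactual. Everything in it is an assumption made for illustration: a single parent $a \in \{-1, +1\}$, standard normal noise, and the hypothetical functions `sem_a` and `sem_b`.

```python
import numpy as np

def sem_a(a, eps):
    """SEM A: the child ignores its parent, X := eps."""
    return eps

def sem_b(a, eps):
    """SEM B: the parent flips the sign of the noise, X := a * eps."""
    return a * eps

# Both SEMs induce X | a ~ N(0, 1) for a in {-1, +1}: compare sample moments.
rng = np.random.default_rng(0)
eps = rng.standard_normal(100_000)
x_a = sem_a(-1, eps)
x_b = sem_b(-1, eps)

# Counterfactual query: we observed a = +1 and X = 0.5.
# What would X have been had a been -1, with the same noise?
a_obs, x_obs, a_new = 1, 0.5, -1
eps_a = x_obs              # abduction in SEM A: eps = X
eps_b = x_obs / a_obs      # abduction in SEM B: eps = X / a
counterfactual_a = sem_a(a_new, eps_a)
counterfactual_b = sem_b(a_new, eps_b)
```

The two models are indistinguishable from observational or interventional samples, yet they give counterfactual answers of opposite sign.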

As we discussed above, a single joint distribution has multiple valid Markov factorizations, and the same Markov factorization can be expressed as different SEMs. We can think of joint distributions, Markov factorizations, and SEMs as increasingly fine-grained model classes: joint distributions $\subset$ Markov factorizations $\subset$ SEMs. The more aspects of the data generating process you model, the more elaborate the set of inferences you can make becomes. Thus, joint distributions allow you to make predictions under no mechanism shift, Markov factorizations allow you to model interventions, and SEMs allow you to make counterfactual statements.

The price you pay for more expressive models is that they also get generally much harder to estimate from data. In fact, some aspects of causal models are impossible to infer from i.i.d. observational data. Moreover, some counterfactual inferences are experimentally not verifiable.

A few days ago we had a talk by Gergely Neu, who presented his recent work:

I'm writing this post mostly to annoy him, by presenting this work using super hand-wavy intuitions and cartoon figures. If this isn't enough, I will even find a way to mention GANs in this context.

But truthfully, I'm just excited because for once, there is a little bit of learning theory that I half-understand, at least at an intuitive level, thanks to its reliance on KL divergences and the mutual information.

Let's start with a simple thought experiment to illustrate why and how mutual information may be useful in describing an algorithm's ability to generalize. Say we're given two datasets, $\mathcal{D}_{train}$ and $\mathcal{D}_{test}$, of the same size for simplicity. We play the following game: we both have access to $\mathcal{D}_{train}$ and $\mathcal{D}_{test}$, and we both know what learning algorithm, $\operatorname{Alg}$, we're going to use.

Now I toss a coin and I keep the result (recorded as random variable $Y$) a secret. If it's heads, I run $\operatorname{Alg}$ on the training set $\mathcal{D}_{train}$. If it's tails, I run $\operatorname{Alg}$ on the test data $\mathcal{D}_{test}$ instead. I don't tell you which of these I did, I only reveal to you the final parameter value $W$. Can you guess, just by looking at $W$, whether I trained on training or test data?

If you cannot guess $Y$, that means that the algorithm returns the same random $W$ irrespective of whether you train it on training or test data. So the training and test losses become interchangeable. This implies that the algorithm will generalize very well (on average) and not overfit to the data it's trained on.

The mutual information, in this case between $W$ and $Y$, quantifies your theoretical ability to guess $Y$ from $W$. The higher this value is, the easier it is to tell which dataset the algorithm was trained on. If it's easy to reverse engineer my coin toss from parameters, it means that the algorithm's output is very sensitive to the input dataset it's trained on. And that likely implies poor generalization.

Side note: an algorithm generalizing well on average doesn't mean it works well on average. It just means that there won't be a large gap between the expected training and expected test error. Take for example an algorithm that returns a randomly initialized neural network, without even touching the data. That algorithm generalizes extremely well on average: it does just as poorly on test data as it does on training data.

Below is an illustration of my thought experiment for SGD.

In the top row, I doodled the distribution of the parameter $W_t$ at various timesteps $t=0,1,2,\ldots,T$ of SGD. We start the algorithm by initializing $W$ randomly from a Gaussian (left panel). Then, each stochastic gradient update changes the distribution of $W_t$ a bit compared to the distribution of $W_{t-1}$. How the shape of the distribution changes depends on the data we use in the SGD steps. In the top row, let's say I ran SGD on $\mathcal{D}_{train}$ and in the bottom row, I ran it on $\mathcal{D}_{test}$. The distributions $p(W_t\vert \mathcal{D})$ I drew here describe where the SGD iterate is likely to be after $t$ steps of SGD started from random initialization. They are not to be confused with Bayesian posteriors, for example.

We know that running the algorithm on the test set would produce low test error. Therefore, sampling a weight vector $W$ from $p(W_T\vert \mathcal{D}_{test})$ would be great if we could do that. But in practice, we can't train on the test data; all we have the ability to sample from is $p(W_T\vert \mathcal{D}_{train})$. So what we'd like for good generalization is for $p(W_T\vert \mathcal{D}_{test})$ and $p(W_T\vert \mathcal{D}_{train})$ to be as close as possible. The mutual information between $W_T$ and my coinflip $Y$ measures this closeness in terms of the Jensen-Shannon divergence:

$$

\mathbb{I}[Y, W_T] = \operatorname{JSD}[p(W_T\vert \mathcal{D}_{test})\|p(W_T\vert \mathcal{D}_{train})]

$$
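The identity between the mutual information and the Jensen-Shannon divergence can be verified numerically when $W_T$ takes finitely many values. The sketch below assumes a fair coin and two made-up discrete distributions standing in for $p(W_T\vert \mathcal{D}_{train})$ and $p(W_T\vert \mathcal{D}_{test})$; all quantities are in nats.

```python
import numpy as np

def kl(p, q):
    """KL divergence between discrete distributions, skipping zero-probability terms of p."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def jsd(p, q):
    """Jensen-Shannon divergence with equal weights."""
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def mutual_information(joint):
    """Mutual information of a discrete joint table p(y, w)."""
    p_y = joint.sum(axis=1, keepdims=True)
    p_w = joint.sum(axis=0, keepdims=True)
    mask = joint > 0
    return float(np.sum(joint[mask] * np.log(joint[mask] / (p_y * p_w)[mask])))

# Hypothetical distributions over four possible parameter values.
p_train = np.array([0.7, 0.2, 0.1, 0.0])
p_test = np.array([0.1, 0.2, 0.3, 0.4])

# Fair coin Y: heads selects the training set, tails the test set.
joint_yw = 0.5 * np.stack([p_train, p_test])
mi = mutual_information(joint_yw)
js = jsd(p_train, p_test)
```

When the two distributions have disjoint support the divergence saturates at $\log 2$, i.e. one full bit: the coin flip can be read off perfectly.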

So, in summary, if we can guarantee that the final parameter an algorithm comes up with doesn't reveal too much information about what dataset it was trained on, we can hope that the algorithm has good generalization properties.

These vague intuitions can be formalized into real information-theoretic generalization bounds. These were first presented in a general context in (Russo and Zou, 2016) and in a more clearly machine learning context in (Xu and Raginsky, 2017). I'll give a quick - and possibly somewhat handwavy - overview of the main results.

Let $\mathcal{D}$ and $\mathcal{D}'$ be random datasets of size $n$, drawn i.i.d. from some underlying data distribution $P$. Let $W$ be a parameter vector, which we obtain by running a learning algorithm $\operatorname{Alg}$ on the training data $\mathcal{D}$: $W = \operatorname{Alg}(\mathcal{D})$. The algorithm may be non-deterministic, i.e. it may output a random $W$ given a dataset. Let $\mathcal{L}(W, \mathcal{D})$ denote the loss of model $W$ on dataset $\mathcal{D}$. The expected generalization error of $\operatorname{Alg}$ is defined as follows:

$$

\text{gen}( \operatorname{Alg}, P) = \mathbb{E}_{\mathcal{D}\sim P^n,\mathcal{D}'\sim P^n, W\vert \mathcal{D}\sim \operatorname{Alg}(\mathcal{D})}[\mathcal{L}(W, \mathcal{D}') - \mathcal{L}(W, \mathcal{D})]

$$

If we unpack this, we have two datasets $\mathcal{D}$ and $\mathcal{D}'$, the former taking the role of the training dataset, the latter of the test data. We look at the expected difference between the training and test losses ($\mathcal{L}(W, \mathcal{D})$ and $\mathcal{L}(W, \mathcal{D}')$), where $W$ is obtained by running $\operatorname{Alg}$ on the training data $\mathcal{D}$. The expectation is taken over all possible random training sets, test sets, and over all possible random outcomes of the learning algorithm.

The information theoretic bound states that for any learning algorithm, and any loss function that's bounded by $1$, the following inequality holds:

$$

\text{gen}(\operatorname{Alg}, P) \leq \sqrt{\frac{\mathbb{I}[W, \mathcal{D}]}{n}}

$$

The main term in the RHS of this bound is the mutual information between the training data $\mathcal{D}$ and the parameter vector $W$ the algorithm finds. It essentially quantifies the number of bits of information the algorithm leaks about the training data into the parameters it learns. The lower this number, the better the algorithm generalizes.

The problem with applying these nice, intuitive bounds to SGD is that SGD, in fact, leaks too much information about the specific minibatches it is trained on. Let's go back to my illustrative example of having to guess if we ran the algorithm on training or test data. Consider the scenario where we start from some parameter value $w_t$ and we update either with a random minibatch of training data (blue) or a random minibatch of test data (orange).

Since the training and test datasets are assumed to be of finite size, there are only a finite number of possible minibatches. Each of these minibatches can take the parameter to a unique new location. The problem is, the set of locations you can reach with one dataset (blue dots) does not overlap with the set of locations you can reach if you update with the other dataset (orange dots). Suddenly, if I give you $w_{t+1}$, you can immediately tell if it's an orange or blue dot, so you can immediately reconstruct my coinflip $Y$.

In the more general case, the problem with SGD in the context of information-theoretic bounds is that the amount of information SGD leaks about the dataset it was trained on is high, and in some cases may even be infinite. This is actually related to the problem that several of us noticed in the context of GANs, where the true and fake distributions may have non-overlapping support, making the KL divergence infinite, and saturating out the Jensen-Shannon divergence. The first trick we came up with to solve this problem was to smooth things out by adding Gaussian noise. Indeed, adding noise is key to what researchers have been doing to apply these information-theoretic bounds to SGD.

The first thing people did (Pensia et al, 2018) was to study a noisy cousin of SGD: stochastic gradient Langevin dynamics (SGLD). SGLD is like SGD but in each iteration we add a bit of Gaussian noise to the parameters in addition to the gradient update. To understand why SGLD leaks less information, consider the previous example with the orange and blue point clouds. SGLD makes those point clouds overlap by convolving them with Gaussian noise.

However, SGLD is not exactly SGD, and it's not really used as much in practice. In order to say something about SGD specifically, Neu (2021) did something else, while still relying on the idea of adding noise. Instead of baking the noise in as part of the algorithm, Neu only adds noise as part of the analysis. The algorithm being analysed is still SGD, but when we measure the mutual information, we measure $\mathbb{I}[W + \xi; \mathcal{D}]$, the mutual information between the noise-perturbed parameters and the data, where $\xi$ is Gaussian noise.

I leave it to you to check out the details of the paper. While the findings fall short of explaining whether SGD has any tendency to find solutions that generalise well, some of the results are nice and interpretable: they connect the generalization of SGD to the noisiness of gradients as well as the smoothness of the loss along the specific optimization path that was taken.
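A minimal sketch of the noisy update, and of why the added noise helps, is below. The function names and the free `noise_std` parameter are assumptions made for illustration; the cited papers may tie the noise scale to the step size in a specific way that this sketch does not reproduce.

```python
import numpy as np

def sgld_step(w, grad_fn, step_size, noise_std, rng):
    """One SGD-style update followed by additive Gaussian noise on the parameters."""
    noise = noise_std * rng.standard_normal(w.shape)
    return w - step_size * grad_fn(w) + noise

def kl_between_smoothed_points(a, b, noise_std):
    """KL(N(a, s^2 I) || N(b, s^2 I)) for s = noise_std > 0.

    Two distinct point masses at a and b have disjoint support (infinite KL);
    after Gaussian smoothing the divergence becomes finite.
    """
    return float(np.sum((a - b) ** 2) / (2.0 * noise_std ** 2))

rng = np.random.default_rng(0)
w = np.array([1.0, -2.0])
quadratic_grad = lambda w: 2.0 * w   # gradient of the toy loss ||w||^2

w_noiseless = sgld_step(w, quadratic_grad, 0.1, 0.0, rng)   # reduces to a plain gradient step
w_noisy = sgld_step(w, quadratic_grad, 0.1, 0.05, rng)
smoothed_kl = kl_between_smoothed_points(np.array([0.0]), np.array([1.0]), 1.0)
```

The smaller the noise, the larger the divergence between the smoothed orange and blue points, which is why the noise level shows up as a trade-off in bounds of this kind.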

I wanted to highlight an intriguing paper I presented at a journal club recently:

- Samuel L Smith, Benoit Dherin, David Barrett, Soham De (2021) On the Origin of Implicit Regularization in Stochastic Gradient Descent

There's actually a related paper that came out simultaneously, studying full-batch gradient descent instead of SGD:

- David G.T. Barrett, Benoit Dherin (2021) Implicit Gradient Regularization

One of the most important insights in machine learning over the past few years relates to the importance of optimization algorithms in generalization performance.

In order to understand why deep learning works as well as it does, it is insufficient to reason about the loss function or the model class, which is what classical generalisation theory focussed on. Instead, the algorithms we use to find minima (namely, stochastic gradient descent) seem to play an important role. In many tasks, powerful neural networks are able to interpolate training data, i.e. achieve near-0 training loss. There are in fact several minima of the training loss which are virtually indistinguishably good on the training data. Some of these minima generalise well (i.e. result in low test error), others can be arbitrarily badly overfit.

What seems to be important then is not whether the optimization algorithm converges quickly to a local minimum, but which of the available "virtually global" minima it prefers to reach. It seems to be the case that the optimization algorithms we use to train deep neural networks *prefer* some minima over others, and that this preference results in better generalisation performance. The preference of optimization algorithms to converge to certain minima while avoiding others is described as *implicit regularization*.

I wrote this note as an overview on how we/I currently think about why deep networks generalize.

One of the interesting new theories that helped me imagine what happens in deep learning training is that of neural tangent kernels. In this framework we study neural network training in the limit of infinitely wide layers, full-batch training and infinitesimally small learning rate, i.e. when gradient descent becomes continuous gradient flow, described by an ordinary differential equation. Although the theory is useful and appealing, full-batch training with infinitesimally small learning rates is very much a cartoon version of what we actually do in practice. Firstly, in practice, the smallest learning rate doesn't always work best. Secondly, the stochasticity of gradient updates in minibatch-SGD seems to be of importance as well.

What Smith et al (2021) do differently in this paper is they try to study minibatch-SGD, for small, but not infinitesimally small, learning rates. This is much closer to practice. The toolkit that allows them to study this scenario is borrowed from the study of differential equations and is called backward error analysis. The cartoon illustration below shows what backward error analysis tries to achieve:

Let's say we have a differential equation $\dot{\omega} = f(\omega)$. The solution to this ODE with initial condition $\omega_0$ is a continuous trajectory $\omega_t$, shown in the image in black. We usually can't compute this solution in closed form, and instead simulate the ODE using Euler's method, $\omega_{k+1} = \omega_k + \epsilon f(\omega_k)$. This results in a discrete trajectory shown in teal. Due to discretization error, for finite stepsize $\epsilon$, this discrete path may not lie exactly where the continuous black path lies. Errors accumulate over time, as shown in this illustration. The goal of backward error analysis is to find a different ODE, $\dot{\omega} = \tilde{f}(\omega)$, such that the approximate discrete path we got from Euler's method lies near the continuous path which solves this new ODE. Our goal is to reverse engineer a modified $\tilde{f}$ such that the discrete iteration can be well-modelled by an ODE.

Why is this useful? Because the form $\tilde{f}$ takes can reveal interesting aspects of the behaviour of the discrete algorithm, particularly if it has any implicit bias towards moving into different areas of the space. When the authors apply this technique to (full-batch) gradient descent, it already suggests the kind of implicit regularization bias gradient descent has.

In gradient descent with a cost function $C$, the original ODE is $\dot{\omega} = f(\omega)$ with $f(\omega) = -\nabla C (\omega)$. The modified ODE which corresponds to a finite stepsize $\epsilon$ takes the form $\dot{\omega} = -\nabla\tilde{C}_{GD}(\omega)$ where

$$

\tilde{C}_{GD}(\omega) = C(\omega) + \frac{\epsilon}{4} \|\nabla C(\omega)\|^2

$$

So, gradient descent with finite stepsize $\epsilon$ is like running gradient flow, but with an added penalty that penalises the gradients of the loss function. The second term is what Barrett and Dherin (2021) call implicit gradient regularization.

Analysing SGD in this framework is a bit more difficult because the trajectory in stochastic gradient descent is, well, stochastic. Therefore, you don't have a single discrete trajectory to optimize, but instead you have a distribution of different trajectories which you'd traverse if you randomly reshuffle your data. Here's a picture illustrating this situation:

Starting from the initial point $\omega_0$ we now have multiple trajectories. These correspond to different ways we can shuffle data (the paper assumes a fixed allocation of datapoints to minibatches, so the randomness comes from the order in which the minibatches are considered). The two teal trajectories illustrate two potential paths. The paths end up at a random location; the teal dots show additional random endpoints where trajectories may end up. The teal star shows the mean of the distribution of random trajectory endpoints.
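The effect of the extra penalty can be checked on a one-dimensional quadratic, where both gradient flows have closed-form solutions. The sketch below assumes the toy loss $C(\omega) = \frac{a}{2}\omega^2$, for which the modified loss is $\tilde{C}_{GD}(\omega) = \frac{a}{2}\left(1 + \frac{\epsilon a}{2}\right)\omega^2$; the constants are arbitrary choices for illustration.

```python
import math

a, eps, steps, w0 = 1.0, 0.1, 50, 1.0

# Discrete gradient descent on C(w) = a * w**2 / 2.
w_gd = w0
for _ in range(steps):
    w_gd -= eps * a * w_gd

t = eps * steps  # total continuous time covered by the discrete steps

# Gradient flow on the original loss: w(t) = w0 * exp(-a * t).
w_flow = w0 * math.exp(-a * t)

# Gradient flow on the modified loss C + (eps / 4) * |grad C|^2,
# whose gradient is a * (1 + eps * a / 2) * w.
w_modified = w0 * math.exp(-a * (1.0 + eps * a / 2.0) * t)

err_flow = abs(w_gd - w_flow)
err_modified = abs(w_gd - w_modified)
```

The flow on the modified loss tracks the discrete iterates noticeably more closely than the flow on the original loss, which is the sense in which the penalty describes what finite-stepsize gradient descent does.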

The goal in (Smith et al, 2021) is to reverse-engineer an ODE so that the continuous (orange) path lies close to this mean location. The corresponding ODE is of the form $\dot{\omega} = -\nabla \tilde{C}_{SGD}(\omega)$, where

$$

\tilde{C}_{SGD}(\omega) = C(\omega) + \frac{\epsilon}{4m} \sum_{k=1}^{m} \|\nabla \hat{C}_k(\omega)\|^2,

$$

where $\hat{C}_k$ is the loss function on the $k^{th}$ minibatch. There are $m$ minibatches in total. Note that this is similar to what we had for gradient descent, but instead of the squared norm of the full-batch gradient we now have the average squared norm of minibatch gradients as the implicit regularizer. Another interesting view on this is to look at the difference between the GD and SGD regularizers:

$$

\tilde{C}_{SGD}(\omega) = \tilde{C}_{GD}(\omega) + \frac{\epsilon}{4m} \sum_{k=1}^{m} \|\nabla \hat{C}_k(\omega) - \nabla C(\omega)\|^2

$$

This additional regularization term, $\frac{1}{m}\sum_{k=1}^{m} \|\nabla \hat{C}_k(\omega) - \nabla C(\omega)\|^2$, is something like the total variance of minibatch gradients (the trace of the empirical Fisher information matrix). Intuitively, this regularizer term will avoid parts of the parameter-space where the variance of gradients calculated over different minibatches is high.

Importantly, while $\tilde{C}_{GD}$ has the same minima as $C$, this is no longer true for $\tilde{C}_{SGD}$. Some minima of $C$ where the variance of gradients is high are no longer minima of $\tilde{C}_{SGD}$. As an implication, not only does SGD follow different trajectories than full-batch GD, it may also converge to completely different solutions.
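The step from the SGD regularizer to the GD regularizer plus a variance term is the usual decomposition of a mean squared norm into a squared mean plus a variance. The sketch below checks it numerically with random stand-in gradients, assuming the full-batch gradient is the plain average of the minibatch gradients (equal-sized minibatches).

```python
import numpy as np

rng = np.random.default_rng(0)
m, d = 8, 5   # number of minibatches and parameter dimension (arbitrary)

# Hypothetical minibatch gradients at some fixed parameter value.
minibatch_grads = rng.standard_normal((m, d))
full_grad = minibatch_grads.mean(axis=0)

# Average squared norm of minibatch gradients (the SGD regularizer, up to eps / 4).
mean_sq_norm = np.mean(np.sum(minibatch_grads ** 2, axis=1))

# Squared norm of the full-batch gradient (the GD regularizer, up to eps / 4).
sq_norm_of_mean = np.sum(full_grad ** 2)

# Total variance of minibatch gradients around the full-batch gradient.
total_variance = np.mean(np.sum((minibatch_grads - full_grad) ** 2, axis=1))

decomposition_holds = np.isclose(mean_sq_norm, sq_norm_of_mean + total_variance)
```

At a minimum of $C$ the full-batch gradient vanishes, so the whole SGD penalty there is the variance term.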

As a sidenote, there are many versions of SGD, based on how data is sampled for the gradient updates. Here, it is assumed that the datapoints are assigned to minibatches, but then the minibatches are randomly sampled. This is different from randomly sampling datapoints with replacement from the training data (Li et al (2015) consider that case), and indeed an analysis of that variant may well lead to different results.

Why would an implicit regularization effect avoiding high minibatch gradient variance be useful for generalisation? Well, let's consider a cartoon illustration of two local minima below:

Both minima are the same as far as the average loss $C$ is concerned: the value of the minimum is the same, and the width of the two minima is the same. Yet, in the left-hand situation, the wide minimum arises as the average of several minibatch losses, which all look the same, and which all are relatively wide themselves. In the right-hand minimum, the wide average loss minimum arises as the average of a lot of sharp minibatch losses, which all disagree on where exactly the location of the minimum is.

If we have these two options, it is reasonable to expect the left-hand minimum to generalise better, because the loss function seems to be less sensitive to whichever specific minibatch we are evaluating it on. As a consequence, the loss function also may be less sensitive to whether a datapoint is in the training set or in the test set.

In summary, this paper is a very interesting analysis of stochastic gradient descent. While it has its limitations (which the authors don't try to hide and discuss transparently in the paper), it nevertheless contributes a very interesting new technique for analysing optimization algorithms with finite stepsize. I found the paper to be well-written, with the explanation of somewhat tedious details of the analysis clearly laid out. But perhaps I liked this paper most because it confirmed my intuitions about why SGD works, and what type of minima it tends to prefer.

This is a half-guest-post written jointly with Dóra, a fellow participant in a reading group where we recently discussed the original paper on $\beta$-VAEs:

- Irina Higgins et al (ICLR 2017): $\beta$-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework

On the surface of it, $\beta$-VAEs are a straightforward extension of VAEs where we are allowed to directly control the tradeoff between the reconstruction and KL loss terms. In an attempt to better understand where the $\beta$-VAE objective comes from, and to further motivate why it makes sense, here we derive $\beta$-VAEs from different first principles than those presented in the paper. Over to mostly Dóra for the rest of this post:

- $p_\mathcal{D}(\mathbf{x})$: data distribution
- $q_\psi(\mathbf{z}\vert \mathbf{x})$: representation distribution
- $q_\psi(\mathbf{z}) = \int p_\mathcal{D}(\mathbf{x})q_\psi(\mathbf{z}\vert \mathbf{x})\,\mathrm{d}\mathbf{x}$: aggregate posterior - marginal distribution of representation $\mathbf{z}$
- $q_\psi(\mathbf{x}\vert \mathbf{z}) = \frac{q_\psi(\mathbf{z}\vert \mathbf{x})p_\mathcal{D}(\mathbf{x})}{q_\psi(\mathbf{z})}$: "inverted posterior"

Learning disentangled representations that recover the independent data generative factors has been a long-term goal for unsupervised representation learning.

$\beta$-VAEs were introduced in 2017 with a proposed modification to the original VAE formulation that can achieve better disentanglement in the posterior $q_\psi(\mathbf{z}\vert \mathbf{x})$. An assumption of $\beta$-VAEs is that there are two sets of latent factors, $\mathbf{v}$ and $\mathbf{w}$, that contribute to generating observations $\mathbf{x}$ in the real world. One set, $\mathbf{v}$, is coordinate-wise conditionally independent given the observed variable, i.e., $\log p(\mathbf{v}\vert \mathbf{x}) = \sum_k \log p(v_k\vert \mathbf{x})$. At the same time, we don't assume anything about the remaining factors $\mathbf{w}$.

The factors $\mathbf{v}$ are going to be the main object of interest for us. The conditional independence assumption allows us to formulate what it means to disentangle these factors of variation. Consider a representation $\mathbf{z}$ which *entangles* coordinates of $\mathbf{v}$, in that each coordinate of $\mathbf{z}$ depends on multiple coordinates of $\mathbf{v}$, e.g. $z_1 = f_1(v_1, v_2)$ and $z_2 = f_2(v_1, v_2)$. Such a $\mathbf{z}$ won't necessarily satisfy coordinate-wise conditional independence $\log p(\mathbf{z}\vert \mathbf{x}) = \sum_k \log p(z_k\vert \mathbf{x})$. However, if each component of $\mathbf{z}$ depended only on one corresponding coordinate of $\mathbf{v}$, for example $z_1 = g_1(v_1)$ and $z_2 = g_2(v_2)$, the component-wise conditional independence would hold for $\mathbf{z}$ too.

Thus, under these assumptions we can encourage disentanglement to happen by encouraging the posterior $q_\psi(\mathbf{z}\vert \mathbf{x})$ to be coordinate-wise conditionally independent. This can be done by adding a new hyperparameter $\beta$ to the original VAE formulation

$$

\mathcal{L}(\theta, \psi; x, z, \beta) = -\mathbb{E}_{q_{\psi}(z\vert x)p_\mathcal{D}(x)}[\log p_{\theta}(x \vert z)] + \beta \operatorname{KL} (q_{\psi}(z\vert x)\| p(z)),

$$

where $\beta$ controls the trade-off between the capacity of the latent information channel and learning conditionally independent latent factors. When $\beta$ is greater than 1, we encourage the posterior $q_\psi(z\vert x)$ to be closer to the isotropic unit Gaussian $p(z) = \mathcal{N}(0, I)$, which itself is coordinate-wise independent.
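To make the role of $\beta$ concrete, here is a minimal sketch of the per-example loss. It assumes a diagonal Gaussian posterior (the usual VAE choice, though not spelled out above), for which the KL term to $\mathcal{N}(0, I)$ has a closed form. The encoder outputs `mu`, `logvar` and the reconstruction value `recon_nll` are made-up numbers for illustration.

```python
import math

def gaussian_kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over coordinates."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))

def beta_vae_loss(recon_nll, mu, logvar, beta):
    """Per-example loss: reconstruction term plus beta-weighted KL term."""
    return recon_nll + beta * gaussian_kl_to_standard_normal(mu, logvar)

# Hypothetical encoder output for one datapoint (2 latent dimensions).
mu, logvar = [0.5, -1.0], [0.0, math.log(0.25)]
recon_nll = 3.2  # hypothetical value of -log p_theta(x|z) for one sampled z

kl = gaussian_kl_to_standard_normal(mu, logvar)
loss_vae = beta_vae_loss(recon_nll, mu, logvar, beta=1.0)    # standard VAE weighting
loss_beta4 = beta_vae_loss(recon_nll, mu, logvar, beta=4.0)  # KL term weighted 4x
```

With `beta=1.0` this reduces to the usual VAE loss; larger values only change how strongly the same KL term is penalized.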

In this post, we revisit the conditional independence assumption of latent factors, and argue that a more appropriate objective would be to have marginal independence in the latent factors. To show you our intuition, let's revisit the "Explaining Away" phenomenon from Probabilistic Graphical Models.

Consider three random variables:

$A$: Ferenc is grading exams

$B$: Ferenc is in a good mood

$C$: Ferenc is tweeting a meme

with the following graphical model $A \rightarrow C \leftarrow B$.

Here we could assume that Ferenc grading exams is independent of him being in a good mood, i.e., $A \perp B$. However, the pressure of marking exams results in increased likelihood of procrastination, which increases the chances of tweeting memes, too.

However, as soon as we see a meme being tweeted by him, we know that he is either in a good mood or grading exams. If we know he is grading exams, that explains why he is tweeting memes, so it's less likely he's tweeting memes because he's in a good mood. Consequently, $A \not\perp B \vert C$.

In all seriousness, if we have a graphical model $A \rightarrow C \leftarrow B$, then once we observe $C$, independence between $A$ and $B$ no longer holds.
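The effect is easy to reproduce by brute-force enumeration. The sketch below uses made-up probabilities for the three binary variables (none of these numbers come from real data) and computes the probability of a good mood under different amounts of evidence.

```python
from itertools import product

# Hypothetical probabilities, chosen only for illustration.
P_A = 0.3          # P(A=1): grading exams
P_B = 0.6          # P(B=1): good mood
P_C_GIVEN = {      # P(C=1 | A, B): tweeting a meme
    (0, 0): 0.05, (0, 1): 0.5, (1, 0): 0.6, (1, 1): 0.8,
}

def joint(a, b, c):
    """p(a, b, c) for the collider A -> C <- B with A and B independent."""
    pa = P_A if a else 1 - P_A
    pb = P_B if b else 1 - P_B
    pc = P_C_GIVEN[(a, b)] if c else 1 - P_C_GIVEN[(a, b)]
    return pa * pb * pc

def prob(event, given=lambda a, b, c: True):
    """P(event | given), computed by summing the joint over all 8 outcomes."""
    states = list(product((0, 1), repeat=3))
    num = sum(joint(*s) for s in states if event(*s) and given(*s))
    den = sum(joint(*s) for s in states if given(*s))
    return num / den

p_b = prob(lambda a, b, c: b == 1)
p_b_given_a = prob(lambda a, b, c: b == 1, lambda a, b, c: a == 1)
p_b_given_c = prob(lambda a, b, c: b == 1, lambda a, b, c: c == 1)
p_b_given_ac = prob(lambda a, b, c: b == 1, lambda a, b, c: a == 1 and c == 1)
```

With these numbers, knowing about the exams alone leaves the good-mood probability at 0.6. Seeing a meme raises it to about 0.80, and then also learning about the exams pulls it back down to about 0.67: grading explains the meme away.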

We argue that the explaining away phenomenon makes the conditional independence of latent factors undesirable. A much more reasonable assumption about the generative process of the data is that the factors of variation $\mathbf{v}$ are drawn independently, and then the observations are generated conditioned on them. However, if we consider two coordinates of $\mathbf{v}$ and the observation $\mathbf{x}$, we now have a $V_1 \rightarrow \mathbf{X} \leftarrow V_2$ graphical model, and thus conditional independence cannot hold.

Instead, we argue that to recover the generative factors of the data, we should encourage latent factors to be marginally independent. In the next section, we set out to derive an algorithm that encourages marginal independence in the representation $Z$. We will also show how the resulting loss function from this new derivation is actually equivalent to the original $\beta$-VAE formulation.

We'll start from desired properties of the representation distribution $q_\psi(z\vert x)$. We'd like this representation to satisfy two properties:

1. Marginal independence: We would like the aggregate posterior $q_\psi(z)$ to be close to some fixed and factorized unit Gaussian prior $p(z) = \prod_i p(z_i)$. This encourages $q_\psi(z)$ to exhibit coordinate-wise independence.
2. Maximum Information: We'd like the representation $Z$ to retain as much information as possible about the input data $X$.

Note that without (1), (2) is insufficient, because then any deterministic and invertible function of $X$ would contain maximum information about $X$, but that wouldn't make it a useful or disentangled representation. Similarly, without (2), (1) is insufficient, because if we set $q_\psi(z\vert x) = p(z)$ we would get a latent representation $Z$ that is coordinate-wise independent, but also independent of the data, which is not very useful.

We can achieve a combination of these desiderata by optimizing an objective with the weighted combination of two terms corresponding to the two goals we set out above:

$$

\mathcal{L}(\psi) = \operatorname{KL}[q_\psi(z)\| p(z)] - \lambda \mathbb{I}_{q_\psi(z\vert x) p_\mathcal{D}(x)}[X, Z]

$$

Remember, we use $q_\psi(z)$ to denote the aggregate posterior. We will refer to this as the InfoMax objective. Now we're going to show how this objective can be related to the $\beta$-VAE objective. Let's first consider the KL term in the above objective:

\begin{align}
\operatorname{KL}[q_\psi(z)\| p(z)] &= \mathbb{E}_{q_\psi(z)} \log \frac{q_\psi(z)}{p(z)}\\
&= \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log \frac{q_\psi(z)}{p(z)}\\
&= \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log \frac{q_\psi(z)}{q_\psi(z\vert x)} + \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log \frac{q_\psi(z\vert x)}{p(z)}\\
&= \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log \frac{q_\psi(z)p_\mathcal{D}(x)}{q_\psi(z\vert x)p_\mathcal{D}(x)} + \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log \frac{q_\psi(z\vert x)}{p(z)}\\
&= -\mathbb{I}_{q_\psi(z\vert x)p_\mathcal{D}(x)}[X,Z] + \mathbb{E}_{p_\mathcal{D}}\operatorname{KL}[q_\psi(z\vert x)\| p(z)]
\end{align}

This is interesting. If the mutual information between $X$ and $Z$ is non-zero (which is ideally the case), the above equation shows that latent factors cannot be both marginally and conditionally independent at the same time. It also gives us a way to relate the KL terms representing marginal and conditional independence.
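As a sanity check, the identity between the marginal KL, the expected conditional KL and the mutual information can be verified numerically on a small discrete example. The distributions below are made up for illustration (three values of $x$, two values of $z$), and the mutual information is computed as the expected KL between $q_\psi(z\vert x)$ and the aggregate posterior.

```python
import math

p_data = [0.5, 0.3, 0.2]                            # p_D(x), hypothetical
q_z_given_x = [[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]]  # q(z|x), one row per x, hypothetical
prior = [0.5, 0.5]                                  # p(z), hypothetical

def kl(q, p):
    """KL divergence between two discrete distributions given as lists."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)

# Aggregate posterior q(z) = sum_x p_D(x) q(z|x).
q_z = [sum(p_data[x] * q_z_given_x[x][z] for x in range(3)) for z in range(2)]

kl_marginal = kl(q_z, prior)
expected_kl_conditional = sum(p_data[x] * kl(q_z_given_x[x], prior) for x in range(3))
mutual_information = sum(p_data[x] * kl(q_z_given_x[x], q_z) for x in range(3))

# The two sides of the identity derived above.
lhs = kl_marginal
rhs = expected_kl_conditional - mutual_information
```

Since the mutual information is non-negative, the marginal KL can never exceed the expected conditional KL.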

Putting this back into the InfoMax objective, we have that

\begin{align}
\mathcal{L}(\psi) &= \operatorname{KL}[q_\psi(z)\| p(z)] - \lambda \mathbb{I}_{q_\psi(z\vert x)p_\mathcal{D}(x)}[X, Z]\\
&= \mathbb{E}_{p_\mathcal{D}}\operatorname{KL}[q_\psi(z\vert x)\| p(z)] - (\lambda + 1) \mathbb{I}_{q_\psi(z\vert x)p_\mathcal{D}(x)}[X, Z]
\end{align}

Using the KL term in the InfoMax objective, we were able to recover the KL-divergence term that also appears in the $\beta$-VAE (and consequently, VAE) objective.

At this point, we still haven't defined the generative model $p_\theta(x\vert z)$: the above objective expresses everything in terms of the data distribution $p_\mathcal{D}$ and the posterior/representation distribution $q_\psi$.

We will now focus on the 2nd term in our desired objective, the weighted mutual information, which we still can't easily evaluate. We will now show that we can recover the reconstruction term in $\beta$-VAEs by doing a variational approximation to the mutual information.

Note the following equality:

\begin{equation}

\mathbb{I}[X,Z] = \mathbb{H}[X] - \mathbb{H}[X\vert Z]

\end{equation}

Since we sample $X$ from the data distribution $p_\mathcal{D}$, we see that the first term $\mathbb{H}[X]$, the entropy of $X$, is constant with respect to the variational parameter $\psi$. We are left to focus on finding a good approximation to the second term $\mathbb{H}[X\vert Z]$. We can do so by minimizing the KL divergence between $q_\psi(x\vert z)$ and an auxiliary distribution $p_\theta(x\vert z)$ to make a variational approximation to the mutual information:

$$\mathbb{H}[X\vert Z] = - \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log q_\psi(x\vert z) \leq \inf_\theta - \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log p_\theta(x\vert z)$$
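To see why the inequality holds, it may help to write out the gap between the two sides. Using the notation defined at the top of the post, and the fact that $q_\psi(z\vert x)p_\mathcal{D}(x) = q_\psi(z)q_\psi(x\vert z)$, the gap is an expected KL divergence, which is never negative:

$$

-\mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log p_\theta(x\vert z) = \mathbb{H}[X\vert Z] + \mathbb{E}_{q_\psi(z)} \operatorname{KL}[q_\psi(x\vert z)\| p_\theta(x\vert z)] \geq \mathbb{H}[X\vert Z]

$$

The bound is tight exactly when $p_\theta(x\vert z)$ matches the inverted posterior $q_\psi(x\vert z)$.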

Finding this lower bound to MI, we have now recovered the reconstruction term from the $\beta$-VAE objective:

$$

\mathcal{L}(\psi) + \text{const} \leq - (1 + \lambda) \mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log p_\theta(x\vert z) + \mathbb{E}_{p_\mathcal{D}}\operatorname{KL}[q_\psi(z\vert x)\| p(z)]

$$

This is essentially the same as the $\beta$-VAE objective function, where $\beta$ is related to our previous $\lambda$. In particular, $\beta = \frac{1}{1 + \lambda}$. Thus, since we assumed $\lambda>0$ for the InfoMax objective to make sense, we can say that the $\beta$-VAE objective encourages disentanglement in the InfoMax sense for values of $0<\beta<1$.
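To make the identification of $\beta$ explicit, one can divide both sides of the bound by $1 + \lambda > 0$, a rescaling that does not change the minimizer:

$$

\frac{\mathcal{L}(\psi) + \text{const}}{1 + \lambda} \leq -\mathbb{E}_{q_\psi(z\vert x)p_\mathcal{D}(x)} \log p_\theta(x\vert z) + \frac{1}{1 + \lambda}\, \mathbb{E}_{p_\mathcal{D}}\operatorname{KL}[q_\psi(z\vert x)\| p(z)]

$$

The right-hand side has the reconstruction term with unit weight and the KL term with weight $\frac{1}{1+\lambda}$. For instance, $\lambda = 1$ corresponds to $\beta = \frac{1}{2}$, and $\lambda \rightarrow 0$ recovers the standard VAE weighting $\beta \rightarrow 1$.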

Conceptually, this derivation is interesting because the main object of interest is now the recognition model, $q_\psi(z\vert x)$. That is, the posterior becomes the focus of the objective function - something that is not the case when we are maximizing model likelihood alone (as explained here). In this respect, this derivation of the $\beta$-VAE makes more sense from a representation learning viewpoint than the derivation of VAE from maximum likelihood.

There is a nice symmetry to these two views. There are two joint distributions over latents and observable variables in a VAE. On one hand we have $q_\psi(z\vert x)p_\mathcal{D}(x)$ and on the other we have $p(z)p_\theta(x\vert z)$. The "latent variable model" $q_\psi(z\vert x)p_\mathcal{D}(x)$ is a family of LVMs which has a marginal distribution on observable $\mathbf{x}$ that is exactly the same as the data distribution $p_\mathcal{D}$. So one can say $q_\psi(z\vert x)p_\mathcal{D}(x)$ is a parametric family of latent variable models whose likelihood is maximal - and we want to choose from this family a model where the representation $q_\psi(z\vert x)$ has nice properties.

On the flipside, $p(z)p_\theta(x\vert z)$ is a parametric set of models where the marginal distribution of latents is coordinate-wise independent, but we would like to choose from this family a model that has good data likelihood.

The VAE objective tries to move these two latent variable models closer to one another. From the perspective of $q_\psi(z\vert x)p_\mathcal{D}(x)$ this amounts to reproducing the prior $p(z)$ with the aggregate posterior. From the perspective of $p(z)p_\theta(x\vert z)$, it amounts to maximising the data likelihood. When the $\beta$-VAE objective is used, we additionally wish to maximise the mutual information between the observed data and the representation.

This dual role of information maximization and maximum likelihood has been pointed out before, for example in this paper about the IM algorithm. The symmetry of variational learning has been exploited a few times, for example in yin-yang machines, and more recently also in methods like adversarially learned inference.


Neural tangent kernels are a useful tool for understanding neural network training and implicit regularization in gradient descent. But it's not the easiest concept to wrap your head around. The paper that I found to have been most useful for me to develop an understanding is this one:

In this post I will illustrate the concept of neural tangent kernels through a simple 1D regression example. Please feel free to peruse the Google Colab notebook I used to make these plots.

Let's start with a very boring case. Let's say we have a function defined over integers between -10 and 20. We parametrize our function as a look-up table, that is, the value of the function $f(i)$ at each integer $i$ is described by a separate parameter $\theta_i = f(i)$. I'm initializing the parameters of this function as $\theta_i = 3i+2$. The function is shown by the black dots below:

Now, let's consider what happens if we observe a new datapoint, $(x, y) =(10, 50)$, shown by the blue cross. We're going to take a gradient descent step updating $\theta$. Let's say we use the squared error loss function $(f(10; \theta) - 50)^2$ and a learning rate $\eta=0.1$. Because the function's value at $x=10$ only depends on one of the parameters, $\theta_{10}$, only this parameter will be updated. The rest of the parameters, and therefore the rest of the function values, remain unchanged. The red arrows illustrate the way function values move in a single gradient descent step: most values don't move at all, only one of them moves closer to the observed data. Hence only one visible red arrow.
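In code, the look-up table update looks roughly like this. It is a minimal sketch of the step described above, not the notebook code; the dictionary representation is just a convenient choice.

```python
# Lookup-table "model": one parameter per integer input in [-10, 20].
xs = list(range(-10, 21))
theta = {i: 3 * i + 2 for i in xs}      # initialization from the text

def f(i, theta):
    """The function value at i is simply the i-th table entry."""
    return theta[i]

x_obs, y_obs, eta = 10, 50, 0.1

# Squared error loss (f(10) - 50)^2; its gradient is non-zero only for theta[10].
grad = {i: 0.0 for i in xs}
grad[x_obs] = 2 * (f(x_obs, theta) - y_obs)

theta_new = {i: theta[i] - eta * grad[i] for i in xs}
moved = [i for i in xs if theta_new[i] != theta[i]]
```

Here `moved` contains only the input 10, whose value goes from 32 to about 35.6, a step towards the observed 50.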

However, in machine learning we rarely parametrize functions as lookup tables of individual function values. This parametrization is pretty useless as it doesn't allow you to interpolate let alone extrapolate to unseen data. Let's see what happens in a more familiar model: linear regression.

Let's now consider the linear function $f(x, \theta) = \theta_1 x + \theta_2$. I initialize the parameters to $\theta_1=3$ and $\theta_2=1$, so at initialisation, I have exactly the same function over integers as I had in the first example. Let's look at what happens to this function as I update $\theta$ by performing a single gradient descent step incorporating the observation $(x, y) =(10, 50)$ as before. Again, red arrows show how function values move:

Whoa! What's going on now? Since individual function values are no longer independently parametrized, we can't move them independently. The model binds them together through its global parameters $\theta_1$ and $\theta_2$. If we want to move the function closer to the desired output $y=50$ at location $x=10$ the function values elsewhere have to change, too.

In this case, updating the function with an observation at $x=10$ changes the function value far away from the observation. It even changes the function value in the opposite direction than what one would expect. This might seem a bit weird, but that's really how linear models work.
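The same step for the linear model can be sketched as follows. The step size `eta = 1e-3` is an assumption made here so that the step does not overshoot the target; it is not necessarily the value used for the plot.

```python
def f(x, theta):
    """Linear model f(x) = theta_1 * x + theta_2."""
    return theta[0] * x + theta[1]

theta = (3.0, 1.0)
x_obs, y_obs = 10, 50
eta = 1e-3  # assumed step size, chosen small to avoid overshooting

# Gradient of (f(x_obs) - y_obs)^2 with respect to (theta_1, theta_2).
residual = f(x_obs, theta) - y_obs
grad = (2 * residual * x_obs, 2 * residual)
theta_new = (theta[0] - eta * grad[0], theta[1] - eta * grad[1])

# Change in the function value at a few inputs (the "red arrows").
change = {x: f(x, theta_new) - f(x, theta) for x in (-10, 0, 10, 20)}
```

With these numbers the value at $x=10$ moves up by about 3.84, while the value at $x=-10$ moves down by about 3.76, even though the only observation asked for an increase.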

Now we have a little bit of background to start talking about this neural tangent kernel thing.

Given a function $f_\theta(x)$ which is parametrized by $\theta$, its *neural tangent kernel* $k_\theta(x, x')$ quantifies how much the function's value at $x$ changes as we take an infinitesimally small gradient step in $\theta$ incorporating a new observation at $x'$. Another way of phrasing this is: $k(x, x')$ measures how sensitive the function value at $x$ is to prediction errors at $x'$.

In the plots before, the size of the red arrow at each location $x$ was given by the following equation:

$$

\eta \tilde{k}_\theta(x, x') = f\left(x, \theta + \eta \frac{df_\theta(x')}{d\theta}\right) - f(x, \theta)

$$

In neural network parlance, this is what's going on: The loss function tells me to increase the function value $f_\theta(x')$. I back-propagate this through the network to see what change in $\theta$ I have to make to achieve this. However, moving $f_\theta(x')$ this way also simultaneously moves $f_\theta(x)$ at other locations $x \neq x'$. $\tilde{k}_\theta(x, x')$ expresses by how much.

The neural tangent kernel is basically the limit of $\tilde{k}$ as the step size becomes infinitesimally small. In particular:

$$

k(x, x') = \lim_{\eta \rightarrow 0} \frac{f\left(x, \theta + \eta \frac{df_\theta(x')}{d\theta}\right) - f(x, \theta)}{\eta}

$$

Using a 1st order Taylor expansion of $f_\theta(x)$, it is possible to show that

$$

k_\theta(x, x') = \left\langle \frac{df_\theta(x)}{d\theta} , \frac{df_\theta(x')}{d\theta} \right\rangle

$$
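The agreement between the limit definition and the inner-product formula can be checked numerically. The sketch below uses a hypothetical three-parameter model (a scaled and shifted tanh, not one of the models in this post) and compares the inner product of gradients with the finite-step quotient for a tiny and a large step size.

```python
import math

def f(x, theta):
    """Hypothetical small model: theta[0] * tanh(theta[1] * x) + theta[2]."""
    return theta[0] * math.tanh(theta[1] * x) + theta[2]

def grad_f(x, theta):
    """Analytic gradient of f(x, theta) with respect to theta."""
    t = math.tanh(theta[1] * x)
    return [t, theta[0] * x * (1.0 - t * t), 1.0]

def ntk(x, x_prime, theta):
    """Inner product of parameter gradients at x and x_prime."""
    return sum(a * b for a, b in zip(grad_f(x, theta), grad_f(x_prime, theta)))

def finite_step_kernel(x, x_prime, theta, eta):
    """Change in f(x) after a step of size eta along df(x_prime)/dtheta, divided by eta."""
    g = grad_f(x_prime, theta)
    theta_new = [p + eta * gi for p, gi in zip(theta, g)]
    return (f(x, theta_new) - f(x, theta)) / eta

theta = [1.5, 0.3, -0.5]  # arbitrary parameter values
k_exact = ntk(2.0, 1.0, theta)
k_small_step = finite_step_kernel(2.0, 1.0, theta, eta=1e-6)
k_large_step = finite_step_kernel(2.0, 1.0, theta, eta=0.5)
```

For the tiny step the two quantities agree closely, while for the large step the nonlinearity of the model makes $\tilde{k}$ deviate noticeably from $k$.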

As homework for you: find $k(x, x')$ and/or $\tilde{k}(x, x')$ for a fixed $\eta$ in the linear model in the previous example. Is it linear? Is it something else?

Note that this is a different derivation from what's in the paper (which starts from continuous differential equation version of gradient descent).

Now, I'll go back to the examples to illustrate two more important properties of this kernel: sensitivity to parametrization, and changes during training.

It's well known that neural networks can be reparametrized in ways that don't change the actual output of the function, but which may lead to differences in how optimization works. Batchnorm is a well-known example of this. Can we see the effect of reparametrization in the neural tangent kernel? Yes we can. Let's look at what happens if I reparametrize the linear function I used in the second example as:

$$

f_\theta(x) = \theta_1 x + \color{blue}{10\cdot}\theta_2

$$

but now with parameters $\theta_1=3, \theta_2=\color{blue}{0.1}$. I highlighted in blue what changed. The function itself, at initialization, is the same since $10 \cdot 0.1 = 1$. The function class is the same, too, as I can still implement arbitrary linear functions. However, when we look at the effect of a single gradient step, we see that the function changes differently when gradient descent is performed in this parametrisation.

In this parametrization, it became easier for gradient descent to push the whole function up by a constant, while in the previous parametrisation it decided to change the slope. What this demonstrates is that the neural tangent kernel $k_\theta(x, x')$ is sensitive to reparametrization.
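A quick way to see the difference is to evaluate the inner product of parameter gradients at an observation $x'=10$ under both parametrizations. This is only a sketch: the gradients are written out by hand for the two linear models, and the grid of inputs is arbitrary.

```python
def grad_a(x):
    """Gradient of f(x) = theta_1 * x + theta_2 with respect to (theta_1, theta_2)."""
    return (x, 1.0)

def grad_b(x):
    """Gradient of f(x) = theta_1 * x + 10 * theta_2 with respect to (theta_1, theta_2)."""
    return (x, 10.0)

def kernel(grad, x, x_prime):
    """Inner product of parameter gradients at x and x_prime."""
    return sum(a * b for a, b in zip(grad(x), grad(x_prime)))

xs = (-10, 0, 10, 20)
k_a = {x: kernel(grad_a, x, 10) for x in xs}  # original parametrization
k_b = {x: kernel(grad_b, x, 10) for x in xs}  # reparametrized offset
```

In the original parametrization the kernel is negative at $x=-10$, so the function value there moves the "wrong" way. In the reparametrized model the contribution of the offset is much larger, so on this grid no value moves down.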

While the linear models may be a good illustration, let's look at what $k_\theta(x, x')$ looks like in a nonlinear model. Here, I'll consider a model with two squared exponential basis functions:

$$

f_\theta(x) = \theta_1 \exp\left(-\frac{(x - \theta_2)^2}{30}\right) + \theta_3 \exp\left(-\frac{(x - \theta_4)^2}{30}\right) + \theta_5,

$$

with initial parameter values $(\theta_1, \theta_2, \theta_3, \theta_4, \theta_5) = (4.0, -10.0, 25.0, 10.0, 50.0)$. These are chosen somewhat arbitrarily and to make the result visually appealing:

We can visualise the function $\tilde{k}_\theta(x, 10)$ directly, rather than plotting it on top of the function. Here I also normalize it by dividing by $\tilde{k}_\theta(10, 10)$.

What we can see is that this starts to look a bit like a *kernel* function in that it has higher values near $10$ and decreases as you go farther away. However, a few things are worth noting: the maximum of this kernel function is not at $x=10$, but at $x=7$. This means that the function value $f(7)$ changes more in reaction to an observation at $x'=10$ than the value $f(10)$. Secondly, there are some negative values. In this case the previous figure provides a visual explanation why: we can increase the function value at $x=10$ by pushing the valley centred at $\theta_1=4$ away from it, to the left. This parameter change in turn decreases function values on the left-hand wall of the valley. Third, the kernel function converges to a positive constant at its tails - this is because of the offset $\theta_5$.

Now I'm going to illustrate another important property of the neural tangent kernel: in general, the kernel depends on the parameter value $\theta$, and therefore it changes as the model is trained. Here I show what happens to the kernel as I take 15 gradient ascent steps trying to increase $f(10)$. The purple curve is the one I had at initialization (above), and the orange one shows the kernel at the last gradient step.

The corresponding changes to the function $f_{\theta_t}$ are shown below:

So we can see that as the parameter changes, the kernel also changes. The kernel becomes flatter. An explanation of this is that eventually we reach a region of parameter space, where $\theta_4$ changes the fastest.

It turns out the neural tangent kernel becomes particularly useful when studying learning dynamics in infinitely wide feed-forward neural networks. Why? Because in this limit, two things happen:

- First: if we initialize $\theta_0$ randomly from appropriately chosen distributions, the initial NTK of the network $k_{\theta_0}$ approaches a deterministic kernel as the width increases. This means that at initialization, $k_{\theta_0}$ doesn't really depend on $\theta_0$ but is a fixed kernel independent of the specific initialization.
- Second: in the infinite limit the kernel $k_{\theta_t}$ stays constant over time as we optimise $\theta_t$. This removes the parameter dependence during training.

These two facts put together imply that gradient descent in the infinitely wide and infinitesimally small learning rate limit can be understood as a pretty simple algorithm called *kernel gradient descent* with a fixed kernel function that depends only on the architecture (number of layers, activations, etc).
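To give a feel for what kernel gradient descent does, here is a small sketch in function space. It assumes a squared error loss with the constant factor absorbed into the step size, and it uses a squared exponential kernel purely as a stand-in: the actual limiting NTK depends on the architecture and is not computed here. All data values are made up.

```python
import math

def rbf_kernel(x, x_prime, length_scale=2.0):
    """Hypothetical fixed kernel standing in for a limiting NTK."""
    return math.exp(-((x - x_prime) ** 2) / (2.0 * length_scale ** 2))

train_x = [-3.0, 0.0, 4.0]
train_y = [1.0, -1.0, 2.0]
query_x = [-3.0, -1.0, 0.0, 2.0, 4.0]

def kernel_update(x, residuals):
    """Sum over training points of k(x, x_i) * (f(x_i) - y_i)."""
    return sum(rbf_kernel(x, xi) * r for xi, r in zip(train_x, residuals))

def kernel_gd(steps, eta=0.1):
    """Function-space updates f(x) <- f(x) - eta * sum_i k(x, x_i) * (f(x_i) - y_i)."""
    f_train = [0.0] * len(train_x)   # function values at training inputs, starting at zero
    f_query = [0.0] * len(query_x)   # function values at query inputs
    for _ in range(steps):
        residuals = [ft - y for ft, y in zip(f_train, train_y)]
        f_query = [fq - eta * kernel_update(x, residuals) for fq, x in zip(f_query, query_x)]
        f_train = [ft - eta * kernel_update(x, residuals) for ft, x in zip(f_train, train_x)]
    return f_train, f_query

def train_error(f_train):
    """Sum of squared errors on the training points."""
    return sum((ft - y) ** 2 for ft, y in zip(f_train, train_y))

error_start = train_error(kernel_gd(0)[0])
error_end = train_error(kernel_gd(500)[0])
```

Note that no parameters appear anywhere: with a fixed kernel, the function values at any input evolve according to the residuals on the training points alone.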

These results, taken together with an older known result by Neal (1994), allow us to characterise the probability distribution of minima that gradient descent converges to in this infinite limit as a Gaussian process. For details, see the paper mentioned above.

There are two somewhat related sets of results both involving infinitely wide neural networks and kernel functions, so I just wanted to clarify the difference between them:

- the older, well-known result by Neal (1994), later extended by others, is that the distribution of $f_\theta$ under random initialization of $\theta$ converges to a Gaussian process. This Gaussian process has a kernel or covariance function which is not, in general, the same as the neural tangent kernel. This old result doesn't say anything about gradient descent, and is typically used to motivate the use of Gaussian process-based Bayesian methods.
- the new, NTK, result is that the evolution of $f_{\theta_t}$ during gradient descent training can be described in terms of a kernel, the neural tangent kernel, and that in the infinite limit this kernel stays constant during training and is deterministic at initialization. Using this result, it is possible to show that in some cases the distribution of $f_{\theta_t}$ is a Gaussian process at every timestep $t$, not just at initialization. This result also allows us to identify the Gaussian process which describes the limit as $t \rightarrow \infty$. This limiting Gaussian process however is not the same as the posterior Gaussian process which Neal and others would calculate on the basis of the first result.

So I hope this post helps a bit by building some intuition about what the neural tangent kernel is. If you're interested, check out the simple Colab notebook I used for these illustrations.

I recently encountered this cool paper in a reading group presentation:

- Rezende et al (2020) Causally Correct Partial Models for Reinforcement Learning

It's frankly taken me a long time to understand what was going on, and it took me weeks to write this half-decent explanation of it. The first notes I wrote followed the logic of the paper more, but in this post I thought I'd just focus on the high-level idea, after which hopefully the paper is more straightforward. I wanted to capture the key idea, without the distractions of RNN hidden states, etc, which I found confusing to think about.

To start with the basics, this paper deals with the partially observed Markov decision process (POMDP) setup. The diagram below illustrates what's going on:

The grey nodes $e_t$ show the unobserved state of the environment at each timestep $t=0,1,2\ldots$. At each timestep the agent observes $y_t$ which depends on the current state of the environment (red-ish nodes). The agent then updates its state $s_t$ based on its past state $s_{t-1}$, the new observation $y_t$, and the previous action taken $a_{t-1}$. This is shown by the blue squares (they're squares, signifying that this node depends deterministically on its parents). Then, based on the agent's state, it chooses an action $a_t$ by sampling from the policy $\pi(a_t\vert s_t)$. The action influences how the environment's state, $e_{t+1}$, changes.

We assume that the agent's ultimate goal is to maximise reward at the last state at time $T$, which we assume is a deterministic function of the observation $r(y_T)$. Think of this reward as the score in an Atari game, which is written on the screen whose contents are made available in $y_t$.

Let's start by stating what we ultimately would like to estimate from the data we have. The assumption is that we sampled the data using some policy $\pi$, but we would like to be able to say how well a different policy $\tilde{\pi}$ would do, in other words, what would be the expected score at time $T$ if instead of $\pi$ we used a different policy $\tilde{\pi}$.

What we are interested in is a causal/counterfactual query:

$$

\mathbb{E}_{\tau\sim\tilde{p}}[r(y_T)],

$$

where $\tau = [(s_t, y_t, e_t, a_t) : t=0\ldots T]$ denotes a trajectory or rollout up to time $T$, and $\tilde{p}$ denotes the generative process when using policy $\tilde{\pi}$, that is:

$$

\tilde{p}(\tau) = p(e_0)p(y_0\vert e_0) \tilde{\pi}(a_0\vert s_0) p(s_0)\prod_{t=1}^T p(e_t\vert a_{t-1}, e_{t-1}) p(y_t\vert e_t) \tilde{\pi}(a_t\vert s_t) \delta (s_t - g(s_{t-1}, y_t, a_{t-1}))

$$

I called this a causal or counterfactual query, because we are interested in making predictions under a different distribution $\tilde{p}$ than $p$ which we have observations from. The difference between $\tilde{p}$ and $p$ can be called an intervention, where we replace specific factors in the data generating process with different ones.

There are - at least - two ways one could go about estimating such a counterfactual quantity:

- model-free, via *importance sampling*. This method tries to directly estimate the causal query by calculating a weighted average over the observed data. The weights are given by the ratio between $\tilde{\pi}(a_t\vert s_t)$, the probability by which $\tilde{\pi}$ would choose an action, and $\pi(a_t\vert s_t)$, the probability it was chosen by the policy that we used to collect the data. A great paper explaining how this works is (Bottou et al, 2013). Importance sampling has the advantage that we don't have to build any model of the environment, we can directly evaluate the average reward from the samples we have, using only $\pi$ and $\tilde{\pi}$ to calculate the weights. The downside, however, is that importance sampling often gives an incredibly high-variance estimate, and is only reliable if $\tilde{\pi}$ and $\pi$ are very close.
- model-based, via *causal calculus*. If possible, we can use do-calculus to express the causal query in an alternative way, using various conditional distributions estimated from the observed data. This approach has the disadvantage that it requires us to build a model from the data first. We then use the conditional distributions learned from the data to approximate the quantity of interest by plugging them into the formula we got from do-calculus. If our models are imperfect, these imperfections/approximation errors can compound when the causal estimand is calculated, potentially leading to large biases and inaccuracies. On the other hand, our models may be accurate enough to extrapolate to situations where importance weighting would be unreliable.
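To illustrate the model-free option, here is a minimal importance sampling sketch. It is deliberately simplified to a single timestep with no agent state, so it is much simpler than the POMDP setup above. Both policies and the reward function are hypothetical.

```python
import random

random.seed(0)

def pi_behaviour(a):
    """Probability of action a under the data-collecting policy (hypothetical)."""
    return 0.5

def pi_target(a):
    """Probability of action a under the policy we want to evaluate (hypothetical)."""
    return 0.8 if a == 1 else 0.2

def reward(a):
    """Hypothetical deterministic reward: action 1 pays 1, action 0 pays 0."""
    return float(a)

# Collect data with the behaviour policy.
actions = [1 if random.random() < pi_behaviour(1) else 0 for _ in range(20000)]

# Reweight each observed reward by the ratio of policy probabilities.
weights = [pi_target(a) / pi_behaviour(a) for a in actions]
is_estimate = sum(w * reward(a) for w, a in zip(weights, actions)) / len(actions)

# For comparison: the unweighted average estimates the value of the behaviour policy.
naive_estimate = sum(reward(a) for a in actions) / len(actions)
true_value = pi_target(1) * reward(1) + pi_target(0) * reward(0)
```

In a multi-step rollout the weight becomes a product of such ratios over all timesteps, which is one way to see where the high variance comes from.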

The paper focuses on solving the problem with causal calculus. This requires us to build a model of observed data, which we can then use to make causal predictions. The key question this paper asks is:

How much of the data do we have to model to be able to make the kinds of causal inferences we would like to make?

One way we can answer the query above is to model the joint distribution of everything, or mostly everything, that we can observe. For example, we could build a full autoregressive model of observations $y_t$ conditioned on actions $a_t$. In essence this would amount to fitting a model to $p(y_{0:T}\vert a_{0:T})$.

If we had such a model, we would theoretically be able to make causal predictions, for reasons I will explain later. However, this option is ruled out in the paper because we assume the observations $y_t$ are very high dimensional, such as images rendered in a computer game. Thus, modelling the joint distribution of the whole observation sequence $y_{1:T}$ accurately is hopeless and would require a lot of data. Therefore, we would like to get away without modelling the whole observation sequence $y_{1:T}$, which brings us to partial models.

Partial models try to avoid modelling the joint distribution of high-dimensional observations $y_{1:T}$ or agent-state sequences $s_{0:T}$, and focus on modelling directly the distribution of $r(y_T)$ - i.e. only the reward component of the last observation, given the action-sequence $a_{0:T}$. This is clearly a lot easier to do, because $r(y_T)$ is assumed to be a low-dimensional aspect of the full observation $y_T$, so all we have to learn is a model of a scalar conditioned on a sequence of actions $q_\theta(r(y_T)\vert a_{0:T})$. We know very well how to fit such models to realistic amounts of data.

However, if we don't include either $y_t$ or $s_t$ in our model, we won't be able to make the counterfactual inferences we wanted to make in the first place. Why? Let's look at the data generating process once more:

We are trying to model the causal impact of actions $a_0$ and $a_1$ on the outcome $y_2$. Let's focus on $a_1$. $y_2$ is clearly statistically dependent on $a_1$. However, this statistical dependence emerges due to completely different effects:

- **causal association:** $a_1$ influences the state of the environment $e_2$, resulting in an observation $y_2$. Therefore, $a_1$ has an indirect causal effect on $y_2$, mediated by $e_2$.
- **spurious association due to confounding:** The unobserved hidden state $e_1$ is a confounder between the action $a_1$ and the observation $y_2$. The state $e_1$ has an indirect causal effect on $a_1$ mediated by the observation $y_1$ and the agent's state $s_1$. Similarly $e_1$ has an indirect effect on $y_2$ mediated by $e_2$.

I illustrated these two sources of statistical association by colour-coding the different paths in the causal graph. The blue path is the confounding path: correlation is induced because both $a_1$ and $y_2$ have $e_1$ as a causal ancestor. The red path is the causal path: $a_1$ indirectly influences $y_2$ via the hidden state $e_2$.

If we would like to correctly evaluate the consequence of changing policies, we have to be able to disambiguate between these two sources of statistical association, get rid of the blue path, and only take the red path into account. Unfortunately, this is not possible in a partial model, where we only model the distribution of $y_2$ conditional on $a_0$ and $a_1$.

If we want to draw causal inferences, we **must model the distribution of at least one variable along the blue path.** Clearly, $y_1$ and $s_1$ are theoretically observable, and are on the confounding path. Adding either of these to our model would allow us to use the backdoor adjustment formula (explained in the paper). However, this would take us back to Option 1, where we have to model the joint distribution of either sequences of observations $y_{0:T}$ or sequences of states $s_{0:T}$, both assumed to be high-dimensional and difficult to model.

So we finally got to the core of what is proposed in the paper: a kind of halfway-house between modeling everything and modeling too little. We are going to model *enough* variables to be able to evaluate causal queries, while keeping the dimensionality of the model we have to fit low. To do this, we change the data generating process slightly - by splitting the policy into two stages:

The agent first generates $z_t$ from the state $s_t$, and then uses the sampled $z_t$ value to make a decision $a_t$. One can understand $z_t$ as being a stochastic bottleneck between the agent's high-dimensional state $s_t$, and the low-dimensional action $a_t$. The assumption is that the sequence $z_{0:T}$ should be a lot easier to model than either $y_{0:T}$ or $s_{0:T}$. However, if we now build a model $p(r(y_T), z_{0:T} \vert a_{0:T})$, we are now able to use this model to evaluate the causal queries of interest, thanks to the backdoor adjustment formula. For how to precisely do this, please refer to the paper.

Intuitively, this approach helps by adding a low-dimensional stochastic node along the confounding path. This allows us to compensate for confounding, without having to build a full generative model of sequences of high-dimensional variables. It allows us to solve the problem we need to solve without having to solve a ridiculously difficult subproblem.

Last night on the train I read this nice paper by David Duvenaud and colleagues. Around midnight I got a calendar notification "it's David Duvenaud's birthday". So I thought it's time for a David Duvenaud birthday special (don't get too excited David, I won't make it an annual tradition...)

- Jonathan Lorraine, Paul Vicol, David Duvenaud (2019) Optimizing Millions of Hyperparameters by Implicit Differentiation

I recently covered iMAML: the meta-learning algorithm that makes use of implicit gradients to sidestep backpropagating through the inner loop optimization in meta-learning/hyperparameter tuning. The method presented in (Lorraine et al, 2019) uses the same high-level idea, but introduces a different - on the surface less fiddly - approximation to the crucial inverse Hessian. I won't spend a lot of time introducing the whole meta-learning setup from scratch, you can use the previous post as a starting point.

Many - though not all - meta-learning or hyperparameter optimization problems can be stated as nested optimization problems. If we have some hyperparameters $\lambda$ and some parameters $\theta$ we are interested in

$$

\operatorname{argmin}_\lambda \mathcal{L}_V (\operatorname{argmin}_\theta \mathcal{L}_T(\theta, \lambda)),

$$

where $\mathcal{L}_T$ is some training loss and $\mathcal{L}_V$ a validation loss. The optimal parameter to the training problem, $\theta^\ast$, implicitly depends on the hyperparameters $\lambda$:

$$

\theta^\ast(\lambda) = \operatorname{argmin}_\theta \mathcal{L}_T(\theta, \lambda)

$$

If this implicit function mapping $\lambda$ to $\theta^\ast$ is differentiable, and subject to some other conditions, the implicit function theorem states that its derivative is

$$

\left.\frac{\partial\theta^{\ast}}{\partial\lambda}\right\vert_{\lambda_0} = \left.-\left[\frac{\partial^2 \mathcal{L}_T}{\partial \theta \partial \theta}\right]^{-1}\frac{\partial^2\mathcal{L}_T}{\partial \theta \partial \lambda}\right\vert_{\lambda_0, \theta^\ast(\lambda_0)}

$$

The formula we obtained for iMAML is a special case of this where the mixed derivative $\frac{\partial^2\mathcal{L}_T}{\partial \theta \partial \lambda}$ is, up to a constant, the identity. This is because there, the hyperparameter controls a quadratic regularizer $\frac{1}{2}\|\theta - \lambda\|^2$, and indeed if you differentiate this with respect to both $\lambda$ and $\theta$ you are left with a constant times identity.
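As a sanity check of the implicit function theorem formula, here is a minimal sketch on a toy problem of my own choosing (not taken from the paper): a quadratic training loss plus the regularizer $\frac{1}{2}\|\theta - \lambda\|^2$. The matrix `A`, the vector `b` and all function names are hypothetical. Because the inner problem has a closed-form solution here, the Jacobian given by the formula can be compared against finite differences.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 3
M = rng.normal(size=(d, d))
A = M @ M.T + np.eye(d)          # symmetric positive definite (toy assumption)
b = rng.normal(size=d)

def theta_star(lam):
    # Closed-form minimiser of
    # L_T(theta, lam) = 0.5 theta^T A theta - b^T theta + 0.5 ||theta - lam||^2
    return np.linalg.solve(A + np.eye(d), b + lam)

hessian = A + np.eye(d)          # d^2 L_T / d theta d theta
cross = -np.eye(d)               # d^2 L_T / d theta d lambda (minus identity)
ift_jacobian = -np.linalg.solve(hessian, cross)

# Central finite differences: column k is d theta* / d lambda_k
lam0 = rng.normal(size=d)
eps = 1e-6
fd_jacobian = np.column_stack([
    (theta_star(lam0 + eps * e) - theta_star(lam0 - eps * e)) / (2 * eps)
    for e in np.eye(d)
])

max_abs_diff = np.max(np.abs(ift_jacobian - fd_jacobian))
print(max_abs_diff)
```

In this toy case the cross derivative is $-I$, so the implicit Jacobian reduces to the inverse Hessian, which is the iMAML special case mentioned above.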

The primary difficulty of course is approximating the inverse Hessian, or indeed matrix-vector products involving this inverse Hessian. This is where iMAML and the method proposed by Lorraine et al (2019) differ. iMAML uses a conjugate gradient method to iteratively approximate the gradient. In this work, they use a Neumann series approximation, which, for a matrix $U$, looks as follows:

$$

U^{-1} = \sum_{i=0}^{\infty}(I - U)^i

$$

This is basically a generalization of the better known sum of a geometric series: if you have a scalar $\vert q \vert<1$ then

$$

\sum_{i=0}^\infty q^i = \frac{1}{1-q}.

$$

Using a finite truncation of the Neumann series one can approximate the inverse Hessian in the following way:

$$

\left[\frac{\partial^2 \mathcal{L}_T}{\partial \theta \partial \theta}\right]^{-1} \approx \sum_{i=0}^j \left(I - \frac{\partial^2 \mathcal{L}_T}{\partial \theta \partial \theta}\right)^i.

$$

This Neumann series approximation, at least on the surface, seems significantly less hassle to implement than running a conjugate gradient optimization step.
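To get a feel for how little code the truncated Neumann series needs, here is a minimal NumPy sketch. It is not the authors' implementation: the matrix `H` is a random positive definite stand-in for the Hessian, and I rescale it so that its eigenvalues lie in $(0, 1)$, which is an assumption needed for the series to converge. Only matrix-vector products are used, so the inverse is never formed.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
M = rng.normal(size=(d, d))
H = M @ M.T + np.eye(d)                        # positive definite stand-in for the Hessian
H = H / (1.1 * np.linalg.eigvalsh(H).max())    # rescale so eigenvalues lie in (0, 1)
v = rng.normal(size=d)

def neumann_inverse_vector_product(H, v, num_terms):
    # Approximates H^{-1} v by sum_{i=0}^{num_terms} (I - H)^i v
    term = v.copy()      # holds (I - H)^i v, starting with i = 0
    total = v.copy()
    for _ in range(num_terms):
        term = term - H @ term
        total = total + term
    return total

exact = np.linalg.solve(H, v)
errors = [
    np.linalg.norm(neumann_inverse_vector_product(H, v, j) - exact)
    for j in (1, 10, 100, 1000)
]
print(errors)
```

The error shrinks as more terms are added, at a rate governed by the largest eigenvalue of $I - H$, so a poorly conditioned Hessian needs many terms.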

One of the fun bits of this paper is the interesting set of experiments the authors used to demonstrate the versatility of this approach. For example, in this framework, one can treat the training dataset as a hyperparameter. Optimizing pixel values in a small training dataset, one image per class, allowed the authors to "distill" a dataset into a set of prototypical examples. If you train your neural net on this distilled dataset, you get relatively good validation performance. The results are not quite as image-like as one would imagine, but for some classes, like bikes, you even get recognisable shapes:

In another experiment the authors trained a network to perform data augmentation, treating parameters of this network as a hyperparameter of a learning task. In both of these cases, the number of hyperparameters optimized was in the hundreds of thousands, way beyond the number we usually consider as hyperparameters.

This method inherits some of the limitations I already discussed with iMAML. Please also see the comments where various people gave pointers to work that overcomes some of these limitations.

Most crucially, methods based on implicit gradients assume that your learning algorithm (inner loop) finds a unique, optimal parameter that minimises some loss function. This is simply not a valid assumption for SGD where different random seeds might produce very different and differently behaving optima.

Secondly, this assumption only allows for hyperparameters that control the loss function, but not for ones that control other aspects of the optimization algorithm, such as learning rates, batch sizes or initialization. For those kinds of situations, explicit differentiation may still be the most competitive solution. On that note, I also recommend reading this recent paper on generalized inner-loop meta-learning and the associated PyTorch package higher.

Happy birthday David. Nice work!

My parents didn't raise me in a religious tradition. It all started to change when a great scientist took me under his wing and taught me the teachings of Bayes. I travelled the world and spent 4 years in a Bayesian monastery in Cambridge, UK. This particular place practiced the nonparametric Bayesian doctrine.

We were religious Bayesians. We looked at the world and all we saw was the face of Bayes: if something worked, it did because it had a Bayesian interpretation. If an algorithm did not work, we shunned its creator for being unfaithful to Bayes. We scorned point estimates, despised p-values. Bayes had the answer to everything. But above all, we believed in our models.

At a convention dominated by Bayesian thinkers I was approached by a frequentist, let's call him Lucifer (in fact his real name is Laci so not that far off). "Do you believe your data exists?" - he asked. "Yes" I answered. "Do you believe your model and its parameters exist?" "Well, not really, it's just a model I use to describe reality" I said. Then he told me the following, poisoning my pure Bayesian heart forever: "If you use Bayes' rule, you assume that a joint distribution between model parameters and data exists. This, however, only exists if your data and your parameters both exist, in the same $\sigma$-algebra. You can't have it both ways. You have to think your model really exists somewhere."

I never forgot this encounter, but equally I didn't think much about it since then. Over the years, I started to doubt more and more aspects of my Bayesian faith. I realised the likelihood was important, but not the only thing that exists. There were scoring rules, loss functions which couldn't be written as a log-likelihood. I noticed nonparametric Bayesian models weren't automatically more useful than large parametric ones. I worked on weird stuff like loss-calibrated Bayes. I started having thoughts about model misspecification, kind of a taboo topic in the Bayesian church.

Over the years I came to terms with my Bayesian heritage, and I now live my life as a secular Bayesian. Certain elements of the Bayesian way are no doubt useful: engineering inductive biases explicitly into a prior distribution, using probabilities, divergences, information, variational bounds as tools for developing new algorithms. Posterior distributions can capture model uncertainty which can be exploited for active learning or exploration in interactive learning. Bayesian methods often - though not always - lead to increased robustness, better calibration, and so much more. At the same time, I can carry on living my life, use gradient descent to find local minima, use the bootstrap to capture uncertainty. And first and foremost, I do not have to believe that my models really exist or perfectly describe reality anymore. I am free to think about model misspecification.

Lately, I have started to familiarize myself with a new body of work, which I call secular Bayesianism, that combines Bayesian inference with more frequentist ideas about learning from observation. In this body of work, people study model misspecification (see e.g. M-open Bayesian inference). And, I found a resolution to the "you have to believe in your model, you can't have it both ways" problem that bothered me all these years.

After this rather long intro, let me present the paper this post is really about and which, as a secular Bayesian, I found very interesting:

- P.G. Bissiri, C.C. Holmes and S.G. Walker (2016) A General Framework for Updating Belief Distributions

This paper basically asks: can we take the belief out of belief distributions? Let's say we want to estimate some parameter of interest $\theta$ from data. Does it still make sense to specify a prior distribution over this parameter, and then update it in light of data using some kind of Bayes rule-like update mechanism to form posterior distributions, all without assuming that the parameter of interest $\theta$ and the observations $x_i$ are linked to one another via a probabilistic model? And if it is meaningful, what form would that update rule take?

First of all, for simplicity, let's assume that data $x_i$ is sampled i.i.d. from some distribution $P$. That's right, not exchangeable, actually i.i.d. like in frequentist settings. Let's also assume that we have some parameter of interest $\theta$. Unlike in Bayesian analysis where $\theta$ usually parametrises some kind of generative model for data $x_i$, we don't assume anything like that. All we assume is that there is a loss function $\ell$ which connects the parameter to the observations: $\ell(\theta, x)$ measures how well the estimate $\theta$ agrees with observation $x$.

Let's say that a priori, without seeing any datapoints we have a prior distribution $\pi$ over $\theta$. Now we observe a datapoint $x_1$. How should we make use of our observation $x_1$, the loss function $\ell$ and the prior $\pi$ to come up with some kind of posterior over this parameter? Let's denote this update rule $\psi(\ell(\cdot, x_1), \pi)$. There are many ways we could do this, but is there one which is better than the rest?

The paper lists a number of desiderata - desired properties the update rule $\psi$ should satisfy. These are all meaningful assumptions to have. The main one is coherence, which is a property somewhat analogous to exchangeability: if we observe a sequence of observations, we would like the resulting posterior to be the same, irrespective of which order the observations are presented. The coherence property can be written as follows

$$

\psi\left(\ell(\cdot, x_2), \psi\left(\ell(\cdot, x_1), \pi\right)\right) = \psi\left(\ell(\cdot, x_1), \psi\left(\ell(\cdot, x_2), \pi \right)\right)

$$

As a desired property, this makes a lot of sense, and Bayes rule obviously satisfies it. However, this is not really how the authors actually define coherence. In Equation (3) they use a more restrictive definition of coherence, further restricting the set of acceptable update rules as follows:

$$

\psi\left(\ell(\cdot, x_2), \psi\left(\ell(\cdot, x_1), \pi\right)\right) = \psi\left(\ell(\cdot, x_1) + \ell(\cdot, x_2), \pi \right)

$$

By combining losses from the two observations in an additive way, one can indeed ensure permutation invariance. However, the sum is not the only way to do this. Any pooling operation over observations would also have satisfied this. For example, one could replace the $\ell(\cdot, x_1) + \ell(\cdot, x_2)$ bit by $\max(\ell(\cdot, x_1), \ell(\cdot, x_2))$ and still satisfy the general principle of coherence. The most general class of permutation invariant functions which would satisfy the general coherence desideratum are discussed in DeepSets. Overall, my hunch is that going with the sum is a design choice, rather than a general desideratum. This choice is the real reason why the resulting update rule will end up very Bayes-rule like, as we will see later.

The other desiderata the paper proposes are actually discussed separately in Section 1.2 of (Bissiri et al, 2016), and called assumptions instead. These are much more basic requirements for the update function. Assumption 2 for example talks about how restricting the prior to a subset should result in a posterior which is also the restricted version of the original posterior. Assumption 3 requires that lower evidence (larger loss) for a parameter should yield smaller posterior probabilities - a monotonicity property.

One contribution of the paper is showing that all the desiderata mentioned above pinpoint a specific update rule $\psi$ which satisfies all the desired properties. This update takes the following form:

$$

\pi(\theta\vert x_{1:N}) = \psi\left(\sum_{n=1}^N\ell(\cdot, x_n), \pi\right) \propto \exp\left\{-\sum_{n=1}^N\ell(\theta, x_n)\right\}\pi(\theta)

$$

Just like Bayes rule we have a normalized product of the prior with something that takes the role of the likelihood term. If the loss is the logarithmic loss of a probabilistic model, we recover the Bayes rule, but this update rule makes sense for arbitrary loss functions.
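Here is a minimal sketch of this update on a discretized parameter grid, assuming a Bernoulli model with a uniform prior as a toy example (my choice, not from the paper). The function names `log_loss` and `update` are hypothetical. It checks two things: that applying the update one observation at a time gives the same result as applying it once to the summed loss, and that with the logarithmic loss it reproduces the usual Beta posterior.

```python
import numpy as np

theta = np.linspace(0.001, 0.999, 999)      # grid over the Bernoulli parameter
prior = np.ones_like(theta) / theta.size    # uniform prior on the grid

def log_loss(theta, x):
    # Negative log-likelihood of one Bernoulli observation x in {0, 1}
    return -(x * np.log(theta) + (1 - x) * np.log(1 - theta))

def update(loss_values, prior):
    # psi(loss, prior): proportional to exp(-loss) * prior, normalised on the grid
    unnormalised = np.exp(-loss_values) * prior
    return unnormalised / unnormalised.sum()

x = [1, 0, 1, 1]

# Sequential updates, one observation at a time
sequential = prior
for xn in x:
    sequential = update(log_loss(theta, xn), sequential)

# Single update with the summed loss
batch = update(sum(log_loss(theta, xn) for xn in x), prior)

# Bayes rule with a uniform prior: posterior proportional to theta^3 (1 - theta)
bayes = theta**3 * (1 - theta)
bayes = bayes / bayes.sum()

print(np.allclose(sequential, batch), np.allclose(batch, bayes))
```

Swapping `log_loss` for any other loss function leaves the rest of the code unchanged, which is the sense in which the update generalizes Bayes rule.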

Again, this solution is unique under the very strong and specific desideratum that we'd like the losses from i.i.d. observations combine in an additive way, and I presume that, had we chosen a different permutation invariant function, we would end up with a similar generalization of Bayes rule with that permutation invariant function appearing in the exponent.

Now that we have an update rule which satisfies our desiderata, can we say if it's actually a good or useful update rule? It seems it is, in the following sense.

Let's think about a way to measure the usefulness of a posterior $\nu$. Suppose we have data sampling distribution $P$, losses are still measured by $\ell$, and our prior is $\pi$. A good posterior does two things well: it allows us to make good decisions in some kind of downstream test scenario, and it is informed by our prior. It therefore makes sense to define a loss function over the posterior $\nu$ as a sum of two terms:

$$

L(\nu; \ell, \pi, P) = h_1(\nu; \ell, P) + h_2(\nu; \pi)

$$

The first term, $h_1$ measures the posterior's usefulness at test time, and $h_2$ measures how well it's influenced by the prior. The authors define $h_1$ to be as follows:

$$

h_1(\nu; \ell, P) = \mathbb{E}_{x\sim P} \mathbb{E}_{\theta\sim\nu} \ell(\theta, x)

$$

So basically, we will sample from the posterior, and then evaluate the random sample parameter $\theta$ on a randomly chosen test datapoint $x$ using our loss $\ell$. I would say this is a rather narrow view on what it means for a posterior to do well on a downstream task, more about it later in the criticism section. In any case it's one possible goal for a posterior to try to achieve.

Now we turn to choosing $h_2$, and the authors note something very interesting. If we want the resulting optimal posterior to possess the coherence property (as defined in their Eqn. (3)), it turns out the only choice for $h_2$ is the KL divergence between the prior and posterior. Any other choice would lead to incoherent updates. This, I believe is only true for the additive definition of coherence, not the more general definition I gave above.

Putting $h_1$ and $h_2$ together it turns out that the posterior that minimizes this loss function is precisely of the form $\pi(\theta\vert x_{1:N}) \propto \exp\{-\sum_{n=1}^N \ell(\theta, x_n)\}\pi(\theta)$. So, not only is this update rule the only update rule that satisfies the desired properties, it is also optimal under this particular definition of optimality/rationality.
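For readers who want to see why an exponentiated-loss posterior is the minimizer, here is a short derivation sketch. It rests on two assumptions of mine: that the expectation over $P$ in $h_1$ is replaced by the sum over the observed datapoints, and that $h_2$ is $\operatorname{KL}[\nu \| \pi]$. I write $\nu^\ast(\theta) = \frac{1}{Z}\exp\{-\sum_{n}\ell(\theta, x_n)\}\pi(\theta)$, where $Z$ is the normalizing constant.

$$

\begin{aligned}
L(\nu) &= \int \nu(\theta) \sum_{n=1}^N \ell(\theta, x_n)\, d\theta + \int \nu(\theta) \log \frac{\nu(\theta)}{\pi(\theta)}\, d\theta \\
&= \int \nu(\theta) \log \frac{\nu(\theta)}{\exp\{-\sum_{n=1}^N \ell(\theta, x_n)\}\,\pi(\theta)}\, d\theta \\
&= \operatorname{KL}[\nu \| \nu^\ast] - \log Z
\end{aligned}

$$

Since $\log Z$ does not depend on $\nu$ and the KL divergence is non-negative and zero only when $\nu = \nu^\ast$, the minimizer under these assumptions is $\nu^\ast$.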

This work is interesting because it gives a new justification for Bayes rule-like updates to belief distributions, and as a result it also provides a different/new perspective on Bayesian inference. Crucially, never in this derivation did we have to reason about a joint distribution between $\theta$ and the observations $x$ (or conditionals of one given the other). Even though I wrote $\pi(\theta \vert x_{1:N})$ to denote a posterior, this is really just a shorthand notation, syntactic sugar. This is important. One of the main technical criticisms of the Bayesian terminology is that in order to reason about the joint distribution between two random variables ($x$ and $\theta$), these variables have to live in the same probability space, so if you believe that your data exists, you have to believe that your model and its parameters exist as well. This framework sidesteps that.

It allows rational updates of belief distributions, without forcing you to believe in anything.

From a practical viewpoint, this work also extends Bayesian inference in a meaningful way. While Bayesian inference only made sense if you inferred the whole set of parameters jointly, here you are allowed to specify any loss function, and really focus on the parameter of importance. For example, if you're only interested in estimating the median of a distribution in a Bayesian way, without assuming it follows a certain distribution, you can now do this by specifying your loss to be $\vert x-\theta\vert$. This is explained a lot more clearly in the paper, so I encourage you to read it.
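The median example can be sketched on a grid in a few lines. This is only an illustration under assumptions of my own: the data is a shifted Cauchy sample, the prior is a broad Gaussian, and the loss is used without any scaling factor.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_cauchy(size=200) + 2.0      # heavy-tailed data with true median 2

theta = np.linspace(-5.0, 10.0, 1501)        # grid of candidate medians
log_prior = -0.5 * (theta / 10.0) ** 2       # broad Gaussian prior (unnormalised)

# Summed absolute loss sum_n |x_n - theta| for every grid point
total_loss = np.abs(x[:, None] - theta[None, :]).sum(axis=0)

log_post = log_prior - total_loss
post = np.exp(log_post - log_post.max())     # subtract the max for numerical stability
post /= post.sum()

posterior_mode = theta[np.argmax(post)]
sample_median = np.median(x)
print(posterior_mode, sample_median)
```

No distributional assumption about the data enters anywhere: the only link between $\theta$ and the observations is the absolute loss, whose sum is minimized at the sample median.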

My main criticism of this work is that it made a number of assumptions that ultimately limited the range of acceptable solutions, and to my cynical eye it appears that these choices were specifically made so that Bayes rule-like update rules came out winning. So rather than really deriving Bayesian updates from first principles, we engineered principles under which Bayesian updates are optimal. In other words, the top-down analysis was rigged in favour of familiar Bayes-like updates. There are two specific assumptions which I would personally like to see relaxed:

The first one is the restrictive notion of coherence, which requires losses to combine additively from multiple observations. I think this very clearly gives rise to the convenient exponential, log-additive form in the end. It would be interesting to see whether other types of permutation invariant update rules also make sense in practice.

Secondly, the way the authors defined optimality, in terms of the loss $h_1$ above is very limiting. We rarely use posterior distributions in this way (take a random sample). Instead, we might be interested in integrating over the posterior, and evaluating the loss of that classifier. This is a loss that cannot be written in the bilinear form that is the formula for $h_1$ above. I wonder if using more elaborate losses for the posterior, perhaps along the lines of general decision problems as in (Lacoste-Julien et al, 2011), could lead to more interesting update rules which don't look at all like Bayes rule but are still rational.

Yesterday I read this intriguing paper about the mind-boggling fact that it is possible to use an exponentially growing learning rate schedule when training neural networks with batch normalization:

- Zhiyuan Li and Sanjeev Arora (2019) An Exponential Learning Rate Schedule for Deep Learning

The paper provides both theoretical insights as well as empirical demonstration of this remarkable property.

The reason why this works boils down to the observation that batch-normalization renders the loss function of neural networks scale invariant - scaling the weights by a constant does not change the output, or the loss, of the batch normalized network. It turns out that this property alone might result in somewhat unexpected and potentially helpful properties for optimization. I will use this post to illustrate some of the properties of scale invariant loss functions - and gradient descent trajectories on them - using a 2D toy example:

Here, I drew a loss function which has the scale invariance property. The value of the loss only depends on the angle, but not the magnitude of the weight vector. The value of the loss along any radial line from the origin outwards is constant. Simple consequences of scale invariance (Lemma 1 of the paper) are that

- the gradient of this function is always orthogonal to the current value of the parameter vector, and that
- the farther you are from the origin, the smaller the magnitude of the gradient. This is perhaps less intuitive but think about how the function behaves on a circle around the origin. The function is the same, but as you increase the radius you stretch the same function round a larger circle - it gets fatter, therefore its gradients decrease.
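Both properties are easy to check numerically. The sketch below uses a toy scale-invariant loss of my own choosing, one minus the cosine between the weight vector and a fixed direction `a`; it is not necessarily the function plotted in this post.

```python
import numpy as np

a = np.array([1.0, 2.0]) / np.sqrt(5.0)   # hypothetical unit-norm target direction

def loss(w):
    # Depends on w only through its direction w / ||w||
    return 1.0 - a @ w / np.linalg.norm(w)

def grad(w):
    # Analytic gradient of the loss above
    r = np.linalg.norm(w)
    return -(a / r - (a @ w) * w / r**3)

w = np.array([-0.7, 0.7])
g = grad(w)

scale_gap = abs(loss(w) - loss(5.0 * w))                     # scale invariance of the loss
inner = w @ g                                                # orthogonality: should be ~0
ratio = np.linalg.norm(grad(3.0 * w)) / np.linalg.norm(g)    # should be ~1/3
print(scale_gap, inner, ratio)
```

Tripling the norm of the weight vector divides the gradient norm by three, which is the "stretched around a larger circle" effect described in the second bullet.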

Here is a - somewhat messy - quiver plot showing the gradients of the function above:

The quiver plot is messy because the gradients around the origin explode. But you can perhaps see how the gradients get larger and larger towards the origin - and remain perpendicular to the value itself.

So imagine doing vanilla gradient descent (no momentum, no weight decay, fixed learning rate) on such a loss surface. Because the gradient is always perpendicular to the current value of the parameter, by the Pythagorean theorem, the norm of the parameter vector increases with each iteration. So gradient descent takes you away from the origin. However, the weight vector won't completely blow up to infinity, because the gradients also get smaller and smaller as the weight vector grows, so it settles at some point. Here is what a gradient descent path looks like starting from the coordinate $(-0.7, 0.7)$:

In fact, you can't really see it but the optimization kind of gets stuck in there, and doesn't move any longer. It's interesting to see what happens if we add weight decay, which is the same as adding an L2 regularizer over the weights:

We can see that once the trajectory is about to get stuck in a local minimum, weight decay pulls it back towards the origin, which is where gradients become larger. This, in turn, perturbs the trajectory often pushing it out of the current local minimum. So in a way, we can start to build the intuition that weight decay on a scale-invariant loss function acts as a kind of learning rate adjustment.
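The two behaviours described above, the norm growing under vanilla gradient descent and weight decay pulling the weights back towards the origin, can be reproduced with a minimal sketch. The toy loss (one minus the cosine to a fixed direction `a`), the learning rate and the weight decay strength are all assumptions of mine rather than the settings behind the plots.

```python
import numpy as np

a = np.array([1.0, 2.0]) / np.sqrt(5.0)   # hypothetical unit-norm target direction

def grad(w):
    # Gradient of the scale-invariant toy loss 1 - a.w / ||w||
    r = np.linalg.norm(w)
    return -(a / r - (a @ w) * w / r**3)

lr, weight_decay, steps = 0.1, 0.1, 100

w_plain = np.array([-0.7, 0.7])           # vanilla gradient descent
w_decay = w_plain.copy()                  # gradient descent with weight decay
plain_norms = [np.linalg.norm(w_plain)]
for _ in range(steps):
    w_plain = w_plain - lr * grad(w_plain)
    w_decay = (1.0 - lr * weight_decay) * w_decay - lr * grad(w_decay)
    plain_norms.append(np.linalg.norm(w_plain))
plain_norms = np.array(plain_norms)

cosine_plain = a @ w_plain / np.linalg.norm(w_plain)
print(plain_norms[0], plain_norms[-1], np.linalg.norm(w_decay), cosine_plain)
```

Without weight decay the norm never decreases, because each step adds a vector orthogonal to the current weights. With weight decay the norm ends up smaller, which is where the gradients are larger.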

In fact, what the paper works out is an equivalence between two things:

- weight decay with constant learning rate and
- no weight decay and an exponentially growing learning rate

On the plot below I show the trajectory with the exponentially growing learning rate which is equivalent to the one I showed before with weight decay. This one has no weight decay, and its learning rate keeps growing:

We can see that the trajectory blows up, and quickly gets out of bound on this animation. How can this be equivalent to the weight decay trajectory? Well, from the perspective of the loss function, the magnitude of the weight vector is irrelevant, and we only care about the angle when viewed from the origin. Turns out, if you look at those angles, the two trajectories are the same. To illustrate this, I use the normalization formula from Theorem 2.1 to project this trajectory back to the same magnitude the weight decay one would have. I obtain something that indeed looks very much like the trajectory above:

After a while, the trajectories start working differently, which I think is probably due to the accumulation of numerical errors in my implementation of the toy example. I could probably fix this, but I'm not sure it's worth the effort. The authors show much more convincing empirical evidence that this works in real, complicated neural network losses that people actually want to optimize.
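For the plain gradient descent case the equivalence can be worked out by hand from scale invariance: if weight decay shrinks the weights by $\alpha = 1 - \eta\lambda$ per step, then running without weight decay and with learning rate $\eta\,\alpha^{-2t-1}$ at step $t$ gives the same trajectory up to a rescaling by $\alpha^{t}$. The sketch below checks this on a toy loss. The schedule is my own derivation under these assumptions, so please treat it as an illustration rather than the exact statement of Theorem 2.1. I deliberately use a mild weight decay and few steps, a regime in which rounding errors do not get amplified.

```python
import numpy as np

a = np.array([1.0, 2.0]) / np.sqrt(5.0)   # hypothetical unit-norm target direction

def grad(w):
    # Gradient of the scale-invariant toy loss 1 - a.w / ||w||
    r = np.linalg.norm(w)
    return -(a / r - (a @ w) * w / r**3)

eta, lam, steps = 0.1, 0.1, 50
alpha = 1.0 - eta * lam                   # per-step shrink factor due to weight decay

w_wd = np.array([-0.7, 0.7])              # weight decay, constant learning rate
w_exp = w_wd.copy()                       # no weight decay, growing learning rate
max_gap = 0.0
for t in range(steps):
    w_wd = alpha * w_wd - eta * grad(w_wd)
    w_exp = w_exp - eta * alpha ** (-2 * t - 1) * grad(w_exp)
    rescaled = alpha ** (t + 1) * w_exp   # "zoom out" to undo the growth in norm
    max_gap = max(max_gap, np.linalg.norm(rescaled - w_wd))

print(max_gap, np.linalg.norm(w_exp), np.linalg.norm(w_wd))
```

After rescaling, the two trajectories agree to within floating point error, even though the raw norm of `w_exp` keeps growing while that of `w_wd` shrinks.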

You can think of this renormalization I did above as "constantly zooming out" on the loss landscape to keep up with the exponentially exploding parameter. I tried to illustrate this below:

On the left-hand plot, I show the original, weight-decayed gradient descent with a constant learning rate. On the right-hand plot I show the equivalent trajectory with exponentially growing learning rate and no weight decay, and I also added a constant zoom to counteract the explosion of the parameter's norm, in line with Theorem 2.1. We can see that, especially initially, the two paths behave the same way when viewed from the origin. They then work differently which I believe is down to the numerical precision issue that could probably be worked out.

The paper shows a similar equivalence in the presence of momentum as well; if you are interested, read the details in the paper.

I thought this observation was very cool, and may well lead to a better understanding of the mechanisms by which batchnorm and other weight normalization schemes work. It also explains why the combination of weight decay with weight normalization schemes results in a relatively robust gradient descent regime where constant learning rate works well.

Here's a paper someone has pointed me to, along the lines of "everything that works, works because it's Bayesian":

- Edwin Fong, Chris Holmes (2019) On the marginal likelihood and cross-validation

I found this paper to be lacking on the accessibility front, mostly owing to the fact that it is a mixture of two somewhat related but separate things:

- (A) a simple-in-hindsight and cool observation about the relationship between marginal likelihoods and cross validation which I will present in this post, and
- (B) a somewhat tangential sidetrack about a generalized form of Bayesian inference and prequential analysis which I think is mostly there to advertise an otherwise interesting line of research Chris Holmes and colleagues have been working on for some time. The advertising worked for sure, as I found the underlying paper (Bissiri et al, 2016) quite interesting as well. I will leave discussing that to a different post, maybe next week.

To discuss the connection between marginal likelihoods to (Bayesian) cross validation, let's first define what is what.

First of all, we are in the world of exchangeable data, assuming we model a sequence of observations $x_1,\ldots,x_N$ by a probabilistic model which renders them conditionally independent given some global parameter $\theta$. Our model is thus specified by the observation model $p(x\vert \theta)$ and prior $p(\theta)$. The marginal likelihood is the probability mass this model assigns to a given sequence of observations:

$$

p(x_1,\ldots,x_N) = \int \prod_{i=1}^N p(x_i \vert \theta) p(\theta) d\theta

$$

Important for the discussion of its connection with cross-validation, the marginal likelihood, like any multivariate distribution, can be decomposed by the chain rule:

$$

p(x_1,\ldots,x_N) = p(x_1)\prod_{n=1}^{N-1}p(x_{n+1}\vert x_1,\ldots, x_{n})

$$

And, of course, a similar decomposition exists for any arbitrary ordering of the observations $x_n$.
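The chain rule decomposition is easy to verify in a conjugate model where everything is available in closed form. Here is a minimal sketch assuming a Bernoulli observation model with a Beta$(a, b)$ prior, which is my choice of toy example. The helper names are hypothetical.

```python
import math

def log_beta(p, q):
    # Logarithm of the Beta function B(p, q)
    return math.lgamma(p) + math.lgamma(q) - math.lgamma(p + q)

def log_marginal(xs, a=1.0, b=1.0):
    # log p(x_1, ..., x_N) for Bernoulli data under a Beta(a, b) prior
    k, n = sum(xs), len(xs)
    return log_beta(a + k, b + n - k) - log_beta(a, b)

def log_predictive(x_new, seen, a=1.0, b=1.0):
    # log p(x_new | seen), with theta integrated out under the posterior
    p_one = (a + sum(seen)) / (a + b + len(seen))
    return math.log(p_one if x_new == 1 else 1.0 - p_one)

xs = [1, 0, 1, 1, 0]

# Chain rule in the original order
chain = sum(log_predictive(xs[n], xs[:n]) for n in range(len(xs)))

# Chain rule in a different, arbitrary order
reordered = [xs[i] for i in (3, 0, 4, 2, 1)]
chain_reordered = sum(
    log_predictive(reordered[n], reordered[:n]) for n in range(len(reordered))
)

print(chain, chain_reordered, log_marginal(xs))
```

All three numbers coincide (here they equal $\log(1/60)$), illustrating that every ordering of the observations gives a valid decomposition of the same marginal likelihood.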

Another related quantity is a single-fold leave-$P$-out cross-validation. Here, we set the last $P \leq N$ observations aside, fit our model to the first $N-P$ observations, and then we calculate the average predictive log loss on the held-out points. This can be written as:

$$

- \frac{1}{P}\sum_{p=1}^{P} \log p(x_{N-p+1}\vert x_1, \ldots, x_{N-P})

$$

Importantly, here, we assume that we perform Bayesian cross-validation of the model. I.e. in this formula, the parameter $\theta$ is integrated out. In fact what we're looking at is:

$$

- \frac{1}{P}\sum_{p=1}^{P} \log \int p(x_{N-p+1}\vert \theta) p(\theta \vert x_1, \ldots, x_{N-P}) d\theta

$$

Now of course, we could leave any other subset of size $P$ of the observations out. If we repeat this process $K$ times with a uniform random subset of datapoints left out each time, and average the results over the $K$ trials, we have $K$-fold leave-$P$-out cross validation. If $K$ is large enough, we might be trying all possible subsets of $P$ with the same probability. I will cheesily call this $\infty$-fold cross-validation. Mathematically, $\infty$-fold leave-$P$-out Bayesian cross-validation is the following quantity:

$$

- \frac{1}{\binom{N}{P}} \sum_{\substack{S\subset\{1,\ldots,N\}\\ |S|=P}} \frac{1}{P} \sum_{i \in S} \log p(x_i\vert x_j : j \notin S),

$$

which is Eqn (10) in the paper with slightly different notation.

The connection I think is best illustrated in the following way. Let's consider three observations, and all the possible ways we can permute them. There are $3! = 6$ different permutations. For each of these permutations we can decompose the marginal likelihood as a product of conditionals, or equivalently we can write the log marginal likelihood as a sum of logs of the same conditionals. Let's arrange these log conditionals into a table as follows:

Each column corresponds to a different ordering of variables, and summing up the terms in each column gives the log marginal likelihood. So, the sum of all the terms in this matrix gives the log marginal likelihood times 6 (as there are 6 columns). In general it gives $N!$ times the log marginal likelihood for $N$ observations. Now look at the sums of the terms in each row. The first row is full of terms you'd see in leave-$3$-out cross validation (which doesn't make too much sense with $3$ observations). In the second row, you see terms for leave-2-out CV. Third row corresponds to leave-1-out CV. So, if you do some careful combinatorics (homework) and count how many duplicate terms you'll find in each row, one can conclude that the sum of leave-$P$-out $\infty$-fold Bayesian cross-validation errors for all values of $P$ gives you the log marginal likelihood times a constant. Which is the main point of the paper.
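The "homework" can also be checked by brute force on a small example. The sketch below assumes a Beta-Bernoulli model (my choice) and assumes that each leave-$P$-out score is averaged over the $P$ held-out points as well as over all test sets of size $P$. Under that convention the constant turns out to be $-1$: the scores for $P = 1, \ldots, N$ add up to the negative log marginal likelihood.

```python
import math
from itertools import combinations

def log_marginal(xs, a=1.0, b=1.0):
    # log p(x_1, ..., x_N) for Bernoulli data under a Beta(a, b) prior
    k, n = sum(xs), len(xs)
    log_beta = math.lgamma(a + k) + math.lgamma(b + n - k) - math.lgamma(a + b + n)
    log_beta_prior = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return log_beta - log_beta_prior

def log_predictive(x_new, seen, a=1.0, b=1.0):
    # log p(x_new | seen), with theta integrated out under the posterior
    p_one = (a + sum(seen)) / (a + b + len(seen))
    return math.log(p_one if x_new == 1 else 1.0 - p_one)

def leave_p_out_cv(xs, p):
    # Exhaustive leave-p-out Bayesian CV: average log loss over the held-out
    # points of each test set, then average over all test sets of size p
    n = len(xs)
    total = 0.0
    for test in combinations(range(n), p):
        train = [xs[j] for j in range(n) if j not in test]
        total += sum(log_predictive(xs[i], train) for i in test) / p
    return -total / math.comb(n, p)

xs = [1, 0, 1, 1]
cv_sum = sum(leave_p_out_cv(xs, p) for p in range(1, len(xs) + 1))
print(cv_sum, -log_marginal(xs))
```

The enumeration over all test sets is exponential in $N$, so this is only feasible for tiny datasets. That is exactly why the marginal likelihood is an efficient way of computing the same quantity.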

This observation gives a really good motivation for using the marginal likelihood, and also gives a new perspective on how it works. For $N$ datapoints, there are $2^N-1$ different ways of selecting a non-empty test set and corresponding training set. Calculating the marginal likelihood amounts to evaluating the average predictive score on all of these exponentially many 'folds'.

Before we jump to the conclusion that cross-validation, too, works only because it is essentially an approximation to Bayesian model selection, we must remind ourselves that this connection only holds for Bayesian cross-validation. What this means is that in each fold of cross-validation, we integrate $\theta$ in a Bayesian fashion.

In practice, when cross-validating neural networks, we usually optimize over the parameters rather than integrate in a Bayesian way. Or, at best, we use a variational approximation to the posterior and integrate over that approximately. As the relationship only holds in theory, when exact parameter marginalization is performed, it remains to be seen how useful and robust this connection will prove in potential applications.

This week I read this cool new paper on meta-learning: it takes a slightly different approach compared to its predecessors based on some observations about differentiating the optima of regularized optimization.

- Aravind Rajeswaran, Chelsea Finn, Sham Kakade, Sergey Levine (2019) Meta-Learning with Implicit Gradients

Another paper that came out at the same time has discovered similar techniques, so I thought I'd update the post and mention it, although I won't cover it in detail and the post was written primarily about Rajeswaran et al (2019).

- Yutian Chen, Abram L. Friesen, Feryal Behbahani, David Budden, Matthew W. Hoffman, Arnaud Doucet, Nando de Freitas (2019) Modular Meta-Learning with Shrinkage

- I will give a high-level overview of the meta-learning setup, where our goal is to learn a good initialization or regularization strategy for SGD so it converges to better minima across a range of tasks.
- I illustrate how iMAML works on a 1D toy-example, and discuss the behaviour and properties of the meta-objective.
- I will then discuss a limitation of iMAML: that it only considers the location of minima, and not the probability with which a stochastic algorithm ends up in a specific minimum.
- I will finally relate iMAML to a variational approach to meta-learning.

Meta-learning has several possible formulations. I will try to explain the setup of this paper following my own interpretation and notation, which differs from the paper but will (hopefully) make my explanations clearer.

In meta-learning we have a series of independent tasks, with associated training and validation loss functions $f_i$ and $g_i$, respectively. We have a set of model parameters $\theta$ which are shared across the tasks, and the loss functions $f_i(\theta)$ and $g_i(\theta)$ evaluate how well the model with parameters $\theta$ does on the training and test cases of task $i$. We have an algorithm that has access to the training loss $f_i$ and some meta-parameters $\theta_0$, and outputs some optimal or learned parameters $\theta_i^\ast = Alg(f_i, \theta_0)$. The goal of the meta-learning algorithm is to optimize the meta-objective

$$

\mathcal{M}(\theta_0) = \sum_i g_i(Alg(f_i, \theta_0))

$$

with respect to the meta-parameters $\theta_0$.

In early versions of this work, MAML, the algorithm was chosen to be stochastic gradient descent, $f_i$ and $g_i$ being the training and test loss of a neural network, for example. The meta-parameter $\theta_0$ was the point of initialization for the SGD algorithm, shared between all the tasks. Since SGD updates are differentiable, one can compute the gradient of the meta-objective with respect to the initial value $\theta_0$ by simply backpropagating through the SGD steps. This was essentially what MAML did.

However, the effect of initialization on the final value of $\theta$ is pretty weak, and difficult - if at all possible - to characterise analytically. If we allow the SGD to go on for many steps, we might converge to a better parameter, but the trajectory will be very long, and the gradients with respect to the initial value vanish. If we make the trajectories short enough, the gradients w.r.t. $\theta_0$ are informative but we may not reach a very good final value.
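To see the vanishing gradient concretely, here is a minimal sketch that runs plain (noise-free) gradient descent on a hypothetical quadratic loss and tracks the sensitivity of the current iterate to the initial value with the chain rule. The loss, step size and function names are assumptions made for illustration only.

```python
def gd_with_sensitivity(theta0, grad, hess, lr, steps):
    """Run gradient descent and track d(theta_t)/d(theta_0) by the chain rule."""
    theta, sensitivity = theta0, 1.0
    for _ in range(steps):
        # theta_{t+1} = theta_t - lr * f'(theta_t), hence
        # d theta_{t+1}/d theta_0 = (1 - lr * f''(theta_t)) * d theta_t/d theta_0
        sensitivity *= 1.0 - lr * hess(theta)
        theta -= lr * grad(theta)
    return theta, sensitivity

# Hypothetical toy training loss f(theta) = 0.5 * (theta - 2)^2
grad = lambda theta: theta - 2.0
hess = lambda theta: 1.0

for steps in (5, 50, 200):
    theta_T, dtheta = gd_with_sensitivity(-1.0, grad, hess, lr=0.1, steps=steps)
    print(steps, theta_T, dtheta)
```

For this quadratic the sensitivity shrinks geometrically with the number of steps, so a long trajectory reaches the minimum but carries almost no gradient signal back to $\theta_0$.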

This is why Rajeswaran et al opted to make the dependence of the final point of the trajectory on the meta-parameter $\theta_0$ way stronger: instead of simply initializing SGD from $\theta_0$ they also anchor the parameter to stay in the vicinity of $\theta_0$ by adding a quadratic regularizer $\|\theta - \theta_0\|^2$ to their loss. Because of this, two things happen:

- now all steps of the SGD depend on $\theta_0$, not just the initial point
- now the location of the minimum SGD eventually converges to also depends on $\theta_0$

It is this second property that iMAML exploits. Let me illustrate what that dependence looks like:

In the figure above, let's say that we would like to minimise an objective function $f(\theta)$. This would be the training loss of one of the tasks the meta-learning algorithm has to solve. Our current meta-parameter $\theta_0$ is marked on the x axis, and the orange curve shows the associated quadratic penalty. The teal curve shows the sum of the objective with the penalty. The red star shows the location of the minimum, which is what the learning algorithm finds.

Now let's animate this plot. I'm going to move the anchor point $\theta_0$ around, and reproduce the same plots. You can see that, as we move $\theta_0$ and the associated penalty, the local (and therefore global) minima of the regularized objective move:

So it's clear that there is a non-trivial, non-linear relationship between the anchor-point $\theta_0$ and the location of a local minimum $\theta^\ast$. Let's plot this relationship as a function of the anchor point:

We can see that this function is not at all nice to work with: it has sharp jumps when the closest local minimum to $\theta_0$ changes, and it is relatively flat between these jumps. In fact, you can observe that the sharper the local minimum nearest to $\theta_0$ is, the flatter the relationship between $\theta_0$ and $\theta^\ast$. This is because if $f$ has a sharp local minimum near $\theta_0$, then the location of the regularized minimum will be mostly determined by $f$, and the location of the anchor point $\theta_0$ doesn't matter much. If the local minimum of $f$ is wide, there's a lot of wiggle room for the optimum and the effect of the regularization will be larger.

And now we come to the whole point of the iMAML procedure. The gradient of this function $\theta^\ast(\theta_0)$ in fact can be calculated in closed form. It is, indeed, related to the curvature, or second derivative, of $f$ around the minimum we find:

$$

\frac{d\theta^\ast}{d\theta_0} = \frac{1}{1 + f''(\theta^\ast)}

$$
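For readers who want the missing step, here is a short derivation. It assumes that the penalty is scaled as $\frac{1}{2}(\theta - \theta_0)^2$ (a scaling chosen here so that the constants match the formula above) and that $f$ is twice differentiable at the minimum. The regularized minimum satisfies a stationarity condition, and differentiating that condition with respect to $\theta_0$ gives the result:

$$

f'(\theta^\ast) + (\theta^\ast - \theta_0) = 0 \quad\Rightarrow\quad f''(\theta^\ast)\frac{d\theta^\ast}{d\theta_0} + \frac{d\theta^\ast}{d\theta_0} - 1 = 0 \quad\Rightarrow\quad \frac{d\theta^\ast}{d\theta_0} = \frac{1}{1 + f''(\theta^\ast)}

$$

With a general penalty strength, $\frac{\lambda}{2}(\theta - \theta_0)^2$, the same argument gives $\lambda / (\lambda + f''(\theta^\ast))$.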

In order to check that this formula works, I calculated the derivative numerically and compared it with what the theory predicts; they match perfectly:
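A check along these lines can be reproduced in a few lines. The sketch below is not the code behind the figures: the toy loss $f(\theta) = \sin(3\theta) + 0.1\theta^2$, the penalty scaling $\frac{1}{2}(\theta - \theta_0)^2$ and the use of plain gradient descent as the inner algorithm are all assumptions made for illustration.

```python
import math

# Hypothetical 1-D training loss f(theta) = sin(3 theta) + 0.1 theta^2,
# given through its first two derivatives.
def f_prime(theta):
    return 3.0 * math.cos(3.0 * theta) + 0.2 * theta

def f_second(theta):
    return -9.0 * math.sin(3.0 * theta) + 0.2

def regularized_minimum(theta0, lr=0.05, steps=2000):
    """Gradient descent on f(theta) + 0.5 * (theta - theta0)^2, started at theta0."""
    theta = theta0
    for _ in range(steps):
        theta -= lr * (f_prime(theta) + (theta - theta0))
    return theta

theta0, eps = 0.3, 1e-4
theta_star = regularized_minimum(theta0)

# Central finite difference of theta*(theta0) versus the closed-form expression.
numerical = (regularized_minimum(theta0 + eps) - regularized_minimum(theta0 - eps)) / (2 * eps)
predicted = 1.0 / (1.0 + f_second(theta_star))
print(theta_star, numerical, predicted)
```

The finite difference is only meaningful while $\theta_0 \pm \epsilon$ stays within the same basin; near one of the jumps the two runs would land in different minima.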

When the parameter space is high-dimensional, we have a similar formula involving the inverse of the Hessian plus the identity. In high dimensions, inverting or even calculating and storing the Hessian is not very practical. One of the main contributions of the iMAML paper is a practical way to approximate gradients, using a conjugate gradient inner optimization loop. For details, please read the paper.
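To give a flavour of how the inverse can be avoided, here is a minimal sketch of conjugate gradient applied to the linear system $(I + H)\,x = \nabla g(\theta^\ast)$, using only Hessian-vector products. It is an illustration under assumptions, not the paper's implementation: the Hessian is a random positive semi-definite stand-in, the penalty strength is taken to be 1, and `hvp` and `val_grad` are hypothetical names.

```python
import numpy as np

def conjugate_gradient(matvec, b, max_iters=50, tol=1e-10):
    """Solve A x = b for symmetric positive-definite A, given only v -> A v."""
    x = np.zeros_like(b)
    r = b.copy()                 # residual b - A x for the starting point x = 0
    p = r.copy()
    rs = r @ r
    for _ in range(max_iters):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

rng = np.random.default_rng(0)
dim = 5
A = rng.normal(size=(dim, dim))
hessian = A @ A.T / dim              # stand-in for the Hessian of f at theta*
hvp = lambda v: hessian @ v          # in practice: a Hessian-vector product from autodiff
val_grad = rng.normal(size=dim)      # stand-in for the gradient of g at theta*

# Meta-gradient (I + H)^{-1} grad g(theta*), computed without forming the inverse.
meta_grad = conjugate_gradient(lambda v: v + hvp(v), val_grad)
print(meta_grad)
```

The point of the design is that `matvec` never needs the full Hessian, only its action on a vector, which is cheap to obtain with automatic differentiation.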

When optimizing the anchor point in a meta-learning setting, it is not the location $\theta^\ast$ we are interested in, only the value that the function $f$ takes at this location. (In reality, we would now use the validation loss in place of the training loss used for gradient descent, but for simplicity, I assume the two losses overlap.) The value of $f$ at its local optimum is plotted below:

Oh dear. This function is not very pretty. The meta-objective $f(\theta^\ast(\theta_0))$ becomes a piecewise continuous function, a connection of neighbouring basins, with non-smooth boundaries. The local gradients of this function contain very little information about the global structure of the loss function: they only tell you where to go to reach the nearest local minimum. So I wouldn't say this is the nicest function to optimize.

Thankfully, though, this function is not what we have to optimize. In meta-learning, we have a distribution over functions $f$ we optimize, so the actual meta-objective is something like $\sum_i f_i(\theta_i^\ast(\theta_0))$. And the sum of a bunch of ugly functions might well turn into something smooth and nice. In addition, the 1-D function I use for this blog post is not representative of the high-dimensional loss functions of neural networks which we want to apply iMAML to. Take for example the concept of mode connectivity (see e.g. Garipov et al, 2018): it seems that the modes found by SGD using different random seeds are not just isolated basins, but they are connected by smooth valleys along which the training and test error are low. This may in turn make the meta-objective behave more smoothly between minima.

An important aspect that MAML or iMAML do not consider is the fact that we usually use stochastic optimization algorithms. Rather than deterministically finding a particular local minimum, SGD samples different minima: when run with different random seeds it will find different minima.

A more generous formulation of the meta-objective would allow for stochastic algorithms. If we denote by $\mathcal{Alg}(f_i, \theta_0)$ the distribution over solutions the algorithm finds, the meta-objective would be

$$

\mathcal{M}_{\text{stochastic}}(\theta_0) = \sum_i \mathbb{E}_{\theta \sim \mathcal{Alg}(f_i, \theta_0)} g_i(\theta)

$$

Allowing for stochastic behaviour might actually be a great feature for meta-learning. While the position of the global minimum of the regularized objective can change abruptly as a function of $\theta_0$ (as illustrated in the third figure above), allowing for stochastic behaviour might smooth out the meta-learning objective.
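As a rough illustration of the stochastic meta-objective, the sketch below estimates the expectation by Monte Carlo for a single task. Everything in it is an assumption made for the example: the toy loss, the penalty scaling $\frac{1}{2}(\theta - \theta_0)^2$, the jittered starting points, the annealed gradient noise standing in for SGD, and taking $g = f$.

```python
import numpy as np

def f(theta):
    return np.sin(3.0 * theta) + 0.1 * theta**2       # hypothetical task loss

def f_prime(theta):
    return 3.0 * np.cos(3.0 * theta) + 0.2 * theta

def sample_solutions(theta0, n_runs, rng, lr=0.05, steps=400, noise=1.0):
    """Noisy gradient descent on f(theta) + 0.5 * (theta - theta0)^2, one run per entry."""
    theta = theta0 + rng.normal(size=n_runs)           # assumed: jittered starting points
    for t in range(steps):
        scale = noise * (1.0 - t / steps)              # gradient noise, annealed towards zero
        grad = f_prime(theta) + (theta - theta0) + scale * rng.normal(size=n_runs)
        theta = theta - lr * grad
    return theta

rng = np.random.default_rng(0)
solutions = sample_solutions(theta0=0.3, n_runs=200, rng=rng)

# Monte Carlo estimate of E[g(theta)] over the algorithm's randomness, taking g = f.
meta_objective = f(solutions).mean()
print(meta_objective, np.ptp(solutions))
```

Different runs end up in different minima, so the estimate averages over several basins rather than following a single one.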

Now suppose that SGD anchored to $\theta_0$ converges to one of a finite set of local minima. The meta-learning objective now depends on $\theta_0$ in two different ways:

- as we change the anchor $\theta_0$, the locations of the minima change, as illustrated above. This change is differentiable, and we know its derivative.
- as we change the anchor $\theta_0$, the probability with which we find the different solutions changes. Some solutions will be found more often, some less often.

iMAML accounts for the first influence, but it ignores the influence through the second mechanism. This is not to say that iMAML is broken, but that it misses a possibly crucial contribution of stochastic behaviour that MAML or explicitly differentiating through the algorithm does not.

Of course this work reminded me of a Bayesian approach. Whenever someone describes quadratic penalties, all I see are Gaussian distributions.

In a Bayesian interpretation of iMAML, one can think of the anchor point $\theta_0$ as the mean of a prior distribution over the neural network's weights. The inner loop of the algorithm, or $Alg(f_i, \theta_0)$ then finds the maximum-a-posteriori (MAP) approximation to the posterior over $\theta$ given the dataset in question. This is assuming that the loss is a log likelihood of some kind. The question is, how should one update the meta-parameter $\theta_0$?

In the Bayesian world, we would seek to optimize $\theta_0$ by maximising the marginal likelihood. As this is usually intractable, it is common to turn to a variational approximation, which in this case would look something like this:

$$

\mathcal{M}_{\text{variational}}(\theta_0, Q_i) = \sum_i \left( KL[Q_i\vert \mathcal{N}_{\theta_0}] + \mathbb{E}_{\theta \sim Q_i} f_i(\theta) \right),

$$

where $Q_i$ approximates the posterior over model parameters for task $i$. A specific choice of $Q_i$ is a Dirac delta distribution centred at a specific point, $Q_i(\theta) = \delta(\theta - \theta^{\ast}_i)$. If we generously ignore some constants that blow up to infinity, the KL divergence between the Gaussian prior and the degenerate point-posterior is a simple squared Euclidean distance, and our variational objective reduces to:

$$

\mathcal{M}_{\text{variational}}(\theta_0, \theta_i) = \sum_i \left( \|\theta_i - \theta_0\|^2 + f_i(\theta_i) \right)

$$

Now this objective function looks very much like the optimization problem that the inner loop of iMAML attempts to solve. If we were working in the pure variational framework, this may be where we leave things, and we could jointly optimize all the $\theta_i$s as well as $\theta_0$. Someone in the know, please comment below pointing me to the best references where this is being done for meta-learning.

This objective is significantly easier to optimize and involves no inner-loop optimization or black magic. It simply ends up pulling $\theta_0$ closer to the centre of gravity of the various optima found for each task $i$. Not sure if this is such a good idea though for meta-learning, as the final values of $\theta_i$ which we reach by jointly optimizing over everything may not be reachable by doing SGD from $\theta_0$ from scratch. But who knows. A good idea may be, given the observations above, to jointly minimize the variational objective with respect to $\theta_0$ and $\theta_i$, but every once in a while reinitialize $\theta_i$ to be $\theta_0$. But at this point, I'm really just making stuff up...
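The centre-of-gravity behaviour is easy to see on a toy problem. The sketch below jointly minimizes $\sum_i \left( (\theta_i - \theta_0)^2 + f_i(\theta_i) \right)$ by gradient descent, assuming three hypothetical quadratic task losses $f_i(\theta) = \frac{c_i}{2}(\theta - m_i)^2$; the constants and step size are made up for illustration.

```python
import numpy as np

# Hypothetical quadratic task losses f_i(theta) = 0.5 * c_i * (theta - m_i)^2
m = np.array([-1.0, 0.5, 3.0])
c = np.array([1.0, 4.0, 0.5])

theta0 = 0.0                 # shared anchor (the meta-parameter)
theta = np.zeros(3)          # one task-specific parameter per task
lr = 0.05

for _ in range(5000):
    grad_theta = 2.0 * (theta - theta0) + c * (theta - m)   # d/d theta_i of the objective
    grad_theta0 = -2.0 * np.sum(theta - theta0)             # d/d theta_0 of the objective
    theta = theta - lr * grad_theta
    theta0 = theta0 - lr * grad_theta0

print(theta0, theta.mean(), theta)
```

At the joint optimum the gradient with respect to $\theta_0$ vanishes exactly when $\theta_0$ equals the mean of the $\theta_i$.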

Anyway, back to iMAML, which does something interesting with this variational objective, and I think it can be understood as a kind of amortized computation: Instead of treating $\theta_i$ as separate auxiliary parameters, it specifies that $\theta_i$ are in fact a deterministic function of $\theta_0$. As the variational objective is a valid upper bound for any value of $\theta_i$, it is also a valid upper bound if we make $\theta_i$ explicitly dependent on $\theta_0$. The variational objective thus becomes a function of $\theta_0$ only (and also of hyperparameters of the algorithm $Alg$ if it has any):

$$

\mathcal{M}_{\text{variational}}(\theta_0) = \sum_i \left( \|Alg(f_i, \theta_0) - \theta_0\|^2 + f_i(Alg(f_i, \theta_0)) \right)

$$

And there we have it. A variational objective for meta-learning $\theta_0$ which is very similar to the MAML/iMAML meta-objective, except it also has the $\|Alg(f_i, \theta_0) - \theta_0\|^2$ term which factors into updating $\theta_0$ which we didn't have before. Also notice that I did not use separate training and validation loss $f_i$ and $g_i$ but that would be a very justified choice as well.

What is cool about this is that this provides extra justification and interpretation for what iMAML is trying to do, and suggests directions in which iMAML could perhaps be improved. On the flipside, the implicit differentiation trick in iMAML might be useful in other situations where we want to amortize the variational posterior similarly.

I'm pretty sure I missed many references, please comment below if you think I should add anything, especially on the variational bit.

I finally got around to reading this new paper by Arjovsky et al. It debuted on Twitter with a big splash, being described as 'beautiful' and 'long awaited' 'gem of a paper'. It almost felt like a new superhero movie or Disney remake just came out.

- Martin Arjovsky, Léon Bottou, Ishaan Gulrajani, David Lopez-Paz (2019) Invariant Risk Minimization

The paper is, indeed, very well written, and describes a very elegant idea, a practical algorithm, some theory and lots of discussion around how this is related to various bits. Here, I will describe the main idea and then provide an information theoretic view on the same topic.

We would like to learn robust predictors that are based on invariant causal associations between variables, rather than spurious surface correlations that might be present in our data. If we only observe i.i.d. data from a generative process, this is generally not possible.

In this paper, the authors assume that we have access to data sampled from different environments $e$. The data distribution in these different environments is different, but there is an underlying causal dependence of the variable of interest $Y$ on some of the observed features $X$ that remains constant, or invariant across all environments. The question is, can we exploit the variability across different environments to learn this underlying invariant association?

Usual empirical risk minimisation (ERM) approaches cannot distinguish between statistical associations that correspond to causal connections, and those that are just spurious correlations. Invariant Risk Minimization can, in certain situations. It does this by finding a representation $\phi$ of features, such that the optimal predictor is simultaneously Bayes optimal in all environments.

The authors then propose a practical loss function that tries to capture this property:

$$

\min_\phi \sum_e \left\{ \mathcal{R}^e(\phi) + \lambda \left\|\nabla_{w\vert w=1}\mathcal{R}^e(w \cdot \phi)\right\|^2_2 \right\}

$$

The first term is the usual ERM: we're trying to minimize average risk across all environments, using a single predictor $\phi$. The second term is where the interesting bit happens. I said before that what we want this term to encourage is that $\phi$ is simultaneously Bayes-optimal in all environments. What the term actually looks at is whether $\phi$ is locally optimal, that is, whether it can be improved locally by scaling by a constant $w$. For details, I recommend reading the paper where a lot of intuitive explanation is provided. In this post, I'll focus on an information theoretic interpretation of what's going on.
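To make the penalty concrete, here is a minimal sketch assuming squared loss and a scalar-valued representation, in which case the gradient with respect to $w$ at $w=1$ has a closed form. The data-generating process, the two environments and the function names are hypothetical and only serve to contrast a representation built on a cause of $y$ with one built on an effect of $y$.

```python
import numpy as np

def irm_v1_terms(phi_x, y):
    """Risk and IRMv1 penalty for one environment, assuming squared loss and scalar phi(x)."""
    residual = phi_x - y                        # prediction error of w * phi(x) at w = 1
    risk = np.mean(residual**2)
    grad_w = np.mean(2.0 * residual * phi_x)    # d/dw mean((w * phi(x) - y)^2) at w = 1
    return risk, grad_w**2

def irm_v1_objective(environments, lam):
    """Sum over environments of risk + lam * penalty."""
    return sum(risk + lam * penalty
               for risk, penalty in (irm_v1_terms(p, y) for p, y in environments))

rng = np.random.default_rng(0)
n = 50_000
invariant_envs, spurious_envs = [], []
for noise_scale in (0.5, 1.5):                  # two hypothetical environments
    x1 = rng.normal(size=n)
    y = x1 + rng.normal(size=n)                 # invariant mechanism: x1 causes y
    x2 = y + noise_scale * rng.normal(size=n)   # effect of y, with environment-dependent noise
    invariant_envs.append((x1, y))              # representation phi(x) = x1
    spurious_envs.append((x2, y))               # representation phi(x) = x2

penalty_invariant = sum(irm_v1_terms(p, y)[1] for p, y in invariant_envs)
penalty_spurious = sum(irm_v1_terms(p, y)[1] for p, y in spurious_envs)
print(penalty_invariant, penalty_spurious)
```

For the causal feature the penalty is close to zero in both environments, since rescaling the predictor cannot reduce the risk in either; for the downstream feature the best rescaling differs between environments and the penalty is clearly positive.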

Unlike the authors, who treat the environment index $e$ as something outside of the structural equation model, I prefer to think of $E$ as also being part of the generative process: an observable random variable. This may not be the most useful formulation in all circumstances, but it will help when trying to derive IRM through the lens of conditional dependence relationships.

The most general generative model of data in the IRM setup looks something like this:

There are three (sets of) observable variables: $E$, the environment index, $X$, the features describing the datapoint and $Y$, the label we wish to predict. I also assume the existence of a hidden confounder $W$. In the above graph, I separated $X$ into upstream dimensions $X_1$ and downstream dimensions $X_2$ based on where they are in the causal chain relative to $Y$. In reality, we don't know what this breakdown is, and breaking $X$ up into $X_1$ and $X_2$ may not be trivial due to entanglement, but it is still reasonable to assume that some components of the input $X$ encode causal parents of $Y$, and others encode causal descendants of $Y$.

The environment $E$ influences every factor in this generative model, except the factor $p(Y\vert X_1, W)$: notice there is no arrow from $E$ to $Y$. In other words, it is assumed that the relationship of variable $Y$ to its observable causal parents $X_1$ and hidden variables $W$ is stable across all environments. This is the primary underlying assumption of IRM. Now, let's read out some conditional independence relationships from this graph:

- $Y \cancel{\perp\mkern-13mu\perp} E$: this is simply saying that the marginal distribution of $Y$ can, generally, change across environments.
- $Y \perp\mkern-13mu\perp E\vert X_1, W$: The observable $X_1$ and latent $W$ shield the label $Y$ from the influence of the environment. I already said that this is the key assumption on which IRM is based: that there is an underlying causal mechanism determining the value of Y from its causal parents, which does not change across environments.
- $Y \cancel{\perp\mkern-13mu\perp} E\vert X_1$: If we leave $W$ out of the conditioning, the above environment-independence no longer holds. This is because the confounder $W$ introduces spurious association between $X_1$ and $Y$. This spurious correlation is assumed to be environment-dependent.
- $Y \cancel{\perp\mkern-13mu\perp} E\vert X_1, X_2$: this is, perhaps, the most important point. This dependence statement says that the way $Y$ depends on the observable variables $X = (X_1, X_2)$ is environment-dependent. This can be verified by noticing that $X_2$ is a collider between $E$ and $Y$. Conditioning on a collider introduces spurious correlations (for example, explaining away).

In summary, the association between $X$ and $Y$ will be a result of three sources of correlation:

- real causal relationship between some components of $X$ and $Y$.
- spurious association introduced by the unobserved confounder $W$.
- spurious association introduced by conditioning on parts of $X$ which are causally influenced by $Y$, rather than the other way around.

In general, if our generative model describes the world accurately, the conditional independence statements we observed tell us that while the real causal association is stable across environments ($Y \perp\mkern-13mu\perp E\vert X_1, W$), the other two are environment-dependent ($Y \cancel{\perp\mkern-13mu\perp} E\vert X_1$ and $Y \cancel{\perp\mkern-13mu\perp} E\vert X_1, X_2$). Thus, we can eliminate the spurious associations by seeking associations that are stable across environments, i.e. independent of $E$.
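These statements can be illustrated with a small simulation. The sketch below assumes a linear-Gaussian version of the graph with two environments (the coefficients and noise scales are invented for the example) and uses least-squares coefficients as a simple proxy for how $Y$ depends on the conditioning set in each environment.

```python
import numpy as np

def simulate(env, n, rng):
    """Hypothetical linear-Gaussian version of the graph; env = (a, s) sets E's influence."""
    a, s = env
    w = rng.normal(size=n)                      # hidden confounder W
    x1 = a * w + rng.normal(size=n)             # E changes how W drives X1
    y = x1 + w + rng.normal(size=n)             # invariant mechanism p(Y | X1, W)
    x2 = y + s * rng.normal(size=n)             # E changes the noise on the descendant X2
    return w, x1, x2, y

def ols(features, y):
    """Least-squares coefficients of y on the given (zero-mean) feature columns."""
    X = np.column_stack(features)
    return np.linalg.lstsq(X, y, rcond=None)[0]

rng = np.random.default_rng(0)
coefs = {"x1,w": [], "x1": [], "x1,x2": []}
for env in [(0.2, 0.5), (1.0, 2.0)]:
    w, x1, x2, y = simulate(env, 200_000, rng)
    coefs["x1,w"].append(ols([x1, w], y))       # conditioning on X1 and W
    coefs["x1"].append(ols([x1], y))            # conditioning on X1 only
    coefs["x1,x2"].append(ols([x1, x2], y))     # conditioning on X1 and X2

for name, (c_a, c_b) in coefs.items():
    print(name, np.round(c_a, 2), np.round(c_b, 2))
```

Only the regression on $(X_1, W)$ yields the same coefficients in both environments; dropping $W$ or adding $X_2$ makes the fitted relationship change with the environment.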

One can interpret the objective of Invariant Risk Minimisation as seeking a representation of observable variables $\phi(x)$, such that:

- $Y \perp\mkern-13mu\perp E\vert \phi(X)$, and
- $\phi$ is informative about $y$, i.e. we can predict $y$ accurately from $\phi(x)$

This is a bit similar to the information bottleneck criterion which would seek a stochastic representation $Z = \phi(X, \nu)$, $\nu$ being some random noise, by solving the following optimization problem:

$$

\max_{\phi} \left\{ I[Y, Z] - \beta I[X, Z] \right\}

$$

We could similarly attempt to find an invariant representation $Z = \phi(X)$ by maximizing an objective like:

$$

\max_{\phi} \left\{ I[Y, Z] - \beta I[Y, E \vert Z] \right\}

$$

Notice that one can write the conditional mutual information $I[Y, E \vert \phi(X)]$ in the following variational form:

$$

I[Y, E \vert \phi(x)] = \max_q \min_r \mathbb{E}_{x,y,e} [\log q(y\vert \phi(x), e) - \log r(y\vert \phi(x))]

$$

With a little bit of math juggling, we can write the above stability objective, up to additive and multiplicative constants that don't depend on $\phi$, in the following form:

\begin{align}

\max_{\phi} \left\{ I[Y, \phi(X)] - \beta I[Y, E \vert \phi(X)] \right\} = \max_\phi \max_r \min_q \mathbb{E}_{x,y,e} \left[\log r(y\vert \phi(x)) - \frac{\beta}{1 + \beta}\log q(y\vert \phi(x), e) \right]

\end{align}

This optimization problem is intuitively already very similar to IRM: we would like a representation $\phi$ so we can predict $y$ from it accurately, but we shouldn't be able to build a much better predictor by overfitting to one of the environments. In other words: here, too, we seek a representation so that the Bayes optimal predictor of $y$ is close to Bayes optimal simultaneously in all environments.

Sadly, this is a minimax type of problem, so $\phi$ can only be found with a GAN-like iterative algorithm - something we don't really like, but we're kind of getting used to. It would be interesting to see how such an algorithm would work, and if you're aware of this being done already, please feel free to point me to the relevant references in the comments section.

I wanted to add a sidenote here that I think we can recover something very similar in spirit to (IRMv1). I leave developing the full connection as homework to you. Let me just illustrate the basic idea here.

Say we have a parametric family of functions $f(y\vert \phi(x); \theta)$ for predicting $y$ from $\phi(x)$. The conditional information can be approximated as follows:

\begin{align}

I[Y, E \vert \phi(x)] &\approx \min_\theta \mathbb{E}_{x,y} \ell (f(y\vert \phi(x); \theta)) - \mathbb{E}_e \min_{\theta_e} \mathbb{E}_{x,y\vert e} \ell (f(y\vert \phi(x); \theta_e))\\

&= \min_\theta \mathbb{E}_e \mathcal{R}^e(f_\theta\circ\phi) - \mathbb{E}_e \min_{\theta_e} \mathcal{R}^e(f_{\theta_e}\circ\phi)

\end{align}

where $\ell$ is the log-loss, if we want to recover Shannon's information. If we assume that $f$ is a universal function approximator, equality holds. If, instead of globally optimizing $\theta_e$, we only search locally within a trust region around $\theta$, we can create the following (approximate) lower-bound to the information.

\begin{align}

I[Y, E \vert \phi(x)] &\geq \min_\theta \left\{ \mathbb{E}_{x,y}\ell (f(y\vert \phi(x); \theta)) - \mathbb{E}_e \min_{\|d_e\|\leq \epsilon} \mathbb{E}_{x,y\vert e} \ell (f(y\vert \phi(x); \theta + d_e)) \right\} \\

&= \min_\theta \mathbb{E}_e \left\{ \mathcal{R}^e(f_\theta\circ\phi) - \min_{\|d_e\|\leq \epsilon}\mathcal{R}^e(f_{\theta + d_e}\circ\phi) \right\}

\end{align}
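The global version of this quantity, the gap between the best pooled risk and the average best per-environment risk, can be sketched numerically. The example below makes several assumptions that are not in the post: linear predictors, squared loss instead of log-loss (the two agree up to constants for a Gaussian model with fixed variance), equally sized environments, and a made-up data-generating process without a hidden confounder.

```python
import numpy as np

def min_risk(features, y):
    """Smallest mean squared error of a linear predictor on the given features."""
    X = np.column_stack(features)
    theta = np.linalg.lstsq(X, y, rcond=None)[0]
    return np.mean((X @ theta - y) ** 2)

def risk_gap(envs):
    """Pooled-fit risk minus average per-environment-fit risk (equal-sized environments)."""
    pooled_features = [np.concatenate(cols) for cols in zip(*(f for f, _ in envs))]
    pooled_y = np.concatenate([y for _, y in envs])
    pooled = min_risk(pooled_features, pooled_y)
    per_env = np.mean([min_risk(f, y) for f, y in envs])
    return pooled - per_env

rng = np.random.default_rng(0)
n = 50_000
envs_invariant, envs_all = [], []
for s in (0.5, 2.0):                            # two hypothetical environments
    x1 = rng.normal(size=n)
    y = x1 + rng.normal(size=n)                 # invariant mechanism
    x2 = y + s * rng.normal(size=n)             # descendant of y, environment-dependent noise
    envs_invariant.append(([x1], y))            # phi(x) = x1
    envs_all.append(([x1, x2], y))              # phi(x) = (x1, x2)

gap_invariant = risk_gap(envs_invariant)
gap_all = risk_gap(envs_all)
print(gap_invariant, gap_all)
```

A gap near zero says that knowing the environment does not help predict $y$ from $\phi(x)$; a clearly positive gap says that it does.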

Now, we can approximate the risk $\mathcal{R}^e(f_{\theta + d_e}\circ\phi)$ locally by a first order Taylor approximation around $\theta$, and show that, as $\epsilon \rightarrow 0$, we obtain (up to a factor of $\epsilon$, which can be absorbed into the constant $\lambda$ used later) that:

$$

I[Y, E \vert \phi(x)] \geq \min_\theta \mathbb{E}_e \| \nabla_\theta \mathbb{E}_{x,y\vert e} [\ell f(y\vert \phi(x), \theta)] \|_2= \min_\theta \mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2

$$
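The first-order step can be spelled out as follows; this is a sketch that assumes the risk is differentiable at $\theta$ and keeps only the linear term of the expansion:

$$

\mathcal{R}^e(f_{\theta + d_e}\circ\phi) \approx \mathcal{R}^e(f_\theta\circ\phi) + d_e^\top \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi), \qquad \min_{\|d_e\|\leq \epsilon} d_e^\top \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) = -\epsilon \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2

$$

The minimum is attained by stepping a distance $\epsilon$ against the gradient, so $\mathcal{R}^e(f_\theta\circ\phi) - \min_{\|d_e\|\leq \epsilon}\mathcal{R}^e(f_{\theta + d_e}\circ\phi) \approx \epsilon \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2$. The factor $\epsilon$ only rescales the bound.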

Compare this with the second term in Eqn IRMv1 of the paper. If we now add back in the requirement that we would like to be able to predict $y$ from $\phi(x)$, we get an optimization problem of the following form:

$$

\min_\phi \left\{ \min_\theta \mathbb{E}_e \mathcal{R}^e(f_\theta\circ\phi) + \lambda \min_\theta \mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2 \right\},

$$

which is almost like the IRM objective. Technically, there are two minimizations over $\theta$ and there's no reason why the two shouldn't be done separately. Note however, that a global minimum of the second term $\mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2$ is always a local minimum of the first term. This justifies connecting the two minimization problems together:

$$

\min_\phi \min_\theta \left\{ \mathbb{E}_e \mathcal{R}^e(f_\theta\circ\phi) + \lambda \mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2 \right\},

$$

This is no longer an ugly minimax problem, however, it is still not amazing. It is a lower bound to the original objective which we originally wished to minimize. The lower bound was created when we replaced global minimization with local minimization. Thus, the bound is actually tight if all local minima with respect to $\theta$ of $\mathcal{R}^e(f_\theta\circ\phi)$ are also global minima, e.g. if the loss is convex. In non-convex problems, all bets are off. It may still work, but who knows.

This is indeed a nice paper, with lots of great insights. Unfortunately, I am not sure how realistic the assumptions are that we can sample from a multitude of different environments, which differ from each other sufficiently so that the invariant causal quantities can be identified.

I would like to mention that this is not the first time that invariance and causality have been connected and exploited for domain adaptation. I personally first encountered this idea in a talk by Jonas Peters at the Causality Workshop in 2018. Here is a related paper by him that I wanted to highlight here:

- Peters, Bühlmann, Meinshausen (2016) Causal inference by using invariant prediction: identification and confidence intervals

And here are two more papers which propose a causal treatment of the domain adaptation problem:

- Adarsh Subbaswamy, Peter Schulam, Suchi Saria (2018) Preventing Failures Due to Dataset Shift: Learning Predictive Models That Transport
- Christina Heinze-Deml, Nicolai Meinshausen (2019) Conditional Variance Penalties and Domain Shift Robustness

The second paper, which commenters also pointed out to me, is perhaps the most closely related, but it is based on slightly different assumptions about what is invariant across the domains.

Finally, commenters asked me about domain-adversarial learning, so I wanted to include a pointer here for completeness:

- Yaroslav Ganin, Victor Lempitsky (2014) Unsupervised Domain Adaptation by Backpropagation

On this paper, I agree with Arjovsky et al (2019)'s discussion in the IRM paper: it promotes the wrong invariance property by trying to learn a data representation that is marginally independent of the domain index. See the discussion on this in the comments section below.

Finally, I wanted to point out another slightly looser connection: non-stationarity, or the availability of data from multiple environments, has been exploited by Hyvarinen and Morioka (2016) for unsupervised feature learning. It turns out that this non-stationarity and the availability of different environments make otherwise non-identifiable nonlinear ICA models identifiable.

Welcome to my ICML 2019 jetlag special - because what else do you do when you wake up earlier than anyone else, other than write a blog post? Here's a paper that was presented yesterday which I really liked.

- Ruiz and Titsias (2019) A Contrastive Divergence for Combining Variational Inference and MCMC

First, some background on why I found this paper particularly interesting. When AlphaGo Zero came out, I wrote a post about the *principle of minimal improvement*: Suppose you have an operator which may be computationally expensive to run, but which can take a policy and improve it. Using such an improvement operator you can define an objective function for policies by measuring the extent to which the operator changes a policy. If your policy is already optimal, the operator can't improve it any further, so the change will be zero. In the case of AlphaGo Zero, the improvement operator is Monte Carlo Tree Search (MCTS). I noted in that post how the same principle may be applicable to approximate inference: expectation propagation and contrastive divergence both can be cast in a similar light.

The paper I'm talking about uses a very similar argument to come up with a contrastive divergence for variational inference, where the improvement operator is an MCMC step.

The two dominant ways of performing inference in latent variable models are variational inference (including amortized inference, such as in VAE), and Markov Chain Monte Carlo (MCMC). VI approximates the posterior with a parametric distribution. This can be computationally efficient and practically convenient as VI results in end-to-end differentiable, unbiased estimates of the evidence lower bound (ELBO). As a drawback, the parametric posterior approximation often can't approximate the posterior perfectly, and an approximation error almost always remains. By contrast, an MCMC approximate posterior can always be improved by running the chains longer, and obtaining more independent samples, but it is more difficult to work with and computationally more demanding than VI.

A lot of great work has been done recently on combining VI with MCMC (see references in Ruiz and Titsias, 2019). Usually, one starts from a crude, parametric variational approximation to the posterior, and then improves it by running a couple steps of MCMC. Crucially, one can view the transition kernel, $\Pi$, of the MCMC as an improvement operator: given any distribution $q$, taking an MCMC step should take you closer to the posterior. In other words, $\Pi q$ is an improvement over $q$. Taking multiple steps, i.e. $\Pi^t q$, should provide an even greater improvement. Improvement is $0$ only when $\Pi q = q$ which only holds for the true posterior. It is a reasonable criterion therefore to seek a posterior approximation $q_\theta$ such that the improvement in $\Pi^t q_\theta$ over $q_\theta$ is minimized.
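The improvement-operator view can be checked on a small discrete example. The sketch below assumes a five-state target and a Metropolis kernel with uniform proposals (both hypothetical choices), applies the kernel repeatedly to a poor initial $q$, and records the KL divergence to the target after each step.

```python
import numpy as np

def metropolis_kernel(target):
    """Row-stochastic Metropolis kernel with uniform proposals, leaving `target` invariant."""
    n = len(target)
    K = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i != j:
                K[i, j] = (1.0 / n) * min(1.0, target[j] / target[i])
        K[i, i] = 1.0 - K[i].sum()              # rejected proposals stay at state i
    return K

def kl(a, b):
    """KL[a || b] for strictly positive discrete distributions."""
    return float(np.sum(a * (np.log(a) - np.log(b))))

p = np.array([0.05, 0.15, 0.4, 0.3, 0.1])       # hypothetical target posterior
q = np.array([0.6, 0.1, 0.1, 0.1, 0.1])         # hypothetical crude approximation
K = metropolis_kernel(p)

dist, kls = q.copy(), [kl(q, p)]
for _ in range(10):
    dist = dist @ K                             # one application of the kernel: Pi q
    kls.append(kl(dist, p))
print(np.round(kls, 4))
```

The sequence of KL values never increases, which is the sense in which each MCMC step is an improvement.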

There are two ways of quantifying the amount of improvement, or change, that MCMC provides over a parametric posterior $q_\theta$. The first measures how much closer we got to the posterior $p$ by comparing the KL divergences:

$$

\mathcal{L}_1(\theta) = \operatorname{KL}\left[q_\theta\middle\|p\right] - \operatorname{KL}\left[\Pi^tq_\theta\middle\|p\right]

$$

The second one measures the amount of change between $\Pi^tq_\theta$ and $q_\theta$, measured as the KL divergence. This objective function merely tries to identify fixed points of the improvement operator:

$$

\mathcal{L}_2(\theta) = \operatorname{KL}\left[\Pi^tq_\theta\middle\|q_\theta\right]

$$

Either of these objectives would make sense and is justified in its own right, but sadly neither of them can be evaluated or optimized easily in this case. Both require taking expectations over $\Pi^tq_\theta$, as well as evaluating $\log \Pi^tq_\theta$. However, the brilliant insight in this paper is that when you sum them together, the most problematic terms cancel out, leaving you with a tractable objective to minimise.

$$

\mathcal{L}_1(\theta) + \mathcal{L}_2(\theta) = \mathbb{E}_{z\sim \Pi^t q_\theta} f_\theta(z) - \mathbb{E}_{z\sim q_\theta} f_\theta(z),

$$

where $f_\theta(z) = \log p(z,x) - \log q_\theta(z)$, $x$ is the observed, $z$ the hidden variable.
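The cancellation can be verified numerically on a discrete latent variable, where all the terms are computable exactly. The sketch below assumes a five-state latent, a made-up joint $p(z, x)$ at the observed $x$, a Metropolis kernel with uniform proposals as $\Pi$, and $t = 3$; none of these choices come from the paper.

```python
import numpy as np

def metropolis_kernel(target):
    """Row-stochastic Metropolis kernel with uniform proposals, leaving `target` invariant."""
    n = len(target)
    K = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i != j:
                K[i, j] = (1.0 / n) * min(1.0, target[j] / target[i])
        K[i, i] = 1.0 - K[i].sum()
    return K

def kl(a, b):
    """KL[a || b] for strictly positive discrete distributions."""
    return float(np.sum(a * (np.log(a) - np.log(b))))

joint = np.array([0.02, 0.06, 0.16, 0.12, 0.04])    # hypothetical p(z, x) at the observed x
posterior = joint / joint.sum()                     # p(z | x)
q = np.array([0.6, 0.1, 0.1, 0.1, 0.1])             # hypothetical q_theta(z)
q_t = q @ np.linalg.matrix_power(metropolis_kernel(posterior), 3)   # Pi^t q with t = 3

loss_1 = kl(q, posterior) - kl(q_t, posterior)      # needs log Pi^t q: intractable in general
loss_2 = kl(q_t, q)                                 # needs log Pi^t q as well
f = np.log(joint) - np.log(q)                       # f_theta(z) = log p(z, x) - log q_theta(z)
tractable = q_t @ f - q @ f                         # only needs samples from Pi^t q and from q
print(loss_1 + loss_2, tractable)
```

The two printed numbers agree, even though the right-hand side never evaluates $\log \Pi^t q_\theta$.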

There is one more technical hurdle to overcome, which is to calculate or estimate the derivative of this objective with respect to $\theta$. The authors propose a REINFORCE-like score function gradient estimator in Eqn. (12), which is somewhat worrying as it is known to have very high variance. The authors propose overcoming this using a control variate. For more details, please refer to the paper.
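To get a feel for why a control variate matters, here is a minimal toy sketch of a score-function estimator. It is not the estimator from Eqn. (12) of the paper: the Gaussian sampling distribution, the integrand $z^2$ and the function name `score_function_grads` are all assumptions chosen so that the true gradient ($2\theta$) and a good baseline ($\theta^2 + 1$) are known in closed form.

```python
import numpy as np

def score_function_grads(theta, n_samples, baseline=0.0, seed=0):
    """Per-sample score-function estimates of d/dtheta E_{z~N(theta,1)}[z^2]."""
    rng = np.random.default_rng(seed)
    z = rng.normal(loc=theta, scale=1.0, size=n_samples)
    f = z ** 2                      # toy integrand (an assumption, not the paper's f_theta)
    score = z - theta               # d/dtheta log N(z; theta, 1)
    return (f - baseline) * score   # subtracting a constant baseline keeps the mean unchanged

theta = 1.0
plain = score_function_grads(theta, 100_000)
# Baseline E[z^2] = theta^2 + 1, available in closed form for this toy problem.
with_cv = score_function_grads(theta, 100_000, baseline=theta ** 2 + 1)

print("true gradient:", 2 * theta)
print("plain   mean/var:", plain.mean(), plain.var())
print("with CV mean/var:", with_cv.mean(), with_cv.var())
```

Both estimators target the same gradient, but the one with the baseline subtracted should show a visibly smaller per-sample variance.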

There is further discussion on the behaviour of this objective function in the limit of infinitely long MCMC paths, i.e. $t\rightarrow\infty$. It turns out that the criterion then works like the symmetrized KL divergence $KL[q\|p] + KL[p\|q]$. The difference between this objective and the usual conservative, mode-seeking VI objective is neatly illustrated in Figure 1 of the paper:

Variational Contrastive Divergence (VCD) favours posterior approximations which have a much higher coverage of the true posterior compared to VI, which locks onto modes and avoids allocating mass to areas where the true posterior has none.

]]>

]]>This post is a short note on an excellent recent paper on empirical Fisher information matrices:

- Kunstner, Balles and Hennig (2019) Limitations of the Empirical Fisher Approximation

I was debating with myself whether I should write a post about this because it's a superbly written paper that you should probably read in full. There isn't a whole lot of novelty in the paper, but it is a great discussion paper that provides a concise overview of the Fisher information, the empirical Fisher matrix and their connections to generalized Gauss-Newton methods. For more detail, I also recommend this older and longer overview by James Martens.

For a supervised learning model $p_\theta (y\vert x)$ the Fisher information matrix for a specific input $x$ is defined as follows:

$$

F_x = \mathbb{E}_{p_\theta(y\vert x)} \left[ \nabla_\theta \log p_\theta(y\vert x) \nabla_\theta \log p_\theta(y\vert x)^\top \right]

$$

The Fisher information roughly measures the sensitivity of the model's output distribution to small changes in the parameters $\theta$. For me, the most intuitive interpretation is that it is the local curvature of $KL[p_\theta\|p_{\theta'}]$ as $\theta' \rightarrow \theta$. So, by using the Fisher information as a quadratic penalty, or as a preconditioner in second order gradient descent, you're penalizing changes to the model's output distribution as measured by the KL divergence. This makes sense especially for neural networks, where small changes in network parameters can lead to arbitrary changes in the function the model actually implements.
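Written out, the curvature interpretation is the standard second-order expansion of the KL divergence around $\theta$ for a small perturbation $\delta$ (the symbol $\delta$ is introduced here only for illustration):

$$
KL\left[p_\theta(y\vert x)\,\middle\|\,p_{\theta+\delta}(y\vert x)\right] = \frac{1}{2}\delta^\top F_x \delta + O(\|\delta\|^3)
$$

The zeroth- and first-order terms vanish because the KL divergence is minimised, with value $0$, at $\delta = 0$.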

If you have an entire distribution or a dataset, you normally add up the pointwise $F_x$ matrices to form the Fisher information for the whole training data:

$$

F = \sum_n F_{x_n} = \sum_n \mathbb{E}_{p_\theta(y\vert x_n)} \left[ \nabla_\theta \log p_\theta(y\vert x_n) \nabla_\theta \log p_\theta(y\vert x_n)^\top \right]

$$

In practice, however, people often work with another matrix instead, which is called the empirical Fisher information matrix, defined as follows:

$$

\tilde{F} = \sum_n\left[ \nabla_\theta \log p_\theta(y_n\vert x_n) \nabla_\theta \log p_\theta(y_n\vert x_n)^\top \right]

$$

The difference, of course, is that rather than taking an expectation over $y$ sampled from the model itself, here we use the observed labels $y_n$ from the dataset. This matrix, while it's often used, does not have a principled motivation. The paper shows how it lacks most of the properties that make the Fisher information such a useful quantity in many cases. The paper's simple, yet insightful, examples show how using the empirical Fisher instead of the real Fisher in second order optimization can lead to bonkers results:

What the above example shows is the vector field corresponding to differently preconditioned gradient descent algorithms in a simple two-parameter least squares linear regression example. The first panel shows the gradient field. The second shows the natural gradient field, i.e. gradients corrected by the inverse Fisher information. The third shows the gradients corrected by the empirical Fisher instead. You can see that the empirical Fisher does not make the optimization problem any easier or better conditioned. The bottom line is: you probably shouldn't use it.
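The two matrices are easy to compare in a linear regression setting like this one. The sketch below is not the paper's experiment: it assumes a unit-variance Gaussian likelihood $y\vert x \sim \mathcal{N}(\theta^\top x, 1)$, for which the per-example gradient of the log-likelihood is $(y_n - \theta^\top x_n)x_n$ and $F_x = xx^\top$, and it deliberately generates data with a larger noise level so that the model is misspecified. The function names and the data-generating parameters are made up for illustration.

```python
import numpy as np

def fisher(X):
    """Fisher for y | x ~ N(theta^T x, 1): sum_n x_n x_n^T (independent of y and theta)."""
    return X.T @ X

def empirical_fisher(X, y, theta):
    """Empirical Fisher: sum_n g_n g_n^T with g_n = (y_n - theta^T x_n) x_n."""
    residuals = y - X @ theta
    G = residuals[:, None] * X      # per-example log-likelihood gradients, shape (N, D)
    return G.T @ G

rng = np.random.default_rng(0)
N, D = 1000, 2
X = rng.normal(size=(N, D))
theta_true = np.array([1.0, -2.0])
y = X @ theta_true + 2.0 * rng.normal(size=N)     # noise std 2: the model is misspecified

theta_hat = np.linalg.lstsq(X, y, rcond=None)[0]  # minimiser of the squared loss
F = fisher(X)
EF = empirical_fisher(X, y, theta_hat)
print(np.trace(EF) / np.trace(F))   # tracks the residual variance rather than being 1
```

Even at the minimiser of the loss, the empirical Fisher is rescaled by the squared residuals, so the two matrices do not agree when the assumed noise level is wrong.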

One particular *myth* that Kunstner et al (2019) try to debunk is that the Fisher and empirical Fisher coincide near minima of the loss function, and therefore either can be safely used once you're done with training your model. I have made this argument several times in the past myself. If you look at the definitions, it's easy to conclude that they should indeed coincide under two conditions:

- that the model $p_\theta(y\vert x)$ closely approximates the *true* sampling distribution $p_\mathcal{D}(y\vert x)$, and
- that you have sufficiently large $N$ (number of samples) so that both the Fisher and empirical Fisher converge to their respective population average.

The paper argues that these two conditions are often at odds in practice. One either uses a simple model, in which case model misspecification is likely, meaning $p_\theta(y\vert x)$ can't approximate the true underlying $p_\mathcal{D}(y\vert x)$ well. Or, one uses an expressive, overparametrised model, such as a big neural network, in which case some form of overfitting likely happens, and $N$ is too small for the two matrices to coincide. The paper offers a lot more nuance to this, and again, I highly recommend reading it whole.

The paper also looks at comparing the direction of the natural gradient with the empirical Fisher-corrected gradients, and finds that the EF gradient often tries to go in a very different direction than the natural gradient.

It's also interesting to look at the projection of the EF gradient onto the original gradient, to see how big an EF-gradient update is in the direction of normal gradient descent.

The gradient for a single datapoint $(x_n, y_n)$ is $g_n = -\nabla_\theta \log p_\theta(y_n\vert x_n)$. If we use the EF matrix as preconditioner, the gradient is modified to $\tilde{g}_n = \tilde{F}^{-1}g_n = -\tilde{F}^{-1} \nabla_\theta \log p_\theta(y_n\vert x_n)$. Let's look at the inner product $g^\top_n\tilde{g}_n$ on average over the dataset $p_\mathcal{D}$:

\begin{align}

\mathbb{E}_{p_\mathcal{D}} g^\top_n\tilde{g}_n &= \mathbb{E}_{p_\mathcal{D}}g^\top_n\tilde{F}^{-1}g_n \\

&= \operatorname{tr} \mathbb{E}_{p_\mathcal{D}} g^\top_n\tilde{F}^{-1}g_n\\

&= \operatorname{tr}\mathbb{E}_{p_\mathcal{D}} g_ng^\top_n \tilde{F}^{-1}\\

&= \operatorname{tr} \tilde{F}\tilde{F}^{-1}\\

&= D,

\end{align}

where $D$ is the dimensionality of $\theta$. (Strictly speaking, with $\tilde{F}$ defined as a sum over $n$ as above, the dataset average equals $D/N$ and it is the sum over datapoints that equals $D$; either way the value does not depend on $\theta$.) This means that the projection of the EF gradient onto the original gradient is constant. Since the length of the gradient $g_n$ can change as a function of $\theta$, this implies that, weirdly, there is an inverse relationship between the norm of $g_n$ and the norm of $\tilde{g}_n$ when projected in the same direction. On the whole, this is consistent with the weird-looking vector field in Figure 1 of the paper.
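The identity is easy to check numerically. The following sketch assumes a unit-variance Gaussian linear regression model (not taken from the paper) and defines $\tilde{F}$ as a sum over datapoints, so it is the sum of $g_n^\top\tilde{F}^{-1}g_n$ over the dataset that should equal $D$. The helper name `summed_projection` and the synthetic data are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 500, 3
X = rng.normal(size=(N, D))
y = X @ np.array([0.5, -1.0, 2.0]) + rng.normal(size=N)

def summed_projection(theta):
    """sum_n g_n^T EF^{-1} g_n for unit-variance Gaussian linear regression."""
    G = -((y - X @ theta)[:, None] * X)      # rows are g_n = -grad log p(y_n | x_n)
    EF = G.T @ G                             # empirical Fisher as a sum over n
    EF_inv_G = np.linalg.solve(EF, G.T).T    # rows are EF^{-1} g_n
    return float(np.einsum("nd,nd->", G, EF_inv_G))

for theta in (np.zeros(D), np.ones(D), 10 * np.ones(D)):
    print(summed_projection(theta))          # close to D for every theta
```

The value stays at $D$ no matter how far $\theta$ is from the optimum, even though the gradient norms themselves change by orders of magnitude.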

In summary, I highly recommend reading this paper. It has a lot more than what I covered here, and it is all explained rather nicely. The main message is clear: think twice (or more) about using empirical Fisher information for pretty much any application.

]]>Video is an interesting domain for unsupervised, or self-supervised, representation learning. But we still don't know what type of inductive biases will enable us to best exploit the information encoded in the temporal sequence of video frames. Slow Feature Analysis (SFA) and its more recent cousin Learning to Linearize (e.g. Goroshin et al., 2016) attempt to learn neural representations of images such that the trajectory spanned by the sequence of transformed video frames satisfies a desired property, such as *slowness* or linearity/*straightness*. Whether or not the brain employs similar principles to learn representations is a fascinating topic.

Computational neuroscientists showed examples where representations learnt by the brain are consistent with the slowness principle (see e.g. Franzius et al, 2007). Brand new work by Eero Simoncelli's group suggests that the brain's representations are also consistent with the straightness principle:

- Hénaff, Goris and Simoncelli (2019): Perceptual straightening of natural videos, Nature Neuroscience pdf

This work is in fact a culmination of some great work by the authors over several years on developing the techniques that eventually enabled this study.

As you play a video, the frames of the video span a trajectory in pixel space. This trajectory is usually very complicated even for videos representing relatively simple actions. Imagine a video scene with the camera slowly panning over a static texture, such as a field of grass, at a constant speed. Although the video appears conceptually simple and predictable, the trajectory of individual pixel values is highly non-linear and crazy: most pixels abruptly change their color from one frame to the next.

Instead of representing videos as sequences of frames in pixel space, we would like to map frames into a representation space, such that the trajectory behaves a bit more normally and predictably. For example, it makes sense to hope that in representation space the trajectory is straight or linear, meaning that it has a small curvature. If a trajectory is straight, it also means that linear extrapolation from the last couple points to the future works well.

Imagine we have a sequence (of frames) $x_t$. We map this into a representation space via a nonlinear mapping $f$. We can measure the local *straightness* of such representation by calculating the angle between the difference vectors $f(x_{t+1}) - f(x_{t})$ and $f(x_{t}) - f(x_{t-1})$. This angle is given by:

\begin{align}

\cos\alpha_t = \frac{(f(x_{t+1}) - f(x_{t}))^T(f(x_{t}) - f(x_{t-1}))}{\|(f(x_{t+1}) - f(x_{t}))\|\|(f(x_{t}) - f(x_{t-1}))\|}

\end{align}

The closer this value is to $1$, the straighter the sequence $f(x_{t-1}), f(x_{t}), f(x_{t+1})$ is locally. One can measure the straightness of the whole trajectory by averaging this local measure of curvature across time. It's worth noting that in high-dimensional spaces, straightness is difficult to achieve: a $\cos\alpha$ close to $1$ is very rare. Two vectors drawn independently from an isotropic Gaussian distribution are almost perpendicular to one another with high probability, yielding a cosine close to $0$. So, for example, straight trajectories have almost $0$ probability under a high-dimensional Brownian motion or Ornstein–Uhlenbeck (OU) process. As slow feature analysis can be seen as learning representations that behave roughly like OU processes (Turner and Sahani, 2017), SFA would not, by default, yield straight trajectories. To recover straighter trajectories, one needs stronger priors and constraints, such as $k$-times integrated OU processes. See also Goroshin et al. (2016) for another approach to recovering straight representations.
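As a quick illustration of both points, here is a minimal NumPy sketch that averages $\cos\alpha_t$ along a trajectory. The function name `mean_cos_angle` and the two synthetic trajectories are assumptions made for this example; the rows of the array play the role of the representations $f(x_t)$.

```python
import numpy as np

def mean_cos_angle(traj):
    """Average cos(alpha_t) between successive difference vectors of a (T, D) trajectory."""
    diffs = np.diff(traj, axis=0)                        # rows are f(x_{t+1}) - f(x_t)
    unit = diffs / np.linalg.norm(diffs, axis=1, keepdims=True)
    return float(np.mean(np.sum(unit[1:] * unit[:-1], axis=1)))

rng = np.random.default_rng(0)
T, D = 50, 10_000
straight = np.outer(np.arange(T), rng.normal(size=D))    # points along a single direction
brownian = np.cumsum(rng.normal(size=(T, D)), axis=0)    # discretised Brownian motion

print(mean_cos_angle(straight))   # 1 up to floating point error
print(mean_cos_angle(brownian))   # close to 0 in high dimensions
```

The Brownian trajectory has independent Gaussian increments, so consecutive difference vectors are nearly orthogonal and the average cosine sits near $0$.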

When we train our own neural network, we can of course directly evaluate the straightness of learnt representations. But how do we study whether the representations learnt by the brain exhibit straight trajectories or not? We can't directly observe the representations learned by the human brain. Instead, Hénaff et al (2019) devised an indirect way to measure the curvature, i.e. to infer it from behavioural/psychophysics data. Their approach relies on ideal observer models for a certain perceptual discrimination task: they infer perceptual curvature from data about how accurately subjects can differentiate between pairs of stimuli.

Here is how to intuitively understand why the approach makes sense: we assume that whether or not a human is able to discriminate between two stimuli, when flashed for a fraction of a second, is related to the Euclidean distance between the representations of the two stimuli in the person's representation space. Now, if you take a triplet of stimuli (video frames) $A, B, C$ and measure the pairwise distances $d_{AB}, d_{AC}, d_{BC}$ between them, you can establish the straightness of them as a sequence: if the sequence is straight, you expect that $d_{AC}\approx d_{AB} + d_{BC}$. On the other hand, if the sequence is not straight, and the two steps are instead close to perpendicular, you'd expect $d^2_{AC}\approx d^2_{AB} + d^2_{BC}$ to hold instead. So, if we have a model - called a normative or ideal observer model - that can somehow directly relate Euclidean distances in representation space to perceptual discriminability of stimuli, we have a way to experimentally infer the straightness of representations.
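The geometric part of this argument is just the law of cosines: the cosine of the angle between the steps $B - A$ and $C - B$ can be recovered from the three pairwise distances alone. The sketch below illustrates only this geometric step, under the assumption that exact Euclidean distances are available; it is not the paper's inference procedure, which has to estimate distances from noisy discrimination data. The function name `cos_from_distances` is made up for illustration.

```python
import numpy as np

def cos_from_distances(d_ab, d_bc, d_ac):
    """cos of the angle between B-A and C-B, recovered from pairwise distances alone."""
    # From |u + v|^2 = |u|^2 + |v|^2 + 2 u.v with u = B - A and v = C - B.
    return (d_ac ** 2 - d_ab ** 2 - d_bc ** 2) / (2 * d_ab * d_bc)

print(cos_from_distances(1.0, 1.0, 2.0))            # collinear: d_AC = d_AB + d_BC
print(cos_from_distances(1.0, 1.0, np.sqrt(2.0)))   # right angle: d_AC^2 = d_AB^2 + d_BC^2
```

A cosine near $1$ corresponds to the additive (straight) case, and a cosine near $0$ to the Pythagorean case.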

For more detailed description of the psychophysics experiments, data preparation, and the ideal observer model, please see the paper.

The main result of the paper - as expected - is that natural video sequences indeed appear to be mapped to straight trajectories in representation space. Furthermore, the figure below (Fig. 3b in the article) shows that the straightness of trajectories which correspond to natural stimuli is higher than the straightness of unnatural, artificially generated stimuli:

The blue curves correspond to a natural video stimulus in pixel space and in (inferred) representation space. You can see that the sequence is highly curved in pixel space, but in representation space it is straight. On the other hand, the green curves show an artificial video created by linearly interpolating between the first and the last frame of a video. This of course results in a straight trajectory in pixel space. However, the inferred trajectory in representation space is highly curved, which was the authors' prediction as this is an unnatural stimulus for the human.

I really liked this paper and it is an excellent example of ideal observer model-based analysis of psychophysics data. It formulates an interesting hypothesis and finds a very nice way to answer it from well-designed experimental data.

That said, I was left wondering just how critically the results depend on the assumptions of the ideal observer model, and how robust the findings are to changing some of those assumptions. For one, the paper assumes Gaussian observation noise in representation space, and I wonder how robust the analysis would be to assuming heavy-tailed noise. Similarly, our very definition of straightness and angles relies on the assumption that the representation space is somehow Euclidean. If the representation space is not in fact Euclidean, our analysis might pick up on that, rather than the straightness of specific trajectories. Even so, the most convincing aspect of the paper is the contrast between natural and unnatural stimuli. No matter how precise our inference of straightness is, and how much the actual values depend on changing these assumptions, this is a very reassuring finding.

]]>