
Neural tangent kernels are a useful tool for understanding neural network training and implicit regularization in gradient descent. But it's not the easiest concept to wrap your head around. The paper I found most useful for developing an understanding is this one:

In this post I will illustrate the concept of neural tangent kernels through a simple 1D regression example. Please feel free to peruse the Google Colab notebook I used to make these plots.

Let's begin with a very boring case. Let's say we have a function defined over the integers between -10 and 20. We parametrize our function as a look-up table, that is, the value of the function $f(i)$ at each integer $i$ is described by a separate parameter $\theta_i = f(i)$. I'm initializing the parameters of this function as $\theta_i = 3i+1$. The function is shown by the black dots below:

Now, let's consider what happens if we observe a new datapoint, $(x, y) = (10, 50)$, shown by the blue cross. We're going to take a gradient descent step updating $\theta$. Let's say we use the squared error loss function $(f(10; \theta) - 50)^2$ and a learning rate $\eta=0.1$. Because the function's value at $x=10$ depends on only one of the parameters, $\theta_{10}$, only this parameter will be updated. The rest of the parameters, and therefore the rest of the function values, remain unchanged. The red arrows illustrate the way function values move in a single gradient descent step: most values don't move at all, and only one of them moves closer to the observed data. Hence only one visible red arrow.
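A minimal Python sketch of this step, assuming the initialization $\theta_i = 3i+1$ (the same function as the linear model $3x+1$ in the next example). The dictionary-based "model" and the helper name `gd_step_lookup` are just for illustration:

```python
# Lookup-table model: one parameter per integer input, theta[i] = f(i).
xs = list(range(-10, 21))                # the integers -10, ..., 20
theta = {i: 3.0 * i + 1.0 for i in xs}   # assumed initialization f(i) = 3i + 1


def gd_step_lookup(theta, x_obs, y_obs, lr):
    """One gradient descent step on the squared error (theta[x_obs] - y_obs)**2."""
    new_theta = dict(theta)
    # The loss depends only on theta[x_obs]; every other partial derivative is zero.
    grad = 2.0 * (theta[x_obs] - y_obs)
    new_theta[x_obs] -= lr * grad
    return new_theta


new_theta = gd_step_lookup(theta, x_obs=10, y_obs=50.0, lr=0.1)
# Size of each "red arrow": how much each function value moved.
arrows = {i: new_theta[i] - theta[i] for i in xs}
moved = [i for i in xs if arrows[i] != 0.0]
print(moved, arrows[10])  # only x = 10 moves, by about +3.8
```

The step moves $f(10)$ from 31 to about 34.8, which is closer to 50, and leaves every other value where it was.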

However, in machine learning we rarely parametrize functions as lookup tables of individual function values. This parametrization is pretty useless, as it doesn't allow you to interpolate, let alone extrapolate, to unseen data. Let's see what happens in a more familiar model: linear regression.

Let's now consider the linear function $f(x, \theta) = \theta_1 x + \theta_2$. I initialize the parameters to $\theta_1=3$ and $\theta_2=1$, so at initialisation I have exactly the same function over integers as I had in the first example. Let's look at what happens to this function as I update $\theta$ by performing a single gradient descent step incorporating the observation $(x, y) =(10, 50)$ as before. Again, red arrows show how function values move:

Whoa! What's going on now? Since individual function values are no longer independently parametrized, we can't move them independently. The model binds them together through its global parameters $\theta_1$ and $\theta_2$. If we want to move the function closer to the desired output $y=50$ at location $x=10$ the function values elsewhere have to change, too.

In this case, updating the function with an observation at $x=10$ changes the function value far away from the observation. It even moves some function values in the opposite direction to what one would expect. This might seem a bit weird, but that's really how linear models work.
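The same step for the linear model, sketched in Python. With the stated learning rate $\eta=0.1$ the step is large, because the gradient with respect to the slope is scaled by $x=10$, so it overshoots the target at $x=10$. What this sketch shows is that every function value moves, and that values at negative $x$ move down:

```python
xs = list(range(-10, 21))


def f_lin(x, theta):
    slope, intercept = theta
    return slope * x + intercept


def grad_f_lin(x, theta):
    # df/dtheta = (df/dslope, df/dintercept) = (x, 1), whatever theta is
    return (float(x), 1.0)


def gd_step(theta, x_obs, y_obs, lr):
    """One gradient descent step on (f(x_obs) - y_obs)**2 for the linear model."""
    residual = f_lin(x_obs, theta) - y_obs
    g = grad_f_lin(x_obs, theta)
    return tuple(t - lr * 2.0 * residual * gi for t, gi in zip(theta, g))


theta0 = (3.0, 1.0)
theta1 = gd_step(theta0, x_obs=10, y_obs=50.0, lr=0.1)
arrows = {x: f_lin(x, theta1) - f_lin(x, theta0) for x in xs}
print(theta1)                   # roughly (41.0, 4.8)
print(arrows[-10], arrows[10])  # negative at x = -10, positive at x = 10
```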

Now we have a little bit of background to start talking about this neural tangent kernel thing.

Given a function $f_\theta(x)$ which is parametrized by $\theta$, its *neural tangent kernel* $k_\theta(x, x')$ quantifies how much the function's value at $x$ changes as we take an infinitesimally small gradient step in $\theta$ incorporating a new observation at $x'$. Another way of phrasing this is: $k(x, x')$ measures how sensitive the function value at $x$ is to prediction errors at $x'$.

In the plots before, the size of the red arrows at each location $x$ was given by the following equation:

$$

\eta \tilde{k}_\theta(x, x') = f\left(x, \theta + \eta \frac{df_\theta(x')}{d\theta}\right) - f(x, \theta)

$$

In neural network parlance, this is what's going on: the loss function tells me to increase the function value $f_\theta(x')$. I back-propagate this through the network to see what change in $\theta$ I have to make to achieve this. However, moving $f_\theta(x')$ this way also simultaneously moves $f_\theta(x)$ at other locations $x \neq x'$. $\tilde{k}_\theta(x, x')$ expresses by how much.

The neural tangent kernel is basically the limit of $\tilde{k}$ as the step size becomes infinitesimally small. In particular:

$$

k(x, x') = \lim_{\eta \rightarrow 0} \frac{f\left(x, \theta + \eta \frac{df_\theta(x')}{d\theta}\right) - f(x, \theta)}{\eta}

$$

Using a first-order Taylor expansion of $f_\theta(x)$, it is possible to show that

$$

k_\theta(x, x') = \left\langle \frac{df_\theta(x)}{d\theta} , \frac{df_\theta(x')}{d\theta} \right\rangle

$$
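Spelled out, the first-order expansion around $\theta$ in the direction of the step $\eta \frac{df_\theta(x')}{d\theta}$ reads

$$

f\left(x, \theta + \eta \frac{df_\theta(x')}{d\theta}\right) = f(x, \theta) + \eta \left\langle \frac{df_\theta(x)}{d\theta}, \frac{df_\theta(x')}{d\theta} \right\rangle + O(\eta^2),

$$

so subtracting $f(x, \theta)$, dividing by $\eta$ and letting $\eta \rightarrow 0$ leaves exactly the inner product of the two gradients.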

As homework for you: find $k(x, x')$ and/or $\tilde{k}(x, x')$ for a fixed $\eta$ in the linear model in the previous example. Is it linear? Is it something else?

Note that this is a different derivation from the one in the paper (which starts from the continuous-time, differential-equation version of gradient descent).

Now, I'll go back to the examples to illustrate two more important properties of this kernel: sensitivity to parametrization, and changes during training.

It's well known that neural networks can be reparametrized in ways that don't change the actual output of the function, but which may lead to differences in how optimization works. Batchnorm is a well-known example of this. Can we see the effect of reparametrization in the neural tangent kernel? Yes we can. Let's look at what happens if I reparametrize the linear function I used in the second example as:
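A quick numerical sanity check of this result: for a small nonlinear model, the finite-step kernel $\tilde{k}$ approaches the inner product of gradients as $\eta$ shrinks. The model $f(x; \theta) = \theta_1 \tanh(\theta_2 x) + \theta_3$, its parameter values and the function names are hypothetical; the same recipe works for any differentiable $f$:

```python
import math

THETA = (2.0, 0.5, 1.0)  # hypothetical parameters (theta_1, theta_2, theta_3)


def f(x, th):
    return th[0] * math.tanh(th[1] * x) + th[2]


def grad_f(x, th):
    """Analytic df/dtheta for f(x; theta) = theta_1 * tanh(theta_2 * x) + theta_3."""
    t = math.tanh(th[1] * x)
    return (t, th[0] * x * (1.0 - t * t), 1.0)


def k_tilde(x, xp, th, eta):
    """Finite-step kernel: [f(x, theta + eta * df(x')/dtheta) - f(x, theta)] / eta."""
    stepped = tuple(t + eta * g for t, g in zip(th, grad_f(xp, th)))
    return (f(x, stepped) - f(x, th)) / eta


def k_inner(x, xp, th):
    """Inner product of the parameter gradients at x and x'."""
    return sum(a * b for a, b in zip(grad_f(x, th), grad_f(xp, th)))


for eta in (1.0, 0.1, 0.01, 1e-4):
    print(eta, k_tilde(-1.0, 2.0, THETA, eta), k_inner(-1.0, 2.0, THETA))
```

As $\eta$ decreases, the second column should settle onto the third.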

$$

f_\theta(x) = \theta_1 x + \color{blue}{10\cdot}\theta_2

$$

but now with parameters $\theta_1=3, \theta_2=\color{blue}{0.1}$. I highlighted in blue what changed. The function itself is the same at initialization, since $10 \cdot 0.1 = 1$. The function class is the same, too, as I can still implement arbitrary linear functions. However, when we look at the effect of a single gradient step, we see that the function changes differently when gradient descent is performed in this parametrisation.

In this parametrisation, it became easier for gradient descent to push the whole function up by a constant, while in the previous parametrisation it decided to change the slope. What this demonstrates is that the neural tangent kernel $k_\theta(x, x')$ is sensitive to reparametrization.

While linear models make a good illustration, let's look at what $k_\theta(x, x')$ looks like in a nonlinear model. Here, I'll consider a model with two squared exponential basis functions:
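Both parametrizations are linear in $\theta$, so their kernels follow directly from the gradients. Here is a sketch in which `scale` is a hypothetical name for the blue factor: 1 in the original parametrization, 10 in the new one. Each kernel is normalised by its value at $x = x' = 10$:

```python
def ntk_linear(x, xp, scale):
    """NTK of f(x) = theta_1 * x + scale * theta_2.

    The gradient df/dtheta = (x, scale) does not depend on theta, so neither does the kernel.
    """
    grad_x = (x, scale)
    grad_xp = (xp, scale)
    return sum(a * b for a, b in zip(grad_x, grad_xp))


for x in (-10, 0, 10, 20):
    k1 = ntk_linear(x, 10, scale=1) / ntk_linear(10, 10, scale=1)
    k10 = ntk_linear(x, 10, scale=10) / ntk_linear(10, 10, scale=10)
    print(x, round(k1, 3), round(k10, 3))
```

With `scale=10`, an observation at $x'=10$ moves $f(0)$ half as much as it moves $f(10)$, instead of barely moving it at all. It also no longer pushes $f(-10)$ in the opposite direction. This is the "push the whole function up" behaviour described above.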

$$

f_\theta(x) = \theta_1 \exp\left(-\frac{(x - \theta_2)^2}{30}\right) + \theta_3 \exp\left(-\frac{(x - \theta_4)^2}{30}\right) + \theta_5,

$$

with initial parameter values $(\theta_1, \theta_2, \theta_3, \theta_4, \theta_5) = (4.0, -10.0, 25.0, 10.0, 50.0)$. These are chosen somewhat arbitrarily and to make the result visually appealing:

We can visualise the function $\tilde{k}_\theta(x, 10)$ directly, rather than plotting it on top of the function. Here I also normalize it by dividing by $\tilde{k}_\theta(10, 10)$.

What we can see is that this starts to look a bit like a *kernel* function, in that it has higher values near $10$ and decreases as you go farther away. However, a few things are worth noting. First, the maximum of this kernel function is not at $x=10$, but at $x=7$. This means that the function value $f(7)$ changes more in reaction to an observation at $x'=10$ than the value $f(10)$ does. Second, there are some negative values. In this case the previous figure provides a visual explanation why: we can increase the function value at $x=10$ by pushing the valley centred at $x=4$ away from it, to the left. This parameter change in turn decreases function values on the left-hand wall of the valley. Third, the kernel function converges to a positive constant at its tails; this is because of the offset $\theta_5$.
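Here is a sketch of the infinitesimal-step kernel $k_\theta(x, 10)$ for this model, using hand-derived gradients. One caveat applies. Read literally, the formula and parameter tuple above put a bump of height 4 at $x=-10$, not a valley at $x=4$. To reproduce the behaviour described here, this sketch assumes that the first basis function has amplitude $-10$ and centre $4$; in other words, it swaps the first two entries of the tuple. Because it computes $k$ rather than the finite-step $\tilde{k}$ used for the plots, the exact location of the peak may differ slightly:

```python
import math

WIDTH = 30.0
# Assumption: (amplitude_1, centre_1, amplitude_2, centre_2, offset), with a
# negative first amplitude so that the first basis function is a valley at x = 4.
THETA0 = (-10.0, 4.0, 25.0, 10.0, 50.0)


def grad_f(x, th):
    """df/dtheta for f(x) = a1*exp(-(x-c1)^2/W) + a2*exp(-(x-c2)^2/W) + b."""
    a1, c1, a2, c2, _ = th
    g1 = math.exp(-(x - c1) ** 2 / WIDTH)
    g2 = math.exp(-(x - c2) ** 2 / WIDTH)
    return (
        g1,                                  # d/d a1
        a1 * g1 * 2.0 * (x - c1) / WIDTH,    # d/d c1
        g2,                                  # d/d a2
        a2 * g2 * 2.0 * (x - c2) / WIDTH,    # d/d c2
        1.0,                                 # d/d b (offset)
    )


def ntk(x, xp, th):
    return sum(u * v for u, v in zip(grad_f(x, th), grad_f(xp, th)))


xs = [i / 10 for i in range(-300, 401)]  # x from -30 to 40
k_norm = [ntk(x, 10.0, THETA0) / ntk(10.0, 10.0, THETA0) for x in xs]
print(max(zip(k_norm, xs)))  # (peak value, location): the peak is not at x = 10
print(min(k_norm))           # negative values on the left-hand wall of the valley
```

Far from both basis functions only the offset gradient survives. It equals 1 at every input, so the unnormalised kernel tends to 1 in the tails.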

Now I'm going to illustrate another important property of the neural tangent kernel: in general, the kernel depends on the parameter value $\theta$, and therefore it changes as the model is trained. Here I show what happens to the kernel as I take 15 gradient ascent steps trying to increase $f(10)$. The purple curve is the one I had at initialization (above), and the orange ones show the kernel at the last gradient step.

The corresponding changes to the function $f_{\theta_t}$ are shown below:

So we can see that as the parameter changes, the kernel also changes. The kernel becomes flatter. An explanation of this is that eventually we reach a region of parameter space where $\theta_4$ changes the fastest.
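The following self-contained sketch tracks the kernel during training. Any reproduction of this figure needs some reading of the parameter tuple. This one assumes the first basis function is a valley with amplitude $-10$ centred at $x=4$, as the discussion of the kernel describes. The learning rate of 0.5 is a hypothetical choice, since the post doesn't state the one used for the plots:

```python
import math

W = 30.0


def grad_f(x, th):
    """df/dtheta for f(x) = a1*exp(-(x-c1)^2/W) + a2*exp(-(x-c2)^2/W) + b."""
    a1, c1, a2, c2, _ = th
    g1 = math.exp(-(x - c1) ** 2 / W)
    g2 = math.exp(-(x - c2) ** 2 / W)
    return (g1, a1 * g1 * 2.0 * (x - c1) / W, g2, a2 * g2 * 2.0 * (x - c2) / W, 1.0)


def f(x, th):
    a1, c1, a2, c2, b = th
    return a1 * math.exp(-(x - c1) ** 2 / W) + a2 * math.exp(-(x - c2) ** 2 / W) + b


def ntk(x, xp, th):
    return sum(u * v for u, v in zip(grad_f(x, th), grad_f(xp, th)))


def ascend(th, x_target, lr, steps):
    """Gradient ascent on f(x_target; theta); returns every iterate, including the start."""
    path = [tuple(th)]
    for _ in range(steps):
        g = grad_f(x_target, th)
        th = tuple(t + lr * gi for t, gi in zip(th, g))
        path.append(th)
    return path


path = ascend((-10.0, 4.0, 25.0, 10.0, 50.0), x_target=10.0, lr=0.5, steps=15)
xs = [i / 2 for i in range(-60, 81)]
k_first = [ntk(x, 10.0, path[0]) for x in xs]
k_last = [ntk(x, 10.0, path[-1]) for x in xs]
print(f(10.0, path[0]), f(10.0, path[-1]))                # f(10) goes up
print(max(abs(a - b) for a, b in zip(k_first, k_last)))  # the kernel has changed
```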

It turns out the neural tangent kernel becomes particularly useful when studying learning dynamics in infinitely wide feed-forward neural networks. Why? Because in this limit, two things happen:

- First: if we initialize $\theta_0$ randomly from appropriately chosen distributions, the initial NTK of the network $k_{\theta_0}$ approaches a deterministic kernel as the width increases. This means that, at initialization, $k_{\theta_0}$ doesn't really depend on $\theta_0$ but is a fixed kernel independent of the specific initialization.
- Second: in the infinite-width limit the kernel $k_{\theta_t}$ stays constant over time as we optimise $\theta_t$. This removes the parameter dependence during training.

These two facts put together imply that gradient descent in the infinitely wide and infinitesimally small learning rate limit can be understood as a pretty simple algorithm called *kernel gradient descent* with a fixed kernel function that depends only on the architecture (number of layers, activations, etc).
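Kernel gradient descent updates function values directly: a prediction error at training input $x_i$ moves $f(x)$ in proportion to $k(x, x_i)$, just like the red arrows earlier. Below is a minimal sketch. A squared-exponential kernel stands in for the NTK; this is an assumption for illustration, since the real infinite-width NTK is a specific kernel fixed by the architecture. The training data and step size are also hypothetical:

```python
import math


def rbf(x, xp, lengthscale=3.0):
    # Hypothetical stand-in for a fixed kernel; the true infinite-width NTK
    # is a different kernel determined by the architecture.
    return math.exp(-((x - xp) ** 2) / (2.0 * lengthscale ** 2))


def kernel_gradient_descent(kernel, x_train, y_train, f0, lr, steps):
    """Kernel gradient descent on 0.5 * sum_i (f(x_i) - y_i)^2 with a fixed kernel.

    Each step applies f(x) <- f(x) - lr * sum_i kernel(x, x_i) * (f(x_i) - y_i),
    so f always equals f0 plus a kernel combination of the training inputs;
    we only need to track the coefficients of that combination.
    """
    coeffs = [0.0] * len(x_train)

    def f(x):
        return f0(x) + sum(c * kernel(x, xi) for c, xi in zip(coeffs, x_train))

    for _ in range(steps):
        residuals = [f(xi) - yi for xi, yi in zip(x_train, y_train)]
        coeffs = [c - lr * r for c, r in zip(coeffs, residuals)]
    return f


x_train = [-5.0, 0.0, 10.0]
y_train = [0.0, 2.0, 50.0]
f_trained = kernel_gradient_descent(rbf, x_train, y_train, f0=lambda x: 0.0, lr=0.5, steps=200)
print([round(f_trained(x), 4) for x in x_train])  # close to y_train
```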

These results, taken together with an older result by Neal (1994), allow us to characterise the probability distribution of minima that gradient descent converges to in this infinite limit as a Gaussian process. For details, see the paper mentioned above.

There are two somewhat related sets of results, both involving infinitely wide neural networks and kernel functions, so I just wanted to clarify the difference between them:

- the older, well-known result by Neal (1994), later extended by others, is that the distribution of $f_\theta$ under random initialization of $\theta$ converges to a Gaussian process. This Gaussian process has a kernel or covariance function which is not, in general, the same as the neural tangent kernel. This old result doesn't say anything about gradient descent, and is typically used to motivate the use of Gaussian process-based Bayesian methods.
- the new, NTK, result is that the evolution of $f_{\theta_t}$ during gradient descent training can be described in terms of a kernel, the neural tangent kernel, and that in the infinite limit this kernel stays constant during training and is deterministic at initialization. Using this result, it is possible to show that in some cases the distribution of $f_{\theta_t}$ is a Gaussian process at every timestep $t$, not just at initialization. This result also allows us to identify the Gaussian process which describes the limit as $t \rightarrow \infty$. This limiting Gaussian process however is not the same as the posterior Gaussian process which Neal and others would calculate on the basis of the first result.

So I hope this post helps a bit by building some intuition about what the neural tangent kernel is. If you're interested, check out the simple colab notebook I used for these illustrations.


I recently encountered this cool paper in a reading group presentation:

- Rezende et al (2020) Causally Correct Partial Models for Reinforcement Learning

It's frankly taken me a long time to understand what was going on, and it took me weeks to write this half-decent explanation of it. The first notes I wrote followed the logic of the paper more closely; in this post I thought I'd just focus on the high-level idea, after which the paper will hopefully be more straightforward. I wanted to capture the key idea without the distractions of RNN hidden states, etc., which I found confusing to think about.

To start with the basics, this paper deals with the partially observed Markov decision process (POMDP) setup. The diagram below illustrates what's going on:

The grey nodes $e_t$ show the unobserved state of the environment at each timestep $t=0,1,2\ldots$. At each timestep the agent observes $y_t$, which depends on the current state of the environment (red-ish nodes). The agent then updates its state $s_t$ based on its past state $s_{t-1}$, the new observation $y_t$, and the previous action taken $a_{t-1}$. This is shown by the blue squares (they're squares, signifying that this node depends deterministically on its parents). Then, based on the agent's state, it chooses an action $a_t$ by sampling from the policy $\pi(a_t\vert s_t)$. The action influences how the environment's state $e_{t+1}$ changes.

We assume that the agent's ultimate goal is to maximise the reward at the final timestep $T$, which we assume is a deterministic function $r(y_T)$ of the final observation. Think of this reward as the score in an Atari game, which is written on the screen whose contents are made available in $y_t$.

Let's start by stating what we ultimately would like to estimate from the data we have. The assumption is that we sampled the data using some policy $\pi$, but we would like to be able to say how well a different policy $\tilde{\pi}$ would do; in other words, what would the expected score at time $T$ be if, instead of $\pi$, we used a different policy $\tilde{\pi}$?

What we are interested in is a causal/counterfactual query:

$$

\mathbb{E}_{\tau\sim\tilde{p}}[r(y_T)],

$$

where $\tau = [(s_t, y_t, e_t, a_t) : t=0\ldots T]$ denotes a trajectory or rollout up to time $T$, and $\tilde{p}$ denotes the generative process when using policy $\tilde{\pi}$, that is:

$$

\tilde{p}(\tau) = p(e_0)p(y_0\vert e_0) \tilde{\pi}(a_0\vert s_0) p(s_0)\prod_{t=1}^T p(e_t\vert a_{t-1}, e_{t-1}) p(y_t\vert e_t) \tilde{\pi}(a_t\vert s_t) \delta (s_t - g(s_{t-1}, y_t))

$$

I called this a causal or counterfactual query because we are interested in making predictions under a different distribution $\tilde{p}$ than the distribution $p$ from which we have observations. The difference between $\tilde{p}$ and $p$ can be called an intervention, where we replace specific factors in the data generating process with different ones.

There are, at least, two ways one could go about estimating such a counterfactual quantity:

- model-free, via *importance sampling*. This method tries to directly estimate the causal query by calculating a weighted average over the observed data. The weights are given by the ratio between $\tilde{\pi}(a_t\vert s_t)$, the probability with which $\tilde{\pi}$ would choose an action, and $\pi(a_t\vert s_t)$, the probability with which it was chosen by the policy we used to collect the data. A great paper explaining how this works is (Bottou et al, 2013). Importance sampling has the advantage that we don't have to build any model of the environment: we can directly evaluate the average reward from the samples we have, using only $\pi$ and $\tilde{\pi}$ to calculate the weights. The downside, however, is that importance sampling often produces estimates with incredibly high variance, and is only reliable if $\tilde{\pi}$ and $\pi$ are very close.
- model-based, via *causal calculus*. If possible, we can use do-calculus to express the causal query in an alternative way, using various conditional distributions estimated from the observed data. This approach has the disadvantage that it requires us to build a model from the data first. We then use the conditional distributions learned from the data to approximate the quantity of interest by plugging them into the formula we got from do-calculus. If our models are imperfect, these imperfections/approximation errors can compound when the causal estimand is calculated, potentially leading to large biases and inaccuracies. On the other hand, our models may be accurate enough to extrapolate to situations where importance weighting would be unreliable.
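The following toy sketch shows the importance sampling estimator. To keep it short, it makes several simplifying assumptions: both policies ignore the agent state, the "environment" is just a terminal reward equal to 1 if all $T=3$ actions were 0, and the policy probabilities are made up. The weight of each trajectory is the product of the per-step ratios $\tilde{\pi}(a_t\vert s_t)/\pi(a_t\vert s_t)$:

```python
import random

T = 3  # number of decisions per trajectory (hypothetical)


def pi_behaviour(a):
    """Policy used to collect the data: picks action 0 or 1 uniformly (hypothetical)."""
    return 0.5


def pi_target(a):
    """Policy we want to evaluate: strongly prefers action 0 (hypothetical)."""
    return 0.9 if a == 0 else 0.1


def reward(actions):
    """Hypothetical terminal reward: 1 if every action was 0, otherwise 0."""
    return 1.0 if all(a == 0 for a in actions) else 0.0


def importance_weight(actions):
    """Product over timesteps of pi_target(a_t) / pi_behaviour(a_t)."""
    w = 1.0
    for a in actions:
        w *= pi_target(a) / pi_behaviour(a)
    return w


def is_estimate(trajectories):
    """Importance-weighted average of the observed terminal rewards."""
    return sum(importance_weight(tr) * reward(tr) for tr in trajectories) / len(trajectories)


rng = random.Random(0)
data = [[rng.choice((0, 1)) for _ in range(T)] for _ in range(20000)]
estimate = is_estimate(data)
true_value = 0.9 ** T  # probability that pi_target picks action 0 at all T steps
print(estimate, true_value)
```

Here a single trajectory can carry a weight of up to $1.8^T$. Because the weights multiply across timesteps, the variance of the estimate tends to blow up with $T$ unless $\tilde{\pi}$ and $\pi$ are close.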

The paper focuses on solving the problem with causal calculus. This requires us to build a model of the observed data, which we can then use to make causal predictions. The key question this paper asks is

How much of the data do we have to model to be able to make the kinds of causal inferences we would like to make?

One way we can answer the query above is to model the joint distribution of everything, or mostly everything, that we can observe. For example, we could build a full autoregressive model of observations $y_t$ conditioned on actions $a_t$. In essence this would amount to fitting a model to $p(y_{0:T}\vert a_{0:T})$.

If we had such a model, we would theoretically be able to make causal predictions, for reasons I will explain later. However, this option is ruled out in the paper because we assume the observations $y_t$ are very high-dimensional, such as images rendered in a computer game. Thus, modelling the joint distribution of the whole observation sequence $y_{0:T}$ accurately is hopeless and would require a lot of data. Therefore, we would like to get away without modelling the whole observation sequence $y_{0:T}$, which brings us to partial models.

Partial models try to avoid modelling the joint distribution of high-dimensional observations $y_{0:T}$ or agent-state sequences $s_{0:T}$, and focus on directly modelling the distribution of $r(y_T)$, i.e. only the reward component of the last observation, given the action sequence $a_{0:T}$. This is clearly a lot easier to do, because $r(y_T)$ is assumed to be a low-dimensional aspect of the full observation $y_T$, so all we have to learn is a model of a scalar conditioned on a sequence of actions, $q_\theta(r(y_T)\vert a_{0:T})$. We know very well how to fit such models to realistic amounts of data.

However, if we don't include either $y_t$ or $s_t$ in our model, we won't be able to make the counterfactual inferences we wanted to make in the first place. Why? Let's look at the data generating process once more:

We are trying to model the causal impact of actions $a_0$ and $a_1$ on the outcome $y_2$. Let's focus on $a_1$. $y_2$ is clearly statistically dependent on $a_1$. However, this statistical dependence arises from two completely different effects:

- **causal association:** $a_1$ influences the state of the environment $e_2$, resulting in an observation $y_2$. Therefore, $a_1$ has a causal effect on $y_2$, mediated by $e_2$.
- **spurious association due to confounding:** The unobserved hidden state $e_1$ is a confounder between the action $a_1$ and the observation $y_2$. The state $e_1$ has an indirect causal effect on $a_1$, mediated by the observation $y_1$ and the agent's state $s_1$. Similarly, $e_1$ has an indirect effect on $y_2$, mediated by $e_2$.

I illustrated these two sources of statistical association by colour-coding the different paths in the causal graph. The blue path is the confounding path: correlation is induced because both $a_1$ and $y_2$ have $e_1$ as causal ancestor. The red path is the causal path: $a_1$ indirectly influences $y_2$ via the hidden state $e_2$.

If we would like to correctly evaluate the consequence of changing policies, we have to be able to disambiguate between these two sources of statistical association, get rid of the blue path, and only take the red path into account. Unfortunately, this is not possible in a partial model, where we only model the distribution of $y_2$ conditional on $a_0$ and $a_1$.

If we want to draw causal inferences, we **must model the distribution of at least one variable along the blue path.** Clearly, $y_1$ and $s_1$ are theoretically observable, and are on the confounding path. Adding either of these to our model would allow us to use the backdoor adjustment formula (explained in the paper). However, this would take us back to the first option, where we have to model the joint distribution of either sequences of observations $y_{0:T}$ or sequences of states $s_{0:T}$, both assumed to be high-dimensional and difficult to model.

So we have finally got to the core of what is proposed in the paper: a kind of halfway house between modeling everything and modeling too little. We are going to model *enough* variables to be able to evaluate causal queries, while keeping the dimensionality of the model we have to fit low. To do this, we change the data generating process slightly, by splitting the policy into two stages:

The agent first generates $z_t$ from the state $s_t$, and then uses the sampled $z_t$ value to make a decision $a_t$. One can understand $z_t$ as a stochastic bottleneck between the agent's high-dimensional state $s_t$ and the low-dimensional action $a_t$. The assumption is that the sequence $z_{0:T}$ should be a lot easier to model than either $y_{0:T}$ or $s_{0:T}$. If we now build a model $p(r(y_T), z_{0:T} \vert a_{0:T})$, we are able to use this model to evaluate the causal queries of interest, thanks to the backdoor adjustment formula. For how to do this precisely, please refer to the paper.

Intuitively, this approach helps by adding a low-dimensional stochastic node along the confounding path. This allows us to compensate for confounding, without having to build a full generative model of sequences of high-dimensional variables. It allows us to solve the problem we need to solve without having to solve a ridiculously difficult subproblem.
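The role of $z$ can be illustrated on a one-step toy version of the problem; this is a simplification, not the paper's sequential setup. A hidden state $e$ drives a low-dimensional $z$, the action $a$ depends only on $z$, and the outcome $y = a + 2e$ depends on both $a$ and $e$. All probabilities are hypothetical. Conditioning on $a$ alone mixes the causal and confounding paths, while the backdoor adjustment $\sum_z p(z)\,\mathbb{E}[y\vert a, z]$ recovers the interventional answer:

```python
from itertools import product

# Hypothetical one-step model:  e -> z -> a,  e -> y,  a -> y


def p_e(e):
    return 0.5


def p_z_given_e(z, e):
    p1 = 0.8 if e == 1 else 0.2
    return p1 if z == 1 else 1.0 - p1


def p_a_given_z(a, z):
    p1 = 0.9 if z == 1 else 0.1
    return p1 if a == 1 else 1.0 - p1


def outcome(a, e):
    return a + 2 * e  # y depends causally on a, and on the hidden state e


def joint():
    """Yield (e, z, a, probability) for every configuration of the observational model."""
    for e, z, a in product((0, 1), repeat=3):
        yield e, z, a, p_e(e) * p_z_given_e(z, e) * p_a_given_z(a, z)


def expect_y(cond):
    """E[y | cond(e, z, a)] under the observational distribution."""
    num = sum(p * outcome(a, e) for e, z, a, p in joint() if cond(e, z, a))
    den = sum(p for e, z, a, p in joint() if cond(e, z, a))
    return num / den


def naive(a_val):
    # What a partial model of y given a alone would learn.
    return expect_y(lambda e, z, a: a == a_val)


def backdoor(a_val):
    # sum_z p(z) E[y | a, z]: uses only the observable variables z, a and y.
    total = 0.0
    for z_val in (0, 1):
        p_z = sum(p for e, z, a, p in joint() if z == z_val)
        total += p_z * expect_y(lambda e, z, a: z == z_val and a == a_val)
    return total


def interventional(a_val):
    # Ground truth E[y | do(a)]: set a by hand, e keeps its prior.
    return sum(p_e(e) * outcome(a_val, e) for e in (0, 1))


for a_val in (0, 1):
    print(a_val, naive(a_val), backdoor(a_val), interventional(a_val))
```

The naive contrast $2.48 - 0.52$ nearly doubles the true effect of 1, because $e=1$ makes both $a=1$ and a large $y$ more likely.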


Last night on the train I read this nice paper by David Duvenaud and colleagues. Around midnight I got a calendar notification "it's David Duvenaud's birthday". So I thought it's time for a David Duvenaud birthday special (don't get too excited David, I won't make it an annual tradition...)

- Jonathan Lorraine, Paul Vicol, David Duvenaud (2019) Optimizing Millions of Hyperparameters by Implicit Differentiation

I recently covered iMAML: the meta-learning algorithm that makes use of implicit gradients to sidestep backpropagating through the inner loop optimization in meta-learning/hyperparameter tuning. The method presented in (Lorraine et al, 2019) uses the same high-level idea, but introduces a different - on the surface less fiddly - approximation to the crucial inverse Hessian. I won't spend a lot of time introducing the whole meta-learning setup from scratch, you can use the previous post as a starting point.

Many - though not all - meta-learning or hyperparameter optimization problems can be stated as nested optimization problems. If we have some hyperparameters $\lambda$ and some parameters $\theta$ we are interested in

$$

\operatorname{argmin}_\lambda \mathcal{L}_V (\operatorname{argmin}_\theta \mathcal{L}_T(\theta, \lambda)),

$$

where $\mathcal{L}_T$ is some training loss and $\mathcal{L}_V$ a validation loss. The optimal parameter of the training problem, $\theta^\ast$, implicitly depends on the hyperparameters $\lambda$:

$$

\theta^\ast(\lambda) = \operatorname{argmin}_\theta \mathcal{L}_T(\theta, \lambda)

$$

If this implicit function mapping $\lambda$ to $\theta^\ast$ is differentiable, and subject to some other conditions, the implicit function theorem states that its derivative is

$$

\left.\frac{\partial\theta^{\ast}}{\partial\lambda}\right\vert_{\lambda_0} = \left.-\left[\frac{\partial^2 \mathcal{L}_T}{\partial \theta \partial \theta}\right]^{-1}\frac{\partial^2\mathcal{L}_T}{\partial \theta \partial \lambda}\right\vert_{\lambda_0, \theta^\ast(\lambda_0)}

$$

The formula we obtained for iMAML is a special case of this, in which $\frac{\partial^2\mathcal{L}_T}{\partial \theta \partial \lambda}$ is a constant times the identity. This is because there, the hyperparameter controls a quadratic regularizer $\frac{1}{2}\|\theta - \lambda\|^2$, and if you differentiate this with respect to both $\lambda$ and $\theta$, that is exactly what you are left with.

The primary difficulty, of course, is approximating the inverse Hessian, or indeed matrix-vector products involving this inverse Hessian. This is where iMAML and the method proposed by Lorraine et al, (2019) differ. iMAML uses a conjugate gradient method to iteratively approximate the gradient. In this work, they use a Neumann series approximation, which, for a matrix $U$ with $\|I - U\| < 1$, looks as follows:
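Here is a quick numerical check of this formula on a toy problem where everything is available in closed form: a scalar ridge regression with training loss $\mathcal{L}_T(\theta, \lambda) = \frac{1}{2}\sum_i (x_i\theta - y_i)^2 + \frac{\lambda}{2}\theta^2$. The data and helper names are hypothetical. In this problem $\frac{\partial^2 \mathcal{L}_T}{\partial\theta\partial\theta} = \sum_i x_i^2 + \lambda$ and $\frac{\partial^2 \mathcal{L}_T}{\partial\theta\partial\lambda} = \theta$. The hypergradient obtained from the implicit function theorem should therefore match a finite-difference estimate:

```python
# Hypothetical toy data.
x_tr, y_tr = [1.0, 2.0, 3.0], [1.2, 1.9, 3.2]
x_va, y_va = [1.5, 2.5], [1.4, 2.6]
SXX = sum(x * x for x in x_tr)
SXY = sum(x * y for x, y in zip(x_tr, y_tr))


def theta_star(lam):
    """Closed-form argmin over theta of L_T(theta, lam)."""
    return SXY / (SXX + lam)


def val_loss(theta):
    return 0.5 * sum((x * theta - y) ** 2 for x, y in zip(x_va, y_va))


def hypergrad_ift(lam):
    """dL_V/dlam via the implicit function theorem, using only derivatives at the optimum."""
    th = theta_star(lam)
    hess = SXX + lam   # d^2 L_T / dtheta dtheta
    mixed = th         # d^2 L_T / dtheta dlam
    dtheta_dlam = -mixed / hess
    dval_dtheta = sum((x * th - y) * x for x, y in zip(x_va, y_va))
    return dval_dtheta * dtheta_dlam


def hypergrad_fd(lam, eps=1e-6):
    """Central finite difference of L_V(theta*(lam)), for comparison."""
    return (val_loss(theta_star(lam + eps)) - val_loss(theta_star(lam - eps))) / (2 * eps)


print(hypergrad_ift(0.5), hypergrad_fd(0.5))  # the two should agree closely
```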

$$

U^{-1} = \sum_{i=0}^{\infty}(I - U)^i

$$

This is basically a generalization of the better-known sum of a geometric series: if you have a scalar $\vert q \vert<1$, then

$$

\sum_{i=0}^\infty q^i = \frac{1}{1-q}.

$$

Using a finite truncation of the Neumann series one can approximate the inverse Hessian in the following way:

$$

\left[\frac{\partial^2 \mathcal{L}_T}{\partial \theta \partial \theta}\right]^{-1} \approx \sum_{i=0}^j \left(I - \frac{\partial^2 \mathcal{L}_T}{\partial \theta \partial \theta}\right)^i.

$$

This Neumann series approximation, at least on the surface, seems significantly less hassle to implement than running a conjugate gradient optimization step.
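Here is a sketch of the truncated Neumann series applied to a vector, which is all the hypergradient needs; it only uses products of the Hessian with vectors, never its inverse. The series $\sum_i (I - H)^i$ only converges when the eigenvalues of $H$ lie in $(0, 2)$. For that reason the sketch includes a rescaling factor `alpha`, based on the identity $H^{-1} = \alpha \sum_{i\ge 0}(I - \alpha H)^i$. The 2x2 matrix is a hypothetical stand-in for the training-loss Hessian:

```python
def matvec(A, v):
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def neumann_inverse_hvp(H, v, num_terms, alpha=1.0):
    """Approximate H^{-1} v by alpha * sum_{i=0}^{num_terms} (I - alpha*H)^i v.

    Only matrix-vector products with H are needed. The series converges when
    every eigenvalue of alpha*H lies strictly between 0 and 2.
    """
    p = list(v)    # current term (I - alpha*H)^i v, starting at i = 0
    acc = list(v)  # running sum of the terms
    for _ in range(num_terms):
        Hp = matvec(H, p)
        p = [pi - alpha * hpi for pi, hpi in zip(p, Hp)]
        acc = [a + pi for a, pi in zip(acc, p)]
    return [alpha * a for a in acc]


def solve_2x2(A, v):
    """Exact A^{-1} v for a 2x2 matrix, as a reference."""
    (a, b), (c, d) = A
    det = a * d - b * c
    return [(d * v[0] - b * v[1]) / det, (-c * v[0] + a * v[1]) / det]


H = [[1.0, 0.3], [0.3, 0.5]]  # hypothetical positive definite "Hessian"
v = [1.0, 2.0]
exact = solve_2x2(H, v)
approx = neumann_inverse_hvp(H, v, num_terms=200)
approx_scaled = neumann_inverse_hvp(H, v, num_terms=200, alpha=0.5)
rough = neumann_inverse_hvp(H, v, num_terms=3)
print(exact, approx, approx_scaled, rough)
```

With only a few terms the approximation is visibly off. Truncating the series therefore trades accuracy of the hypergradient for fewer Hessian-vector products.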

One of the fun bits of this paper is the interesting set of experiments the authors used to demonstrate the versatility of this approach. For example, in this framework, one can treat the training dataset as a hyperparameter. Optimizing pixel values in a small training dataset, one image per class, allowed the authors to "distill" a dataset into a set of prototypical examples. If you train your neural net on this distilled dataset, you get relatively good validation performance. The results are not quite as image-like as one would imagine, but for some classes, like bikes, you even get recognisable shapes:

In another experiment the authors trained a network to perform data augmentation, treating the parameters of this network as hyperparameters of a learning task. In both of these cases, the number of hyperparameters optimized was in the hundreds of thousands, way beyond the number we usually consider as hyperparameters.

This method inherits some of the limitations I already discussed with iMAML. Please also see the comments where various people gave pointers to work that overcomes some of these limitations.

Most crucially, methods based on implicit gradients assume that your learning algorithm (inner loop) finds a unique, optimal parameter that minimises some loss function. This is simply not a valid assumption for SGD where different random seeds might produce very different and differently behaving optima.

Secondly, this assumption only allows for hyperparameters that control the loss function, but not for ones that control other aspects of the optimization algorithm, such as learning rates, batch sizes or initialization. For those kinds of situations, explicit differentiation may still be the most competitive solution. On that note, I also recommend reading this recent paper on generalized inner-loop meta-learning and the associated PyTorch package higher.

Happy birthday David. Nice work!


My parents didn't raise me in a religious tradition. It all started to change when a great scientist took me under his wing and taught me the teachings of Bayes. I travelled the world and spent 4 years in a Bayesian monastery in Cambridge, UK. This particular place practiced the nonparametric Bayesian doctrine.

We were religious Bayesians. We looked at the world and all we saw was the face of Bayes: if something worked, it did because it had a Bayesian interpretation. If an algorithm did not work, we shunned its creator for being unfaithful to Bayes. We scorned point estimates and despised p-values. Bayes had the answer to everything. But above all, we believed in our models.

At a convention dominated by Bayesian thinkers I was approached by a frequentist, let's call him Lucifer (in fact his real name is Laci, so not that far off). "Do you believe your data exists?" he asked. "Yes," I answered. "Do you believe your model and its parameters exist?" "Well, not really, it's just a model I use to describe reality," I said. Then he told me the following, poisoning my pure Bayesian heart forever: "If you use Bayes' rule, you assume that a joint distribution between model parameters and data exists. This, however, only exists if your data and your parameters both exist, in the same $\sigma$-algebra. You can't have it both ways. You have to think your model really exists somewhere."

I never forgot this encounter, but equally I haven't thought much about it since. Over the years, I started to doubt more and more aspects of my Bayesian faith. I realised the likelihood was important, but not the only thing that exists. There were scoring rules and loss functions which couldn't be written as a log-likelihood. I noticed nonparametric Bayesian models weren't automatically more useful than large parametric ones. I worked on weird stuff like loss-calibrated Bayes. I started having thoughts about model misspecification, kind of a taboo topic in the Bayesian church.

Over the years I came to terms with my Bayesian heritage, and I now live my life as a secular Bayesian. Certain elements of the Bayesian way are no doubt useful: engineering inductive biases explicitly into a prior distribution, and using probabilities, divergences, information and variational bounds as tools for developing new algorithms. Posterior distributions can capture model uncertainty, which can be exploited for active learning or exploration in interactive learning. Bayesian methods often, though not always, lead to increased robustness, better calibration, and so much more. At the same time, I can carry on living my life, use gradient descent to find local minima, and use the bootstrap to capture uncertainty. And first and foremost, I no longer have to believe that my models really exist or perfectly describe reality. I am free to think about model misspecification.

Lately, I have started to familiarize myself with a new body of work, which I call secular Bayesianism, that combines Bayesian inference with more frequentist ideas about learning from observation. In this body of work, people study model misspecification (see e.g. M-open Bayesian inference). And I found a resolution to the "you have to believe in your model, you can't have it both ways" problem that bothered me all these years.

After this rather long intro, let me present the paper this post is really about and which, as a secular Bayesian, I found very interesting:

- P.G. Bissiri, C.C. Holmes and S.G. Walker (2016) A General Framework for Updating Belief Distributions

This paper basically asks: can we take the belief out of belief distributions? Let's say we want to estimate some parameter of interest $\theta$ from data. Does it still make sense to specify a prior distribution over this parameter, and then update it in light of data using some kind of Bayes rule-like update mechanism to form a posterior distribution, all without assuming that the parameter of interest $\theta$ and the observations $x_i$ are linked to one another via a probabilistic model? And if it is meaningful, what form would that update rule take?

First of all, for simplicity, let's assume that data $x_i$ is sampled i.i.d from some distribution $P$. That's right, not exchangeable, actually i.i.d. like in frequentist settings. Let's also assume that we have some parameter of interest $\theta$. Unlike in Bayesian analysis where $\theta$ usually parametrises some kind of generative model for data $x_i$, we don't assume anything like that. All we assume is that there is a loss function $\ell$ which connects the parameter to the observations: $\ell(\theta, x)$ measures how well the estimate $\theta$ agrees with observation $x$.

Let's say that a priori, without seeing any datapoints we have a prior distribution $\pi$ over $\theta$. Now we observe a datapoint $x_1$. How should we make use of our observation $x_1$, the loss function $\ell$ and the prior $\pi$ to come up with some kind of posterior over this parameter? Let's denote this update rule $\psi(\ell(\cdot, x_1), \pi)$. There are many ways we could do this, but is there one which is better than the rest?

The paper lists a number of desiderata - desired properties the update rule $\psi$ should satisfy. These are all meaningful assumptions to have. The main one is coherence, a property somewhat analogous to exchangeability: if we observe a sequence of observations, we would like the resulting posterior to be the same irrespective of the order in which the observations are presented. The coherence property can be written as follows:

$$

\psi\left(\ell(\cdot, x_2), \psi\left(\ell(\cdot, x_1), \pi\right)\right) = \psi\left(\ell(\cdot, x_1), \psi\left(\ell(\cdot, x_2), \pi \right)\right)

$$

As a desired property, this makes a lot of sense, and Bayes rule obviously satisfies it. However, this is not really how the authors actually define coherence. In Equation (3) they use a more restrictive definition of coherence, further restricting the set of acceptable update rules as follows:

$$

\psi\left(\ell(\cdot, x_2), \psi\left(\ell(\cdot, x_1), \pi\right)\right) = \psi\left(\ell(\cdot, x_1) + \ell(\cdot, x_2), \pi \right)

$$

By combining losses from the two observations in an additive way, one can indeed ensure permutation invariance. However, the sum is not the only way to do this. Any pooling operation over observations would also have satisfied this. For example, one could replace the $\ell(\cdot, x_1) + \ell(\cdot, x_2)$ bit by $\max(\ell(\cdot, x_1), \ell(\cdot, x_2))$ and still satisfy the general principle of coherence. The most general class of permutation invariant functions which would satisfy the general coherence desideratum are discussed in DeepSets. Overall, my hunch is that going with the sum is a design choice, rather than a general desideratum. This choice is the real reason why the resulting update rule will end up very Bayes-rule like, as we will see later.

The other desiderata the paper proposes are actually discussed separately in Section 1.2 of (Bissiri et al, 2016), and are called assumptions instead. These are much more basic requirements for the update function. Assumption 2, for example, says that restricting the prior to a subset should result in a posterior which is also the restricted version of the original posterior. Assumption 3 requires that lower evidence (larger loss) for a parameter should yield smaller posterior probabilities - a monotonicity property.

One contribution of the paper is showing that all the desiderata mentioned above pinpoint a specific update rule $\psi$ which satisfies all the desired properties. This update takes the following form:

$$

\pi(\theta\vert x_{1:N}) = \psi\left(\sum_{n=1}^N \ell(\cdot, x_n), \pi\right) \propto \exp\left\{-\sum_{n=1}^N\ell(\theta, x_n)\right\}\pi(\theta)

$$

Just like Bayes rule we have a normalized product of the prior with something that takes the role of the likelihood term. If the loss is the logarithmic loss of a probabilistic model, we recover the Bayes rule, but this update rule makes sense for arbitrary loss functions.
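To make the update rule concrete, here is a minimal sketch that represents distributions on a grid over a scalar $\theta$. The grid, the standard normal prior and the squared loss are my own illustrative choices, not taken from the paper. The sketch checks that updating on one observation at a time, in either order, gives the same result as a single update with the summed loss. Because $\frac{1}{2}(\theta - x)^2$ is the negative log likelihood of a unit-variance Gaussian up to a constant, the result should also coincide with ordinary Bayes rule for a Gaussian mean:

```python
import numpy as np

# Grid over a scalar parameter theta (illustrative choice)
theta = np.linspace(-5, 5, 2001)
prior = np.exp(-0.5 * theta**2)  # standard normal prior, up to normalisation
prior /= prior.sum()


def update(loss_values, prior):
    """General Bayesian update on the grid: posterior ∝ exp(-loss) * prior."""
    log_post = -loss_values + np.log(prior)
    log_post -= log_post.max()  # avoid underflow before exponentiating
    post = np.exp(log_post)
    return post / post.sum()


def sq_loss(x):
    """Squared loss of every grid value of theta against observation x."""
    return 0.5 * (theta - x) ** 2


x1, x2 = 1.3, -0.4
one_then_two = update(sq_loss(x2), update(sq_loss(x1), prior))
two_then_one = update(sq_loss(x1), update(sq_loss(x2), prior))
batch = update(sq_loss(x1) + sq_loss(x2), prior)

print(np.allclose(one_then_two, two_then_one), np.allclose(one_then_two, batch))
# Conjugate Gaussian check: the posterior mean should be (x1 + x2) / 3 = 0.3
print((theta * batch).sum())
```

Swapping `sq_loss` for any other loss leaves the coherence check unchanged, since it only relies on the losses adding up in the exponent.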

Again, this solution is unique only under the very strong and specific desideratum that we'd like the losses from i.i.d. observations to combine in an additive way, and I presume that, had we chosen a different permutation invariant function, we would end up with a similar generalization of Bayes rule with that permutation invariant function appearing in the exponent.

Now that we have an update rule which satisfies our desiderata, can we say if it's actually a good or useful update rule? It seems it is, in the following sense.

Let's think about a way to measure the usefulness of a posterior $\nu$. Suppose we have data sampling distribution $P$, losses are still measured by $\ell$, and our prior is $\pi$. A good posterior does two things well: it allows us to make good decisions in some kind of downstream test scenario, and it is informed by our prior. It therefore makes sense to define a loss function over the posterior $\nu$ as a sum of two terms:

$$

L(\nu; \ell, \pi, P) = h_1(\nu; \ell, P) + h_2(\nu; \pi)

$$

The first term, $h_1$ measures the posterior's usefulness at test time, and $h_2$ measures how well it's influenced by the prior. The authors define $h_1$ to be as follows:

$h_1(\nu; \ell, P) = \mathbb{E}_{x\sim P}\, \mathbb{E}_{\theta\sim\nu}\, \ell(\theta, x)$

So basically, we will sample from the posterior, and then evaluate the random sample parameter $\theta$ on a randomly chosen test datapoint $x$ using our loss $\ell$. I would say this is a rather narrow view on what it means for a posterior to do well on a downstream task, more about it later in the criticism section. In any case it's one possible goal for a posterior to try to achieve.

Now we turn to choosing $h_2$, and the authors note something very interesting. If we want the resulting optimal posterior to possess the coherence property (as defined in their Eqn. (3)), it turns out the only choice for $h_2$ is the KL divergence between the prior and posterior. Any other choice would lead to incoherent updates. This, I believe, is only true for the additive definition of coherence, not the more general definition I gave above.

Putting $h_1$ and $h_2$ together, it turns out that the posterior that minimizes this loss function is precisely of the form $\pi(\theta\vert x_{1:N}) \propto \exp\{-\sum_{n=1}^N \ell(\theta, x_n)\}\pi(\theta)$. So, not only is this update rule the only update rule that satisfies the desired properties, it is also optimal under this particular definition of optimality/rationality.
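The optimality claim is easy to check numerically on a grid. In this sketch I use what I understand to be the finite-sample version of the objective: the expected loss $h_1$ is replaced by the total loss on the observed data, and $h_2$ is the KL divergence from the prior. The data, prior and grid are hypothetical, and I use the absolute loss to emphasise that nothing here needs a likelihood. Randomly perturbing the Gibbs-form posterior should never decrease the objective, and at the optimum the objective equals minus the log normaliser:

```python
import numpy as np

rng = np.random.default_rng(0)
theta = np.linspace(-5, 5, 501)
prior = np.exp(-0.5 * theta**2)
prior /= prior.sum()
data = np.array([1.3, -0.4, 2.0])  # hypothetical observations

# Total loss of each grid value of theta; absolute loss, no likelihood involved
total_loss = np.abs(theta[:, None] - data[None, :]).sum(axis=1)


def objective(nu):
    """Expected total loss under nu plus KL(nu || prior)."""
    support = nu > 0
    kl = np.sum(nu[support] * np.log(nu[support] / prior[support]))
    return np.sum(nu * total_loss) + kl


log_z = np.log(np.sum(prior * np.exp(-total_loss)))
gibbs = prior * np.exp(-total_loss - log_z)  # posterior ∝ exp(-loss) * prior

# Random multiplicative perturbations of the Gibbs posterior, renormalised
perturbed = [gibbs * np.exp(0.3 * rng.standard_normal(theta.size)) for _ in range(100)]
gibbs_is_best = all(objective(gibbs) <= objective(p / p.sum()) for p in perturbed)

print(gibbs_is_best, np.isclose(objective(gibbs), -log_z))
```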

This work is interesting because it gives a new justification for Bayes rule-like updates to belief distributions, and as a result it also provides a different/new perspective on Bayesian inference. Crucially, never in this derivation did we have to reason about a joint distribution between $\theta$ and the observations $x$ (or conditionals of one given the other). Even though I wrote $\pi(\theta \vert x_{1:N})$ to denote a posterior, this is really just a shorthand notation, syntactic sugar. This is important. One of the main technical criticisms of the Bayesian terminology is that in order to reason about the joint distribution between two random variables ($x$ and $\theta$), these variables have to live in the same probability space, so if you believe that your data exists, you have to believe in your model, and model parameters exist as well. This framework sidesteps that.

It allows rational updates of belief distributions, without forcing you to believe in anything.

From a practical viewpoint, this work also extends Bayesian inference in a meaningful way. While Bayesian inference only made sense if you inferred the whole set of parameters jointly, here you are allowed to specify any loss function, and really focus on the parameter of importance. For example, if you're only interested in estimating the median of a distribution in a Bayesian way, without assuming it follows a certain distribution, you can now do this by specifying your loss to be $\vert x-\theta\vert$. This is explained a lot more clearly in the paper, so I encourage you to read it.
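Here is a minimal sketch of the median example, again on a grid and with choices of my own: simulated exponential data (whose median $\log 2 \approx 0.69$ differs from its mean of 1) and a broad Gaussian prior. With the loss $\vert x-\theta\vert$, the exponent $-\sum_n \vert x_n - \theta\vert$ is maximised at the sample median, so the resulting belief distribution should concentrate there rather than around the sample mean:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.exponential(scale=1.0, size=1001)  # skewed data: mean 1, median log(2)

theta = np.linspace(-1, 5, 3001)
log_prior = -0.5 * (theta / 10) ** 2       # broad N(0, 10^2) prior, unnormalised

# Absolute loss summed over the data for every grid value of theta
total_loss = np.abs(theta[:, None] - x[None, :]).sum(axis=1)

log_post = -total_loss + log_prior
log_post -= log_post.max()                 # the losses are large, so work in log space
post = np.exp(log_post)
post /= post.sum()

print("sample median:", np.median(x), "sample mean:", x.mean())
print("posterior mode:", theta[np.argmax(post)], "posterior mean:", (theta * post).sum())
```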

My main criticism of this work is that it made a number of assumptions that ultimately limited the range of acceptable solutions, and to my cynical eye it appears that these choices were specifically made so that Bayes rule-like update rules came out winning. So rather than really deriving Bayesian updates from first principles, we engineered principles under which Bayesian updates are optimal. In other words, the top-down analysis was rigged in favour of familiar Bayes-like updates. There are two specific assumptions which I would personally like to see relaxed:

The first one is the restrictive notion of coherence, which requires losses to combine additively from multiple observations. I think this very clearly gives rise to the convenient exponential, log-additive form in the end. It would be interesting to see whether other types of permutation invariant update rules also make sense in practice.

Secondly, the way the authors defined optimality, in terms of the loss $h_1$ above, is very limiting. We rarely use posterior distributions in this way (taking a single random sample). Instead, we might be interested in integrating over the posterior, and evaluating the loss of the resulting averaged classifier. This is a loss that cannot be written in the bilinear form of $h_1$ above. I wonder if using more elaborate losses for the posterior, perhaps along the lines of general decision problems as in (Lacoste-Julien et al, 2011), could lead to more interesting update rules which don't look at all like Bayes rule but are still rational.

Yesterday I read this intriguing paper about the mind-boggling fact that it is possible to use an exponentially growing learning rate schedule when training neural networks with batch normalization:

- Zhiyuan Li and Sanjeev Arora (2019) An Exponential Learning Rate Schedule for Deep Learning

The paper provides both theoretical insights as well as empirical demonstrations of this remarkable property.

The reason why this works boils down to the observation that batch-normalization renders the loss function of neural networks scale invariant - scaling the weights by a constant does not change the output, or the loss, of the batch normalized network. It turns out that this property alone might result in somewhat unexpected and potentially helpful properties for optimization. I will use this post to illustrate some of the properties of scale invariant loss functions - and gradient descent trajectories on them - using a 2D toy example:

Here, I drew a loss function which has the scale invariance property. The value of the loss only depends on the angle, but not the magnitude of the weight vector: the value of the loss along any radial line from the origin outwards is constant. Simple consequences of scale invariance are (Lemma 1 of the paper):

- that the gradient of this function is always orthogonal to the current value of the parameter vector, and
- that the farther you are from the origin, the smaller the magnitude of the gradient (in fact, it is inversely proportional to the norm of the parameter vector). This is perhaps less intuitive, but think about how the function behaves on a circle around the origin. The function is the same, but as you increase the radius you stretch the same function round a larger circle - it gets more spread out, therefore its gradients decrease.
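Both properties are easy to check numerically. In this sketch the toy loss $-\cos(3\varphi)$, where $\varphi$ is the angle of the 2D weight vector, is a hypothetical stand-in for the function in my plots. Its gradient works out to $3\sin(3\varphi)\,(-w_2, w_1)/\|w\|^2$, which is orthogonal to $w$ by construction and scales like $1/\|w\|$:

```python
import numpy as np


def loss(w):
    """Hypothetical scale-invariant loss: depends only on the angle of w."""
    return -np.cos(3 * np.arctan2(w[1], w[0]))


def grad(w):
    """Analytic gradient: h'(phi) * dphi/dw, with dphi/dw = (-w2, w1) / |w|^2."""
    phi = np.arctan2(w[1], w[0])
    return 3 * np.sin(3 * phi) * np.array([-w[1], w[0]]) / np.dot(w, w)


def numerical_grad(w, eps=1e-6):
    """Central finite differences, used to double-check the analytic gradient."""
    e = np.eye(2)
    return np.array([(loss(w + eps * e[i]) - loss(w - eps * e[i])) / (2 * eps) for i in range(2)])


w = np.array([-0.7, 0.7])
print(np.allclose(grad(w), numerical_grad(w)))                # analytic gradient is correct
print(np.dot(w, grad(w)))                                     # ~0: orthogonal to w
print(np.linalg.norm(grad(2 * w)) / np.linalg.norm(grad(w)))  # ~0.5: twice as far, half the gradient
```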

Here is a - somewhat messy - quiver plot showing the gradients of the function above:

The quiver plot is messy because the gradients around the origin explode. But you can perhaps see how the gradients get larger and larger as you approach the origin - and remain perpendicular to the parameter vector itself.

So imagine doing vanilla gradient descent (no momentum, no weight decay, fixed learning rate) on such a loss surface. Because the gradient is always perpendicular to the current value of the parameter, by the Pythagorean theorem, the norm of the parameter vector increases with each iteration. So gradient descent takes you away from the origin. However, the weight vector won't completely blow up to infinity, because the gradients also get smaller and smaller as the weight vector grows, so it settles at some point. Here is what a gradient descent path looks like starting from the coordinate $(-0.7, 0.7)$:

In fact, you can't really see it, but the optimization kind of gets stuck in there and doesn't move any longer. It's interesting to see what happens if we add weight decay, which is the same as adding an L2 regularizer over the weights:

We can see that once the trajectory is about to get stuck in a local minimum, weight decay pulls it back towards the origin, which is where gradients become larger. This, in turn, perturbs the trajectory often pushing it out of the current local minimum. So in a way, we can start to build the intuition that weight decay on a scale-invariant loss function acts as a kind of learning rate adjustment.
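The Pythagorean argument can be checked with plain gradient descent on the same hypothetical toy loss. Because each update is orthogonal to the current weights, $\|w_{t+1}\|^2 = \|w_t\|^2 + \eta^2\|\nabla f(w_t)\|^2$ exactly, so without weight decay the norm can only grow. With weight decay $\lambda$, the weights are additionally shrunk by a factor $(1 - \eta\lambda)$ each step, which counteracts the growth. The learning rate and weight decay values are arbitrary choices of mine:

```python
import numpy as np


def grad(w):
    """Gradient of the hypothetical scale-invariant toy loss -cos(3 * angle(w))."""
    phi = np.arctan2(w[1], w[0])
    return 3 * np.sin(3 * phi) * np.array([-w[1], w[0]]) / np.dot(w, w)


def gradient_descent(w0, lr=0.1, weight_decay=0.0, steps=500):
    """Vanilla gradient descent; weight decay multiplies w by (1 - lr * weight_decay) each step."""
    w = np.array(w0, dtype=float)
    norms = [np.linalg.norm(w)]
    for _ in range(steps):
        w = (1 - lr * weight_decay) * w - lr * grad(w)
        norms.append(np.linalg.norm(w))
    return w, np.array(norms)


_, norms_plain = gradient_descent([-0.7, 0.7])
_, norms_wd = gradient_descent([-0.7, 0.7], weight_decay=0.05)

print(np.all(np.diff(norms_plain) >= -1e-12))  # the norm never decreases without weight decay
print(norms_plain[-1], norms_wd[-1])
```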

In fact, what the paper works out is an equivalence between two things:

- weight decay with constant learning rate and
- no weight decay and an exponentially growing learning rate

On the plot below I show the trajectory with the exponentially growing learning rate which is equivalent to the one I showed before with weight decay. This one has no weight decay, and its learning rate keeps growing:

We can see that the trajectory blows up, and quickly gets out of bounds on this animation. How can this be equivalent to the weight decay trajectory? Well, from the perspective of the loss function, the magnitude of the weight vector is irrelevant, and we only care about the angle when viewed from the origin. It turns out that, if you look at those angles, the two trajectories are the same. To illustrate this, I use the normalization formula from Theorem 2.1 to project this trajectory back to the same magnitude the weight decay one would have. I obtain something that indeed looks very much like the trajectory above:

After a while, the trajectories start to diverge, which I think is probably due to the accumulation of numerical errors in my implementation of the toy example. I could probably fix this, but I'm not sure it's worth the effort. The authors show much more convincing empirical evidence that this works in real, complicated neural network losses that people actually want to optimize.
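For plain gradient descent without momentum, the equivalence can be derived in a few lines; what follows is my own derivation rather than a transcription of the paper's Theorem 2.1, whose exact parametrisation I have not reproduced. Write $\rho = 1 - \eta\lambda$ for the weight decay shrinkage factor and use the fact that, for a scale-invariant loss, $\nabla f(cw) = \nabla f(w)/c$. If $w_t$ follows gradient descent with weight decay and a constant learning rate $\eta$, then $u_t = \rho^{-t} w_t$ follows gradient descent with no weight decay and learning rate $\eta_t = \eta\rho^{-(2t+1)}$, which grows exponentially. Since $u_t$ and $w_t$ point in the same direction, their losses agree. The sketch below checks this on the hypothetical toy loss for a modest number of steps:

```python
import numpy as np


def grad(w):
    """Gradient of the hypothetical scale-invariant toy loss -cos(3 * angle(w))."""
    phi = np.arctan2(w[1], w[0])
    return 3 * np.sin(3 * phi) * np.array([-w[1], w[0]]) / np.dot(w, w)


lr, weight_decay, steps = 0.1, 0.05, 30
rho = 1 - lr * weight_decay

w = np.array([-0.7, 0.7])  # weight decay, constant learning rate
u = w.copy()               # no weight decay, exponentially growing learning rate
for t in range(steps):
    w = rho * w - lr * grad(w)
    u = u - lr * rho ** (-(2 * t + 1)) * grad(u)

# Undo the blow-up of u ("zooming out"): rescaled, it should coincide with w
print(np.allclose(w, rho**steps * u))
print(np.linalg.norm(u) / np.linalg.norm(w))  # equals rho^(-steps)
```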

You can think of this renormalization I did above as "constantly zooming out" on the loss landscape to keep up with the exponentially exploding parameter. I tried to illustrate this below:

On the left-hand plot, I show the original, weight-decayed gradient descent with a constant learning rate. On the right-hand plot, I show the equivalent trajectory with an exponentially growing learning rate and no weight decay, and I also added a constant zoom to counteract the explosion of the parameter's norm, in line with Theorem 2.1. We can see that, especially initially, the two paths behave the same way when viewed from the origin. They then diverge, which I believe is down to numerical precision issues that could probably be worked out.

The paper shows a similar equivalence in the presence of momentum as well; if you're interested, read the details in the paper.

I thought this observation was very cool, and may well lead to a better understanding of the mechanisms by which batchnorm and other weight normalization schemes work. It also explains why the combination of weight decay with weight normalization schemes results in a relatively robust gradient descent regime where constant learning rate works well.

Here's a paper someone has pointed me to, along the lines of "everything that works, works because it's Bayesian":

- Edwin Fong, Chris Holmes (2019) On the marginal likelihood and cross-validation

I found this paper to be lacking on the accessibility front, mostly owing to the fact that it is a mixture of two somewhat related but separate things:

- (A) a simple-in-hindsight and cool observation about the relationship between marginal likelihoods and cross validation which I will present in this post, and
- (B) a somewhat tangential sidetrack about a generalized form of Bayesian inference and prequential analysis which I think is mostly there to advertise an otherwise interesting line of research Chris Holmes and colleagues have been working on for some time. The advertising worked for sure, as I found the underlying paper (Bissiri et al, 2016) quite interesting as well. I will leave discussing that to a different post, maybe next week.

To discuss the connection between marginal likelihoods and (Bayesian) cross-validation, let's first define what is what.

First of all, we are in the world of exchangeable data, assuming we model a sequence of observations $x_1,\ldots,x_N$ by a probabilistic model which renders them conditionally independent given some global parameter $\theta$. Our model is thus specified by the observation model $p(x\vert \theta)$ and prior $p(\theta)$. The marginal likelihood is the probability mass this model assigns to a given sequence of observations:

$$

p(x_1,\ldots,x_N) = \int \prod_{i=1}^N p(x_i \vert \theta) p(\theta) d\theta

$$

Important for the discussion of its connection with cross-validation, the marginal likelihood, like any multivariate distribution, can be decomposed by the chain rule:

$$

p(x_1,\ldots,x_N) = p(x_1)\prod_{n=1}^{N-1}p(x_{n+1}\vert x_1,\ldots, x_{n})

$$

And, of course, a similar decomposition exists for any arbitrary ordering of the observations $x_n$.

Another related quantity is a single-fold leave-$P$-out cross-validation. Here, we set the last $P \leq N$ observations aside, fit our model to the first $N-P$ observations, and then we calculate the average predictive log loss on the held-out points. This can be written as:

$$

- \frac{1}{P}\sum_{p=1}^{P} \log p(x_{N-p+1}\vert x_1, \ldots, x_{N-P})

$$

Importantly, here, we assume that we perform Bayesian cross-validation of the model. That is, in this formula, the parameter $\theta$ is integrated out. In fact, what we're looking at is:

$$

- \frac{1}{P}\sum_{p=1}^{P} \log \int p(x_{N-p+1}\vert \theta) p(\theta \vert x_1, \ldots, x_{N-P}) d\theta

$$

Now of course, we could leave any other subset of size $P$ of the observations out. If we repeat this process $K$ times with a uniform random subset of datapoints left out each time, and average the results over the $K$ trials, we have $K$-fold leave-$P$-out cross-validation. If $K$ is large enough, we are effectively averaging over all possible subsets of size $P$, each with the same weight. I will cheesily call this $\infty$-fold cross-validation. Mathematically, $\infty$-fold leave-$P$-out Bayesian cross-validation is the following quantity:

$$

- \frac{1}{N \choose P} \sum_{\substack{S \subset \{1,\ldots,N\}\\ |S|=P}} \frac{1}{P}\sum_{i \in S} \log p(x_i\vert x_j : j \notin S),

$$

which is Eqn (10) in the paper with slightly different notation.

The connection I think is best illustrated in the following way. Let's consider three observations, and all the possible ways we can permute them. There are $3! = 6$ different permutations. For each of these permutations we can decompose the marginal likelihood as a product of conditionals, or equivalently we can write the log marginal likelihood as a sum of logs of the same conditionals. Let's arrange these log conditionals into a table as follows:

Each column corresponds to a different ordering of variables, and summing up the terms in each column gives the log marginal likelihood. So, the sum of all the terms in this matrix gives the log marginal likelihood times 6 (as there are 6 columns). In general it gives $N!$ times the log marginal likelihood for $N$ observations. Now look at the sums of the terms in each row. The first row is full of terms you'd see in leave-$3$-out cross-validation (which doesn't make too much sense with $3$ observations). In the second row, you see terms for leave-$2$-out CV. The third row corresponds to leave-$1$-out CV. If you do some careful combinatorics (homework) and count how many duplicate terms you'll find in each row, you'll see that each row sums to exactly $-N!$ times the corresponding leave-$P$-out score defined above. So summing the leave-$P$-out $\infty$-fold Bayesian cross-validation errors over all values of $P$ gives you exactly the negative log marginal likelihood. Which is the main point of the paper.
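The identity can be verified by brute force on a tiny example. This sketch uses a Beta-Bernoulli model, my own choice because its predictive distributions are available in closed form, so $\theta$ is integrated out exactly as Bayesian cross-validation requires; the prior hyperparameters and the data are hypothetical. The leave-$P$-out score averages the negative log predictive over the $P$ held-out points of every one of the $N \choose P$ folds, and with that normalisation, summing over $P = 1, \ldots, N$ should recover the negative log marginal likelihood:

```python
import itertools
import math

a, b = 2.0, 3.0        # hypothetical Beta(a, b) prior on the Bernoulli parameter
x = [1, 0, 1, 1, 0]    # hypothetical binary observations
N = len(x)


def log_pred(i, train):
    """Bayesian log predictive log p(x_i | x_train), with theta integrated out analytically."""
    n1 = sum(x[j] for j in train)
    p1 = (a + n1) / (a + b + len(train))
    return math.log(p1 if x[i] == 1 else 1 - p1)


def log_beta(p, q):
    return math.lgamma(p) + math.lgamma(q) - math.lgamma(p + q)


def cv_score(P):
    """Leave-P-out Bayesian CV over all C(N, P) folds, averaged over folds and held-out points."""
    total = 0.0
    for test in itertools.combinations(range(N), P):
        train = [j for j in range(N) if j not in test]
        total -= sum(log_pred(i, train) for i in test) / P
    return total / math.comb(N, P)


n1 = sum(x)
log_ml = log_beta(a + n1, b + N - n1) - log_beta(a, b)  # closed-form log marginal likelihood
print(log_ml, -sum(cv_score(P) for P in range(1, N + 1)))
```

Dropping the `/ P` inside `cv_score` breaks the equality, which is a quick way to see why the per-fold averaging over held-out points matters.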

This observation gives a really good motivation for using the marginal likelihood, and also gives a new perspective on how it works. For $N$ datapoints, there are $2^N-1$ different ways of selecting a non-empty test set and corresponding training set. Calculating the marginal likelihood amounts to evaluating the average predictive score on all of these exponentially many 'folds'.

Before we jump to the conclusion that cross-validation, too, works only because it is essentially an approximation to Bayesian model selection, we must remind ourselves that this connection only holds for Bayesian cross-validation. What this means is that in each fold of cross-validation, we integrate $\theta$ in a Bayesian fashion.

In practice, when cross-validating neural networks, we usually optimize over the parameters rather than integrate in a Bayesian way. Or, at best, we use a variational approximation to the posterior and integrate over that approximately. As the relationship only holds in theory, when exact parameter marginalization is performed, it remains to be seen how useful and robust this connection will prove in potential applications.

This week I read this cool new paper on meta-learning: it takes a slightly different approach compared to its predecessors, based on some observations about differentiating the optima of regularized optimization problems.

- Aravind Rajeswaran, Chelsea Finn, Sham Kakade, Sergey Levine (2019) Meta-Learning with Implicit Gradients

Another paper that came out at the same time discovered similar techniques, so I thought I'd update the post and mention it, although I won't cover it in detail, as the post was written primarily about Rajeswaran et al (2019):

- Yutian Chen, Abram L. Friesen, Feryal Behbahani, David Budden, Matthew W. Hoffman, Arnaud Doucet, Nando de Freitas (2019) Modular Meta-Learning with Shrinkage

- I will give a high-level overview of the meta-learning setup, where our goal is to learn a good initialization or regularization strategy for SGD so it converges to better minima across a range of tasks.
- I will illustrate how iMAML works on a 1D toy example, and discuss the behaviour and properties of the meta-objective.
- I will then discuss a limitation of iMAML: that it only considers the location of minima, and not the probability with which a stochastic algorithm ends up in a specific minimum.
- I will finally relate iMAML to a variational approach to meta-learning.

Meta-learning has several possible formulations. I will try to explain the setup of this paper following my own interpretation and notation, which differs from the paper but will (hopefully) make my explanations clearer.

In meta-learning we have a series of independent tasks, with associated training and validation loss functions $f_i$ and $g_i$, respectively. We have a set of model parameters $\theta$ which are shared across the tasks, and the loss functions $f_i(\theta)$ and $g_i(\theta)$ evaluate how well the model with parameters $\theta$ does on the training and validation cases of task $i$. We have an algorithm that has access to the training loss $f_i$ and some meta-parameters $\theta_0$, and outputs some optimal or learned parameters $\theta_i^\ast = Alg(f_i, \theta_0)$. The goal of the meta-learning algorithm is to optimize the meta-objective

$$

\mathcal{M}(\theta_0) = \sum_i g_i(Alg(f_i, \theta_0))

$$

with respect to the meta-parameters $\theta_0$.

In early versions of this work, MAML, the algorithm was chosen to be stochastic gradient descent, $f_i$ and $g_i$ being the training and test loss of a neural network, for example. The meta-parameter $\theta_0$ was the point of initialization for the SGD algorithm, shared between all the tasks. Since SGD updates are differentiable, one can compute the gradient of the meta-objective with respect to the initial value $\theta_0$ by simply backpropagating through the SGD steps. This was essentially what MAML did.

However, the effect of initialization on the final value of $\theta$ is pretty weak, and difficult - if at all possible - to characterise analytically. If we allow the SGD to go on for many steps, we might converge to a better parameter, but the trajectory will be very long, and the gradients with respect to the initial value vanish. If we make the trajectories short enough, the gradients w.r.t. $\theta_0$ are informative but we may not reach a very good final value.
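To make the vanishing-gradient point concrete, here is a sketch of what backpropagating through the inner loop computes in 1D. For deterministic gradient descent $\theta_{k+1} = \theta_k - \eta f'(\theta_k)$, the chain rule gives $d\theta_K/d\theta_0 = \prod_{k<K}(1 - \eta f''(\theta_k))$. Once the trajectory settles into a minimum, where $0 < \eta f'' < 2$, each factor has magnitude below one and the product decays geometrically. The toy loss $f(\theta) = \sin(3\theta) + 0.1\theta^2$ and the step size are hypothetical choices, and I use full-batch gradient descent rather than SGD to keep things deterministic:

```python
import numpy as np


def f_prime(t):
    """Derivative of the hypothetical toy loss f(t) = sin(3t) + 0.1 t^2."""
    return 3 * np.cos(3 * t) + 0.2 * t


def f_second(t):
    """Second derivative of the same toy loss."""
    return -9 * np.sin(3 * t) + 0.2


def descend(theta0, lr=0.05, steps=100):
    """Gradient descent from theta0, accumulating d(theta_K)/d(theta0) by the chain rule."""
    t, jac = theta0, 1.0
    for _ in range(steps):
        jac *= 1 - lr * f_second(t)  # d theta_{k+1} / d theta_k at the current iterate
        t -= lr * f_prime(t)
    return t, jac


for steps in (5, 20, 100):
    print(steps, descend(0.4, steps=steps)[1])
```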

This is why Rajeswaran et al opted to make the dependence of the final point of the trajectory on the meta-parameter $\theta_0$ much stronger: instead of simply initializing SGD from $\theta_0$, they also anchor the parameter to stay in the vicinity of $\theta_0$ by adding a quadratic regularizer $\frac{1}{2}\|\theta - \theta_0\|^2$ to their loss (the paper scales it by a regularization strength $\lambda$, which I set to 1 throughout). Because of this, two things happen:

- now all steps of the SGD depend on $\theta_0$, not just the initial point
- now the location of the minimum that SGD eventually converges to also depends on $\theta_0$

It is this second property that iMAML exploits. Let me illustrate what that dependence looks like:

In the figure above, let's say that we would like to minimise an objective function $f(\theta)$. This would be the training loss of one of the tasks the meta-learning algorithm has to solve. Our current meta-parameter $\theta_0$ is marked on the x axis, and the orange curve shows the associated quadratic penalty. The teal curve shows the sum of the objective with the penalty. The red star shows the location of the minimum, which is what the learning algorithm finds.

Now let's animate this plot. I'm going to move the anchor point $\theta_0$ around and reproduce the same plots. You can see that, as we move $\theta_0$ and the associated penalty, the local (and therefore global) minima of the regularized objective move:

So it's clear that there is a non-trivial, non-linear relationship between the anchor-point $\theta_0$ and the location of a local minimum $\theta^\ast$. Let's plot this relationship as a function of the anchor point:

We can see that this function is not at all nice to work with: it has sharp jumps when the closest local minimum to $\theta_0$ changes, and it is relatively flat between these jumps. In fact, you can observe that the sharper the local minimum nearest to $\theta_0$ is, the flatter the relationship between $\theta_0$ and $\theta^\ast$. This is because if $f$ has a sharp local minimum near $\theta_0$, then the location of the regularized minimum will be mostly determined by $f$, and the location of the anchor point $\theta_0$ doesn't matter much. If the local minimum of $f$ is wide, there's a lot of wiggle room for the optimum and the effect of the regularization will be larger.

And now we come to the whole point of the iMAML procedure. The gradient of this function $\theta^\ast(\theta_0)$ in fact can be calculated in closed form. It is, indeed, related to the curvature, or second derivative, of $f$ around the minimum we find:

$$

\frac{d\theta^\ast}{d\theta_0} = \frac{1}{1 + f''(\theta^\ast)}

$$

In order to check that this formula works, I calculated the derivative numerically and compared it with what the theory predicts, they match perfectly:
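A minimal version of this numerical check looks like this. The loss $f(\theta) = \sin(3\theta) + 0.1\theta^2$ is a hypothetical stand-in for the one in my figures, and the penalty is $\frac{1}{2}(\theta-\theta_0)^2$, the scaling under which the formula above holds. Differentiating the stationarity condition $f'(\theta^\ast) + \theta^\ast - \theta_0 = 0$ with respect to $\theta_0$ gives the formula directly. The finite-difference solves are warm-started from $\theta^\ast$ so that they stay in the same basin:

```python
import numpy as np


def f_prime(t):
    """Derivative of the hypothetical toy loss f(t) = sin(3t) + 0.1 t^2."""
    return 3 * np.cos(3 * t) + 0.2 * t


def f_second(t):
    return -9 * np.sin(3 * t) + 0.2


def inner_solve(theta0, init, lr=0.05, steps=5000):
    """Gradient descent on the regularised inner problem f(t) + 0.5 * (t - theta0)^2."""
    t = init
    for _ in range(steps):
        t -= lr * (f_prime(t) + (t - theta0))
    return t


theta0 = 0.4
t_star = inner_solve(theta0, init=theta0)
eps = 1e-5
numeric = (inner_solve(theta0 + eps, t_star) - inner_solve(theta0 - eps, t_star)) / (2 * eps)
implicit = 1 / (1 + f_second(t_star))  # closed-form implicit gradient
print(numeric, implicit)
```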

When the parameter space is high-dimensional, we have a similar formula involving the inverse of the Hessian plus the identity. In high dimensions, inverting or even calculating and storing the Hessian is not very practical. One of the main contributions of the iMAML paper is a practical way to approximate gradients, using a conjugate gradient inner optimization loop. For details, please read the paper.
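The conjugate gradient idea can be sketched as follows. With the penalty scaled as in the 1D formula, the meta-gradient contribution of one task is $(I + H)^{-1}\nabla g$, where $H$ is the Hessian of the training loss at $\theta^\ast$ and $\nabla g$ is the validation-loss gradient there. Conjugate gradient only needs the map $v \mapsto v + Hv$, which in a real network would come from Hessian-vector products via automatic differentiation, so the Hessian is never formed or stored. Below, a random positive semi-definite matrix stands in for $H$ and a random vector for $\nabla g$; both are hypothetical:

```python
import numpy as np


def conjugate_gradient(matvec, b, iters=50, tol=1e-10):
    """Solve A x = b for symmetric positive-definite A, given only the map x -> A x."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(iters):
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


rng = np.random.default_rng(0)
d = 20
A = rng.standard_normal((d, d))
H = A @ A.T / d                    # stand-in for the training-loss Hessian at theta*
hvp = lambda v: H @ v              # in practice: a Hessian-vector product via autodiff
grad_val = rng.standard_normal(d)  # stand-in for the validation-loss gradient at theta*

meta_grad = conjugate_gradient(lambda v: v + hvp(v), grad_val)
print(np.allclose(meta_grad, np.linalg.solve(np.eye(d) + H, grad_val)))
```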

When optimizing the anchor point in a meta-learning setting, it is not the location $\theta^\ast$ we are interested in, only the value that the function $f$ takes at this location. (In reality, we would use the validation loss here in place of the training loss used for gradient descent, but for simplicity I assume the two losses coincide.) The value of $f$ at its local optimum is plotted below:

Oh dear. This function is not very pretty. The meta-objective $f(\theta^\ast(\theta_0))$ becomes a piecewise continuous function, a patchwork of neighbouring basins, with non-smooth boundaries. The local gradients of this function contain very little information about the global structure of the loss function; they only tell you where to go to reach the nearest local minimum. So I wouldn't say this is the nicest function to optimize.

Thankfully, though, this function is not what we have to optimize. In meta-learning, we have a distribution over functions $f$ we optimize, so the actual meta-objective is something like $\sum_i f_i(\theta_i^\ast(\theta_0))$. And the sum of a bunch of ugly functions might well turn into something smooth and nice. In addition, the 1-D function I use for this blog post is not representative of the high-dimensional loss functions of neural networks which we want to apply iMAML to. Take for example the concept of mode connectivity (see e.g. Garipov et al, 2018): it seems that the modes found by SGD using different random seeds are not just isolated basins, but they are connected by smooth valleys along which the training and test error are low. This may in turn make the meta-objective behave more smoothly between minima.

An important aspect that MAML and iMAML do not consider is the fact that we usually use stochastic optimization algorithms. Rather than deterministically finding a particular local minimum, SGD samples different minima: when run with different random seeds it will find different minima.

A more generous formulation of the meta-objective would allow for stochastic algorithms. If we denote by $\mathcal{Alg}(f_i, \theta_0)$ the distribution over solutions the algorithm finds, the meta-objective would be

$$

\mathcal{M}_{\text{stochastic}}(\theta_0) = \sum_i \mathbb{E}_{\theta \sim \mathcal{Alg}(f_i, \theta_0)} g_i(\theta)

$$

Allowing for stochastic behaviour might actually be a great feature for meta-learning. While the position of the global minimum of the regularized objective can change abruptly as a function of $\theta_0$ (as illustrated in the third figure above), allowing for stochastic behaviour might smooth out the meta-learning objective.

Now suppose that SGD anchored to $\theta_0$ converges to one of a finite set of local minima. The meta-learning objective now depends on $\theta_0$ in two different ways:

- as we change the anchor $\theta_0$, the location of the minima change, as illustrated above. This change is differentiable, and we know its derivative.
- as we change the anchor $\theta_0$, the probability with which we find the different solutions changes. Some solutions will be found more often, some less often.

iMAML accounts for the first influence, but it ignores the influence through the second mechanism. This is not to say that iMAML is broken, but that it misses a possibly crucial contribution of stochastic behaviour that MAML or explicitly differentiating through the algorithm does not.

Of course this work reminded me of a Bayesian approach. Whenever someone describes quadratic penalties, all I see are Gaussian distributions.

In a Bayesian interpretation of iMAML, one can think of the anchor point $\theta_0$ as the mean of a prior distribution over the neural network's weights. The inner loop of the algorithm, or $Alg(f_i, \theta_0)$ then finds the maximum-a-posteriori (MAP) approximation to the posterior over $\theta$ given the dataset in question. This is assuming that the loss is a log likelihood of some kind. The question is, how should one update the meta-parameter $\theta_0$?

In the Bayesian world, we would seek to optimize $\theta_0$ by maximising the marginal likelihood. As this is usually intractable, it is common to turn to a variational approximation, which in this case would look something like this:

$$

\mathcal{M}_{\text{variational}}(\theta_0, Q_i) = \sum_i \left( KL[Q_i \| \mathcal{N}_{\theta_0}] + \mathbb{E}_{\theta \sim Q_i} f_i(\theta) \right),

$$

where $Q_i$ approximates the posterior over model parameters for task $i$. A specific choice of $Q_i$ is a Dirac delta distribution centred at a specific point, $Q_i(\theta) = \delta(\theta - \theta^{\ast}_i)$. If we generously ignore some constants that blow up to infinity, the KL divergence between the Gaussian prior and the degenerate point-posterior is a simple squared Euclidean distance, and our variational objective reduces to:

$$

\mathcal{M}_{\text{variational}}(\theta_0, \theta_i) = \sum_i \left( \|\theta_i - \theta_0\|^2 + f_i(\theta_i) \right)

$$

Now this objective function looks very much like the optimization problem that the inner loop of iMAML attempts to solve. If we were working in the pure variational framework, this may be where we leave things, and we could jointly optimize all the $\theta_i$s as well as $\theta_0$. Someone in the know, please comment below pointing me to the best references where this is being done for meta-learning.

This objective is significantly easier to optimize and involves no inner-loop optimization or black magic. It simply ends up pulling $\theta_0$ closer to the centre of gravity of the various optima found for each task $i$. I'm not sure this is such a good idea for meta-learning, though, as the final values of $\theta_i$ we reach by jointly optimizing over everything may not be reachable by running SGD from $\theta_0$ from scratch. But who knows. Given these observations, a good idea may be to jointly minimize the variational objective with respect to $\theta_0$ and $\theta_i$, but every once in a while reinitialize $\theta_i$ to $\theta_0$. But at this point, I'm really just making stuff up...
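To make the "centre of gravity" intuition concrete, here is a minimal sketch in Python with NumPy. It assumes hypothetical quadratic task losses $f_i(\theta) = \|\theta - c_i\|^2$, with made-up task optima `centres`, and runs plain gradient descent jointly on $\theta_0$ and all $\theta_i$. For these losses the stationary point satisfies $\theta_i = (\theta_0 + c_i)/2$ and $\theta_0 = \frac{1}{n}\sum_i \theta_i$, which together give $\theta_0 = \frac{1}{n}\sum_i c_i$.

```python
import numpy as np

def joint_variational_descent(centres, lr=0.1, steps=2000):
    """Jointly minimise sum_i ||theta_i - theta_0||^2 + f_i(theta_i) by gradient descent,
    using hypothetical toy task losses f_i(theta) = ||theta - c_i||^2."""
    n_tasks, dim = centres.shape
    theta_0 = np.zeros(dim)
    theta_i = np.zeros((n_tasks, dim))
    for _ in range(steps):
        # d/d theta_i: pull towards the anchor plus pull towards the task optimum c_i
        grad_i = 2.0 * (theta_i - theta_0) + 2.0 * (theta_i - centres)
        # d/d theta_0: only the quadratic penalty terms depend on the anchor
        grad_0 = -2.0 * np.sum(theta_i - theta_0, axis=0)
        theta_i = theta_i - lr * grad_i
        theta_0 = theta_0 - lr * grad_0
    return theta_0, theta_i

centres = np.array([[1.0, 0.0], [0.0, 2.0], [-1.0, 1.0]])  # hypothetical task optima c_i
theta_0, theta_i = joint_variational_descent(centres)
print(theta_0)  # approximately [0, 1], the mean of the task optima
```

Note that nothing here runs an inner optimization loop from $\theta_0$; the $\theta_i$ are simply extra parameters, which is exactly why their final values need not be reachable by SGD started at $\theta_0$.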

Anyway, back to iMAML, which does something interesting with this variational objective, and I think it can be understood as a kind of amortized computation: Instead of treating $\theta_i$ as separate auxiliary parameters, it specifies that $\theta_i$ are in fact a deterministic function of $\theta_0$. As the variational objective is a valid upper bound for any value of $\theta_i$, it is also a valid upper bound if we make $\theta_i$ explicitly dependent on $\theta_0$. The variational objective thus becomes a function of $\theta_0$ only (and also of hyperparameters of the algorithm $Alg$ if it has any):

$$

\mathcal{M}_{\text{variational}}(\theta_0) = \sum_i \left( \|Alg(f_i, \theta_0) - \theta_0\|^2 + f_i(Alg(f_i, \theta_0)) \right)

$$

And there we have it: a variational objective for meta-learning $\theta_0$ which is very similar to the MAML/iMAML meta-objective, except that it also contains the $\|Alg(f_i, \theta_0) - \theta_0\|^2$ term, which contributes to the update of $\theta_0$ and which we didn't have before. Also notice that I did not use separate training and validation losses $f_i$ and $g_i$, but that would be a well-justified choice as well.
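As a hedged illustration of the amortized objective, the sketch below assumes hypothetical quadratic task losses $f_i(\theta) = a\|\theta - c_i\|^2$ (the constant `A` and the centres are made up) and takes $Alg(f_i, \theta_0)$ to be a fixed number of gradient steps on iMAML's regularized inner problem $\|\theta - \theta_0\|^2 + f_i(\theta)$, started at $\theta_0$. For quadratics this inner problem has the closed-form solution $(\theta_0 + a c_i)/(1 + a)$, which lets us check that `alg` converges to it, and that plugging its output into $\mathcal{M}_{\text{variational}}$ never exceeds the joint objective evaluated at arbitrary $\theta_i$.

```python
import numpy as np

A = 2.0  # hypothetical curvature of the toy task losses f_i(theta) = A * ||theta - c_i||^2

def task_loss(theta, c):
    return A * np.sum((theta - c) ** 2)

def alg(c, theta_0, lr=0.1, steps=200):
    """Hypothetical inner loop: gradient descent on ||theta - theta_0||^2 + f_i(theta),
    initialised at the anchor theta_0."""
    theta = theta_0.copy()
    for _ in range(steps):
        grad = 2.0 * (theta - theta_0) + 2.0 * A * (theta - c)
        theta = theta - lr * grad
    return theta

def amortized_objective(theta_0, centres):
    """M(theta_0) = sum_i ||Alg(f_i, theta_0) - theta_0||^2 + f_i(Alg(f_i, theta_0))."""
    total = 0.0
    for c in centres:
        theta_i = alg(c, theta_0)  # theta_i is now a deterministic function of theta_0
        total += np.sum((theta_i - theta_0) ** 2) + task_loss(theta_i, c)
    return total

centres = np.array([[1.0, 2.0], [-1.0, 0.0]])  # hypothetical task optima
theta_0 = np.array([0.5, -0.5])
print(amortized_objective(theta_0, centres))
```

Because `alg` solves the inner problem exactly here, the amortized objective coincides with the joint objective minimized over $\theta_i$ for fixed $\theta_0$. With a truncated or stochastic inner loop it would sit above that minimum, but it would remain a valid variational bound.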

What is cool about this is that this provides extra justification and interpretation for what iMAML is trying to do, and suggests directions in which iMAML could perhaps be improved. On the flipside, the implicit differentiation trick in iMAML might be useful in other situations where we want to amortize the variational posterior similarly.

I'm pretty sure I missed many references, please comment below if you think I should add anything, especially on the variational bit.

I finally got around to reading this new paper by Arjovsky et al. It debuted on Twitter with a big splash, being described as 'beautiful', 'long awaited' and a 'gem of a paper'. It almost felt like a new superhero movie or Disney remake had just come out.

- Martin Arjovsky, Léon Bottou, Ishaan Gulrajani, David Lopez-Paz (2019) Invariant Risk Minimization

The paper is, indeed, very well written, and describes a very elegant idea, a practical algorithm, some theory and lots of discussion around how this is related to various bits. Here, I will describe the main idea and then provide an information theoretic view on the same topic.

We would like to learn robust predictors that are based on invariant causal associations between variables, rather than spurious surface correlations that might be present in our data. If we only observe i.i.d. data from a generative process, this is generally not possible.

In this paper, the authors assume that we have access to data sampled from different environments $e$. The data distribution in these different environments is different, but there is an underlying causal dependence of the variable of interest $Y$ on some of the observed features $X$ that remains constant, or invariant, across all environments. The question is, can we exploit the variability across different environments to learn this underlying invariant association?

Usual empirical risk minimisation (ERM) approaches cannot distinguish between statistical associations that correspond to causal connections, and those that are just spurious correlations. Invariant Risk Minimization can, in certain situations. It does this by finding a representation $\phi$ of features, such that the optimal predictor is simultaneously Bayes optimal in all environments.

The authors then propose a practical loss function that tries to capture this property:

$$

\min_\phi \sum_e \left[ \mathcal{R}^e(\phi) + \lambda \|\nabla_{w\vert w=1.0}\mathcal{R}^e(w \cdot \phi)\|^2_2 \right]

$$

The first term is the usual ERM: we're trying to minimize the average risk across all environments using a single predictor $\phi$. The second term is where the interesting bit happens. I said before that what we want this term to encourage is that $\phi$ is simultaneously Bayes-optimal in all environments. What the term actually checks is whether $\phi$ is locally optimal in each environment, that is, whether the risk in environment $e$ could be reduced by scaling the predictor by a constant $w$. For details, I recommend reading the paper, which provides a lot of intuitive explanation. In this post, I'll focus on an information theoretic interpretation of what's going on.
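Here is a minimal NumPy sketch of this objective for a *fixed* scalar representation with squared-error loss, chosen so that the derivative with respect to the dummy scale $w$ can be written by hand: $\frac{d}{dw}\mathbb{E}[(w\phi - y)^2]\big\vert_{w=1} = 2\,\mathbb{E}[(\phi - y)\phi]$. The two-environment data are a made-up assumption for illustration: $Y = X_1 + \text{noise}$ in both environments, while $X_2 = Y + \sigma_e \cdot \text{noise}$ with an environment-specific $\sigma_e$. The causal feature $X_1$ should get a near-zero penalty, the spurious feature $X_2$ a large one.

```python
import numpy as np

def irm_v1_objective(phis, ys, lam=1.0):
    """IRMv1-style objective for fixed scalar representations, squared-error loss.

    The predictor is w * phi with a dummy scale w; the penalty is the squared
    derivative of each environment's risk with respect to w, evaluated at w = 1.
    """
    risk, penalty = 0.0, 0.0
    for phi, y in zip(phis, ys):
        residual = phi - y                        # prediction error at w = 1
        risk += np.mean(residual ** 2)            # R^e(phi)
        grad_w = np.mean(2.0 * residual * phi)    # d/dw mean((w * phi - y)^2) at w = 1
        penalty += grad_w ** 2
    return risk + lam * penalty, risk, penalty

# Hypothetical two-environment toy data
rng = np.random.default_rng(0)
n = 100_000
x1s, x2s, ys = [], [], []
for sigma in (0.5, 2.0):                          # environment-specific noise on X_2
    x1 = rng.normal(size=n)
    y = x1 + rng.normal(size=n)                   # invariant mechanism: Y depends on X_1 only
    x2 = y + sigma * rng.normal(size=n)           # X_2 is a noisy descendant of Y
    x1s.append(x1)
    x2s.append(x2)
    ys.append(y)

_, _, pen_causal = irm_v1_objective(x1s, ys)      # w = 1 is optimal in both environments
_, _, pen_spurious = irm_v1_objective(x2s, ys)    # the optimal scale differs per environment
print(pen_causal, pen_spurious)
```

In the real method $\phi$ is a learned network and the derivative with respect to $w$ comes from automatic differentiation; the closed-form gradient here is only a convenience of the squared loss.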

Unlike the authors, who treat the environment index $e$ as something outside of the structural equation model, I prefer to think of $E$ as also being part of the generative process: an observable random variable. This may not be the most useful formulation in all circumstances, but it will help when trying to derive IRM through the lens of conditional dependence relationships.

The most general generative model of data in the IRM setup looks something like this:

There are three (sets of) observable variables: $E$, the environment index, $X$, the features describing the datapoint, and $Y$, the label we wish to predict. I also assume the existence of a hidden confounder $W$. In the above graph, I separated $X$ into upstream dimensions $X_1$ and downstream dimensions $X_2$ based on where they are in the causal chain relative to $Y$. In reality, we don't know what this breakdown is, and breaking $X$ up into $X_1$ and $X_2$ may not be trivial due to entanglement, but it is still reasonable to assume that some components of the input $X$ encode causal parents of $Y$, and others encode causal descendants of $Y$.

The environment $E$ influences every factor in this generative model, except the factor $p(Y\vert X_1, W)$: notice there is no arrow from $E$ to $Y$. In other words, it is assumed that the relationship of variable $Y$ to its observable causal parents $X_1$ and hidden variables $W$ is stable across all environments. This is the primary underlying assumption of IRM. Now, let's read out some conditional independence relationships from this graph:

- $Y \cancel{\perp\mkern-13mu\perp} E$: this is simply saying that the marginal distribution of $Y$ can, generally, change across environments.
- $Y \perp\mkern-13mu\perp E\vert X_1, W$: The observable $X_1$ and latent $W$ shield the label $Y$ from the influence of the environment. I already said that this is the key assumption on which IRM is based: that there is an underlying causal mechanism determining the value of Y from its causal parents, which does not change across environments.
- $Y \cancel{\perp\mkern-13mu\perp} E\vert X_1$: If we leave $W$ out of the conditioning, the above environment-independence no longer holds. This is because the confounder $W$ introduces a spurious association between $X_1$ and $Y$. This spurious correlation is assumed to be environment-dependent.
- $Y \cancel{\perp\mkern-13mu\perp} E\vert X_1, X_2$: this is, perhaps, the most important point. This dependence statement says that the way $Y$ depends on the observable variables $X = (X_1, X_2)$ is environment-dependent. This can be verified by noticing that $X_2$ is a collider between $E$ and $Y$. Conditioning on a collider introduces spurious correlations (for example, explaining away).
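The collider effect can be seen in a small linear-Gaussian simulation. This is a sketch under simplifying assumptions of my own: the confounder $W$ is dropped, and the environment only changes the spread of $X_1$ and the noise level of $X_2$ given $Y$. Regressing $Y$ on $X_1$ alone gives the same coefficient in every environment, while regressing on $(X_1, X_2)$ gives coefficients that shift with $E$.

```python
import numpy as np

def fit_linear(features, y):
    """Least-squares coefficients; no intercept is needed because all variables are zero-mean."""
    coef, *_ = np.linalg.lstsq(features, y, rcond=None)
    return coef

rng = np.random.default_rng(1)
n = 50_000
results = []
# Hypothetical environments: (scale of X_1, noise level of X_2 given Y)
for x1_scale, x2_noise in [(1.0, 0.5), (3.0, 2.0)]:
    x1 = x1_scale * rng.normal(size=n)        # E changes the distribution of X_1 ...
    y = x1 + rng.normal(size=n)               # ... but not the mechanism p(Y | X_1)
    x2 = y + x2_noise * rng.normal(size=n)    # E changes the mechanism p(X_2 | Y)
    coef_x1 = fit_linear(x1[:, None], y)                  # regress Y on X_1
    coef_both = fit_linear(np.column_stack([x1, x2]), y)  # regress Y on (X_1, X_2)
    results.append((coef_x1, coef_both))

# Population values: coef_x1 is 1 in both environments, while coef_both is
# (0.2, 0.8) in the first environment and (0.8, 0.2) in the second.
print(results)
```

The population values follow from $\mathbb{E}[Y \vert X_1, X_2] = (X_1 + X_2/s^2)/(1 + 1/s^2)$, where $s$ is the noise level of $X_2$: the optimal weights depend on $s$, and hence on $E$.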

In summary, the association between $X$ and $Y$ will be a result of three sources of correlation:

- real causal relationship between some components of $X$ and $Y$.
- spurious association introduced by the unobserved confounder $W$.
- spurious association introduced by conditioning on parts of $X$ which are causally influenced by $Y$, rather than the other way around.

In general, if our generative model describes the world accurately, the conditional independence statements we observed tell us that while the real causal association is stable across environments ($Y \perp\mkern-13mu\perp E\vert X_1, W$), the other two are environment-dependent ($Y \cancel{\perp\mkern-13mu\perp} E\vert X_1$ and $Y \cancel{\perp\mkern-13mu\perp} E\vert X_1, X_2$). Thus, we can eliminate the spurious associations by seeking associations that are stable across environments, i.e. independent of $E$.

One can interpret the objective of Invariant Risk Minimisation as seeking a representation of the observable variables, $\phi(X)$, such that:

- $Y \perp\mkern-13mu\perp E\vert \phi(X)$, and
- $\phi$ is informative about $y$, i.e. we can predict $y$ accurately from $\phi(x)$

This is a bit similar to the information bottleneck criterion, which seeks a stochastic representation $Z = \phi(X, \nu)$, $\nu$ being some random noise, by solving the following optimization problem:

$$

\max_{\phi} \left\{ I[Y, Z] - \beta I[X, Z] \right\}

$$

We could similarly attempt to find an invariant representation $Z = \phi(X)$ by maximizing an objective like:

$$

\max_{\phi} \left\{ I[Y, Z] - \beta I[Y, E \vert Z] \right\}

$$

Notice that one can write the conditional mutual information $I[Y, E \vert \phi(X)]$ in the following variational form:

$$

I[Y, E \vert \phi(x)] = \max_q \min_r \mathbb{E}_{x,y,e} [\log q(y\vert \phi(x), e) - \log r(y\vert \phi(x))]

$$
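For discrete variables this variational form can be checked directly. In the sketch below, a random joint table over $(Y, Z, E)$ is made up for illustration, with $Z$ standing in for $\phi(X)$. It plugs in the optimal environment-aware predictor $q = p(y\vert z, e)$ and the optimal environment-blind predictor $r = p(y\vert z)$, and compares the result with the entropy formula $I[Y, E\vert Z] = H[Y,Z] + H[Z,E] - H[Z] - H[Y,Z,E]$. Replacing $r$ with any other predictor can only increase the expectation, consistent with the inner minimization over $r$.

```python
import numpy as np

def entropy(p):
    """Shannon entropy (in nats) of a probability table of any shape."""
    p = p[p > 0]
    return -np.sum(p * np.log(p))

def variational_cmi(p_yze, q, r):
    """E_{y,z,e}[log q(y | z, e) - log r(y | z)] for a joint table p_yze[y, z, e]."""
    return np.sum(p_yze * (np.log(q) - np.log(r)[:, :, None]))

rng = np.random.default_rng(0)
p_yze = rng.random((3, 4, 2))          # hypothetical joint over |Y|=3, |Z|=4, |E|=2
p_yze /= p_yze.sum()

p_ze = p_yze.sum(axis=0)               # p(z, e)
p_yz = p_yze.sum(axis=2)               # p(y, z)
p_z = p_yze.sum(axis=(0, 2))           # p(z)
q_opt = p_yze / p_ze[None, :, :]       # p(y | z, e): may use the environment
r_opt = p_yz / p_z[None, :]            # p(y | z): environment-blind

cmi_variational = variational_cmi(p_yze, q_opt, r_opt)
cmi_entropies = entropy(p_yz) + entropy(p_ze) - entropy(p_z) - entropy(p_yze)
print(cmi_variational, cmi_entropies)  # the two numbers agree

r_uniform = np.full_like(r_opt, 1.0 / 3)   # a suboptimal environment-blind predictor
print(variational_cmi(p_yze, q_opt, r_uniform))  # never smaller than the true value
```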

With a little bit of math juggling, we can write the above stability objective in the following form:

\begin{align}

\max_{\phi} \left\{ I[Y, \phi(X)] - \beta I[Y, E \vert \phi(X)] \right\} = H[Y] + (1+\beta)\max_\phi \max_r \min_q \mathbb{E}_{x,y,e} \left[\log r(y\vert \phi(x)) - \frac{\beta}{1 + \beta}\log q(y\vert \phi(x), e) \right]

\end{align}

This optimization problem is intuitively already very similar to IRM: we would like a representation $\phi$ from which we can predict $y$ accurately, but we shouldn't be able to build a much better predictor by overfitting to one of the environments. In other words: here, too, we seek a representation such that the Bayes optimal predictor of $y$ is close to Bayes optimal simultaneously in all environments.

Sadly, this is a minimax type of problem, so $\phi$ can only be found with a GAN-like iterative algorithm, something we don't really like, but are kind of getting used to. It would be interesting to see how such an algorithm would work, and if you're aware of this being done already, please feel free to point me to the relevant references in the comments section.

I wanted to add a side note here: I think we can recover something very similar in spirit to Eqn. IRMv1 of the paper. I leave developing the full connection as homework for you. Let me just illustrate the basic idea here.

Say we have a parametric family of functions $f(y\vert \phi(x); \theta)$ for predicting $y$ from $\phi(x)$. The conditional information can be approximated as follows:

\begin{align}

I[Y, E \vert \phi(x)] &\approx \min_\theta \mathbb{E}_{x,y} \ell(f(y\vert \phi(x); \theta)) - \mathbb{E}_e \min_{\theta_e} \mathbb{E}_{x,y\vert e} \ell(f(y\vert \phi(x); \theta_e))\\

&= \min_\theta \mathbb{E}_e \mathcal{R}^e(f_\theta\circ\phi) - \mathbb{E}_e \min_{\theta_e} \mathcal{R}^e(f_{\theta_e}\circ\phi)

\end{align}

where $\ell$ is the log-loss, if we want to recover Shannon's information. If we assume that $f$ is a universal function approximator, equality holds. If, instead of globally optimizing $\theta_e$, we only search locally within a trust region around $\theta$, we obtain the following (approximate) lower bound on the information:

\begin{align}

I[Y, E \vert \phi(x)] &\geq \min_\theta \mathbb{E}_{x,y}\ell(f(y\vert \phi(x); \theta)) - \mathbb{E}_e \min_{\|d_e\|\leq \epsilon} \mathbb{E}_{x,y\vert e} \ell(f(y\vert \phi(x); \theta + d_e)) \\

&= \min_\theta \mathbb{E}_e \left\{ \mathcal{R}^e(f_\theta\circ\phi) - \min_{\|d_e\|\leq \epsilon}\mathcal{R}^e(f_{\theta + d_e}\circ\phi) \right\}

\end{align}

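Now, we can approximate the risk $\mathcal{R}^e(f_{\theta + d_e}\circ\phi)$ locally by a first-order Taylor expansion around $\theta$. Within the trust region, the linearized risk is minimized by $d_e = -\epsilon \nabla_\theta \mathcal{R}^e / \|\nabla_\theta \mathcal{R}^e\|_2$, which decreases it by $\epsilon \|\nabla_\theta \mathcal{R}^e\|_2$. So, for small $\epsilon$, we obtain to first order (the constant factor $\epsilon$ can later be absorbed into the regularization constant $\lambda$):

$$

I[Y, E \vert \phi(x)] \gtrsim \epsilon \min_\theta \mathbb{E}_e \| \nabla_\theta \mathbb{E}_{x,y\vert e} [\ell(f(y\vert \phi(x); \theta))] \|_2 = \epsilon \min_\theta \mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2

$$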
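The first-order step can be checked numerically. Assuming a hypothetical least-squares risk $\mathcal{R}(\theta) = \frac{1}{N}\|A\theta - b\|^2$ on random data (for illustration only), the sketch below steps to the trust-region boundary along $-\nabla\mathcal{R}/\|\nabla\mathcal{R}\|$ and compares the actual decrease in risk with $\epsilon\,\|\nabla_\theta\mathcal{R}\|_2$. The ratio approaches $1$ as $\epsilon$ shrinks.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(50, 3))              # hypothetical design matrix
b = rng.normal(size=50)                   # hypothetical targets

def risk(theta):
    return np.mean((A @ theta - b) ** 2)

def grad_risk(theta):
    return 2.0 * A.T @ (A @ theta - b) / len(b)

theta = rng.normal(size=3)
g = grad_risk(theta)
for eps in (1e-1, 1e-2, 1e-3):
    d = -eps * g / np.linalg.norm(g)      # steepest-descent step on the trust-region boundary
    decrease = risk(theta) - risk(theta + d)
    ratio = decrease / (eps * np.linalg.norm(g))
    print(f"eps={eps:g}: decrease / (eps * ||grad||) = {ratio:.4f}")
```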

Compare this with the second term in Eqn IRMv1 of the paper. If we now add back in the requirement that we would like to be able to predict $y$ from $\phi(x)$, we get an optimization problem of the following form:

$$

\min_\phi \left\{ \min_\theta \mathbb{E}_e \mathcal{R}^e(f_\theta\circ\phi) + \lambda \min_\theta \mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2 \right\},

$$

which is almost like the IRM objective. Technically, there are two separate minimizations over $\theta$, and in principle they should be carried out separately. Note, however, that a global minimum of the second term $\mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2$ at which it reaches zero (i.e. every per-environment gradient vanishes) is also a stationary point of the first term. This justifies connecting the two minimization problems together:

$$

\min_\phi \min_\theta \left\{ \mathbb{E}_e \mathcal{R}^e(f_\theta\circ\phi) + \lambda \mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2 \right\},

$$

This is no longer an ugly minimax problem; however, it is still not amazing. It is a lower bound on the penalty term we originally wished to minimize. The lower bound was created when we replaced global minimization with local minimization. Thus, the bound is tight if all local minima with respect to $\theta$ of $\mathcal{R}^e(f_\theta\circ\phi)$ are also global minima, e.g. if the loss is convex. In non-convex problems, all bets are off. It may still work, but who knows.

This is indeed a nice paper, with lots of great insights. Unfortunately, I am not sure how realistic the assumptions are that we can sample from a multitude of different environments, which differ from each other sufficiently so that the invariant causal quantities can be identified.

I would like to mention that this is not the first time that invariance and causality have been connected and exploited for domain adaptation. I personally first encountered this idea in a talk by Jonas Peters at the Causality Workshop in 2018. Here is a related paper by him that I wanted to highlight here:

- Peters, Bühlmann, Meinshausen (2016) Causal inference by using invariant prediction: identification and confidence intervals

And here are two more papers which propose a causal treatment of the domain adaptation problem:

- Adarsh Subbaswamy, Peter Schulam, Suchi Saria (2018) Preventing Failures Due to Dataset Shift: Learning Predictive Models That Transport
- Christina Heinze-Deml, Nicolai Meinshausen (2019) Conditional Variance Penalties and Domain Shift Robustness

The second paper, which commenters also pointed out to me, is perhaps the most closely related, but it is based on slightly different assumptions about what is invariant across the domains.

Finally, commenters asked me about domain-adversarial learning, so I wanted to include a pointer here for completeness:

- Yaroslav Ganin, Victor Lempitsky (2014) Unsupervised Domain Adaptation by Backpropagation

On this paper, I agree with the discussion in Arjovsky et al. (2019): it promotes the wrong invariance property by trying to learn a data representation that is marginally independent of the domain index. See the discussion on this in the comments section below.

Finally, I wanted to point out another, slightly looser connection: non-stationarity, or the availability of data from multiple environments, has been exploited by Hyvarinen and Morioka (2016) for unsupervised feature learning. It turns out that this non-stationarity, and the availability of different environments, makes otherwise non-identifiable nonlinear ICA models identifiable.

Welcome to my ICML 2019 jetlag special: what else do you do when you wake up earlier than everyone else, other than write a blog post? Here's a paper that was presented yesterday which I really liked.

- Ruiz and Titsias (2019) A Contrastive Divergence for Combining Variational Inference and MCMC

First, some background on why I found this paper particularly interesting. When AlphaGo Zero came out, I wrote a post about the *principle of minimal improvement*: Suppose you have an operator which may be computationally expensive to run, but which can take a policy and improve it. Using such an improvement operator, you can define an objective function for policies by measuring the extent to which the operator changes a policy. If your policy is already optimal, the operator can't improve it any further, so the change will be zero. In the case of AlphaGo Zero, the improvement operator is Monte Carlo Tree Search (MCTS). I noted in that post how the same principle may be applicable to approximate inference: expectation propagation and contrastive divergence can both be cast in a similar light.

The paper I'm talking about uses a very similar argument to come up with a contrastive divergence for variational inference, where the improvement operator is running a few steps of MCMC.

The two dominant ways of performing inference in latent variable models are variational inference (VI, including amortized inference, such as in VAEs) and Markov Chain Monte Carlo (MCMC). VI approximates the posterior with a parametric distribution. This can be computationally efficient and practically convenient, as VI results in end-to-end differentiable, unbiased estimates of the evidence lower bound (ELBO). As a drawback, the parametric family often can't approximate the posterior perfectly, and an approximation error almost always remains. By contrast, an MCMC approximate posterior can always be improved by running the chains longer and obtaining more independent samples, but it is more difficult to work with and computationally more demanding than VI.

A lot of great work has been done recently on combining VI with MCMC (see references in Ruiz and Titsias, 2019). Usually, one starts from a crude, parametric variational approximation to the posterior, and then improves it by running a couple of steps of MCMC. Crucially, one can view the transition kernel, $\Pi$, of the MCMC as an improvement operator: given any distribution $q$, taking an MCMC step should take you closer to the posterior. In other words, $\Pi q$ is an improvement over $q$. Taking multiple steps, i.e. $\Pi^t q$, should provide an even greater improvement. The improvement is $0$ only when $\Pi q = q$, which (for an ergodic chain) only holds for the true posterior. It is therefore a reasonable criterion to seek a posterior approximation $q_\theta$ such that the improvement of $\Pi^t q_\theta$ over $q_\theta$ is minimized.
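The improvement-operator view is easy to see on a small discrete state space. The sketch below is my own toy example, not one from the paper: it builds a Metropolis-Hastings transition matrix $\Pi$ with a uniform proposal targeting a made-up posterior $p$, and tracks $\operatorname{KL}[\Pi^t q\|p]$ for a crude initial $q$. Because $p$ is stationary under $\Pi$ ($\Pi p = p$), the data-processing inequality guarantees that this KL never increases from one step to the next.

```python
import numpy as np

def metropolis_hastings_kernel(p):
    """Transition matrix K[i, j] = P(next = j | current = i) targeting p, uniform proposal."""
    k = len(p)
    K = np.zeros((k, k))
    for i in range(k):
        for j in range(k):
            if i != j:
                K[i, j] = (1.0 / k) * min(1.0, p[j] / p[i])  # propose j, then accept
        K[i, i] = 1.0 - K[i].sum()                            # rejected moves stay put
    return K

def kl(a, b):
    return np.sum(a * (np.log(a) - np.log(b)))

p = np.array([0.05, 0.15, 0.5, 0.2, 0.1])     # hypothetical target posterior
q = np.array([0.6, 0.1, 0.1, 0.1, 0.1])       # crude initial approximation
K = metropolis_hastings_kernel(p)

kls = []
dist = q.copy()
for t in range(10):
    kls.append(kl(dist, p))
    dist = dist @ K                            # Pi q: push the distribution through one MCMC step
print(np.round(kls, 4))                        # a non-increasing sequence
```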

There are two ways of quantifying the amount of improvement, or change, that MCMC provides over a parametric posterior $q_\theta$. The first measures how much closer we got to the posterior $p$ by comparing the KL divergences:

$$

\mathcal{L}_1(\theta) = \operatorname{KL}\left[q_\theta\middle\|p\right] - \operatorname{KL}\left[\Pi^tq_\theta\middle\|p\right]

$$

The second one measures the amount of change between $\Pi^tq_\theta$ and $q_\theta$, measured as the KL divergence. This objective function merely tries to identify fixed points of the improvement operator:

$$

\mathcal{L}_2(\theta) = \operatorname{KL}\left[\Pi^tq_\theta\middle\|q_\theta\right]

$$

Either of these objectives would make sense and is justified in its own right, but sadly neither of them can be evaluated or optimized easily in this case. Both require taking expectations over $\Pi^tq_\theta$, as well as evaluating $\log \Pi^tq_\theta$. However, the brilliant insight in this paper is that when you sum them together, the most problematic terms cancel out, leaving you with a tractable objective to minimise:

$$

\mathcal{L}_1(\theta) + \mathcal{L}_2(\theta) = \mathbb{E}_{z\sim \Pi^t q_\theta} f_\theta(z) - \mathbb{E}_{z\sim q_\theta} f_\theta(z),

$$

where $f_\theta(z) = \log p(z,x) - \log q_\theta(z)$, $x$ is the observed and $z$ the hidden variable.
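The cancellation can be verified numerically on a discrete toy problem. The following self-contained sketch is my own illustration: it uses a made-up unnormalized joint $p(z, x)$ for a fixed observation $x$, and a Metropolis-Hastings kernel with a uniform proposal targeting the posterior. It computes $\mathcal{L}_1$ and $\mathcal{L}_2$ directly, which is only possible because $\Pi^t q_\theta$ has a closed form here, and compares their sum with the tractable right-hand side. Note that $f_\theta$ only needs the unnormalized $p(z, x)$: the unknown $\log p(x)$ is a constant that cancels between the two expectations.

```python
import numpy as np

def mh_kernel(p):
    """Metropolis-Hastings transition matrix with a uniform proposal, targeting p."""
    k = len(p)
    K = np.minimum(1.0, p[None, :] / p[:, None]) / k   # propose j uniformly, accept w.p. min(1, p_j / p_i)
    np.fill_diagonal(K, 0.0)
    np.fill_diagonal(K, 1.0 - K.sum(axis=1))           # rejected proposals stay put
    return K

def kl(a, b):
    return np.sum(a * (np.log(a) - np.log(b)))

joint = np.array([0.2, 0.6, 1.5, 0.4, 0.3])   # hypothetical unnormalised p(z, x) for fixed x
posterior = joint / joint.sum()                # p(z | x)
q = np.array([0.4, 0.3, 0.1, 0.1, 0.1])        # variational approximation q_theta
t = 3
pi_t_q = q @ np.linalg.matrix_power(mh_kernel(posterior), t)   # Pi^t q

L1 = kl(q, posterior) - kl(pi_t_q, posterior)
L2 = kl(pi_t_q, q)
f = np.log(joint) - np.log(q)                  # f_theta(z) = log p(z, x) - log q_theta(z)
vcd = pi_t_q @ f - q @ f                       # E_{Pi^t q}[f] - E_q[f]
print(L1 + L2, vcd)                            # the two numbers agree
```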

There is one more technical hurdle to overcome, which is to calculate or estimate the derivative of this objective with respect to $\theta$. The authors propose a REINFORCE-like score function gradient estimator in Eqn. (12), which is somewhat worrying as it is known to have very high variance. The authors propose overcoming this using a control variate. For more details, please refer to the paper.

There is further discussion on the behaviour of this objective function in the limit of infinitely long MCMC chains, i.e. $t\rightarrow\infty$. It turns out that the criterion then behaves like the symmetrized KL divergence $KL[q\|p] + KL[p\|q]$. The difference between this objective and the usual conservative, mode-seeking VI objective is neatly illustrated in Figure 1 of the paper:

Variational Contrastive Divergence (VCD) favours posterior approximations which cover much more of the true posterior than standard VI, which tends to concentrate on the modes and avoids allocating mass to areas where the true posterior has little.

This post is a short note on an excellent recent paper on empirical Fisher information matrices:

- Kunstner, Balles and Hennig (2019) Limitations of the Empirical Fisher Approximation

I was debating with myself whether I should write a post about this, because it's a superbly written paper that you should probably read in full. There isn't a whole lot of novelty in the paper, but it is a great discussion paper that provides a concise overview of the Fisher information, the empirical Fisher matrix and their connections to generalized Gauss-Newton methods. For more detail, I also recommend this older and longer overview by James Martens.

For a supervised learning model $p_\theta (y\vert x)$, the Fisher information matrix for a specific input $x$ is defined as follows:

$$

F_x = \mathbb{E}_{p_\theta(y\vert x)} \left[ \nabla_\theta \log p_\theta(y\vert x) \nabla_\theta \log p_\theta(y\vert x)^\top \right]

$$

The Fisher information roughly measures the sensitivity of the model's output distribution to small changes in the parameters $\theta$. For me, the most intuitive interpretation is that it is the local curvature of $KL[p_\theta\|p_{\theta'}]$ as $\theta' \rightarrow \theta$, i.e. the Hessian of this KL with respect to $\theta'$ at $\theta' = \theta$. So, by using the Fisher information as a quadratic penalty, or as a preconditioner in second order gradient descent, you're penalizing changes to the model's output distribution as measured by the KL divergence. This makes sense especially for neural networks, where small changes in network parameters can lead to arbitrary changes in the function the model actually implements.
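The curvature interpretation can be checked numerically for a simple model. The sketch below assumes a categorical distribution parametrized directly by its logits $\theta$ (no input $x$, for brevity), for which the Fisher information has the closed form $\operatorname{diag}(p) - pp^\top$. It compares this closed form with the expectation definition above and with a finite-difference Hessian of $\theta' \mapsto \operatorname{KL}[p_\theta\|p_{\theta'}]$ at $\theta' = \theta$.

```python
import numpy as np

def softmax(logits):
    z = np.exp(logits - logits.max())
    return z / z.sum()

def kl(theta, theta_prime):
    p, q = softmax(theta), softmax(theta_prime)
    return np.sum(p * (np.log(p) - np.log(q)))

def fisher_by_definition(theta):
    """E_{y ~ p_theta}[grad log p(y) grad log p(y)^T], using grad log p(y) = onehot(y) - p."""
    p = softmax(theta)
    scores = np.eye(len(p)) - p[None, :]          # row y: gradient of log p_theta(y)
    return (scores * p[:, None]).T @ scores       # sum_y p(y) * score_y score_y^T

def kl_hessian_fd(theta, h=1e-3):
    """Central finite-difference Hessian of theta' -> KL[p_theta || p_theta'] at theta' = theta."""
    d = len(theta)
    H = np.zeros((d, d))
    for i in range(d):
        for j in range(d):
            ei, ej = np.eye(d)[i] * h, np.eye(d)[j] * h
            H[i, j] = (kl(theta, theta + ei + ej) - kl(theta, theta + ei - ej)
                       - kl(theta, theta - ei + ej) + kl(theta, theta - ei - ej)) / (4 * h * h)
    return H

theta = np.array([0.3, -1.0, 0.8])                # hypothetical logits
p = softmax(theta)
F_closed = np.diag(p) - np.outer(p, p)
print(np.round(kl_hessian_fd(theta) - F_closed, 6))  # approximately zero everywhere
```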

If you have an entire distribution or a dataset, you normally add up the pointwise $F_x$ matrices to form the Fisher information for the whole training data:

$$

F = \sum_n F_{x_n} = \sum_n \mathbb{E}_{p_\theta(y\vert x_n)} \left[ \nabla_\theta \log p_\theta(y\vert x_n) \nabla_\theta \log p_\theta(y\vert x_n)^\top \right]

$$

In practice, however, people often work with another matrix instead, which is called the empirical Fisher information matrix, defined as follows:

$$

\tilde{F} = \sum_n\left[ \nabla_\theta \log p_\theta(y_n\vert x_n) \nabla_\theta \log p_\theta(y_n\vert x_n)^\top \right]

$$

The difference, of course, is that rather than sampling $y$ from the model itself, here we use the observed labels $y_n$ from the dataset. This matrix, while often used, lacks a principled motivation. The paper shows that it also lacks most of the properties that make the Fisher information such a useful quantity. The paper's simple, yet insightful, examples show how using the empirical Fisher instead of the real Fisher in second order optimization can lead to bonkers results:

What the above example shows is the vector field corresponding to differently preconditioned gradient descent algorithms in a simple two-parameter least squares linear regression example. The first panel shows the gradient field. The second shows the natural gradient field, i.e. gradients corrected by the inverse Fisher information. The third shows the gradients corrected by the inverse empirical Fisher instead. You can see that the empirical Fisher does not make the optimization problem any easier or better conditioned. The bottom line is: you probably shouldn't use it.

One particular *myth* that Kunstner et al. (2019) try to debunk is that the Fisher and empirical Fisher coincide near minima of the loss function, and therefore either can be safely used once you're done training your model. I have made this argument several times in the past myself. If you look at the definitions, it's easy to conclude that they should indeed coincide under two conditions:

- that the model $p_\theta(y\vert x)$ closely approximates the *true* sampling distribution $p_\mathcal{D}(y\vert x)$, and
- that you have sufficiently large $N$ (number of samples) so that both the Fisher and empirical Fisher converge to their respective population averages.

The paper argues that these two conditions are often at odds in practice. One either uses a simple model, in which case model misspecification is likely, meaning $p_\theta(y\vert x)$ can't approximate the true underlying $p_\mathcal{D}(y\vert x)$ well. Or, one uses an expressive, overparametrised model, such as a big neural network, in which case some form of overfitting likely happens, and $N$ is too small for the two matrices to coincide. The paper offers a lot more nuance to this, and again, I highly recommend reading it whole.
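The role of misspecification can be explored with a small simulation. The sketch below is my own construction, not an experiment from the paper: a logistic-regression model $p_\theta(y\vert x)$, fitted by maximum likelihood with a few Newton steps to avoid depending on an optimization library, and the Fisher and empirical Fisher (both averaged rather than summed over datapoints) evaluated at the fitted parameter. It uses two hypothetical data-generating processes: one the model can represent, and one with a quadratic term in the logit that it cannot.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def fit_logistic(X, y, steps=25):
    """Maximum-likelihood logistic regression via Newton's method."""
    theta = np.zeros(X.shape[1])
    for _ in range(steps):
        p = sigmoid(X @ theta)
        grad = X.T @ (y - p)                          # gradient of the log-likelihood
        hess = (X * (p * (1 - p))[:, None]).T @ X     # negative Hessian of the log-likelihood
        theta = theta + np.linalg.solve(hess, grad)
    return theta

def fisher_and_empirical_fisher(X, y, theta):
    """Both matrices averaged (not summed) over the N datapoints."""
    p = sigmoid(X @ theta)
    # Fisher: the score is (y - p) x, and E_{y ~ p_theta}[(y - p)^2] = p (1 - p)
    F = (X * (p * (1 - p))[:, None]).T @ X / len(y)
    # Empirical Fisher: outer products of the scores at the observed labels
    scores = X * (y - p)[:, None]
    F_emp = scores.T @ scores / len(y)
    return F, F_emp

rng = np.random.default_rng(0)
n = 200_000
x = rng.normal(size=n)
X = np.column_stack([np.ones(n), x])      # intercept + one feature

rel_diff = {}
# Hypothetical true logits: linear (well-specified) vs. with a quadratic term (misspecified)
for name, true_logits in [("well-specified", 0.5 + 1.5 * x),
                          ("misspecified", 0.5 + 1.5 * x - 3.0 * x ** 2)]:
    y = (rng.random(n) < sigmoid(true_logits)).astype(float)
    theta = fit_logistic(X, y)
    F, F_emp = fisher_and_empirical_fisher(X, y, theta)
    rel_diff[name] = np.linalg.norm(F - F_emp) / np.linalg.norm(F)
    print(f"{name}: ||F - F_emp|| / ||F|| = {rel_diff[name]:.3f}")
```

In both cases the fit sits at a minimum of the training loss, so in this toy setup it is the misspecification, not distance from the minimum, that drives the two matrices apart.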

The paper also looks at comparing the direction of the natural gradient with the empirical Fisher-corrected gradients, and finds that the EF gradient often tries to go in a very different direction than the natural gradient.

It's also interesting to look at the projection of the EF gradient onto the original gradient, to see how big an EF-gradient update is in the direction of normal gradient descent.

The gradient for a single datapoint $(x_n, y_n)$ is $g_n = -\nabla_\theta \log p_\theta(y_n\vert x_n)$. If we use the EF matrix as a preconditioner, the gradient is modified to $\tilde{g}_n = \tilde{F}^{-1}g_n = -\tilde{F}^{-1} \nabla_\theta \log p_\theta(y_n\vert x_n)$. Let's look at the inner product $g^\top_n\tilde{g}_n$ on average over the dataset $p_\mathcal{D}$, taking $\tilde{F}$ here as the average $\mathbb{E}_{p_\mathcal{D}} g_n g_n^\top$ rather than the sum over datapoints (using the sum only rescales the result by a factor of $N$):

\begin{align}

\mathbb{E}_{p_\mathcal{D}} g^\top_n\tilde{g}_n &= \mathbb{E}_{p_\mathcal{D}}g^\top_n\tilde{F}^{-1}g_n \\

&= \operatorname{tr} \mathbb{E}_{p_\mathcal{D}} g^\top_n\tilde{F}^{-1}g_n\\

&= \operatorname{tr}\mathbb{E}_{p_\mathcal{D}} g_ng^\top_n \tilde{F}^{-1}\\

&= \operatorname{tr} \tilde{F}\tilde{F}^{-1}\\

&= D,

\end{align}

where $D$ is the dimensionality of $\theta$. This means that, on average, the inner product between the EF gradient and the original gradient is constant. Since the length of the gradient $g_n$ can change as a function of $\theta$, this implies that, weirdly, there is an inverse relationship between the norm of $g_n$ and the norm of $\tilde{g}_n$ when projected in the same direction. On the whole, this is consistent with the weird-looking vector field in Figure 1 of the paper.
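A quick numerical sanity check of this identity, taking $\tilde{F}$ as the average rather than the sum over datapoints, as the step $\mathbb{E}_{p_\mathcal{D}} g_n g_n^\top = \tilde{F}$ requires. The per-example gradients are random stand-ins made up for illustration; the identity holds for any set of gradients for which $\tilde{F}$ is invertible. The second part illustrates the inverse relationship: scaling every $g_n$ by $10$ scales $\tilde{F}$ by $100$, so the preconditioned gradients shrink by a factor of $10$.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 500, 4
g = rng.normal(size=(N, D)) * np.array([0.1, 1.0, 5.0, 20.0])  # stand-in per-example gradients g_n

F_emp = g.T @ g / N                          # empirical Fisher, averaged over the dataset
g_tilde = np.linalg.solve(F_emp, g.T).T      # EF-preconditioned gradients F_emp^{-1} g_n
inner = np.einsum("nd,nd->n", g, g_tilde)    # g_n^T F_emp^{-1} g_n for each datapoint
print(inner.mean())                          # equals D = 4 up to floating-point error

g_big = 10.0 * g                             # ten times larger raw gradients ...
F_big = g_big.T @ g_big / N
g_tilde_big = np.linalg.solve(F_big, g_big.T).T
print(np.allclose(g_tilde_big, g_tilde / 10.0))  # ... give ten times smaller EF steps: True
```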

In summary, I highly recommend reading this paper. It has a lot more than what I covered here, and it is all explained rather nicely. The main message is clear: think twice (or more) about using empirical Fisher information for pretty much any application.

Video is an interesting domain for unsupervised, or self-supervised, representation learning. But we still don't know what type of inductive biases will enable us to best exploit the information encoded in the temporal sequence of video frames. Slow Feature Analysis (SFA) and its more recent cousin Learning to Linearize (e.g. Goroshin et al., 2016) attempt to learn neural representations of images such that the trajectory spanned by the sequence of transformed video frames satisfies a desired property, such as *slowness* or linearity/*straightness*. Whether or not the brain employs similar principles to learn representations is a fascinating topic.

Computational neuroscientists showed examples where representations learnt by the brain are consistent with the slowness principle (see e.g. Franzius et al, 2007). Brand new work by Eero Simoncelli's group suggests that the brain's representations are also consistent with the straightness principle:

- Hénaff, Goris and Simoncelli (2019): Perceptual straightening of natural videos, Nature Neuroscience

This work is in fact a culmination of some great work by the authors over several years on developing the techniques that eventually enabled this study.

As you play a video, the frames of the video span a trajectory in pixel space. This trajectory is usually very complicated, even for videos representing relatively simple actions. Imagine a video scene with the camera slowly panning over a static texture, such as a field of grass, at a constant speed. Although the video appears conceptually simple and predictable, the trajectory of individual pixel values is highly non-linear and crazy: most pixels abruptly change their color from one frame to the next.

Instead of representing videos as sequences of frames in pixel space, we would like to map frames into a representation space, such that the trajectory behaves a bit more normally and predictably. For example, it makes sense to hope that in representation space the trajectory is straight or linear, meaning that it has a small curvature. If a trajectory is straight, it also means that linear extrapolation from the last couple points to the future works well.

Imagine we have a sequence (of frames) $x_t$. We map this into a representation space via a nonlinear mapping $f$. We can measure the local *straightness* of such representation by calculating the angle between the difference vectors $f(x_{t+1}) - f(x_{t})$ and $f(x_{t}) - f(x_{t-1})$. This angle is given by:

\begin{align}

\cos\alpha_t = \frac{(f(x_{t+1}) - f(x_{t}))^T(f(x_{t}) - f(x_{t-1}))}{\|(f(x_{t+1}) - f(x_{t}))\|\|(f(x_{t}) - f(x_{t-1}))\|}

\end{align}

The closer this value is to $1$, the straighter the sequence $f(x_{t-1}), f(x_{t}), f(x_{t+1})$ is locally. One can measure the straightness of the whole trajectory by averaging this local measure of curvature across time. It's worth noting that in high dimensional spaces, straightness is difficult to achieve. A $\cos\alpha$ close to $1$ is very rare: in high dimensions, two independently drawn Gaussian random vectors are almost perpendicular to one another, yielding a cosine close to $0$. So, for example, straight trajectories have an almost $0$ probability under a high-dimensional Brownian motion or Ornstein–Uhlenbeck (OU) process. As slow feature analysis can be seen as learning representations that behave roughly like OU processes (Turner and Sahani, 2007), SFA would not, by default, yield straight trajectories. If we would like to recover straighter trajectories, we need stronger priors and constraints, such as $k$-times integrated OU processes. See also (Goroshin et al., 2016) for another approach to recover straight representations.
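A minimal sketch of this curvature measure, applied to synthetic trajectories made up for illustration: a straight line traversed at constant speed, and a Gaussian random walk (a discretized Brownian motion) in a 1000-dimensional space. The first gives $\cos\alpha_t = 1$ at every step; for the second the cosines concentrate around $0$, illustrating why straightness is hard to achieve by chance in high dimensions.

```python
import numpy as np

def local_cosines(traj):
    """cos(alpha_t) between successive difference vectors of a trajectory (rows = time)."""
    steps = np.diff(traj, axis=0)                # f(x_{t+1}) - f(x_t)
    prev, nxt = steps[:-1], steps[1:]
    return np.sum(prev * nxt, axis=1) / (
        np.linalg.norm(prev, axis=1) * np.linalg.norm(nxt, axis=1))

rng = np.random.default_rng(0)
dim, T = 1000, 50
direction = rng.normal(size=dim)
straight = np.outer(np.arange(T), direction)                # constant-velocity straight line
random_walk = np.cumsum(rng.normal(size=(T, dim)), axis=0)  # discretized Brownian motion

print(local_cosines(straight).mean())      # approximately 1
print(local_cosines(random_walk).mean())   # close to 0
```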

When we train our own neural network, we can of course directly evaluate the straightness of the learnt representations. But how can we study whether the representations learnt by the brain exhibit straight trajectories? We can't directly observe the representations learned by the human brain. Instead, Hénaff et al. (2019) devised an indirect way to measure the curvature, i.e. to infer it from behavioural/psychophysics data. Their approach relies on ideal observer models for a certain perceptual discrimination task: they infer perceptual curvature from data about how accurately subjects can differentiate between pairs of stimuli.

Here is how to intuitively understand why the approach makes sense: we assume that whether or not a human is able to discriminate between two stimuli, when flashed for a fraction of a second, is related to the Euclidean distance between the representations of the two stimuli in the person's representation space. Now, if you take a triplet of stimuli (video frames) $A, B, C$ and measure the pairwise distances $d_{AB}, d_{AC}, d_{BC}$ between them, you can establish how straight they are as a sequence: if the sequence is straight, you expect that $d_{AC}\approx d_{AB} + d_{BC}$. On the other hand, if the sequence turns by a right angle at $B$, you'd expect $d^2_{AC}\approx d^2_{AB} + d^2_{BC}$ to hold instead. So, if we have a model - called a normative or ideal observer model - that can somehow directly relate Euclidean distances in representation space to perceptual discriminability of stimuli, we have a way to experimentally infer the straightness of representations.
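Both cases are special cases of the law of cosines. Writing $C - A = (B - A) + (C - B)$ gives $d^2_{AC} = d^2_{AB} + d^2_{BC} + 2 d_{AB} d_{BC} \cos\alpha$, where $\alpha$ is the turning angle at $B$. So $\cos\alpha$ can be read off from the three distances. This sketch covers only the noise-free geometry. It is not the paper's ideal observer model, which has to infer the distances from noisy discrimination data. The function name is my own:

```python
import numpy as np

def cos_alpha_from_distances(d_ab, d_bc, d_ac):
    """Turning angle at B from pairwise distances (alpha = 0 means straight).

    Law of cosines: d_ac^2 = d_ab^2 + d_bc^2 + 2 * d_ab * d_bc * cos(alpha).
    """
    return (d_ac**2 - d_ab**2 - d_bc**2) / (2 * d_ab * d_bc)

print(cos_alpha_from_distances(1.0, 2.0, 3.0))   # 1.0: d_AC = d_AB + d_BC, straight
print(cos_alpha_from_distances(3.0, 4.0, 5.0))   # 0.0: Pythagoras, right-angle turn

# Sanity check against the direct definition on random points in 50 dimensions.
rng = np.random.default_rng(0)
A, B, C = rng.standard_normal((3, 50))
direct = (C - B) @ (B - A) / (np.linalg.norm(C - B) * np.linalg.norm(B - A))
from_dists = cos_alpha_from_distances(np.linalg.norm(B - A),
                                      np.linalg.norm(C - B),
                                      np.linalg.norm(C - A))
print(np.isclose(direct, from_dists))   # True
```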

For a more detailed description of the psychophysics experiments, data preparation, and the ideal observer model, please see the paper.

The main result of the paper is, as expected, that natural video sequences indeed appear to be mapped to straight trajectories in representation space. Furthermore, the figure below (Fig. 3b in the article) shows that trajectories corresponding to natural stimuli are straighter than those corresponding to unnatural, artificially generated stimuli:

The blue curves correspond to a natural video stimulus in pixel space and in (inferred) representation space. You can see that the sequence is highly curved in pixel space, but in representation space it is straight. On the other hand, the green curves show an artificial video created by linearly interpolating between the first and the last frame of a video. This of course results in a straight trajectory in pixel space. However, the inferred trajectory in representation space is highly curved, which was the authors' prediction, as this is an unnatural stimulus for humans.

I really liked this paper and it is an excellent example of ideal observer model-based analysis of psychophysics data. It formulates an interesting hypothesis and finds a very nice way to answer it from well-designed experimental data.

That said, I was left wondering just how critically the results depend on the assumptions of the ideal observer model, and how robust the findings are to changing some of those assumptions. For one, the paper assumes Gaussian observation noise in representation space, and I wonder how robust the analysis would be to assuming heavy-tailed noise. Similarly, our very definition of straightness and angles relies on the assumption that the representation space is somehow Euclidean. If the representation space is not in fact Euclidean, our analysis might pick up on that, rather than on the straightness of specific trajectories. Still, the most convincing aspect of the paper is the contrast between natural and unnatural stimuli. No matter how precise our inference of straightness is, and how much the actual values depend on these assumptions, this is a very reassuring finding.

One of my favourite recent innovations in neural network architectures is Deep Sets. This relatively simple architecture can implement arbitrary set functions: functions over collections of items where the order of the items does not matter.

This is a guest post by Fabian Fuchs, Ed Wagstaff and Martin Engelcke, authors of a recent paper on the representational power of such architectures and why the deep sets architecture can represent arbitrary set functions in theory. It's a great paper. Imagine what these guys could achieve if their lab was in Cambridge rather than Oxford!

Here are the links to the original Deep Sets paper, and the more recent paper by the authors of this post:

- Zaheer, Kottur, Ravanbakhsh, Poczos, Salakhutdinov and Smola (NeurIPS 2017) Deep Sets
- Wagstaff, Fuchs, Engelcke, Posner and Osborne (2019) On the Limitations of Representing Functions on Sets

Over to Fabian, Ed and Martin for the rest of the post. Enjoy.

Most successful deep learning approaches make use of the structure in their inputs: CNNs work well for images, RNNs and temporal convolutions for sequences, etc. The success of convolutional networks boils down to exploiting a key invariance property: translation invariance. This allows CNNs to

- drastically reduce the number of parameters needed to model high-dimensional data
- decouple the number of parameters from the number of input dimensions, and
- ultimately, to become more data efficient and generalize better.

But images are far from the only data we want to build neural networks for. Often our inputs are sets: sequences of items where the ordering of the items carries no information for the task at hand. In such a situation, the invariance property we can exploit is permutation invariance.

To give a short, intuitive explanation for permutation invariance, this is what a permutation invariant function with three inputs would look like: $f(a, b, c) = f(a, c, b) = f(b, a, c) = \dots$.

Some practical examples where we want to treat data or different pieces of higher order information as sets (i.e. where we want permutation invariance) are:

- working with sets of objects in a scene (think AIR or SQAIR)
- multi-agent reinforcement learning
- perhaps surprisingly, point clouds

We will talk more about applications later in this post.

*Note from Ferenc:* I would like to jump in here - because it's my blog so I get to do that - to say that I think the killer application for this is actually meta-learning and few-shot learning. By meta-learning, don't think of anything fancy: I consider amortized variational inference, as in a VAE, a form of meta-learning. Consider a conditionally i.i.d. model where you have a global parameter $\theta$, and a bunch of observations $x_i$ drawn conditionally i.i.d. from a distribution $p_{X\vert \theta}$. Given a set of observations $x_1, \ldots, x_N$ we'd like to approximate the posterior $p(\theta\vert x_1, \ldots, x_N)$ by some parametric $q(\theta\vert x_1, \ldots, x_N; \psi)$, and we want this to work for any number of observations $N$. Clearly, the real posterior $p$ is invariant to permutations of the $x_n$, so it would make sense to make the recognition model, $q$, a permutation-invariant architecture. To me, this is the killer application of deep sets, especially in an online learning setting, where we want to update our posterior estimate over some parameters with each new data point we observe.
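A conjugate toy case makes this concrete. It is an illustration I'm adding, not a learned recognition model. Assume $\theta \sim \mathcal{N}(\mu_0, s_0^2)$ and $x_i \vert \theta \sim \mathcal{N}(\theta, \sigma^2)$ with known $\sigma$. Then the exact posterior depends on the data only through $\sum_i x_i$ and $N$, so it is itself a sum-decomposition: a per-item embedding $\phi(x) = (x, 1)$, summed, then mapped to the posterior parameters by some $\rho$. The prior and noise values below are arbitrary assumptions:

```python
import numpy as np

MU0, S0, SIGMA = 0.0, 2.0, 1.0   # assumed prior mean, prior std and noise std

def phi(x):
    # per-observation embedding; summing it over the set gives (sum of x, N)
    return np.array([x, 1.0])

def rho(pooled):
    # maps the pooled statistics to the exact Gaussian posterior over theta
    sum_x, n = pooled
    precision = 1.0 / S0**2 + n / SIGMA**2
    mean = (MU0 / S0**2 + sum_x / SIGMA**2) / precision
    return mean, 1.0 / precision     # posterior mean and variance

def posterior(xs):
    pooled = np.zeros(2)
    for x in xs:                      # online: fold in one observation at a time
        pooled += phi(x)
    return rho(pooled)

xs = [1.3, 0.7, 2.1, 1.6]
print(posterior(xs))
print(posterior(xs[::-1]))            # same (up to rounding) for any ordering
print(posterior([]))                  # no data: returns the prior (0.0, 4.0)
```

In a learned amortized $q$, $\phi$ and $\rho$ would be neural networks. This example just shows that the permutation-invariant, variable-$N$ structure matches the true posterior in at least one case where we can check it.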

Having established that there is a need for permutation-invariant neural networks, let's see how to enforce permutation invariance in practice. One approach is to make use of some operation $P$ which is already known to be permutation-invariant. We map each of our inputs separately to some latent representation and apply our $P$ to the set of latents to obtain a latent representation of the set as a whole. $P$ destroys the ordering information, leaving the overall model permutation invariant.

In particular, Deep Sets does this by setting $P$ to be summation in the latent space. Other operations are used as well, e.g. elementwise max. We call the case where the sum is used *sum-decomposition via the latent space*. The high-level description of the full architecture is now reasonably straightforward - transform your inputs into some latent space, destroy the ordering information in the latent space by applying the sum, and then transform from the latent space to the final output. This is illustrated in the following figure:

![](/content/images/2019/02/Architecture.png)
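A minimal forward pass of this architecture in NumPy, following the $\phi$/$\rho$ naming of Zaheer et al. The single `tanh` layers, the sizes and the random (untrained) weights are my assumptions for illustration; in practice both maps are MLPs trained end to end:

```python
import numpy as np

rng = np.random.default_rng(0)
D_IN, D_LATENT, D_OUT = 3, 16, 2   # illustrative sizes

# Random weights stand in for trained networks.
W_phi = rng.standard_normal((D_IN, D_LATENT))
W_rho = rng.standard_normal((D_LATENT, D_OUT))

def phi(X):
    # applied to every set element independently: (n, D_IN) -> (n, D_LATENT)
    return np.tanh(X @ W_phi)

def rho(z):
    # maps the pooled latent (D_LATENT,) to the output (D_OUT,)
    return np.tanh(z) @ W_rho

def deep_sets(X):
    # X has shape (n, D_IN), one row per element; n can vary between calls
    return rho(phi(X).sum(axis=0))   # sum pooling discards the ordering

X = rng.standard_normal((5, D_IN))
perm = rng.permutation(5)
print(np.allclose(deep_sets(X), deep_sets(X[perm])))   # True
```

Swapping `.sum(axis=0)` for `.max(axis=0)` gives the elementwise-max variant mentioned above.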

If we want to actually implement this architecture, we'll need to choose our latent space (in the guts of the model this will mean something like choosing the size of the output layer of a neural network). As it turns out, the choice of latent space will place a limit on how expressive the model is. In general, neural networks are universal function approximators (in the limit), and we'd like to preserve this property. Zaheer et al. provide a theoretical analysis of the ability of this architecture to represent arbitrary functions - that is, can the architecture, in theory, achieve exact equality with any target function, allowing us to use e.g. neural networks to approximate the necessary mappings? In our paper, we build on and extend this analysis, and discuss what implications it has for the choice of latent space.

Zaheer et al. show that, if we're only interested in sets drawn from a countable domain (e.g. $\mathbb{Z}$ or $\mathbb{Q}$), a 1-dimensional latent space is enough to represent any function. Their proof works by defining an injective mapping from sets to real numbers. Once you have an injective mapping, you can recover all the information about the original set, and can, therefore, represent any function. This sounds like good news -- we can do anything we like with a 1-D latent space! Unfortunately, there's a catch -- the mapping that we rely on is not continuous. The implication of this is that to recover the original set, even approximately, we need to know the exact real number that we mapped to -- knowing to within some tolerance doesn't help us. This is impossible on real hardware.
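Here is a sketch of one injective map of this kind. It assumes the domain is the non-negative integers (any countable domain can first be enumerated) and that inputs are sets without repeated elements. The base-2 code and the function names are my choices; treat this as an illustration in the spirit of the argument, not a copy of the paper's proof. Exact rational arithmetic stands in for the "exact real number":

```python
from fractions import Fraction

def encode(s):
    """Map a finite set of non-negative integers to one rational number.

    Element n contributes 2^-(n + 1), so the code is a binary fraction whose
    (n + 1)-th digit after the point is 1 exactly when n is in the set.
    """
    return sum((Fraction(1, 2 ** (n + 1)) for n in s), Fraction(0))

def decode(code):
    """Invert encode by reading off binary digits; needs the exact value."""
    s, k = set(), 1
    while code > 0:
        if code >= Fraction(1, 2 ** k):
            s.add(k - 1)
            code -= Fraction(1, 2 ** k)
        k += 1
    return s

print(sorted(decode(encode({0, 3, 7}))))   # [0, 3, 7]

# The catch: very different sets can receive arbitrarily close codes.
a = encode({0})
b = encode(set(range(1, 21)))
print(float(a - b))   # 2**-21, roughly 4.8e-07
```

The last two lines show the discontinuity: a tiny error in the code can turn $\{0\}$ into a twenty-element set, so knowing the code only up to some tolerance does not pin down the set.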

Above we considered a countable domain, but it's important to consider instead the uncountable domain $\mathbb{R}$, the real numbers. This is because continuity is a much stronger property on $\mathbb{R}$ than on $\mathbb{Q}$, and we need this stronger notion of continuity. The figure below illustrates this, showing a function which is continuous on $\mathbb{Q}$ but not continuous on $\mathbb{R}$ (and certainly not continuous in an intuitive sense). The figure is explained in detail in our paper. Using $\mathbb{R}$ is particularly important if we want to work with neural networks. Neural networks are universal approximators *for continuous functions on compact subsets of $\mathbb{R}^M$*. Continuity on $\mathbb{Q}$ won't do.

![](/content/images/2019/02/continuous.png)

Zaheer et al. go on to provide a proof using continuous functions on $\mathbb{R}$, but it places a limit on the set size for a fixed finite-dimensional latent space. In particular, it shows that with a latent space of $M+1$ dimensions, we can represent any function which takes as input sets of size $M$. If you want to feed the model larger sets, there's no guarantee that it can represent your target function.

As in the countable case, the proof of this statement uses an injective mapping. But the functions we're interested in modeling aren't going to be injective: we're distilling a large set down into a smaller representation. So maybe we don't need injectivity. Maybe there's some clever lower-dimensional mapping to be found, and we can still get away with a smaller latent space?

**No.**

As it turns out, you often do need injectivity into the latent space. This is true even for simple functions, e.g. max, which is clearly far from injective. This means that if we want to use continuous mappings, the dimension of the latent space must be at least the maximum set size. We were also able to show that this dimension suffices for universal function representation. That is, we've improved on the result from Zaheer (latent dimension $N \geq M+1$ is sufficient) to obtain both a weaker sufficient condition, and a necessary condition (latent dimension $N \geq M$ is sufficient *and necessary*). Finally, we've shown that it's possible to be flexible about the input set size. While Zaheer's proof applies to sets of size exactly $M$, we showed that $N=M$ also works if the set size is allowed to vary, as long as it is at most $M$.
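One standard continuous embedding that shows why $N = M$ dimensions can be enough for sets of exactly $M$ real numbers uses power sums. Take $\phi(x) = (x, x^2, \ldots, x^M)$. Summing over the set gives the power sums $p_1, \ldots, p_M$. Newton's identities turn these into the elementary symmetric polynomials, and the set is recovered as the roots of $\prod_j (t - x_j)$. I'm using this as an illustration; I'm not claiming it reproduces the proof in the paper. Root finding also becomes numerically fragile as $M$ grows:

```python
import numpy as np

def encode(xs):
    # power-sum embedding: sum over x of (x, x^2, ..., x^M) with M = len(xs)
    M = len(xs)
    return np.sum([[x ** k for k in range(1, M + 1)] for x in xs], axis=0)

def decode(p):
    """Recover the set from its M power sums via Newton's identities."""
    M = len(p)
    e = [1.0]                                 # elementary symmetric polynomials, e_0 = 1
    for k in range(1, M + 1):
        s = sum((-1) ** (i - 1) * e[k - i] * p[i - 1] for i in range(1, k + 1))
        e.append(s / k)
    # prod_j (t - x_j) = t^M - e_1 t^(M-1) + e_2 t^(M-2) - ...
    coeffs = [(-1) ** k * e[k] for k in range(M + 1)]
    return np.sort(np.roots(coeffs).real)

xs = [0.3, -1.2, 2.5]
p = encode(xs)                                # a 3-dimensional latent for a 3-element set
print(np.allclose(encode(xs[::-1]), p))       # True: permutation invariant
print(decode(p))                              # approximately [-1.2  0.3  2.5]
```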

Why do we care about all of this? Sum-decomposition is in fact used in many different contexts - some more obvious than others - and the above findings directly apply in some of these.

Self-attention via {keys, queries, and values}, as in the *Attention Is All You Need* paper by Vaswani et al. (2017), is closely linked to Deep Sets. Self-attention is itself permutation-invariant (strictly speaking, permutation-equivariant: permuting the inputs permutes the outputs) unless you use positional encoding, as is often done in language applications. In a way, self-attention "generalises" the summation operation as it performs a weighted summation of the value vectors. You can show that when setting all keys and queries to 1.0, the attention weights become uniform and you effectively end up with the Deep Sets architecture, with averaging in place of summation.
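A quick numerical check of that claim, using plain scaled dot-product attention written in NumPy (not any particular library's attention layer). With all keys and queries equal to 1.0, every score is identical, the softmax is uniform, and each output row is the mean of the values, i.e. mean pooling rather than sum pooling. The second check shows the equivariance:

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)   # for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(Q, K, V):
    # scaled dot-product attention: each output row is a weighted sum of rows of V
    weights = softmax(Q @ K.T / np.sqrt(K.shape[1]), axis=1)
    return weights @ V

rng = np.random.default_rng(0)
n, d = 6, 4
V = rng.standard_normal((n, d))     # values, e.g. phi applied to each element
ones = np.ones((n, d))              # every key and every query set to 1.0
out = self_attention(ones, ones, V)
print(np.allclose(out, V.mean(axis=0)))   # True: uniform weights give the mean

# With general keys and queries, permuting the inputs permutes the outputs.
Q, K = rng.standard_normal((n, d)), rng.standard_normal((n, d))
perm = rng.permutation(n)
print(np.allclose(self_attention(Q[perm], K[perm], V[perm]),
                  self_attention(Q, K, V)[perm]))   # True
```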

Therefore, self-attention inherits all the sufficiency statements ('with $N=M$ you can represent everything'), but not the necessity part: it is not clear that $N=M$ is needed in the self-attention architecture, just because it was proved that it is needed in the Deep Sets architecture.

Point clouds are unordered, variable length lists (aka sets) of $xyz$ coordinates. We can also view them as (sparse) 3D occupancy tensors, but there is no 'natural' 1D ordering because we have three equal spatial dimensions. We could e.g. build a kd-tree but again this imposes a somewhat 'unnatural' ordering.

As a specific example, PointNet by Qi et al. (2017) is an expressive set-based model with some more bells and whistles. It handles interactions between points by (1) computing a permutation-invariant global descriptor, (2) concatenating it with point-wise features, and (3) repeating the first two steps several times. They also use (spatial) transformer modules for translation and rotation invariance. So. Much. Invariance!

A stochastic process corresponds to a set of random variables. Here we want to model the joint distributions of the values those random variables take. These distributions need to satisfy the condition of exchangeability, i.e. they need to be invariant to the order of the random variables.

Neural Processes and Conditional Neural Processes (both by Garnelo et al. 2018) achieve this by computing a global latent variable via summation. One well-known instance of this is *Generative Query Networks* by Eslami et al. 2018 which aggregate information from different views via summation to obtain a latent scene representation.

*👋 Hi, this is Ferenc again. Thanks to Fabian, Ed and Martin for the great post.*

*Update:* As commenters pointed out, these papers are, of course, not the only ones dealing with permutation invariance and set functions. Here are a couple more things you might want to look at (there are quite likely many more that I don't mention here; feel free to add more in the comments section below):

- (Vinyals et al, 2015) on handling permutation invariance in the seq2seq framework.
- (Lee et al, 2018) proposed an attention-based architecture to implement set functions.
- (Bloem-Reddy and Teh, 2019) for a more theoretical take on invariances in neural networks.

As I said before, I think that the coolest application of this type of architecture is in meta-learning situations. When someone mentions meta-learning, many people think of complicated "learning to learn to learn via gradient descent via gradient descent via gradient descent" kinds of things. But in reality, simpler variants of meta-learning are a lot closer to being practically useful.

One example is a recommender system developed by Vartak et al. (2017) for Twitter using this idea. Here, a user's preferences are summarized by the set of tweets they recently engaged with on the platform. This set is processed by a DeepSets architecture (the order in which they engaged with tweets is assumed to carry little information in this application). The output of this set function is then fed into another neural network that scores new tweets the user might find interesting.

Such architectures can prove useful in online learning or streaming data settings, where new datapoints arrive over time, in a sequence. For every new datapoint, one can apply the $\phi$ mapping, and then simply maintain a running average of these $\phi$ values. For binary classification, one can keep a running average of $\phi(x)$ over all negative examples, and another over all positive examples. These running averages then provide useful, permutation-invariant summary statistics of all the data received so far.
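A small sketch of this streaming summary. It uses a cumulative (running) mean, which is permutation-invariant up to floating-point rounding. An exponentially weighted moving average would weight recent points more heavily, so its value would depend on the order of arrival. The `phi` here is a hypothetical single `tanh` layer with random weights, standing in for a trained embedding:

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((2, 8))   # hypothetical embedding weights (untrained)

def phi(x):
    return np.tanh(x @ W)

class RunningMean:
    """Cumulative mean of phi(x), updated one datapoint at a time."""

    def __init__(self, dim):
        self.n = 0
        self.mean = np.zeros(dim)

    def update(self, z):
        self.n += 1
        self.mean += (z - self.mean) / self.n   # incremental mean update

def summarize(stream):
    # one running mean per class label (0 = negative, 1 = positive)
    stats = {0: RunningMean(8), 1: RunningMean(8)}
    for x, y in stream:
        stats[y].update(phi(x))
    return stats

stream = [(rng.standard_normal(2), int(rng.integers(2))) for _ in range(100)]
s1 = summarize(stream)
s2 = summarize(stream[::-1])
print(np.allclose(s1[1].mean, s2[1].mean))   # True: arrival order does not matter
```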

In summary, I'm a big fan of this architecture. I think that the work of Wagstaff et al (2019) provides further valuable intuition on their ability to represent arbitrary set functions.

Counterfactuals are weird. I wasn't going to talk about them in my MLSS lectures on Causal Inference, mainly because I wasn't sure I fully understood what they were all about, let alone how to explain them to others. But during the Causality Panel, David Blei made comments about how weird counterfactuals are: how difficult they are to explain and wrap one's head around. So after that panel discussion, and with a grand total of 5 slides finished for my lecture on Thursday, I said "challenge accepted". The figures and story I'm sharing below are what I came up with after that panel discussion, as I finally understood how to think about counterfactuals. I'm hoping others will find them illuminating, too.

This is the third in a series of tutorial posts on causal inference. If you're new to causal inference, I recommend you start from the earlier posts:

- Part 1: Intro to causal inference and do-calculus
- Part 2: Illustrating Interventions with a Toy Example
- ➡️️ Part 3: Counterfactuals
- Part 4 (in prep): Causal Diagrams, Markov Factorization, Structural Equation Models

Let me first point out that *counterfactual* is one of those overloaded words. You can use it, like Judea Pearl, to talk about a very specific definition of counterfactuals: a probabilistic answer to a "what would have happened if" question (I will give concrete examples below). Others use terms like *counterfactual machine learning* or *counterfactual reasoning* more liberally to refer to broad sets of techniques that have anything to do with causal analysis. In this post, I am going to focus on the narrow Pearlian definition of counterfactuals. As promised, I will start with a few examples:

This is an example David brought up during the Causality Panel and I referred back to this in my talk. I'm including it here for the benefit of those who attended my MLSS talk:

*Given that Hillary Clinton did not win the 2016 presidential election, and given that she did not visit Michigan 3 days before the election, and given everything else we know about the circumstances of the election, what can we say about the probability of Hillary Clinton winning the election, had she visited Michigan 3 days before the election?*

Let's try to unpack this. We are interested in the probability that:

- she *hypothetically* wins the election

conditioned on four sets of things:

- she lost the election
- she did not visit Michigan
- any other relevant and observable facts
- she *hypothetically* visits Michigan

It's a weird beast: you're simultaneously conditioning on her visiting Michigan and not visiting Michigan. And you're interested in the probability of her winning the election given that she did not. WHAT?

Why would quantifying this probability be useful? Mainly for credit assignment. We want to know why she lost the election, and to what degree the loss can be attributed to her failure to visit Michigan three days before the election. Quantifying this is useful: it can help political advisors make better decisions next time.

Here's a real-world application of counterfactuals: evaluating the fairness of individual decisions. Consider this counterfactual question:

*Given that Alice did not get promoted in her job, and given that she is a woman, and given everything else we can observe about her circumstances and performance, what is the probability of her getting a promotion if she was a man instead?*

Again, the main reason for asking this question is to establish to what degree being a woman is directly responsible for the observed outcome. Note that this is an individual notion of fairness, unlike the aggregate assessment of whether the promotion process is fair or unfair statistically speaking. It may be that the promotion system is pretty fair overall, but in the particular case of Alice unfair discrimination took place.

A counterfactual question is about a specific datapoint, in this case Alice.

Another weird thing to note about this counterfactual is that the intervention (Alice's gender magically changing to male) is not something we could ever implement or experiment with in practice.

Here's the example I used in my talk, and will use throughout this post: I want to understand to what degree having a beard contributed to getting a PhD:

*Given that I have a beard, and that I have a PhD degree, and everything else we know about me, with what probability would I have obtained a PhD degree, had I never grown a beard?*

Before I start describing how to express this as a probability, let's first think about what we intuitively expect the answer to be. In the grand scheme of things, my beard probably was not a major contributing factor to getting a PhD. I would have pursued PhD studies, and probably completed my degree, even if something had prevented me from keeping my beard. So:

We expect the answer to this counterfactual to be a high probability, something close to 1.

Let's start with the simplest thing one can do to attempt to answer my counterfactual question: collect some data about individuals, whether they have beards, whether they have PhDs, whether they are married, whether they are fit, etc. Here's a cartoon illustration of such a dataset:

I am in this dataset, or someone just like me is in the dataset, in row number four: I have a beard, I am married, I am obviously very fit (this was the point where I hoped the audience would get the joke and laugh, and thankfully they did), and I have a PhD degree. I have it all.

If we have this data and do normal statistical machine learning, without causal reasoning, we'd probably attempt to estimate $p(🎓\vert 🧔=0)$, the conditional probability of possessing a PhD degree given the absence of a beard. As I show at the bottom, this is like predicting values in one column from values in another column.

You hopefully know enough about causal inference by now to know that $p(🎓\vert 🧔=0)$ is certainly not the quantity we seek. Without additional knowledge of causal structure, it can't generalize to the hypothetical scenarios and interventions we care about here. There might be hidden confounders. Perhaps scoring high on the autism spectrum makes it more likely that you grow a beard, and it may also make it more likely that you obtain a PhD. This conditional can't tell whether the two quantities are causally related or not.

We've learned in the previous two posts that if we want to reason about interventions, we have to express a different conditional distribution, $p(🎓\vert do(🧔=0))$. We also know that in order to reason about this distribution, we need more than just a dataset, we also need a causal diagram. Let's add a few things to our figure:

Here, I drew a cartoon causal diagram on top of the data just for illustration, it was simply copy-pasted from previous figures, and does not represent a grand unifying theory of beards and PhD degrees. But let's assume our causal diagram describes reality.

The causal diagram lets us reason about the distribution of data in an alternative world, a parallel universe if you like, in which everyone is somehow magically prevented from growing a beard. You can imagine sampling a dataset from this distribution, shown in the green table. We can measure the association between PhD degrees and beards in this green distribution, which is precisely what $p(🎓\vert do(🧔=0))$ means. As shown by the arrow below the tables, $p(🎓\vert do(🧔=0))$ is about predicting columns of the green dataset from other columns of the green dataset.

Can $p(🎓\vert do(🧔=0))$ express the counterfactual probability we seek? Well, remember that we expected that I would have obtained a PhD degree with a high probability even without a beard. However, $p(🎓\vert do(🧔=0))$ talks about the PhD status of a random individual after a no-beard intervention. If you take a random person off the street and shave their beard if they have one, it is very unlikely that your intervention will make them get a PhD. Not to mention that your intervention has no effect on most women and men without beards. We intuitively expect $p(🎓\vert do(🧔=0))$ to be close to the base-rate of PhD degrees $p(🎓)$, which is apparently somewhere around 1-3%.

$p(🎓\vert do(🧔=0))$ talks about a randomly sampled individual, while a counterfactual talks about a specific individual

Counterfactuals are "personalized" in the sense that you'd expect the answer to change if you substitute a different person in there. My father has a mustache (let's classify that as a type of beard for pedagogical purposes), but he does not have a PhD degree. I expect that preventing him from growing a mustache would not have made him any more likely to obtain a PhD. So his counterfactual probability would be close to 0.

The counterfactual probabilities vary from person to person. If you calculate them for random individuals and average the probabilities, you should expect to get something like $p(🎓\vert do(🧔=0))$. (More on this connection later.) But we are not interested in the population mean now; we are interested in calculating the probabilities for each individual.

To finally explain counterfactuals, I have to step beyond causal graphs and introduce another concept: structural equation models.

A causal graph encodes which variables have a direct causal effect on any given node - we call these the causal parents of the node. A structural equation model goes one step further and specifies this dependence more explicitly: for each variable it has a function which describes the precise relationship between the value of that node and the values of its causal parents.

It's easiest to illustrate this through an example: here's a causal graph of an online advertising system, taken from the excellent paper by Bottou et al. (2013):

It doesn't really matter what these variables mean; if you're interested, read the paper. The dependencies shown by the diagram are equivalently encoded by the following set of equations:

For each node in the graph above we now have a corresponding function $f_i$. The arguments of each function are the causal parents of the variable it instantiates, e.g. $f_1$ computes $x$ from its causal parent $u$, and $f_2$ computes $a$ from its causal parents $x$ and $v$. In order to allow for nondeterministic relationships between the variables, we additionally allow each function $f_i$ to take another input, $\epsilon_i$, which you can think of as a random number. Through the random input $\epsilon_1$, the output of $f_1$ can be random given a fixed value of $u$, hence giving rise to a conditional distribution $p(x \vert u)$.

The structural equation model (SEM) entails the causal graph, in that you can reconstruct the causal graph by looking at the inputs of each function. It also entails the joint distribution, in that you can "sample" from an SEM by evaluating the functions in order, plugging in the random $\epsilon$s where needed.
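To make this concrete, here is a minimal sketch of sampling from a small SEM. It is entirely made up and is not the advertising model above or the SEM in any figure of this post. Three binary variables are assumed: a hypothetical hidden trait $T$, beard $B$ and PhD $D$, with edges $T \to B$, $T \to D$ and $B \to D$. Each function takes its causal parents plus a uniform noise input, and all probabilities are arbitrary assumptions:

```python
import numpy as np

# Hypothetical structural equations: causal parents plus a uniform noise input eps.
def f_trait(eps):
    return (eps < 0.2).astype(int)

def f_beard(t, eps):
    return (eps < np.where(t == 1, 0.6, 0.3)).astype(int)

def f_phd(t, b, eps):
    return (eps < np.where(t == 1, 0.3, 0.02) + 0.01 * b).astype(int)

def sample_sem(n, rng):
    # Draw all noise variables first, then evaluate the functions in causal order.
    eps = rng.uniform(size=(n, 3))
    t = f_trait(eps[:, 0])
    b = f_beard(t, eps[:, 1])
    d = f_phd(t, b, eps[:, 2])
    return t, b, d

rng = np.random.default_rng(0)
t, b, d = sample_sem(200_000, rng)
print(b.mean())   # close to 0.2 * 0.6 + 0.8 * 0.3 = 0.36
```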

In an SEM, an intervention on a variable, say $q$, can be modelled by deleting the corresponding function, $f_4$, and replacing it with another function. For example, $do(Q = q_0)$ would correspond to a simple assignment to a constant, $\tilde{f}_4(x, a) = q_0$.

Now that we know what SEMs are we can return to our example of beards and degrees. Let's add a few more things to the figure:

The first change is that, instead of just a causal graph, I now assume that we model the world by a fully specified structural equation model. I signify this lazily by labelling the causal graph with the functions $f_1, f_2, f_3$. Notice that the SEM of the green situation is the same as the SEM in the blue case, except that I deleted $f_1$ and replaced it with a constant assignment. But $f_2$ and $f_3$ are the same in the blue and the green models.
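The same mechanics can be written in code. This sketch uses a made-up SEM, not the one in the figure: a hypothetical trait $T \to$ beard $B$, $T \to$ PhD $D$, and $B \to D$, with arbitrary probabilities. The intervention deletes the beard equation and replaces it with a constant, while the other equations stay untouched. Filtering the "blue" table on $B = 0$ and reading off the "green" table then give different answers, because $T$ confounds $B$ and $D$:

```python
import numpy as np

# Hypothetical SEM (made up for illustration): T -> B, T -> D, B -> D.
def f_trait(eps):
    return (eps < 0.2).astype(int)

def f_beard(t, eps):
    return (eps < np.where(t == 1, 0.6, 0.3)).astype(int)

def f_phd(t, b, eps):
    return (eps < np.where(t == 1, 0.3, 0.02) + 0.01 * b).astype(int)

def run_sem(eps, do_beard=None):
    t = f_trait(eps[:, 0])
    if do_beard is None:
        b = f_beard(t, eps[:, 1])
    else:
        # mutilated SEM: f_beard is deleted and replaced by a constant assignment
        b = np.full(len(eps), do_beard)
    d = f_phd(t, b, eps[:, 2])
    return t, b, d

rng = np.random.default_rng(0)
eps = rng.uniform(size=(500_000, 3))

_, b, d = run_sem(eps)
p_cond = d[b == 0].mean()               # p(D=1 | B=0): filter the observational table
_, _, d_do = run_sem(eps, do_beard=0)
p_do = d_do.mean()                      # p(D=1 | do(B=0)): the intervened table
print(p_cond, p_do)   # roughly 0.055 and 0.076 under these made-up numbers
```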

Secondly, I make the existence of the $\epsilon_i$ noise variables explicit, and show their values (it's all made up, of course) in the gray table. If you feed the first row of epsilons to the blue structural equation model, you get the first blue datapoint $(0110)$. If you feed the same epsilons to the green SEM, you get the first green datapoint $(0110)$. If you feed the second row of epsilons to the models, you get the second rows in the blue and green tables, and so on.

I like to think about the first green datapoint as the *parallel twin* of the first blue datapoint. When discussing interventions, I talked about making predictions about a *parallel universe* where nobody has a beard. Now imagine that for every person who lives in our observable universe, there is a corresponding person, their parallel twin, in this parallel universe. Your twin is the same as you in every respect, except for the absence of any beard you might have and any downstream consequences of having a beard. If you don't have a beard in this universe, your twin is an exact copy of you in every respect. Indeed, notice that the first blue datapoint is the same as the first green datapoint in my example.

You may know I'm a Star Trek fan and I like to use Star Trek analogies to explain things. In Star Trek, there is this concept called the mirror universe. It's a parallel universe populated by the same people who live in the normal universe, except that everyone who is good in the real universe is evil in the mirror universe. Hilariously, the mirror version of Spock, one of the main protagonists, has a goatee in this mirror universe. This is how you can tell whether you're looking at evil mirror-Spock or normal Spock when watching the episode. This explains, sadly, why I'm using beards to explain counterfactuals. Here are Spock and mirror-Spock:

Now that we established the *twin datapoint* metaphor, we can say that counterfactuals are

making a prediction about features of the unobserved twin datapoint based on features of the observed datapoint.

Crucially, this was possible because we used the same $\epsilon$s in both the blue and the green SEM. This induces a joint distribution between variables in the observable regime, and variables in the unobserved, counterfactual regime. Columns of the green table are no longer independent of columns of the blue table. You can start predicting values in the green table using values in the blue table, as illustrated by the arrows below them.
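Here is a Monte Carlo version of the twin construction on a made-up SEM (a hypothetical trait $T \to$ beard $B$, $T \to$ PhD $D$, $B \to D$, with arbitrary probabilities; not the SEM in the figure). The same noise rows are fed to the factual model and to the mutilated no-beard model. Conditioning on observed columns of the factual table then estimates counterfactual probabilities for the twins. The specific numbers are artifacts of these invented equations; the pattern is what I'd hope it illustrates. Bearded PhD holders get a high counterfactual probability and bearded non-PhDs get a low one. Averaging the twins over everyone gives back the interventional $p(D=1 \vert do(B=0))$:

```python
import numpy as np

# Hypothetical SEM (made up for illustration): T -> B, T -> D, B -> D.
def f_trait(eps):
    return (eps < 0.2).astype(int)

def f_beard(t, eps):
    return (eps < np.where(t == 1, 0.6, 0.3)).astype(int)

def f_phd(t, b, eps):
    return (eps < np.where(t == 1, 0.3, 0.02) + 0.01 * b).astype(int)

def run_sem(eps, do_beard=None):
    t = f_trait(eps[:, 0])
    b = f_beard(t, eps[:, 1]) if do_beard is None else np.full(len(eps), do_beard)
    d = f_phd(t, b, eps[:, 2])
    return t, b, d

rng = np.random.default_rng(0)
eps = rng.uniform(size=(1_000_000, 3))   # shared noise: one row per individual
_, b, d = run_sem(eps)                    # factual (blue) table
_, _, d_twin = run_sem(eps, do_beard=0)   # twins (green table): same eps, no beards

me = (b == 1) & (d == 1)            # observed: beard and PhD
print(d_twin[me].mean())            # about 0.92 under these made-up numbers

father = (b == 1) & (d == 0)        # observed: beard, no PhD
print(d_twin[father].mean())        # exactly 0.0 in this particular SEM

print(d_twin.mean())                # about 0.076, i.e. p(D=1 | do(B=0))
```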

Mathematically, a counterfactual is the following conditional probability:

$$

p(🎓^\ast \vert 🧔^\ast = 0, 🧔=1, 💍=1, 💪=1, 🎓=1),

$$

where variables with an $^\ast$ are unobserved (and unobservable) variables that live in the counterfactual world, while variables without $^\ast$ are observable.

Looking at the data, it turns out that mirror-Ferenc, who does not have a beard, is married and has a PhD, but is not quite as strong as observable Ferenc.

Here is another drawing that some of you might find more appealing, especially those who are into GANs, VAEs and similar generative models:

A SEM is essentially a generative model of data, which uses some noise variables $\epsilon_1, \epsilon_2, \ldots$ and turns them into observations $(U,Z,X,Y)$ in this example. This is shown in the left-hand branch of the graph above. Now if you want to make counterfactual statements under the intervention $X=\hat{x}$, you can construct a *mutilated* SEM, which is the same SEM except with $f_3$ deleted and replaced with the constant assignment $x = \hat{x}$. This modified SEM is shown in the right-hand branch. If you feed the $\epsilon$s into the mutilated SEM, you get another set of variables $(U^\ast,Z^\ast,X^\ast,Y^\ast)$, shown in green. These are the features of the twin as it were. This joint generative model over $(U,Z,X,Y)$ and $(U^\ast,Z^\ast,X^\ast,Y^\ast)$ defines a joint distribution over the combined set of variables $(U,Z,X,Y,U^\ast,Z^\ast,X^\ast,Y^\ast)$. Therefore, now you can calculate all sorts of conditionals and marginals of this joint.
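To make the shared-noise construction concrete, here is a minimal sketch in Python with NumPy. The structural equations (the functions playing the roles of $f_1, \ldots, f_4$ and their coefficients) are hypothetical, invented for illustration only; the post does not specify them. The same noise draws are pushed through the original SEM and through the mutilated SEM in which $f_3$ is replaced by the constant assignment $x = \hat{x}$ (here $\hat{x} = 1$, also an arbitrary choice), which gives joint samples of $(U,Z,X,Y)$ and $(U^\ast,Z^\ast,X^\ast,Y^\ast)$.

```python
import numpy as np

def run_sem(eps, x_hat=None):
    """Hypothetical linear SEM over (U, Z, X, Y).

    eps has shape (4, n): one row of noise per variable.
    If x_hat is given, f3 is deleted and replaced by x = x_hat (the mutilated SEM).
    """
    e1, e2, e3, e4 = eps
    u = e1                                    # f1
    z = 0.5 * u + e2                          # f2
    if x_hat is None:
        x = u - z + e3                        # f3
    else:
        x = np.full_like(e3, x_hat)           # f3 replaced by a constant assignment
    y = 2.0 * x + u + e4                      # f4
    return u, z, x, y

rng = np.random.default_rng(0)
eps = rng.normal(size=(4, 100_000))           # shared noise for both worlds

u, z, x, y = run_sem(eps)                     # observable world (blue)
u_s, z_s, x_s, y_s = run_sem(eps, x_hat=1.0)  # twin world under X = 1 (green)

# Because the noise is shared, the twin's outcome is a deterministic function
# of the observed variables in this particular SEM: y* = y + 2 * (x_hat - x).
print(np.allclose(y_s, y + 2.0 * (1.0 - x)))  # True
```

Variables upstream of $X$ ($U$ and $Z$) are identical in both worlds, which is exactly why the green table can be predicted from the blue one.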

Of particular interest are these conditionals:

$$

p(y^\ast \vert X^\ast = \hat{x}, X = x, Y = y, U = u, Z = z),

$$

which is a counterfactual prediction. In reality, since $X^\ast = \hat{x}$ holds with a probability of $1$, we can drop that conditioning.
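For a single observed individual, this counterfactual can be computed in three steps, often called abduction, action and prediction: recover the noise values consistent with the observation, replace $f_3$ by the intervention, then push the recovered noise through the mutilated SEM. The sketch below assumes a hypothetical invertible linear SEM whose coefficients are invented for illustration. Because the equations are invertible, the noise is recovered exactly and the counterfactual conditional collapses to a single value; in general it is a distribution.

```python
def counterfactual(u, z, x, y, x_hat):
    """Counterfactual (u*, z*, x*, y*) for one observed unit in a hypothetical SEM:
        u = e1,  z = 0.5*u + e2,  x = u - z + e3,  y = 2*x + u + e4.
    """
    # Abduction: recover the noise values consistent with the observation.
    # (e3 is not needed: f3 is replaced by the intervention.)
    e1 = u
    e2 = z - 0.5 * u
    e4 = y - 2.0 * x - u
    # Action: replace f3 with the constant assignment x* = x_hat.
    x_star = x_hat
    # Prediction: push the recovered noise through the mutilated SEM.
    u_star = e1
    z_star = 0.5 * u_star + e2
    y_star = 2.0 * x_star + u_star + e4
    return u_star, z_star, x_star, y_star

print(counterfactual(u=1.0, z=0.5, x=0.0, y=1.5, x_hat=2.0))  # (1.0, 0.5, 2.0, 5.5)
```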

My notation here is a bit sloppy: a lot is going on implicitly under the hood that I'm not making explicit in the notation. I'm sorry if this irritates some readers; I want to avoid overcomplicating things at this point. Now is a good time to point out that Pearl's notation, including the do-notation, is often criticized, but people use it because it is now widely adopted.

We can also express the intervention conditional $p(y\vert do(x))$ using this (somewhat sloppy) notation as:

$$

p(y\vert do(X=\hat{x})) = p(y^\ast \vert X^\ast = \hat{x})

$$

We can see that the intervention conditional only contains variables with an $^\ast$, so it does not require the joint distribution of $(U,Z,X,Y,U^\ast,Z^\ast,X^\ast,Y^\ast)$, only the marginal of the $^\ast$ variables $(U^\ast, Z^\ast,X^\ast,Y^\ast)$. As a consequence, in order to talk about $p(y\vert do(X=\hat{x}))$ we did not need to introduce SEMs or talk about the epsilons.

Furthermore, notice the following equality:

\begin{align}

p(y\vert do(X=\hat{x})) &= p(y^\ast \vert X^\ast = \hat{x}) \\

&= \int_{x,y,u,z} p(y^\ast \vert X^\ast = \hat{x}, X = x, Y = y, U = u, Z = z) p(x,y,u,z) dx dy du dz \\

&= \mathbb{E}_{p_{X,Y,U,Z}} p(y^\ast \vert X^\ast = \hat{x}, X = x, Y = y, U = u, Z = z),

\end{align}

in other words, the intervention conditional $p(y\vert do(X=\hat{x}))$ is the average of counterfactuals over the observable population. This was something I did not realize before my MLSS tutorial; it was pointed out to me by a student in the form of a question. In hindsight, of course this is true!
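This identity can be checked numerically. The sketch below assumes a hypothetical linear SEM (coefficients invented for illustration) and $\hat{x} = 1$. It computes one counterfactual outcome per member of a simulated observed population, then compares the pooled counterfactuals with fresh draws from the intervened model, which uses no shared noise at all.

```python
import numpy as np

# Hypothetical linear SEM (coefficients invented for illustration):
#   u = e1,  z = 0.5*u + e2,  x = u - z + e3,  y = 2*x + u + e4
rng = np.random.default_rng(1)
n, x_hat = 200_000, 1.0

# Observable population.
e1, e2, e3, e4 = rng.normal(size=(4, n))
u = e1
z = 0.5 * u + e2
x = u - z + e3
y = 2.0 * x + u + e4

# One counterfactual per observed unit: abduct e4 = y - 2x - u, then set x* = x_hat.
y_cf = 2.0 * x_hat + u + (y - 2.0 * x - u)

# Fresh draws from the intervened model p(y | do(X = x_hat)), no shared noise.
f1, f4 = rng.normal(size=(2, n))
y_do = 2.0 * x_hat + f1 + f4

print(y_cf.mean(), y_do.mean())  # both close to 2.0
print(y_cf.std(), y_do.std())    # both close to sqrt(2)
```

Each individual counterfactual depends on that unit's observed data, but averaging them over the population recovers the interventional distribution.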

God, this was a loooong post. If you're still reading, thanks, I hope it was useful. I wanted to close with a few slightly philosophical remarks on counterfactuals.

Counterfactuals are often said to be unscientific, primarily because they are **not empirically testable**. In normal ML we are used to benchmark datasets, and to the quality of our predictions always being testable on some test dataset. In causal ML, not everything can be directly tested or empirically benchmarked. For interventions, the best test is to run a randomized controlled trial to directly measure $p(y\vert do(X=x))$ if you can, and then use this experimental data to evaluate your causal inferences. But some interventions are impossible to carry out in practice. Think about all the work on fairness that reasons about interventions on gender or race. So what to do then?

In the world of counterfactuals this is an even bigger problem, as it is outright impossible to observe the variables you make predictions about. You can't go back in time and rerun history with exactly the same circumstances except for a tiny change. You can't travel to parallel universes (at least not before the 24th century, according to Star Trek). Counterfactual judgments remain hypothetical, subjective, untestable, unfalsifiable. There can be no MNIST or ImageNet for counterfactuals that satisfies everyone. Some good datasets do exist, but they cover specific scenarios where explicit testing is possible (e.g. offline A/B testing), or they make use of simulators instead of "real" data.

Despite being untestable and difficult to interpret, counterfactual statements are used by humans all the time, and intuitively they feel pretty useful for intelligent behaviour. Being able to pinpoint the causes that led to a particular situation or outcome is certainly useful for learning, reasoning and intelligence. So my strategy for now is to ignore the philosophical debate about counterfactuals and just get on with it, knowing that the tools are there if such predictions have to be made.

Last week I had the honor to lecture at the Machine Learning Summer School in Stellenbosch, South Africa. I chose to talk about Causal Inference, despite being a newcomer to this whole area. In fact, I chose it exactly because I'm a newcomer: causal inference has been a blindspot for me for a long time. I wanted to communicate some of the intuitions and ideas I learned and developed over the past few months, which I wish someone had explained to me earlier.

Now, I'm turning my presentation into a series of posts, starting with this one, which builds on the previous one I wrote in May. In this one, I will present the toy example I used in my talk to explain interventions. I call this the *three scripts* toy example. I was not sure whether people were going to get it, but I got good feedback on it from the audience, so I'm hoping you will find it useful, too. Here are links to the other posts:

- Part 1: Intro to Causal Inference and do-Calculus
- ➡️️ Part 2: Illustrating Interventions with a Toy Example
- Part 3: Counterfactuals

Imagine you teach a programming course and you ask students to write a Python script that samples from a 2D Gaussian distribution with a certain mean and covariance. Some of the solutions will be correct, but as there are multiple ways to sample from a Gaussian, you might see very different solutions. For example, here are three scripts that implement the same, correct sampling behaviour:

Below each of the code snippets I plotted samples drawn by repeatedly executing the scripts. As you can see, all three scripts produce the same joint distribution between $x$ and $y$. You can feed these samples into a two-sample test, and you will find that they are indeed indistinguishable from each other.

Based on the joint distribution the three scripts are indistinguishable.

But although the three scripts are equivalent in that they generate the same distribution, they are not exactly the same. For example, they behave differently if we interfere with, or intervene in, their execution.

Consider this thought experiment: I am a hacker, and I can inject code into the Python interpreter. For every line of code from the snippet, I can insert a line of code of my choice. Let's say that I really want to set the value of $x$ to $3$, so I use my code injection ability and insert the line `x=3` after each line of your code. So what actually gets executed is this:

We can now run the scripts in this hacked interpreter and see how the intervention changes the distribution of $x$ and $y$:

Of course, we see that the value of $x$ is no longer random: it's deterministically set to $3$, so all samples line up along the vertical line $x=3$. But, interestingly, the distribution of $y$ is different for the different scripts. In the blue script, $y$ has a mean around $5$, while the green and red scripts produce a distribution of $y$ centered around a mean of $1$. Here is a better look at the marginal distribution of $y$ under the intervention:

I labelled this plot $p(y\vert do(X=3))$ which semantically means the distribution of $y$ under the intervention where I set the value of $X$ to $3$. This is generally different from the conditional distribution $p(y\vert x=3)$, which of course is the same for all three scripts. Below I show these conditionals - excuse me for the massive estimation errors here, I was lazy creating these plots, but believe me they technically are all the same:

The important point here is that

the scripts behave differently under intervention.

We have a situation where the scripts are indistinguishable when you only look at the joint distribution of the samples they produce, yet they behave differently under intervention.

Consequently,

the joint distribution of data alone is insufficient to predict behaviour under interventions.
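The whole thought experiment fits in a few lines of Python. The three scripts below are hypothetical stand-ins, not the ones from the original figures: their coefficients were chosen so that all three sample the same 2D Gaussian (mean $(0, 1)$, $\mathrm{var}(x)=1$, $\mathrm{var}(y)=4$, $\mathrm{cov}(x,y)=4/3$) and reproduce the means described above. To mimic the hacker literally, the helper `run` (a hypothetical name) uses `exec` to run each script with an injected line after every line of its source.

```python
import numpy as np

# Three hypothetical scripts (not the ones from the original figures). Each
# samples the same 2D Gaussian: mean (0, 1), var(x) = 1, var(y) = 4, cov = 4/3.
SCRIPTS = {
    "blue": [   # x causes y
        "x = rng.normal(size=n)",
        "y = 1 + 4/3 * x + np.sqrt(20/9) * rng.normal(size=n)",
    ],
    "green": [  # y causes x
        "y = 1 + 2 * rng.normal(size=n)",
        "x = (y - 1) / 3 + np.sqrt(5) / 3 * rng.normal(size=n)",
    ],
    "red": [    # a hidden variable z causes both
        "z = rng.normal(size=n)",
        "x = 0.8 * z + 0.6 * rng.normal(size=n)",
        "y = 1 + 5/3 * z + np.sqrt(11) / 3 * rng.normal(size=n)",
    ],
}

def run(lines, n=100_000, injected=None, seed=0):
    """Execute a script's lines; if `injected` is given, run it after every line."""
    program = []
    for line in lines:
        program.append(line)
        if injected is not None:
            program.append(injected)
    namespace = {"np": np, "rng": np.random.default_rng(seed), "n": n}
    exec("\n".join(program), namespace)
    x = np.broadcast_to(namespace["x"], (n,))  # x may be the scalar 3
    return x, namespace["y"]

for name, lines in SCRIPTS.items():
    x, y = run(lines)                          # normal execution
    _, y_do = run(lines, injected="x = 3")     # hacked interpreter
    print(name, np.round(np.cov(x, y), 1).tolist(), round(float(y_do.mean()), 1))
```

Up to sampling noise, the covariance matrices agree across the three scripts, while the mean of $y$ under the injected `x = 3` should come out near $5$ for the blue script and near $1$ for the other two.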

If the joint distribution is insufficient, what level of description would allow us to make predictions about how the scripts behave under intervention? If I have the full source code, I can of course execute the modified scripts, i.e. *run an experiment*, and directly observe how the intervention affects the distribution.

However, it turns out, you don't need the full source code. It is sufficient to know the *causal diagram* corresponding to the source code. The causal diagram encodes causal relationships between variables, with an arrow pointing from causes to effects. Here is what the causal diagrams would look like for these scripts:

We can see that, even though they produce the same joint distribution, the scripts have different causal diagrams. This additional knowledge of the causal structure allows us to make inferences about interventions without actually running experiments with that intervention. To do this in a general setting, we can use do-calculus, explained in a little more detail in my earlier post.

Graphically, to simulate the effect of an intervention, you *mutilate* the graph by removing all edges that point into the variable on which the intervention is applied, in this case $x$.

At the top row you see the three diagrams that describe the three scripts. In the second row are the *mutilated* graphs where all incoming edges to $x$ have been removed. In the first script, the graph looks the same after mutilation. From this, we can conclude that $p(y\vert do(x)) = p(y\vert x)$, i.e. that the distribution of $y$ under intervention $X=3$ is the same as the conditional distribution of $y$ conditioned on $X=3$. In the second script, after mutilation, $x$ and $y$ become disconnected, therefore independent. From this, we can conclude that $p(y\vert do(X=3)) = p(y)$. Changing the value of $x$ does nothing to change the value of $y$, so whatever you set $X$ to be, $y$ is just going to sample from its marginal distribution. The same argument holds for the third causal diagram.
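The graphical argument above can be sketched in plain Python. Each causal diagram is represented as a hypothetical dict mapping every node to the list of its parents. The helper `do_rule` (also a hypothetical name) only covers the two cases argued in the text: if mutilation changes nothing, $x$ has no parents and intervening equals conditioning; if mutilation disconnects $x$ from $y$, then $p(y\vert do(x)) = p(y)$. Anything else falls back to do-calculus.

```python
def mutilate(parents, target):
    """Copy of a DAG (node -> list of parents) with all edges into `target` removed."""
    return {node: [] if node == target else list(ps) for node, ps in parents.items()}

def connected(parents, a, b):
    """True if a and b are linked by any path, ignoring edge directions."""
    neighbours = {node: set() for node in parents}
    for child, ps in parents.items():
        for p in ps:
            neighbours[child].add(p)
            neighbours[p].add(child)
    seen, stack = {a}, [a]
    while stack:
        for nb in neighbours[stack.pop()] - seen:
            seen.add(nb)
            stack.append(nb)
    return b in seen

def do_rule(parents, x="x", y="y"):
    """Only the two cases discussed in the text; anything else needs do-calculus."""
    mutilated = mutilate(parents, x)
    if mutilated == parents:          # x has no parents: intervening = conditioning
        return "p(y|do(x)) = p(y|x)"
    if not connected(mutilated, x, y):
        return "p(y|do(x)) = p(y)"
    return "use do-calculus"

GRAPHS = {
    "blue": {"x": [], "y": ["x"]},
    "green": {"y": [], "x": ["y"]},
    "red": {"z": [], "x": ["z"], "y": ["z"]},
}
for name, graph in GRAPHS.items():
    print(name, do_rule(graph))
# blue p(y|do(x)) = p(y|x)
# green p(y|do(x)) = p(y)
# red p(y|do(x)) = p(y)
```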

The significance of this is the following: By only looking at the causal diagram, we are now able to predict how the scripts are going to behave under the intervention $X=3$. We can compute and plot $p(y\vert do(X=3))$ for the three scripts by only using data observed during the normal (non-intervened) condition, without ever having to run the experiment or simulate the intervention.

The causal diagram allows us to predict how the models will behave under intervention, without carrying out the intervention.

Here is proof of this: I could estimate the distribution of $y$ observed during the intervention experiment using only samples from the script under the normal (non-intervention) condition. This is called *causal inference from observational data*.
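As a rough sketch of what that estimation involves, the code below starts from a single set of observational samples, assumed to come from a hypothetical 2D Gaussian (mean $(0,1)$, variances $1$ and $4$, covariance $4/3$) that all three scripts share. Only the identity read off each mutilated graph differs. For the $p(y\vert x)$ case it uses a linear fit, which is exact for $\mathbb{E}[y\vert x]$ in a bivariate Gaussian. The function name `mean_y_under_do` is hypothetical, and it only estimates the mean of $p(y\vert do(X=3))$, not the full distribution.

```python
import numpy as np

rng = np.random.default_rng(0)

# Observational (non-intervened) samples. All three scripts produce this same
# joint distribution; the parameters here are hypothetical.
mean = [0.0, 1.0]
cov = [[1.0, 4 / 3], [4 / 3, 4.0]]
x, y = rng.multivariate_normal(mean, cov, size=100_000).T

def mean_y_under_do(rule, x, y, x_hat):
    """Estimate E[y | do(X = x_hat)] from observational data, given the
    identity read off the mutilated graph."""
    if rule == "p(y|x)":
        # A linear fit is exact for E[y | x] in a bivariate Gaussian.
        slope, intercept = np.polyfit(x, y, deg=1)
        return intercept + slope * x_hat
    if rule == "p(y)":
        return y.mean()
    raise ValueError(f"unsupported rule: {rule}")

rules = {"blue": "p(y|x)", "green": "p(y)", "red": "p(y)"}
estimates = {name: mean_y_under_do(rule, x, y, x_hat=3.0) for name, rule in rules.items()}
print(estimates)  # blue near 5, green and red near 1
```

The same data gives different answers for different graphs, which is the whole point: the graph, not the data, decides which estimate is the right one.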

The moral of this story is summed up in the following picture:

Let's consider all the questions you would want to be able to answer given some data (i.i.d. samples from a joint distribution). Having access to data, or more generally, the joint distribution it was sampled from, allows you to answer a great many questions, and solve many tasks. For example, you can do supervised machine learning by approximating $p(y \vert x)$ and then use it for many things, such as image labelling. These questions together make up the blue set.

However, we have already seen that some questions cannot be answered using data/joint distribution alone. Notably, if you want to make predictions about how the system you study would behave under certain interventions or perturbations, you typically won't be able to make such inferences based on the data you have. These types of questions lie outside the blue set.

However, if you complement your data with causal assumptions encoded in a causal diagram - a directed acyclic graph where nodes are your variables - you can exploit these extra assumptions to start answering these questions, shown by the green area.

I am also showing an even larger set of questions still, but I won't tell you what this refers to just yet. I'm leaving that to a future post.

The questions I got most frequently after the lecture were these: what if I don't know what the graph looks like? And what if I get the graph wrong? There are many ways to answer these questions.

To me, a rehabilitated Bayesian, the most appealing answer is that you have to accept that your analysis is conditional on the graph you choose, and your conclusions are valid under the assumptions encoded there. In a way, causal inference from observational data is *subjective*. When you publish a result, you should caveat it with "under these assumptions, this is true". Readers can then dispute and question your assumptions if they disagree.

As to how to obtain a graph, it varies by application. If you work with online systems such as a recommender system, the causal diagram is actually pretty simple to draw, as it corresponds to how various subsystems are hooked up and feed data to one another. In other applications, notably in healthcare, a little more guesswork and thought may be involved.

Finally, you can use various causal discovery techniques to try to identify the causal diagram from the data itself. In general, recovering the full causal graph from observational data alone is impossible. However, if you add certain smoothness or independence assumptions to the mix, you may be able to recover the graph from the data with some reliability.

We have seen that modeling the joint distribution can only get you so far, and if you want to predict the effect of interventions, i.e. calculate $p(y\vert do(x))$-like quantities, you have to add a causal graph to your analysis.

An optimistic take on this is that, in the i.i.d. setting, drawing a causal diagram and using *do-calculus* (or something equivalent) can significantly broaden the set of problems you can tackle with machine learning.

The pessimist's take on this is that if you are not aware of all this, you might be trying to address questions in the green bit, without realizing that the data won't be able to give you an answer to such questions.

Whichever way you look at it, *causal inference from observational data* is an important topic to be aware of.

Bayesian deep learning methods often look like a theoretical curiosity rather than a practically useful tool, and I'm personally a bit skeptical about the practical usefulness of some of the work. However, there are some situations where a decent method of handling and representing residual uncertainties about model parameters might prove crucial. These applications include active learning, reinforcement learning and online/continual learning.

So when I recently read a paper by Tencent, I was surprised to learn that an online Bayesian deep learning algorithm is apparently deployed in production to power click-through-rate prediction in their ad system. I thought, therefore, that this is worth a mention. Here is the paper:

- Xun Liu, Wei Xue, Lei Xiao, Bo Zhang (2017) PBODL : Parallel Bayesian Online Deep Learning for Click-Through Rate Prediction in Tencent Advertising System

Though the paper has a few typos and sometimes inconsistent notation (like going back and forth between using $\omega$ and $w$) which can make it tricky to read, I like the paper's layout: starting from desiderata - what the system has to be able to do - arriving at an elegant solution.

The method relies on the approximate Bayesian online-learning technique often referred to as assumed density filtering (ADF).

ADF has been independently discovered by the statistics, machine learning and control communities (for citations see Minka, 2001). It is perhaps most elegantly described by Opper (1998), and by Minka (2001), who also extended the idea to develop expectation propagation, a highly influential method for approximate Bayesian inference, not just in online settings.

ADF can be explained as a recursive algorithm repeating the following steps:

- The starting point for the algorithm at time $t$ is a "temporal prior" distribution $q_{t-1}(w)$ over parameters $w$. This $q_{t-1}(w)$ incorporates all evidence from datapoints observed in previous timesteps, and it is assumed to approximate the posterior $p(w\vert x_{1:t-1})$ conditioned on all data observed so far. $q_{t-1}(w)$ is assumed to take some simple form, like a Gaussian distribution factorized over the elements of $w$.
- Then, some new observations $x_t$ are observed. The Bayesian way to learn from these new observations would be to update the posterior: $p_t(w|x_t) \propto p(x_t\vert w)q_{t-1}(w)$
- However, this posterior might be complicated, and may not be available in a simple form. Therefore, at this point we approximate $p_t(w|x_t)$ by a new simple distribution $q_{t}$, which is in the same family as $q_{t-1}$. This step can involve a KL divergence-based approximation, e.g. (Opper, 1998; Minka, 2001), or a Laplace approximation, e.g. (Kirkpatrick et al., 2017; Huszár, 2017), or a more involved inner loop such as EP in MatchBox (Stern et al., 2009) or probabilistic backpropagation (Hernandez-Lobato and Adams, 2015) as in this paper.
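To make the three ADF steps concrete, here is a minimal sketch for a much simpler model than a deep network: Bayesian probit regression with a factorized Gaussian $q_{t}(w)$, where the projection step has a closed form. This is a swapped-in toy example, not the update used in the Tencent paper. The exact posterior after each datapoint is replaced by the factorized Gaussian that matches each weight's posterior mean and variance. The function name, the noise scale `beta`, and the synthetic data stream are all hypothetical.

```python
import numpy as np
from scipy.stats import norm

def adf_probit_step(mu, var, x, y, beta=1.0):
    """One ADF step for Bayesian probit regression.

    Prior q_{t-1}(w) = N(mu, diag(var)); likelihood p(y | x, w) = Phi(y * w.x / beta)
    with y in {-1, +1}. The exact posterior is projected back onto a factorized
    Gaussian by matching each weight's mean and variance.
    """
    s2 = beta**2 + np.sum(x**2 * var)              # predictive variance of the score
    s = np.sqrt(s2)
    t = y * mu @ x / s
    v = np.exp(norm.logpdf(t) - norm.logcdf(t))    # phi(t) / Phi(t), computed stably
    w = v * (v + t)
    new_mu = mu + y * var * x / s * v
    new_var = var * (1.0 - var * x**2 / s2 * w)
    return new_mu, new_var

rng = np.random.default_rng(0)
d = 5
w_true = rng.normal(size=d)
mu, var = np.zeros(d), np.ones(d)                  # q_0: standard normal prior
for _ in range(5_000):                             # each datapoint is seen once, then discarded
    x = rng.normal(size=d)
    y = 1.0 if w_true @ x + rng.normal() > 0 else -1.0
    mu, var = adf_probit_step(mu, var, x, y)

cosine = mu @ w_true / (np.linalg.norm(mu) * np.linalg.norm(w_true))
print(cosine, var)
```

The posterior mean should end up pointing in roughly the same direction as `w_true`, while every variance shrinks from its prior value of $1$.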

The Tencent paper is based on probabilistic backpropagation (Hernandez-Lobato and Adams, 2015), which uses the ADF idea multiple times to perform the two non-trivial tasks required in a supervised Bayesian deep network: inference and learning.

**forward propagation:** In Bayesian deep learning, we maintain a distribution $q(w)$ over neural network weights, and each value $w$ defines a conditional probability $p(y\vert x, w)$. To predict the label $y$ from the input $x$, we have to average over $q(w)$, that is, calculate $p(y\vert x) = \int q(w)p(y\vert x,w)dw$, which is difficult due to the nonlinearities.

Probabilistic backprop uses an ADF-like algorithm to approximate this predictive distribution. Starting from the bottom of the network, it approximates the distribution of the first hidden layer's activations given the input $x$ with a Gaussian. (The first hidden layer has a distribution of activations because we have a distribution over the weights.) It then propagates that distribution to the second layer, and approximates the result with a Gaussian. This process is repeated until the distribution of $y$ given $x$ is calculated in the last step, which can be done easily if one uses a probit activation at the last layer.

**backward propagation:** Forward propagation allows us to make a prediction given an input. The second task we have to be able to perform is to incorporate evidence from a new datapoint $(x_t, y_t)$ by updating the distribution over weights $q_{t-1}(w)$ to $p_t(w|x_t, y_t) \propto p(y_t\vert x_t, w) q_{t-1}(w)$. We approximate this $p_t(w|x_t, y_t)$ in an inner loop, by first running probabilistic forward propagation, then a similar ADF-like sweep backwards through the network.
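The forward pass can be sketched for a tiny network: one ReLU hidden layer and a probit output, with a factorized Gaussian over all weights and a deterministic input. This sketch omits details of the actual probabilistic backpropagation implementation, and the layer sizes and weight distributions are hypothetical. Each layer's output is summarized by its mean and variance (a Gaussian approximation), the ReLU moments are computed in closed form, and the last step uses the fact that $\mathbb{E}[\Phi(a)] = \Phi(m/\sqrt{1+v})$ for $a \sim \mathcal{N}(m, v)$.

```python
import numpy as np
from scipy.stats import norm

def linear_moments(M, V, mz, vz):
    """Mean and variance of W @ z for independent W_ij ~ N(M_ij, V_ij) and
    z_j with mean mz_j and variance vz_j; a deterministic input has vz = 0."""
    mean = M @ mz
    var = (M**2) @ vz + V @ (mz**2) + V @ vz
    return mean, var

def relu_moments(m, v):
    """Exact mean and variance of max(0, a) for a ~ N(m, v), elementwise."""
    s = np.sqrt(v)
    alpha = m / s
    cdf, pdf = norm.cdf(alpha), norm.pdf(alpha)
    mean = m * cdf + s * pdf
    second = (m**2 + v) * cdf + m * s * pdf
    return mean, np.maximum(second - mean**2, 1e-12)

def predict_proba(x, M1, V1, M2, V2):
    """Approximate p(y=1 | x) = E_q[Phi(w2 . relu(W1 x))] layer by layer."""
    m, v = linear_moments(M1, V1, x, np.zeros_like(x))  # hidden pre-activations
    mz, vz = relu_moments(m, v)                         # moment-matched activations
    ma, va = linear_moments(M2, V2, mz, vz)             # output pre-activation
    return norm.cdf(ma / np.sqrt(1.0 + va)), ma, va     # probit output, closed form

# Hypothetical factorized Gaussian q(w) for a 3-50-1 network.
rng = np.random.default_rng(0)
D, H = 3, 50
M1, V1 = rng.normal(size=(H, D)) / np.sqrt(D), np.full((H, D), 0.1 / D)
M2, V2 = rng.normal(size=H) / np.sqrt(H), np.full(H, 0.1 / H)
x = np.array([1.0, -0.5, 2.0])
p, ma, va = predict_proba(x, M1, V1, M2, V2)
print(p)
```

A Monte Carlo average over sampled weights is a simple way to sanity-check a pass like this; only the final Gaussian assumption on the output pre-activation is approximate here.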

I find the similarity of this algorithm to gradient calculations in neural network training beautiful: forward propagation to perform inference, forward and backward propagation for learning and parameter updates.

Crucially though, there is no gradient descent, indeed no gradients or iterative optimization, taking place here. After a single forward-backward cycle, information about the new datapoint (or minibatch of data) is, approximately, incorporated into the updated distribution $q_{t}(w)$ of network weights. Once this is done, the mini-batch can be discarded and never revisited. This makes this method a great candidate for online learning in large-scale streaming data situations.

Another advantage of this method is that it can be parallelized in a data-parallel fashion: multiple workers can update the posterior simultaneously on different minibatches of data. Expectation propagation provides a principled way of combining the parameter updates computed by parallel workers. Indeed, this is what the PBODL algorithm of Tencent does.
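The core of the combination rule can be illustrated on a conjugate toy problem. This is not PBODL's actual merge procedure, only a sketch of the principle under stated assumptions: Bayesian linear regression with a known noise variance, four hypothetical workers, and a Gaussian posterior kept in natural parameters (precision and precision times mean). Each worker updates the same shared prior on its own shard and reports its change in natural parameters; adding those changes to the prior reproduces the full-batch posterior exactly here. With non-conjugate likelihoods each change is only an approximate factor, and combining them this way is in the spirit of EP.

```python
import numpy as np

def to_natural(mean, cov):
    """Gaussian (mean, cov) -> natural parameters (precision, precision @ mean)."""
    precision = np.linalg.inv(cov)
    return precision, precision @ mean

def local_update(prior_prec, prior_eta, X, y, noise_var):
    """Exact Gaussian posterior for Bayesian linear regression on one shard."""
    return prior_prec + X.T @ X / noise_var, prior_eta + X.T @ y / noise_var

rng = np.random.default_rng(0)
d, n, noise_var = 3, 400, 0.25
w_true = rng.normal(size=d)
X = rng.normal(size=(n, d))
y = X @ w_true + np.sqrt(noise_var) * rng.normal(size=n)

prior_prec, prior_eta = to_natural(np.zeros(d), np.eye(d))
shards = np.array_split(np.arange(n), 4)          # four hypothetical workers

delta_prec, delta_eta = np.zeros((d, d)), np.zeros(d)
for idx in shards:                                # in practice these run in parallel
    prec_k, eta_k = local_update(prior_prec, prior_eta, X[idx], y[idx], noise_var)
    delta_prec += prec_k - prior_prec
    delta_eta += eta_k - prior_eta

# Merge: shared prior plus the sum of the workers' changes in natural parameters.
merged_prec = prior_prec + delta_prec
merged_mean = np.linalg.solve(merged_prec, prior_eta + delta_eta)

# Compare with a single worker that sees all the data.
full_prec, full_eta = local_update(prior_prec, prior_eta, X, y, noise_var)
print(np.allclose(merged_mean, np.linalg.solve(full_prec, full_eta)))  # True
```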

Bayesian online learning comes with great advantages. In recommender systems and ad marketplaces, the Bayesian online learning approach handles the user and item cold-start problems gracefully. If there is a new ad in the system that has to be recommended, initially our network might say "I don't really know", and then gradually home in on a confident prediction as more data is observed.

One can use the uncertainty estimates in a Bayesian network to perform active learning: actively proposing which items to label in order to increase predictive performance in the future. In the context of ads, this may be a useful feature if implemented with care.

Finally, there is the useful principle that *Bayes never forgets*: if we perform exact Bayesian updates on global parameters of an exchangeable model, the posterior will store information about all datapoints indefinitely, including very early ones. This advantage is clearly demonstrated in DeepMind's work on catastrophic forgetting (Kirkpatrick et al, 2017). This capacity to keep a long memory of course diminishes the more approximations we have to make, which leads me to drawbacks.

The ADF method is only approximate, and over time the approximation errors in each step may accumulate, causing the network to essentially forget what it learned from earlier datapoints. It is also worth pointing out that in stationary situations, the posterior over parameters is expected to shrink, especially in the last hidden layers of the network, where there are often fewer parameters. This posterior compression is often countered by introducing explicit forgetting: without evidence to the contrary, the variance of each parameter is marginally increased in each step. Forgetting and the lack of infinite memory may turn out to be a feature, not a bug, in non-stationary situations where more recent datapoints are more representative of future test cases.
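The effect of explicit forgetting shows up even with a single scalar parameter. The sketch below tracks a Gaussian posterior over an unknown mean from a stream of noisy observations, inflating the variance by a small amount (the hypothetical parameter `forget`) before every update. The non-stationary stream, in which the true mean jumps from $0$ to $5$ halfway through, is synthetic and chosen only for illustration.

```python
import numpy as np

def track(observations, obs_var=1.0, prior_var=10.0, forget=0.0):
    """Online Gaussian posterior over a scalar mean, with optional variance inflation.

    Before each update the variance is increased by `forget`; with forget=0 this
    is exact Bayesian updating for a stationary mean.
    """
    m, v = 0.0, prior_var
    for y in observations:
        v += forget                 # explicit forgetting
        k = v / (v + obs_var)       # weight given to the new observation
        m += k * (y - m)
        v *= 1.0 - k
    return m, v

rng = np.random.default_rng(0)
# Non-stationary stream: the true mean jumps from 0 to 5 halfway through.
true_mean = np.r_[np.zeros(500), np.full(500, 5.0)]
ys = true_mean + rng.normal(size=1000)

m_exact, _ = track(ys)                 # near 2.5: every observation weighted equally
m_forget, _ = track(ys, forget=1e-3)   # near 5: older observations fade out
print(m_exact, m_forget)
```

Without forgetting, the posterior remembers everything and averages the two regimes; with forgetting, it follows the recent data at the cost of a noisier estimate.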

A practical drawback of a method like this in production environments is that probabilistic forward- and backpropagation require non-trivial custom implementations. Probabilistic backprop does not use reverse-mode automatic differentiation, i.e. vanilla backprop, as a subroutine. As a consequence, one cannot rely on the extensively tested and widely used autodiff packages in TensorFlow or PyTorch, and performing probabilistic backprop in new layer types might require significant effort and custom implementations. It is possible that high-quality message-passing packages such as the recently open-sourced Infer.NET will see a renaissance after all.
