I finally got around to reading this new paper by Arjovsky et al. It debuted on Twitter with a big splash, being described as a 'beautiful', 'long awaited' 'gem of a paper'. It almost felt like a new superhero movie or Disney remake had just come out.
The paper is, indeed, very well written, and describes a very elegant idea, a practical algorithm, some theory and lots of discussion around how this is related to various bits. Here, I will describe the main idea and then provide an information theoretic view on the same topic.
We would like to learn robust predictors that are based on invariant causal associations between variables, rather than spurious surface correlations that might be present in our data. If we only observe i.i.d. data from a generative process, this is generally not possible.
In this paper, the authors assume that we have access to data sampled from different environments $e$. The data distribution differs across these environments, but there is an underlying causal dependence of the variable of interest $Y$ on some of the observed features $X$ that remains constant, or invariant, across all environments. The question is: can we exploit the variability across different environments to learn this underlying invariant association?
Usual empirical risk minimisation (ERM) approaches cannot distinguish between statistical associations that correspond to causal connections, and those that are just spurious correlations. Invariant Risk Minimization can, in certain situations. It does this by finding a representation $\phi$ of features, such that the optimal predictor is simultaneously Bayes optimal in all environments.
The authors then propose a practical loss function that tries to capture this property:
$$
\min_\Phi \sum_e \left[ \mathcal{R}^e(\Phi) + \lambda \|\nabla_{w\vert w=1}\mathcal{R}^e(w \cdot \Phi)\|^2_2 \right] \tag{IRMv1}
$$
The first term is the usual ERM: we're trying to minimize average risk across all environments, using a single predictor $\Phi$. The second term is where the interesting bit happens. I said before that what we want this term to encourage is that $\Phi$ is simultaneously Bayes-optimal in all environments. What the term actually looks at is whether $\Phi$ is locally optimal, that is, whether it can be improved locally by scaling it by a constant $w$. For details, I recommend reading the paper, which provides a lot of intuitive explanation. In this post, I'll focus on an information-theoretic interpretation of what's going on.
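To make the penalty concrete, here is a minimal NumPy sketch of the IRMv1 objective. It assumes a squared-error loss and scalar model outputs $\Phi(x)$; the paper's experiments may use other losses, and `env_risk`, `irmv1_penalty` and `irmv1_objective` are illustrative names. For squared error, the gradient with respect to the dummy scale $w$ at $w=1$ has a closed form, so no autodiff is needed:

```python
import numpy as np

def env_risk(w, phi, y):
    """Squared-error risk R^e(w * Phi) for one environment (assumed loss)."""
    return np.mean((w * phi - y) ** 2)

def irmv1_penalty(phi, y):
    """Squared gradient of R^e(w * Phi) with respect to the scalar w, at w = 1.

    For squared error, d/dw mean((w*phi - y)^2) = mean(2 * (w*phi - y) * phi).
    """
    grad_w = np.mean(2.0 * (phi - y) * phi)
    return grad_w ** 2

def irmv1_objective(envs, lam):
    """Sum over environments of risk + lam * penalty.

    `envs` is a list of (phi, y) pairs, where phi holds the current model's
    predictions Phi(x) on that environment's inputs.
    """
    return sum(env_risk(1.0, phi, y) + lam * irmv1_penalty(phi, y)
               for phi, y in envs)
```

A predictor for which $w=1$ is already the best scale in every environment pays no penalty. One whose outputs could be improved by rescaling in some environment does.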
Unlike the authors, who treat the environment index $e$ as something outside of the structural equation model, I prefer to think of $E$ as also being part of the generative process: an observable random variable. This may not be the most useful formulation in all circumstances, but it will help when trying to derive IRM through the lens of conditional dependence relationships.
The most general generative model of data in the IRM setup looks something like this:
There are three (sets of) observable variables: $E$, the environment index, $X$, the features describing the datapoint and $Y$, the label we wish to predict. I also assume the existence of a hidden confounder $W$. In the above graph, I separated $X$ into upstream dimensions $X_1$ and downstream dimensions $X_2$ based on where they are in the causal chain relative to $Y$. In reality, we don't know what this breakdown is, and breaking $X$ up into $X_1$ and $X_2$ may not be trivial due to entanglement, but it is still reasonable to assume that some components of the input $X$ encode causal parents of $Y$, and others encode causal descendants of $Y$.
The environment $E$ influences every factor in this generative model, except the factor $p(Y\vert X_1, W)$: notice there is no arrow from $E$ to $Y$. In other words, it is assumed that the relationship of variable $Y$ to its observable causal parents $X_1$ and hidden variables $W$ is stable across all environments. This is the primary underlying assumption of IRM. Now, let's read out some conditional independence relationships from this graph:
In summary, the association between $X$ and $Y$ will be a result of three sources of correlation:
In general, if our generative model describes the world accurately, the conditional independence statements we read off the graph tell us that while the real causal association is stable across environments ($Y \perp\mkern-13mu\perp E\vert X_1, W$), the other two are environment-dependent ($Y \cancel{\perp\mkern-13mu\perp} E\vert X_1$ and $Y \cancel{\perp\mkern-13mu\perp} E\vert X_1, X_2$). Thus, we can eliminate the spurious associations by seeking associations that are stable across environments, i.e. independent of $E$.
One can interpret the objective of Invariant Risk Minimisation as seeking a representation of observable variables $\Phi(x)$ that remains predictive of $Y$ while satisfying $Y \perp\mkern-13mu\perp E\vert \Phi(X)$.
This is a bit similar to the information bottleneck criterion, which would seek a stochastic representation $Z = \phi(X, \nu)$, $\nu$ being some random noise, by solving the following optimization problem:
$$
\max_{\phi} \left\{ I[Y, Z] - \beta I[X, Z] \right\}
$$
We could similarly attempt to find an invariant representation $Z = \phi(X)$ by maximizing an objective like:
$$
\max_{\phi} \left\{ I[Y, Z] - \beta I[Y, E \vert Z] \right\}
$$
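For discrete variables, the quantity $I[Y, E \vert Z]$ penalized above can be computed exactly from a joint probability table. The sketch below is a hypothetical helper, not something from the paper. It uses $I[Y, E\vert Z] = \sum_{y,e,z} p(y,e,z)\log\frac{p(y,e,z)\,p(z)}{p(y,z)\,p(e,z)}$, which is zero exactly when $Y \perp\mkern-13mu\perp E \vert Z$:

```python
import numpy as np

def conditional_mi(p_yez):
    """I[Y, E | Z] in nats for a joint probability table indexed p[y, e, z]."""
    p = np.asarray(p_yez, dtype=float)
    p = p / p.sum()
    p_z = p.sum(axis=(0, 1))   # p(z)
    p_yz = p.sum(axis=1)       # p(y, z)
    p_ez = p.sum(axis=0)       # p(e, z)
    mi = 0.0
    # Only cells with non-zero probability contribute (0 log 0 = 0).
    for y, e, z in zip(*np.nonzero(p)):
        mi += p[y, e, z] * np.log(p[y, e, z] * p_z[z] / (p_yz[y, z] * p_ez[e, z]))
    return mi
```

In the IRM reading, $Z$ would be a (discretized) representation $\phi(X)$, and a good invariant representation keeps this number small while $I[Y, Z]$ stays large.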
Notice that one can write the conditional mutual information $I[Y, E \vert \phi(X)]$ in the following variational form:
$$
I[Y, E \vert \phi(x)] = \max_q \min_r \mathbb{E}_{x,y,e} [\log q(y\vert \phi(x), e) - \log r(y\vert \phi(x))]
$$
With a little bit of math juggling, we can write the above stability objective in the following form, where the constant $H[Y]$ and the positive factor $1 + \beta$ do not change the optimal $\phi$:
\begin{align}
\max_{\phi} \left\{ I[Y, \phi(X)] - \beta I[Y, E \vert \phi(X)] \right\} = H[Y] + (1 + \beta) \max_\phi \max_r \min_q \mathbb{E}_{x,y,e} \left[\log r(y\vert \phi(x)) - \frac{\beta}{1 + \beta}\log q(y\vert \phi(x), e) \right]
\end{align}
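For readers who want the juggling spelled out, here is one route. It is sketched under three assumptions: $\beta \geq 0$, $r$ and $q$ range over all conditional distributions, and $H$ denotes Shannon entropy. It only uses $I[Y, Z] = H[Y] - H[Y\vert Z]$, $I[Y, E \vert Z] = H[Y\vert Z] - H[Y \vert Z, E]$ and the variational identity $H[Y\vert Z] = -\max_r \mathbb{E}\log r(y\vert z)$, together with its analogue for $q$:
\begin{align}
I[Y, \phi(X)] - \beta I[Y, E \vert \phi(X)] &= H[Y] - (1+\beta) H[Y\vert \phi(X)] + \beta H[Y \vert \phi(X), E]\\
&= H[Y] + (1+\beta) \max_r \mathbb{E}_{x,y} \log r(y\vert \phi(x)) - \beta \max_q \mathbb{E}_{x,y,e} \log q(y\vert \phi(x), e)\\
&= H[Y] + (1+\beta) \max_r \min_q \mathbb{E}_{x,y,e}\left[\log r(y\vert \phi(x)) - \frac{\beta}{1+\beta}\log q(y\vert \phi(x), e)\right]
\end{align}
Since $H[Y]$ does not depend on $\phi$ and $1 + \beta > 0$, maximizing over $\phi$ picks out the same representation whether or not these terms are kept.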
This optimization problem is intuitively already very similar to IRM: we would like a representation $\phi$ from which we can predict $y$ accurately, but we shouldn't be able to build a much better predictor by overfitting to one of the environments. In other words, here too we seek a representation such that the Bayes optimal predictor of $y$ is close to Bayes optimal simultaneously in all environments.
Sadly, this is a minimax type of problem, so $\phi$ can only be found with a GAN-like iterative algorithm - something we don't really like, but are kind of getting used to. It would be interesting to see how such an algorithm would work, and if you're aware of this being done already, please feel free to point me to the relevant references in the comments section.
I wanted to add a sidenote here: I think we can recover something very similar in spirit to (IRMv1). I leave developing the full connection as homework for you; let me just illustrate the basic idea here.
Say we have a parametric family of functions $f(y\vert \phi(x); \theta)$ for predicting $y$ from $\phi(x)$. The conditional information can be approximated as follows:
\begin{align}
I[Y, E \vert \phi(x)] &\approx \min_\theta \mathbb{E}_{x,y}\, \ell(f(y\vert \phi(x); \theta)) - \mathbb{E}_e \min_{\theta_e} \mathbb{E}_{x,y\vert e}\, \ell(f(y\vert \phi(x); \theta_e))\\
&= \min_\theta \mathbb{E}_e \mathcal{R}^e(f_\theta\circ\phi) - \mathbb{E}_e \min_{\theta_e} \mathcal{R}^e(f_{\theta_e}\circ\phi)
\end{align}
where $\ell$ is the log-loss, if we want to recover Shannon's information. If we assume that $f$ is a universal function approximator, equality holds. If, instead of globally optimizing $\theta_e$, we only search locally within a trust region around $\theta$, we obtain the following (approximate) lower bound on the information:
\begin{align}
I[Y, E \vert \phi(x)] &\geq \min_\theta \left\{ \mathbb{E}_{x,y}\, \ell(f(y\vert \phi(x); \theta)) - \mathbb{E}_e \min_{\|d_e\|_2\leq \epsilon} \mathbb{E}_{x,y\vert e}\, \ell(f(y\vert \phi(x); \theta + d_e)) \right\} \\
&= \min_\theta \mathbb{E}_e \left\{ \mathcal{R}^e(f_\theta\circ\phi) - \min_{\|d_e\|_2\leq \epsilon}\mathcal{R}^e(f_{\theta + d_e}\circ\phi) \right\}
\end{align}
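Now, we can approximate the risk $\mathcal{R}^e(f_{\theta + d_e}\circ\phi)$ locally by a first-order Taylor expansion around $\theta$. To first order, the inner minimum is attained by a step of length $\epsilon$ against the gradient, so $\min_{\|d_e\|_2\leq \epsilon}\mathcal{R}^e(f_{\theta + d_e}\circ\phi) \approx \mathcal{R}^e(f_\theta\circ\phi) - \epsilon\|\nabla_\theta \mathcal{R}^e(f_\theta\circ\phi)\|_2$. For small $\epsilon$ we therefore obtain the following, where the factor $\epsilon$ can be absorbed into a regularization weight:
$$
I[Y, E \vert \phi(x)] \gtrsim \epsilon \min_\theta \mathbb{E}_e \| \nabla_\theta \mathbb{E}_{x,y\vert e} [\ell(f(y\vert \phi(x); \theta))] \|_2 = \epsilon \min_\theta \mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2
$$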
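To see the first-order approximation at work, here is a small numerical check. It uses a hypothetical quadratic risk as a stand-in for $\mathcal{R}^e$, not anything from the paper. The improvement obtainable within an $\epsilon$-ball via the steepest-descent step, divided by $\epsilon\|\nabla_\theta \mathcal{R}\|_2$, should approach 1 as $\epsilon$ shrinks:

```python
import numpy as np

def quad_risk(theta, A, a):
    """Hypothetical stand-in for R^e: R(theta) = 0.5 (theta - a)^T A (theta - a)."""
    d = theta - a
    return 0.5 * d @ A @ d

def local_improvement(theta, A, a, eps):
    """R(theta) - R(theta + d) for the steepest-descent step d = -eps * g / ||g||.

    This d minimises the first-order Taylor model of R over the ball ||d|| <= eps,
    so the returned value should be close to eps * ||g|| for small eps.
    """
    g = A @ (theta - a)
    d = -eps * g / np.linalg.norm(g)
    return quad_risk(theta, A, a) - quad_risk(theta + d, A, a)
```

For this quadratic, the relative error of the first-order estimate is exactly $\tfrac{\epsilon}{2}\,\hat{g}^\top A \hat{g}/\|g\|_2$ (with $\hat{g} = g/\|g\|_2$), so it vanishes linearly in $\epsilon$.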
Compare this with the second term in Eqn IRMv1 of the paper. If we now add back in the requirement that we would like to be able to predict $y$ from $\phi(x)$, we get an optimization problem of the following form:
$$
\min_\phi \left\{ \min_\theta \mathbb{E}_e \mathcal{R}^e(f_\theta\circ\phi) + \lambda \min_\theta \mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2 \right\},
$$
which is almost like the IRM objective. Technically, there are two separate minimizations over $\theta$, and nothing forces them to share the same minimizer. Note, however, that if the second term $\mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2$ reaches its global minimum of zero, every per-environment gradient vanishes, so that $\theta$ is also a stationary point of the first term. This justifies connecting the two minimization problems together:
$$
\min_\phi \min_\theta \left\{ \mathbb{E}_e \mathcal{R}^e(f_\theta\circ\phi) + \lambda \mathbb{E}_e \| \nabla_\theta \mathcal{R}^e(f_\theta\circ\phi) \|_2 \right\},
$$
This is no longer an ugly minimax problem; however, it is still not amazing. It is a lower bound on the quantity we originally wished to minimize, and the lower bound was created when we replaced global minimization with local minimization. Thus, the bound is actually tight if all local minima of $\mathcal{R}^e(f_\theta\circ\phi)$ with respect to $\theta$ are also global minima, e.g. if the loss is convex in $\theta$. In non-convex problems, all bets are off. It may still work, but who knows.
This is indeed a nice paper, with lots of great insights. Unfortunately, I am not sure how realistic the assumption is that we can sample from a multitude of environments that differ from each other sufficiently for the invariant causal quantities to be identified.
I would like to mention that this is not the first time that invariance and causality have been connected and exploited for domain adaptation. I personally first encountered this idea in a talk by Jonas Peters at the Causality Workshop in 2018. Here is a related paper by him that I wanted to highlight:
And here are two more papers which propose a causal treatment of the domain adaptation problem:
The second paper, which commenters also pointed out to me, is perhaps the most closely related, but it is based on slightly different assumptions about what is invariant across the domains.
Commenters also asked me about domain-adversarial learning, so I wanted to include a pointer here for completeness:
On this paper, I agree with the discussion in the IRM paper by Arjovsky et al. (2019): it promotes the wrong invariance property by trying to learn a data representation that is marginally independent of the domain index. See the discussion on this in the comments section below.
Finally, I wanted to point out another, slightly looser connection: non-stationarity, or the availability of data from multiple environments, has been exploited by Hyvärinen and Morioka (2016) for unsupervised feature learning. It turns out that non-stationarity and the availability of different environments make otherwise non-identifiable nonlinear ICA models identifiable.
Welcome to my ICML 2019 jetlag special: because what else do you do when you wake up earlier than everyone else, but write a blog post? Here's a paper that was presented yesterday which I really liked.
First, some background on why I found this paper particularly interesting. When AlphaGo Zero came out, I wrote a post about the principle of minimal improvement. Suppose you have an operator which may be computationally expensive to run, but which can take a policy and improve it. Using such an improvement operator, you can define an objective function for policies by measuring the extent to which the operator changes a policy. If your policy is already optimal, the operator can't improve it any further, so the change will be zero. In the case of AlphaGo Zero, the improvement operator is Monte Carlo Tree Search (MCTS). I noted in that post how the same principle may be applicable to approximate inference: expectation propagation and contrastive divergence can both be cast in a similar light.
The paper I'm talking about uses a very similar argument to come up with a contrastive divergence for variational inference, where the improvement operator is an MCMC step.
The two dominant ways of performing inference in latent variable models are variational inference (VI, including amortized inference, such as in VAEs) and Markov Chain Monte Carlo (MCMC). VI approximates the posterior with a parametric distribution. This can be computationally efficient and practically convenient, as VI yields end-to-end differentiable, unbiased estimates of the evidence lower bound (ELBO). On the downside, the parametric family often can't represent the posterior perfectly, so an approximation error almost always remains. By contrast, an MCMC approximation to the posterior can always be improved by running the chains longer and obtaining more independent samples, but it is more difficult to work with and computationally more demanding than VI.
A lot of great work has been done recently on combining VI with MCMC (see references in Ruiz and Titsias, 2019). Usually, one starts from a crude, parametric variational approximation to the posterior, and then improves it by running a couple of steps of MCMC. Crucially, one can view the transition kernel $\Pi$ of the MCMC as an improvement operator: given any distribution $q$, taking an MCMC step should take you closer to the posterior. In other words, $\Pi q$ is an improvement over $q$. Taking multiple steps, i.e. $\Pi^t q$, should provide an even greater improvement. The improvement is $0$ only when $\Pi q = q$, which holds only for the true posterior. It is therefore a reasonable criterion to seek a posterior approximation $q_\theta$ such that the improvement of $\Pi^t q_\theta$ over $q_\theta$ is minimized.
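The improvement-operator view can be checked exactly on a toy discrete posterior. This sketch assumes a Metropolis kernel with a uniform proposal, which is one of many valid kernels; the function names are illustrative. Because the posterior $p$ is a fixed point, $\Pi p = p$, the data processing inequality guarantees $\operatorname{KL}[\Pi q\|p] \leq \operatorname{KL}[q\|p]$ for every $q$:

```python
import numpy as np

def metropolis_kernel(p):
    """Transition matrix T[i, j] = P(z_next = j | z = i) of a Metropolis sampler
    with a uniform proposal over k states, targeting the discrete distribution p."""
    k = len(p)
    T = np.zeros((k, k))
    for i in range(k):
        for j in range(k):
            if i != j:
                # propose j with probability 1/k, accept with min(1, p_j / p_i)
                T[i, j] = (1.0 / k) * min(1.0, p[j] / p[i])
        T[i, i] = 1.0 - T[i].sum()   # rejected proposals stay at i
    return T

def kl(a, b):
    """KL[a || b] for discrete distributions with full support."""
    return float(np.sum(a * np.log(a / b)))
```

Applying the kernel to a distribution written as a row vector is a matrix product, `q @ T`. Iterating it and tracking `kl(q, p)` gives a non-increasing sequence that tends to zero.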
There are two ways of quantifying the amount of improvement, or change, that MCMC provides over a parametric posterior $q_\theta$. The first measures how much closer we got to the posterior $p$ by comparing the KL divergences:
$$
\mathcal{L}_1(\theta) = \operatorname{KL}\left[q_\theta\middle\|p\right] - \operatorname{KL}\left[\Pi^tq_\theta\middle\|p\right]
$$
The second one measures the amount of change between $\Pi^tq_\theta$ and $q_\theta$, measured as the KL divergence. This objective function merely tries to identify fixed points of the improvement operator:
$$
\mathcal{L}_2(\theta) = \operatorname{KL}\left[\Pi^tq_\theta\middle\|q_\theta\right]
$$
Either of these objectives would make sense and is justified in its own right, but sadly neither of them can be evaluated or optimized easily in this case: both require taking expectations over $\Pi^tq_\theta$, as well as evaluating $\log \Pi^tq_\theta$. However, the brilliant insight in this paper is that when you sum them together, the most problematic terms cancel out, leaving you with a tractable objective to minimise.
$$
\mathcal{L}_1(\theta) + \mathcal{L}_2(\theta) = \mathbb{E}_{z\sim \Pi^t q_\theta} f_\theta(z) - \mathbb{E}_{z\sim q_\theta} f_\theta(z),
$$
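where $f_\theta(z) = \log p(z,x) - \log q_\theta(z)$, $x$ is the observed variable and $z$ is the hidden variable. The constant $\log p(x)$ that relates the posterior to the joint cancels between the two expectations.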
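On a discrete toy model, where $\Pi^t q_\theta$ can be computed exactly by matrix powers, one can sanity-check the cancellation numerically. This is a sketch under several assumptions: a small discrete latent $z$, a Metropolis kernel with a uniform proposal, and an arbitrary evidence $p(x)$ baked into `log_joint` to show that it cancels. In real models $\Pi^t q_\theta$ is only available through samples, which is why a gradient estimator is needed:

```python
import numpy as np

def kl(a, b):
    """KL[a || b] for discrete distributions with full support."""
    return float(np.sum(a * np.log(a / b)))

def metropolis_kernel(p):
    """Metropolis kernel with uniform proposal targeting p; T[i, j] = P(j | i)."""
    k = len(p)
    T = np.minimum(1.0, p[None, :] / p[:, None]) / k   # proposal prob * accept prob
    np.fill_diagonal(T, 0.0)
    np.fill_diagonal(T, 1.0 - T.sum(axis=1))            # rejections stay put
    return T

def vcd_objectives(q, log_joint, T, t):
    """Return (L1 + L2, E_{Pi^t q} f - E_q f) for a discrete latent z.

    log_joint[z] = log p(z, x) for the fixed observation x.
    """
    post = np.exp(log_joint - np.logaddexp.reduce(log_joint))  # p(z | x)
    q_t = q @ np.linalg.matrix_power(T, t)                     # Pi^t q
    lhs = (kl(q, post) - kl(q_t, post)) + kl(q_t, q)           # L1 + L2
    f = log_joint - np.log(q)                                  # f_theta(z)
    rhs = q_t @ f - q @ f
    return lhs, rhs
```

Both returned numbers agree up to floating-point rounding for any full-support $q$ and any number of steps $t$.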
There is one more technical hurdle to overcome, which is to calculate or estimate the derivative of this objective with respect to $\theta$. The authors propose a REINFORCE-like score function gradient estimator in Eqn. (12), which is somewhat worrying as it is known to have very high variance. The authors propose overcoming this using a control variate. For more details, please refer to the paper.
There is further discussion on the behaviour of this objective function in the limit of infinitely long MCMC paths, i.e. $t\rightarrow\infty$. It turns out that the criterion then behaves like the symmetrized KL divergence $KL[q\|p] + KL[p\|q]$. The difference between this objective and the usual conservative, mode-seeking VI objective is neatly illustrated in Figure 1 of the paper:
Variational Contrastive Divergence (VCD) favours posterior approximations that cover much more of the true posterior than VI does. VI tends to lock onto the modes and avoids allocating mass to areas where the true posterior has little.
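The contrast between the two limits can be reproduced with a small grid search. The sketch assumes a hypothetical bimodal target, an equal mixture of $\mathcal{N}(-3, 1)$ and $\mathcal{N}(3, 1)$, and a single Gaussian $q$. It only illustrates the $t\rightarrow\infty$ behaviour, not the VCD estimator itself. Minimizing $KL[q\|p]$ alone locks onto one mode, whereas the symmetrized divergence spreads $q$ over both:

```python
import numpy as np

Z = np.linspace(-30.0, 30.0, 6001)   # integration grid, wide enough for sigma <= 5
DZ = Z[1] - Z[0]

def log_normal(z, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma ** 2) - 0.5 * ((z - mu) / sigma) ** 2

# Hypothetical bimodal "true posterior": equal mixture of N(-3, 1) and N(3, 1).
LOG_P = np.logaddexp(np.log(0.5) + log_normal(Z, -3.0, 1.0),
                     np.log(0.5) + log_normal(Z, 3.0, 1.0))

def kl_grid(log_a, log_b):
    """KL[a || b] approximated by a Riemann sum on the grid."""
    return float(np.sum(np.exp(log_a) * (log_a - log_b)) * DZ)

def reverse_kl(log_q):
    """KL[q || p], the usual VI objective."""
    return kl_grid(log_q, LOG_P)

def symmetric_kl(log_q):
    """KL[q || p] + KL[p || q], the t -> infinity limit described above."""
    return kl_grid(log_q, LOG_P) + kl_grid(LOG_P, log_q)

def best_gaussian(objective):
    """Grid search over (mu, sigma) for the Gaussian q minimising objective(log q)."""
    best = (np.inf, None, None)
    for mu in np.linspace(-4.0, 4.0, 41):
        for sigma in np.linspace(0.5, 5.0, 46):
            val = objective(log_normal(Z, mu, sigma))
            if val < best[0]:
                best = (val, mu, sigma)
    return best[1], best[2]
```

Calling `best_gaussian(reverse_kl)` returns a narrow Gaussian sitting on one of the two modes. Calling `best_gaussian(symmetric_kl)` returns a wide Gaussian centred between them.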
This post is a short note on an excellent recent paper on empirical Fisher information matrices:
I was debating with myself whether I should write a post about this, because it's a superbly written paper that you should probably read in full. There isn't a whole lot of novelty in the paper, but it is a great discussion paper that provides a concise overview of the Fisher information, the empirical Fisher matrix and their connections to generalized Gauss-Newton methods. For more detail, I also recommend this older and longer overview by James Martens.
For a supervised learning model $p_\theta (y\vert x)$, the Fisher information matrix for a specific input $x$ is defined as follows:
$$
F_x = \mathbb{E}_{p_\theta(y\vert x)} \left[ \nabla_\theta \log p_\theta(y\vert x) \nabla_\theta \log p_\theta(y\vert x)^\top \right]
$$
The Fisher information roughly measures the sensitivity of the model's output distribution to small changes in the parameters $\theta$. For me, the most intuitive interpretation is that it is the local curvature of $KL[p_\theta\|p_{\theta'}]$ as $\theta' \rightarrow \theta$. So, by using the Fisher information as a quadratic penalty, or as a preconditioner in second-order gradient descent, you're penalizing changes to the model's output distribution as measured by the KL divergence. This makes sense especially for neural networks, where small changes in network parameters can lead to arbitrary changes in the function the model actually implements.
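The KL-curvature interpretation can be checked numerically for an assumed toy model: a categorical distribution parametrized directly by its logits $\theta$. There, $\nabla_\theta \log p_\theta(y) = e_y - \pi$ with $\pi = \operatorname{softmax}(\theta)$, so $F = \operatorname{diag}(\pi) - \pi\pi^\top$. A finite-difference Hessian of $\theta' \mapsto KL[p_\theta\|p_{\theta'}]$ at $\theta' = \theta$ should match it:

```python
import numpy as np

def log_softmax(theta):
    return theta - np.logaddexp.reduce(theta)

def fisher_softmax(theta):
    """F = E_{y ~ p_theta}[grad log p_y grad log p_y^T] for a categorical model
    whose parameters are the logits theta (grad log p_y = e_y - pi)."""
    pi = np.exp(log_softmax(theta))
    F = np.zeros((len(theta), len(theta)))
    for y in range(len(theta)):
        g = -pi.copy()
        g[y] += 1.0
        F += pi[y] * np.outer(g, g)
    return F

def kl_hessian_fd(theta, h=1e-4):
    """Central finite-difference Hessian of theta2 -> KL[p_theta || p_theta2]
    evaluated at theta2 = theta."""
    def kl(theta2):
        lp, lq = log_softmax(theta), log_softmax(theta2)
        return np.sum(np.exp(lp) * (lp - lq))
    n = len(theta)
    H = np.zeros((n, n))
    step = np.eye(n) * h
    for i in range(n):
        for j in range(n):
            H[i, j] = (kl(theta + step[i] + step[j]) - kl(theta + step[i] - step[j])
                       - kl(theta - step[i] + step[j]) + kl(theta - step[i] - step[j])) / (4 * h * h)
    return H
```

Note that $F$ is singular along the all-ones direction: shifting every logit by the same amount leaves the distribution, and hence the KL, unchanged.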
If you have an entire distribution or a dataset, you normally add up the pointwise $F_x$ matrices to form the Fisher information for the whole training data:
$$
F = \sum_n F_{x_n} = \sum_n \mathbb{E}_{p_\theta(y\vert x_n)} \left[ \nabla_\theta \log p_\theta(y\vert x_n) \nabla_\theta \log p_\theta(y\vert x_n)^\top \right]
$$
In practice, however, people often work with another matrix instead, called the empirical Fisher information matrix, defined as follows:
$$
\tilde{F} = \sum_n\left[ \nabla_\theta \log p_\theta(y_n\vert x_n) \nabla_\theta \log p_\theta(y_n\vert x_n)^\top \right]
$$
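For logistic regression, both matrices have a closed form, which makes the difference easy to see. Assume $p_\theta(y=1\vert x) = \sigma(\theta^\top x)$ with labels $y \in \{0, 1\}$. The score is $(y - \sigma(\theta^\top x))\,x$, so averaging its outer product over $y$ drawn from the model gives weights $\sigma(1-\sigma)$, while the empirical Fisher uses the observed residuals $(y_n - \sigma)^2$. A minimal sketch:

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def fisher_and_empirical_fisher(theta, X, y):
    """Logistic regression p(y=1|x) = sigmoid(theta^T x), labels y in {0, 1}.

    grad_theta log p(y|x) = (y - s) x with s = sigmoid(theta^T x), so
      Fisher:           sum_n s_n (1 - s_n) x_n x_n^T   (y sampled from the model)
      empirical Fisher: sum_n (y_n - s_n)^2 x_n x_n^T   (observed labels y_n)
    """
    s = sigmoid(X @ theta)
    F = (X * (s * (1 - s))[:, None]).T @ X
    F_emp = (X * ((y - s) ** 2)[:, None]).T @ X
    return F, F_emp
```

Since $\mathbb{E}_{y\sim p_\theta}[(y - s)^2] = s(1-s)$, the two matrices agree only in expectation over labels drawn from the model itself.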
The difference, of course, is that rather than sampling $y$ from the model itself, here we use the observed labels $y_n$ from the dataset. This matrix, while often used, has no principled motivation, and the paper shows that it lacks most of the properties that make the Fisher information such a useful quantity. The paper's simple yet insightful examples show how using the empirical Fisher instead of the real Fisher in second-order optimization can lead to bonkers results:
What the above example shows is the vector field corresponding to differently preconditioned gradient descent algorithms in a simple two-parameter least-squares linear regression example. The first panel shows the gradient field. The second shows the natural gradient field, i.e. gradients corrected by the inverse Fisher information. The third shows the gradients corrected by the inverse empirical Fisher instead. You can see that the empirical Fisher does not make the optimization problem any easier or better conditioned. The bottom line is: you probably shouldn't use it.
One particular myth that Kunstner et al. (2019) try to debunk is that the Fisher and empirical Fisher coincide near minima of the loss function, and therefore either can be safely used once you're done with training your model. I have made this argument several times in the past myself. If you look at the definitions, it's easy to conclude that they should indeed coincide under two conditions:
The paper argues that these two conditions are often at odds in practice. One either uses a simple model, in which case model misspecification is likely, meaning $p_\theta(y\vert x)$ can't approximate the true underlying $p_\mathcal{D}(y\vert x)$ well. Or, one uses an expressive, overparametrised model, such as a big neural network, in which case some form of overfitting likely happens, and $N$ is too small for the two matrices to coincide. The paper offers a lot more nuance to this, and again, I highly recommend reading it whole.
The paper also compares the direction of the natural gradient with that of the empirical Fisher-corrected gradient, and finds that the EF gradient often points in a very different direction from the natural gradient.
It's also interesting to look at the projection of the EF gradient onto the original gradient, to see how big an EF-gradient update is in the direction of normal gradient descent.
The gradient for a single datapoint $(x_n, y_n)$ is $g_n = -\nabla_\theta \log p_\theta(y_n\vert x_n)$. If we use the EF matrix as a preconditioner, the gradient is modified to $\tilde{g}_n = \tilde{F}^{-1}g_n = -\tilde{F}^{-1} \nabla_\theta \log p_\theta(y_n\vert x_n)$. Let's look at the inner product $g^\top_n\tilde{g}_n$ on average over the dataset $p_\mathcal{D}$, taking $\tilde{F}$ here to be the average $\mathbb{E}_{p_\mathcal{D}} g_n g_n^\top$ rather than the sum used in its definition above:
\begin{align}
\mathbb{E}_{p_\mathcal{D}} g^\top_n\tilde{g}_n &= \mathbb{E}_{p_\mathcal{D}}g^\top_n\tilde{F}^{-1}g_n \\
&= \operatorname{tr} \mathbb{E}_{p_\mathcal{D}} g^\top_n\tilde{F}^{-1}g_n\\
&= \operatorname{tr}\mathbb{E}_{p_\mathcal{D}} g_ng^\top_n \tilde{F}^{-1}\\
&= \operatorname{tr} \tilde{F}\tilde{F}^{-1}\\
&= D,
\end{align}
where $D$ is the dimensionality of $\theta$. (With $\tilde{F}$ defined as a sum over $N$ datapoints, the constant becomes $D/N$.) This means that the average inner product between the EF-preconditioned gradient and the original gradient is constant. Since the length of the gradient $g_n$ can change as a function of $\theta$, this implies that, weirdly, there is an inverse relationship between the norm of $g_n$ and the norm of $\tilde{g}_n$ when projected onto the same direction. On the whole, this is consistent with the weird-looking vector field in Figure 1 of the paper.
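The trace identity is easy to confirm numerically. This sketch uses random stand-in gradients rather than a real model, and takes $\tilde{F}$ as the average $\mathbb{E}_{p_\mathcal{D}} g_n g_n^\top$. The mean of $g_n^\top \tilde{F}^{-1} g_n$ comes out as $D$, no matter how the gradients are scaled:

```python
import numpy as np

def mean_ef_inner_product(G):
    """G holds one per-example gradient g_n per row (shape N x D).

    Returns the average over n of g_n^T F_emp^{-1} g_n, where
    F_emp = mean_n g_n g_n^T is the averaged empirical Fisher.
    """
    F_emp = G.T @ G / G.shape[0]
    G_pre = np.linalg.solve(F_emp, G.T).T      # rows are F_emp^{-1} g_n
    return float(np.mean(np.einsum('nd,nd->n', G, G_pre)))
```

Multiplying every gradient by 100 leaves the result unchanged, because $\tilde{F}^{-1}$ shrinks by the same factor squared.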
In summary, I highly recommend reading this paper. It has a lot more than what I covered here, and it is all explained rather nicely. The main message is clear: think twice (or more) about using empirical Fisher information for pretty much any application.
Video is an interesting domain for unsupervised, or self-supervised, representation learning. But we still don't know what type of inductive biases will enable us to best exploit the information encoded in the temporal sequence of video frames. Slow Feature Analysis (SFA) and its more recent cousin Learning to Linearize (e.g. Goroshin et al., 2016) attempt to learn neural representations of images such that the trajectory spanned by the sequence of transformed video frames satisfies a desired property, such as slowness or linearity/straightness. Whether or not the brain employs similar principles to learn representations is a fascinating topic.
Computational neuroscientists showed examples where representations learnt by the brain are consistent with the slowness principle (see e.g. Franzius et al, 2007). Brand new work by Eero Simoncelli's group suggests that the brain's representations are also consistent with the straightness principle:
This work is in fact a culmination of some great work by the authors over several years on developing the techniques that eventually enabled this study.
As you play a video, the frames of the video span a trajectory in pixel space. This trajectory is usually very complicated, even for videos depicting relatively simple actions. Imagine a video of the camera slowly panning over a static texture, such as a field of grass, at a constant speed. Although the video appears conceptually simple and predictable, the trajectory of individual pixel values is highly nonlinear and crazy: most pixels abruptly change their color from one frame to the next.
Instead of representing videos as sequences of frames in pixel space, we would like to map frames into a representation space, such that the trajectory behaves a bit more normally and predictably. For example, it makes sense to hope that in representation space the trajectory is straight or linear, meaning that it has a small curvature. If a trajectory is straight, it also means that linear extrapolation from the last couple points to the future works well.
Imagine we have a sequence (of frames) $x_t$. We map this into a representation space via a nonlinear mapping $f$. We can measure the local straightness of such a representation by calculating the angle between the difference vectors $f(x_{t+1}) - f(x_{t})$ and $f(x_{t}) - f(x_{t-1})$. The cosine of this angle is given by:
\begin{align}
\cos\alpha_t = \frac{(f(x_{t+1}) - f(x_{t}))^T(f(x_{t}) - f(x_{t-1}))}{\|(f(x_{t+1}) - f(x_{t}))\|\|(f(x_{t}) - f(x_{t-1}))\|}
\end{align}
The closer this value is to $1$, the straighter the sequence $f(x_{t-1}), f(x_{t}), f(x_{t+1})$ is locally. One can measure the straightness of the whole trajectory by averaging this local measure across time. It's worth noting that in high-dimensional spaces, straightness is difficult to achieve: a $\cos\alpha$ close to $1$ is very rare. Two independent, randomly drawn Gaussian vectors in high dimensions are nearly perpendicular to one another, yielding a cosine close to $0$. So, for example, straight trajectories have almost $0$ probability under a high-dimensional Brownian motion or Ornstein–Uhlenbeck (OU) process. As slow feature analysis can be seen as learning representations that behave roughly like OU processes (Turner and Sahani, 2017), SFA would not, by default, yield straight trajectories. To recover straighter trajectories, one needs stronger priors and constraints, such as $k$-times integrated OU processes. See also (Goroshin et al., 2016) for another approach to recover straight representations.
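The local straightness measure is straightforward to compute for any trajectory of representations, and the same code lets one check the high-dimensional claim. In this minimal sketch the function names are illustrative. A constant-velocity straight path gives cosines of exactly 1, while a simulated Brownian path in 1000 dimensions gives an average cosine near 0:

```python
import numpy as np

def local_cosines(traj):
    """cos(alpha_t) for the interior time steps, where row t of `traj` is f(x_t)."""
    d = np.diff(traj, axis=0)                     # rows: f(x_{t+1}) - f(x_t)
    num = np.sum(d[1:] * d[:-1], axis=1)          # dot products of consecutive steps
    den = np.linalg.norm(d[1:], axis=1) * np.linalg.norm(d[:-1], axis=1)
    return num / den

def straightness(traj):
    """Average local cosine over the whole trajectory (1 means perfectly straight)."""
    return float(np.mean(local_cosines(traj)))
```

For example, `np.cumsum(rng.normal(size=(100, 1000)), axis=0)` is a discretized Brownian path whose `straightness` is close to zero, because its consecutive increments are independent.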
When we train our own neural network, we can of course directly evaluate the straightness of the learnt representations. But how do we study whether the representations learnt by the brain exhibit straight trajectories or not? We can't directly observe the representations learned by the human brain. Instead, Hénaff et al. (2019) devised an indirect way to measure the curvature, i.e. to infer it from behavioural/psychophysics data. Their approach relies on ideal observer models for a certain perceptual discrimination task: they infer perceptual curvature from data about how accurately subjects can differentiate between pairs of stimuli.
Here is how to intuitively understand why the approach makes sense. We assume that whether or not a human is able to discriminate between two stimuli, when flashed for a fraction of a second, is related to the Euclidean distance between the two stimuli's representations in the person's representation space. Now, if you take a triplet of stimuli (video frames) $A, B, C$ and measure the pairwise distances $d_{AB}, d_{AC}, d_{BC}$ between them, you can establish how straight they are as a sequence. If the sequence is straight, you expect $d_{AC}\approx d_{AB} + d_{BC}$. On the other hand, if the sequence turns by a right angle at $B$, you'd expect $d^2_{AC}\approx d^2_{AB} + d^2_{BC}$ to hold instead. So, if we have a model, called a normative or ideal observer model, that can somehow directly relate Euclidean distances in representation space to perceptual discriminability of stimuli, we have a way to experimentally infer the straightness of representations.
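The geometric idea behind the distance argument is the law of cosines. Three pairwise distances determine the angle at $B$, and hence how much the path $A \to B \to C$ turns. The sketch below shows only this geometric step. The paper's actual estimator works through an ideal observer model fitted to discrimination data, which is not reproduced here:

```python
import math

def turning_angle(d_ab, d_bc, d_ac):
    """Angle (radians) by which the path A -> B -> C turns at B, inferred from the
    three pairwise Euclidean distances via the law of cosines.

    0 means perfectly straight (d_ac = d_ab + d_bc); pi/2 means a right-angle turn
    (d_ac^2 = d_ab^2 + d_bc^2).
    """
    cos_b = (d_ab ** 2 + d_bc ** 2 - d_ac ** 2) / (2 * d_ab * d_bc)
    cos_b = max(-1.0, min(1.0, cos_b))   # guard against rounding or noisy distances
    return math.pi - math.acos(cos_b)
```

For instance, distances $(1, 2, 3)$ give a turning angle of $0$, and the Pythagorean triple $(3, 4, 5)$ gives $\pi/2$.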
For more detailed description of the psychophysics experiments, data preparation, and the ideal observer model, please see the paper.
The main result of the paper, as expected, is that natural video sequences indeed appear to be mapped to straight trajectories in representation space. Furthermore, the figure below (Fig. 3b in the article) shows that trajectories corresponding to natural stimuli are straighter than those corresponding to unnatural, artificially generated stimuli:
The blue curves correspond to a natural video stimulus in pixel space and in (inferred) representation space. You can see that the sequence is highly curved in pixel space, but in representation space it is straight. On the other hand, the green curves show an artificial video created by linearly interpolating between the first and the last frame of a video. This of course results in a straight trajectory in pixel space. However, the inferred trajectory in representation space is highly curved, which is what the authors predicted, as this is an unnatural stimulus for humans.
I really liked this paper and it is an excellent example of ideal observer model-based analysis of psychophysics data. It formulates an interesting hypothesis and finds a very nice way to answer it from well-designed experimental data.
That said, I was left wondering just how critically the results depend on assumptions of the ideal observer model, and how robust the findings are to changing some of those assumptions. For one, the paper assumes a Gaussian observation noise in representation space, and I wonder how robust the analysis would be to assuming heavy-tailed noise. Similarly, our very definition of straightness and angles relies on the assumption that the representation space is somehow Euclidean. If the representation space is not in fact Euclidean, our analysis might pick up on that, rather than the straightness of specific trajectories. That said, the most convincing aspect of the paper is contrasting natural vs unnatural stimuli. No matter how precise our inference of straightness is, and how much the actual values depend on changing these assumptions, this is a very reassuring finding.
One of my favourite recent innovations in neural network architectures is Deep Sets. This relatively simple architecture can implement arbitrary set functions: functions over collections of items where the order of the items does not matter.
This is a guest post by Fabian Fuchs, Ed Wagstaff and Martin Engelcke, authors of a recent paper on the representational power of such architectures and why the deep sets architecture can represent arbitrary set functions in theory. It's a great paper. Imagine what these guys could achieve if their lab was in Cambridge rather than Oxford!
Here are the links to the original Deep Sets paper, and the more recent paper by the authors of this post:
Over to Fabian, Ed and Martin for the rest of the post. Enjoy.
Most successful deep learning approaches make use of the structure in their inputs: CNNs work well for images, RNNs and temporal convolutions for sequences, etc. The success of convolutional networks boils down to exploiting a key invariance property: translation invariance. This allows CNNs to
But images are far from the only data we want to build neural networks for. Often our inputs are sets: collections of items where the ordering of the items carries no information for the task at hand. In such a situation, the invariance property we can exploit is permutation invariance.
To give a short, intuitive explanation for permutation invariance, this is what a permutation invariant function with three inputs would look like: $f(a, b, c) = f(a, c, b) = f(b, a, c) = \dots$.
Some practical examples where we want to treat data or different pieces of higher order information as sets (i.e. where we want permutation invariance) are:
We will talk more about applications later in this post.
Note from Ferenc: I would like to jump in here (because it's my blog, so I get to do that) to say that I think the killer application for this is actually meta-learning and few-shot learning. By meta-learning, don't think of anything fancy: I consider amortized variational inference, like in a VAE, a form of meta-learning. Consider a conditionally i.i.d. model where you have a global parameter $\theta$, and a bunch of observations $x_i$ drawn conditionally i.i.d. from a distribution $p_{X\vert \theta}$. Given a set of observations $x_1, \ldots, x_N$ we'd like to approximate the posterior $p(\theta\vert x_1, \ldots, x_N)$ by some parametric $q(\theta\vert x_1, \ldots, x_N; \psi)$, and we want this to work for any number of observations $N$. Clearly, the real posterior $p$ is invariant to permutations of the $x_n$, so it would make sense to make the recognition model, $q$, a permutation-invariant architecture. To me, this is the killer application of deep sets, especially in an online learning setting, where we want to update our posterior estimate over some parameters with each new data point we observe.
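In one special case the exact posterior already has this form, which makes the idea easy to see. The sketch below assumes a hypothetical conjugate model, $\theta \sim \mathcal{N}(\mu_0, \tau_0^2)$ and $x_i \vert \theta \sim \mathcal{N}(\theta, \sigma^2)$, in which the posterior depends on the data only through $N$ and $\sum_i x_i$. So $\phi(x) = (1, x)$, summed over the set, followed by a fixed $\rho$ gives the exact posterior for any $N$ and in any order. In a learned amortized $q$, $\phi$ and $\rho$ would instead be neural networks; the names here are illustrative only.

```python
import numpy as np

# Hypothetical conjugate model: theta ~ N(MU0, TAU0^2), x_i | theta ~ N(theta, SIGMA^2).
MU0, TAU0, SIGMA = 0.0, 1.0, 1.0

def phi(x):
    """Per-observation features (1, x_i); summing them gives (N, sum_i x_i)."""
    x = np.asarray(x, dtype=float)
    return np.stack([np.ones_like(x), x], axis=-1)

def rho(s):
    """Map the pooled statistics (N, sum x) to the exact posterior (mean, variance)."""
    n, total = s
    precision = 1.0 / TAU0**2 + n / SIGMA**2
    mean = (MU0 / TAU0**2 + total / SIGMA**2) / precision
    return mean, 1.0 / precision

def posterior(xs):
    return rho(phi(xs).sum(axis=0))

# Online use: keep a running sum of phi and add each datapoint as it arrives.
stats = np.zeros(2)
for x_new in [1.0, 2.0, 3.0]:
    stats += phi(x_new)
```

Because the summary is a sum, `rho(stats)` after the stream equals `posterior` on the whole batch, whatever the arrival order.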
Having established that there is a need for permutation-invariant neural networks, let's see how to enforce permutation invariance in practice. One approach is to make use of some operation $P$ which is already known to be permutation-invariant. We map each of our inputs separately to some latent representation and apply our $P$ to the set of latents to obtain a latent representation of the set as a whole. $P$ destroys the ordering information, leaving the overall model permutation invariant.
In particular, Deep Sets does this by setting $P$ to be summation in the latent space. Other operations are used as well, e.g. elementwise max. We call the case where the sum is used sum-decomposition via the latent space. The high-level description of the full architecture is now reasonably straightforward - transform your inputs into some latent space, destroy the ordering information in the latent space by applying the sum, and then transform from the latent space to the final output. This is illustrated in the following figure:
If we want to actually implement this architecture, we'll need to choose our latent space (in the guts of the model this will mean something like choosing the size of the output layer of a neural network). As it turns out, the choice of latent space will place a limit on how expressive the model is. In general, neural networks are universal function approximators (in the limit), and we'd like to preserve this property. Zaheer et al. provide a theoretical analysis of the ability of this architecture to represent arbitrary functions - that is, can the architecture, in theory, achieve exact equality with any target function, allowing us to use e.g. neural networks to approximate the necessary mappings? In our paper, we build on and extend this analysis, and discuss what implications it has for the choice of latent space.
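A minimal sketch of the sum-decomposition described above, in NumPy. It assumes single-layer `tanh` maps with random weights in place of trained networks for $\phi$ and $\rho$; the sizes, and `D_LATENT` standing for the latent dimension, are hypothetical choices. The `pool` argument lets you swap the sum for an elementwise max.

```python
import numpy as np

rng = np.random.default_rng(0)
D_IN, D_LATENT, D_OUT = 3, 16, 2  # hypothetical sizes; D_LATENT is the latent dimension

# Randomly initialised single-layer maps standing in for trained networks.
W_phi, b_phi = rng.normal(size=(D_IN, D_LATENT)), rng.normal(size=D_LATENT)
W_rho, b_rho = rng.normal(size=(D_LATENT, D_OUT)), rng.normal(size=D_OUT)

def phi(X):
    """Encode each element independently: (set_size, D_IN) -> (set_size, D_LATENT)."""
    return np.tanh(X @ W_phi + b_phi)

def rho(z):
    """Decode the pooled latent vector: (D_LATENT,) -> (D_OUT,)."""
    return np.tanh(z @ W_rho + b_rho)

def deep_set(X, pool=np.sum):
    """rho(pool_i phi(x_i)); pooling over the set axis destroys the ordering."""
    return rho(pool(phi(X), axis=0))

X = rng.normal(size=(5, D_IN))   # a set of 5 elements, one per row
perm = rng.permutation(len(X))   # reordering the rows should not change the output
```

The same function accepts sets of any size, since the pooling step always returns a vector of length `D_LATENT`.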
Zaheer et al. show that, if we're only interested in sets drawn from a countable domain (e.g. $\mathbb{Z}$ or $\mathbb{Q}$), a 1-dimensional latent space is enough to represent any function. Their proof works by defining an injective mapping from sets to real numbers. Once you have an injective mapping, you can recover all the information about the original set, and can, therefore, represent any function. This sounds like good news -- we can do anything we like with a 1-D latent space! Unfortunately, there's a catch -- the mapping that we rely on is not continuous. The implication of this is that to recover the original set, even approximately, we need to know the exact real number that we mapped to -- knowing to within some tolerance doesn't help us. This is impossible on real hardware.
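The countable case can be made concrete. A minimal sketch, assuming the classic construction of enumerating the domain with some $c(x)$ and mapping each element to $4^{-c(x)}$ (here the domain is the non-negative integers with $c(k) = k + 1$; the proof in the paper may differ in its details). Summing over a set writes the set's indicator vector as base-4 digits, so the set can be decoded exactly from the rational sum. With 64-bit floats, however, sets that differ only in a large element collide, which is the hardware problem described above.

```python
from fractions import Fraction

def encode_exact(s):
    """Map a finite set of non-negative integers k to sum_k 4^-(k+1), in exact arithmetic."""
    return sum((Fraction(1, 4 ** (k + 1)) for k in s), Fraction(0))

def decode_exact(r, max_element):
    """Read off base-4 digits of r: digit k+1 is 1 iff k is in the set."""
    s = set()
    for k in range(max_element + 1):
        r *= 4
        if r >= 1:
            s.add(k)
            r -= 1
    return s

def encode_float(s):
    """The same map in double-precision floating point."""
    return sum(4.0 ** -(k + 1) for k in s)
```

For example, `encode_float({0})` and `encode_float({0, 30})` are the same float, because $4^{-31}$ is far below the precision available next to $0.25$, while their exact encodings differ.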
Above we considered a countable domain, but it's important to consider instead the uncountable domain $\mathbb{R}$, the real numbers. This is because continuity is a much stronger property on $\mathbb{R}$ than on $\mathbb{Q}$, and we need this stronger notion of continuity. The figure below illustrates this, showing a function which is continuous on $\mathbb{Q}$ but not continuous on $\mathbb{R}$ (and certainly not continuous in an intuitive sense). The figure is explained in detail in our paper. Using $\mathbb{R}$ is particularly important if we want to work with neural networks. Neural networks are universal approximators for continuous functions on compact subsets of $\mathbb{R}^M$. Continuity on $\mathbb{Q}$ won't do.
Zaheer et al. go on to provide a proof using continuous functions on $\mathbb{R}$, but it places a limit on the set size for a fixed finite-dimensional latent space. In particular, it shows that with a latent space of $M+1$ dimensions, we can represent any function which takes as input sets of size $M$. If you want to feed the model larger sets, there's no guarantee that it can represent your target function.
As in the countable case, the proof of this statement uses an injective mapping. But the functions we're interested in modeling aren't going to be injective -- we're distilling a large set down into a smaller representation. So maybe we don't need injectivity -- maybe there's some clever lower-dimensional mapping to be found, and we can still get away with a smaller latent space?
No.
As it turns out, you often do need injectivity into the latent space. This is true even for simple functions, e.g. max, which is clearly far from injective. This means that if we want to use continuous mappings, the dimension of the latent space must be at least the maximum set size. We were also able to show that this dimension suffices for universal function representation. That is, we've improved on the result from Zaheer (latent dimension $N \geq M+1$ is sufficient) to obtain both a weaker sufficient condition and a necessary condition (latent dimension $N \geq M$ is sufficient and necessary). Finally, we've shown that it's possible to be flexible about the input set size. While Zaheer's proof applies to sets of size exactly $M$, we showed that $N=M$ also works if the set size is allowed to vary, as long as it is at most $M$.
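To see how a continuous map into an $M$-dimensional latent space can be injective, here is a sketch using power sums, $\phi(x) = (x, x^2, \ldots, x^M)$. This is in the spirit of these proofs but is an assumption on my part: the papers' exact constructions may differ (Zaheer et al. use $M+1$ dimensions), and this sketch only covers multisets of exactly $M$ real numbers, not the variable-size case. Newton's identities turn the power sums into elementary symmetric polynomials, which are the coefficients of the polynomial whose roots are the set's elements.

```python
import numpy as np

def power_sum_embedding(xs, M):
    """Sum-decomposition with phi(x) = (x, x^2, ..., x^M): returns power sums p_1..p_M."""
    xs = np.asarray(xs, dtype=float)
    return np.array([np.sum(xs ** k) for k in range(1, M + 1)])

def recover_multiset(p):
    """Invert the embedding for a multiset of exactly M = len(p) real numbers."""
    M = len(p)
    e = [1.0]  # elementary symmetric polynomials, e_0 = 1
    for k in range(1, M + 1):
        # Newton's identities: k * e_k = sum_{i=1}^{k} (-1)^(i-1) * e_{k-i} * p_i
        e.append(sum((-1) ** (i - 1) * e[k - i] * p[i - 1] for i in range(1, k + 1)) / k)
    # The elements are the roots of prod_j (t - x_j) = sum_k (-1)^k e_k t^(M-k).
    coeffs = [(-1) ** k * e[k] for k in range(M + 1)]
    return np.sort(np.roots(coeffs).real)
```

Root finding is numerically ill-conditioned for large $M$ or clustered elements, so this shows injectivity in principle rather than a practical decoder.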
Why do we care about all of this? Sum-decomposition is in fact used in many different contexts - some more obvious than others - and the above findings directly apply in some of these.
Self-attention via {keys, queries, and values} as in the Attention Is All You Need paper by Vaswani et al. 2017 is closely linked to Deep Sets. Unless you add positional encodings, as is often done in language applications, self-attention ignores the order of its inputs: permuting the inputs simply permutes the outputs, and pooling the outputs gives a permutation-invariant function. In a way, self-attention "generalises" the summation operation, as it performs a weighted summation of different value vectors. You can show that when setting all keys and queries to 1.0, all attention weights become equal and you effectively end up with the Deep Sets architecture (with the sum scaled to a mean).
Therefore, self-attention inherits all the sufficiency statements ('with $N=M$ you can represent everything'), but not the necessity part: it is not clear that $N=M$ is needed in the self-attention architecture, just because it was proved that it is needed in the Deep Sets architecture.
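A small NumPy check of the link between the two architectures, assuming standard scaled dot-product attention without positional encodings and random projection matrices `W_q`, `W_k`, `W_v` chosen purely for illustration. With constant queries and keys every score is identical, so each output row is the mean of the value vectors. With generic queries and keys, permuting the inputs permutes the outputs, and summing the outputs gives a permutation-invariant result.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    w = np.exp(a)
    return w / w.sum(axis=axis, keepdims=True)

def self_attention(Q, K, V):
    """Scaled dot-product attention over a set of n elements (no positional encoding)."""
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    return softmax(scores, axis=-1) @ V

rng = np.random.default_rng(0)
n, d, d_v = 5, 4, 3
X = rng.normal(size=(n, d))
W_q, W_k, W_v = rng.normal(size=(d, d)), rng.normal(size=(d, d)), rng.normal(size=(d, d_v))
V = np.tanh(X @ W_v)  # value vectors play the role of phi(x_i)

# All queries and keys equal to 1: uniform weights 1/n, so every row is the mean of V.
ones = np.ones((n, d))
out_constant = self_attention(ones, ones, V)

def attend(X):
    return self_attention(X @ W_q, X @ W_k, np.tanh(X @ W_v))

perm = rng.permutation(n)
```

The mean differs from the Deep Sets sum only by the factor $1/n$, which $\rho$ could absorb if the set size is also made available to it.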
Point clouds are unordered, variable length lists (aka sets) of $xyz$ coordinates. We can also view them as (sparse) 3D occupancy tensors, but there is no 'natural' 1D ordering because we have three equal spatial dimensions. We could e.g. build a kd-tree but again this imposes a somewhat 'unnatural' ordering.
As a specific example, PointNet by Qi et al. 2017 is an expressive set-based model with some more bells and whistles. It handles interactions between points by (1) computing a permutation-invariant global descriptor, (2) concatenating it with point-wise features, (3) repeating the first two steps several times. They also use transformer modules for translation and rotation invariance --- So. Much. Invariance!
A stochastic process corresponds to a set of random variables. Here we want to model the joint distributions of the values those random variables take. These distributions need to satisfy the condition of exchangeability, i.e. they need to be invariant to the order of the random variables.
Neural Processes and Conditional Neural Processes (both by Garnelo et al. 2018) achieve this by computing a global latent variable via summation. One well-known instance of this is Generative Query Networks by Eslami et al. 2018, which aggregate information from different views via summation to obtain a latent scene representation.
👋 Hi, this is Ferenc again. Thanks to Fabian, Ed and Martin for the great post.
Update: As commenters pointed out, these papers are, of course, not the only ones dealing with permutation invariance and set functions. Here are a couple more things you might want to look at (and there are quite likely many more that I don't mention here - feel free to add more in the comments section below)
As I said before, I think that the coolest application of this type of architecture is in meta-learning situations. When someone mentions meta-learning, many people think of complicated "learning to learn to learn via gradient descent via gradient descent via gradient descent" kind of things. But in reality, simpler variants of meta-learning are a lot closer to being practically useful.
Here is an example of a recommender system developed by (Vartak et al, 2017) for Twitter, using this idea. Here, a user's preferences are summarized by the set of tweets they recently engaged with on the platform. This set is processed by a DeepSets architecture (the sequence in which they engaged with tweets is assumed to carry little information in this application). The output of this set function is then fed into another neural network that scores new tweets the user might find interesting.
Such architectures can prove useful in online learning or streaming data settings, where new datapoints arrive over time, in a sequence. For every new datapoint, one can apply the $\phi$ mapping and then simply maintain a running average of these $\phi$ values. For binary classification, one can keep a running average of $\phi(x)$ over all negative examples, and another over all positive examples. These running averages then provide a useful, permutation-invariant summary statistic of all the data received so far.
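A minimal sketch of this streaming summary, assuming a hypothetical fixed feature map `phi` (a random `tanh` projection standing in for a trained network). It keeps a cumulative mean of $\phi(x)$ per class, updated incrementally. A cumulative mean is permutation-invariant; an exponentially weighted moving average would not be, since it weights recent datapoints more heavily.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))

def phi(x):
    """Hypothetical feature map standing in for a trained network: R^3 -> R^4."""
    return np.tanh(W @ x)

class RunningClassSummary:
    """Cumulative mean of phi(x), kept separately for negatives (0) and positives (1)."""

    def __init__(self, dim):
        self.count = np.zeros(2, dtype=int)
        self.mean = np.zeros((2, dim))

    def update(self, x, label):
        self.count[label] += 1
        # Incremental mean: m_n = m_{n-1} + (phi(x_n) - m_{n-1}) / n
        self.mean[label] += (phi(x) - self.mean[label]) / self.count[label]

    def summary(self):
        return self.mean.ravel()  # [mean over negatives, mean over positives]

def summarise(stream):
    s = RunningClassSummary(dim=4)
    for x, label in stream:
        s.update(x, label)
    return s.summary()

stream = [(rng.normal(size=3), int(rng.integers(2))) for _ in range(200)]
shuffled = [stream[i] for i in rng.permutation(len(stream))]
```

Processing `stream` and `shuffled` gives the same summary up to floating-point rounding, and the summary can be fed to a downstream network as a fixed-size input.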
In summary, I'm a big fan of this architecture. I think that the work of Wagstaff et al. (2019) provides further valuable intuition about its ability to represent arbitrary set functions.
Counterfactuals are weird. I wasn't going to talk about them in my MLSS lectures on Causal Inference, mainly because I wasn't sure I fully understood what they were all about, let alone how to explain them to others. But during the Causality Panel, David Blei made comments about how weird counterfactuals are: how difficult they are to explain and wrap one's head around. So after that panel discussion, and with a grand total of 5 slides finished for my lecture on Thursday, I said "challenge accepted". The figures and story I'm sharing below are what I came up with after that panel discussion as I finally understood how to think about counterfactuals. I'm hoping others will find them illuminating, too.
This is the third in a series of tutorial posts on causal inference. If you're new to causal inference, I recommend you start from the earlier posts:
Let me first point out that counterfactual is one of those overloaded words. You can use it, like Judea Pearl, to talk about a very specific definition of counterfactuals: a probabilistic answer to a "what would have happened if" question (I will give concrete examples below). Others use terms like counterfactual machine learning or counterfactual reasoning more liberally to refer to broad sets of techniques that have anything to do with causal analysis. In this post, I am going to focus on the narrow Pearlian definition of counterfactuals. As promised, I will start with a few examples:
This is an example David brought up during the Causality Panel and I referred back to this in my talk. I'm including it here for the benefit of those who attended my MLSS talk:
Given that Hillary Clinton did not win the 2016 presidential election, and given that she did not visit Michigan 3 days before the election, and given everything else we know about the circumstances of the election, what can we say about the probability of Hillary Clinton winning the election, had she visited Michigan 3 days before the election?
Let's try to unpack this. We are interested in the probability that:
conditioned on four sets of things:
It's a weird beast: you're simultaneously conditioning on her visiting Michigan and not visiting Michigan. And you're interested in the probability of her winning the election given that she did not. WHAT?
Why would quantifying this probability be useful? Mainly for credit assignment. We want to know why she lost the election, and to what degree the loss can be attributed to her failure to visit Michigan three days before the election. Quantifying this is useful: it can help political advisors make better decisions next time.
Here's a real-world application of counterfactuals: evaluating the fairness of individual decisions. Consider this counterfactual question:
Given that Alice did not get promoted in her job, and given that she is a woman, and given everything else we can observe about her circumstances and performance, what is the probability of her getting a promotion had she been a man instead?
Again, the main reason for asking this question is to establish to what degree being a woman is directly responsible for the observed outcome. Note that this is an individual notion of fairness, unlike the aggregate assessment of whether the promotion process is fair or unfair statistically speaking. It may be that the promotion system is pretty fair overall, but in the particular case of Alice unfair discrimination took place.
A counterfactual question is about a specific datapoint, in this case Alice.
Another weird thing to note about this counterfactual is that the intervention (Alice's gender magically changing to male) is not something we could ever implement or experiment with in practice.
Here's the example I used in my talk, and will use throughout this post: I want to understand to what degree having a beard contributed to getting a PhD:
Given that I have a beard, and that I have a PhD degree, and everything else we know about me, with what probability would I have obtained a PhD degree, had I never grown a beard?
Before I start describing how to express this as a probability, let's first think about what we intuitively expect the answer to be. In the grand scheme of things, my beard probably was not a major contributing factor to getting a PhD. I would have pursued PhD studies, and probably completed my degree, even if something had prevented me from keeping my beard. So
We expect the answer to this counterfactual to be a high probability, something close to 1.
Let's start with the simplest thing one can do to attempt to answer my counterfactual question: collect some data about individuals, whether they have beards, whether they have PhDs, whether they are married, whether they are fit, etc. Here's a cartoon illustration of such a dataset:
I am in this dataset, or someone just like me is in the dataset, in row number four: I have a beard, I am married, I am obviously very fit (this was the point where I hoped the audience would get the joke and laugh, and thankfully they did), and I have a PhD degree. I have it all.
If we have this data and do normal statistical machine learning, without causal reasoning, we'd probably attempt to estimate $p(🎓\vert 🧔=0)$, the conditional probability of possessing a PhD degree given the absence of a beard. As I show at the bottom, this is like predicting values in one column from values in another column.
You hopefully know enough about causal inference by now to know that $p(🎓\vert 🧔=0)$ is certainly not the quantity we seek. Without additional knowledge of causal structure, it can't generalize to the hypothetical scenarios and interventions we care about here. There might be hidden confounders. Perhaps scoring high on the autism spectrum makes it more likely that you grow a beard, and it may also make it more likely that you obtain a PhD. This conditional cannot tell whether the two quantities are causally related or not.
We've learned in the previous two posts that if we want to reason about interventions, we have to express a different conditional distribution, $p(🎓\vert do(🧔=0))$. We also know that in order to reason about this distribution, we need more than just a dataset, we also need a causal diagram. Let's add a few things to our figure:
Here, I drew a cartoon causal diagram on top of the data just for illustration, it was simply copy-pasted from previous figures, and does not represent a grand unifying theory of beards and PhD degrees. But let's assume our causal diagram describes reality.
The causal diagram lets us reason about the distribution of data in an alternative world, a parallel universe if you like, in which everyone is somehow magically prevented from growing a beard. You can imagine sampling a dataset from this distribution, shown in the green table. We can measure the association between PhD degrees and beards in this green distribution, which is precisely what $p(🎓\vert do(🧔=0))$ means. As shown by the arrow below the tables, $p(🎓\vert do(🧔=0))$ is about predicting columns of the green dataset from other columns of the green dataset.
Can $p(🎓\vert do(🧔=0))$ express the counterfactual probability we seek? Well, remember that we expected that I would have obtained a PhD degree with a high probability even without a beard. However, $p(🎓\vert do(🧔=0))$ talks about the PhD of a random individual after a no-beard intervention. If you take a random person off the street and shave their beard if they have one, your intervention is very unlikely to make them obtain a PhD. Not to mention that your intervention has no effect on most women and men without beards. We intuitively expect $p(🎓\vert do(🧔=0))$ to be close to the base rate of PhD degrees, $p(🎓)$, which is apparently somewhere around 1-3%.
$p(🎓\vert do(🧔=0))$ talks about a randomly sampled individual, while a counterfactual talks about a specific individual
Counterfactuals are "personalized" in the sense that you'd expect the answer to change if you substitute a different person in there. My father has a mustache (let's classify that as a type of beard for pedagogical purposes), but he does not have a PhD degree. I expect that preventing him from growing a mustache would not have made him any more likely to obtain a PhD. So his counterfactual probability would be close to 0.
The counterfactual probabilities vary from person to person. If you calculate them for random individuals and average the probabilities, you should expect to get something like $p(🎓\vert do(🧔=0))$. (More on this connection later.) But we are not interested in the population mean now; we are interested in calculating the probabilities for each individual.
To finally explain counterfactuals, I have to step beyond causal graphs and introduce another concept: structural equation models.
A causal graph encodes which variables have a direct causal effect on any given node; we call these the causal parents of the node. A structural equation model goes one step further and specifies this dependence explicitly: for each variable, it has a function that describes the precise relationship between the value of the node and the values of its causal parents.
It's easiest to illustrate this through an example: here's a causal graph of an online advertising system, taken from the excellent paper by Bottou et al. (2013):
It doesn't really matter what these variables mean, if interested, read the paper. The dependencies shown by the diagram are equivalently encoded by the following set of equations:
For each node in the graph above we now have a corresponding function $f_i$. The arguments of each function are the causal parents of the variable it instantiates, e.g. $f_1$ computes $x$ from its causal parent $u$, and $f_2$ computes $a$ from its causal parents $x$ and $v$. In order to allow for nondeterministic relationships between the variables, we additionally allow each function $f_i$ to take another input, $\epsilon_i$, which you can think of as a random number. Through the random input $\epsilon_1$, the output of $f_1$ can be random given a fixed value of $u$, hence giving rise to a conditional distribution $p(x \vert u)$.
The structural equation model (SEM) entails the causal graph, in that you can reconstruct the causal graph by looking at the inputs of each function. It also entails the joint distribution, in that you can "sample" from an SEM by evaluating the functions in order, plugging in the random $\epsilon$s where needed.
In a SEM an intervention on a variable, say $q$, can be modelled by deleting the corresponding function, $f_4$, and replacing it with another function. For example $do(Q = q_0)$ would correspond to a simple assignment to a constant $\tilde{f}_4(x, a) = q_0$.
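A minimal sketch of these ideas in code. It uses a hypothetical toy SEM with three binary variables (an unobserved trait `z`, `beard` and `phd`); the functions and numbers are made up for illustration and are not the advertising model above. Each structural equation takes the values computed so far plus its own uniform noise `eps[name]`; sampling evaluates the equations in order, and `do` deletes one function and replaces it with a constant assignment.

```python
import numpy as np

# Hypothetical toy SEM (all numbers made up). Dict order is a topological order.
SEM = {
    "z":     lambda v, eps: eps["z"] < 0.3,
    "beard": lambda v, eps: eps["beard"] < 0.2 + 0.5 * v["z"],
    "phd":   lambda v, eps: eps["phd"] < 0.02 + 0.2 * v["z"] + 0.005 * v["beard"],
}

def sample_noise(sem, n, rng):
    """One independent uniform noise variable epsilon_i per structural equation."""
    return {name: rng.uniform(size=n) for name in sem}

def run(sem, eps):
    """Evaluate the structural equations in order, plugging in the noise."""
    values = {}
    for name, f in sem.items():
        values[name] = np.asarray(f(values, eps), dtype=int)
    return values

def do(sem, name, value):
    """Intervention: delete f_name and replace it with a constant assignment."""
    mutilated = dict(sem)
    mutilated[name] = lambda v, eps: np.full(len(eps[name]), value)
    return mutilated

rng = np.random.default_rng(0)
eps = sample_noise(SEM, 200_000, rng)
obs = run(SEM, eps)
p_phd_given_no_beard = obs["phd"][obs["beard"] == 0].mean()  # analytically 0.031 / 0.65
intervened = run(do(SEM, "beard", 0), eps)
p_phd_do_no_beard = intervened["phd"].mean()                  # analytically 0.3*0.22 + 0.7*0.02
```

With these made-up numbers, conditioning and intervening disagree (about 0.048 versus 0.08), because `z` confounds beards and PhDs.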
Now that we know what SEMs are we can return to our example of beards and degrees. Let's add a few more things to the figure:
The first change is that, instead of just a causal graph, I now assume that we model the world by a fully specified structural equation model. I signify this lazily by labelling the causal graph with the functions $f_1, f_2, f_3$. Notice that the SEM of the green situation is the same as the SEM in the blue case, except that I deleted $f_1$ and replaced it with a constant assignment. But $f_2$ and $f_3$ are the same in the blue and the green models.
Secondly, I make the existence of the $\epsilon_i$ noise variables explicit, and show their values (it's all made up, of course) in the gray table. If you feed the first row of epsilons to the blue structural equation model, you get the first blue datapoint, $(0110)$. If you feed the same epsilons to the green SEM, you get the first green datapoint, $(0110)$. If you feed the second row of epsilons to the models, you get the second rows in the blue and green tables, and so on.
I like to think about the first green datapoint as the parallel twin of the first blue datapoint. When discussing interventions, I talked about making predictions about a parallel universe where nobody has a beard. Now imagine that for every person who lives in our observable universe, there is a corresponding person, their parallel twin, in this parallel universe. Your twin is the same as you in every respect, except for the absence of any beard you might have and any downstream consequences of having a beard. If you don't have a beard in this universe, your twin is an exact copy of you. Indeed, notice that the first blue datapoint is the same as the first green datapoint in my example.
You may know I'm a Star Trek fan and I like to use Star Trek analogies to explain things: in Star Trek, there is this concept called the mirror universe. It's a parallel universe populated by the same people who live in the normal universe, except that everyone who is good in the real universe is evil in the mirror universe. Hilariously, the mirror version of Spock, one of the main protagonists, has a goatee in this mirror universe. This is how you can tell whether you're looking at evil mirror-Spock or normal Spock when watching the episode. This explains, sadly, why I'm using beards to explain counterfactuals. Here are Spock and mirror-Spock:
Now that we established the twin datapoint metaphor, we can say that counterfactuals are
making a prediction about features of the unobserved twin datapoint based on features of the observed datapoint.
Crucially, this was possible because we used the same $\epsilon$s in both the blue and the green SEM. This induces a joint distribution between variables in the observable regime, and variables in the unobserved, counterfactual regime. Columns of the green table are no longer independent of columns of the blue table. You can start predicting values in the green table using values in the blue table, as illustrated by the arrows below them.
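The twin construction can be sketched with a hypothetical toy SEM (redefined here so the block runs on its own; all functions and numbers are made up). Feeding the same epsilons into the original and the mutilated SEM produces the blue and green tables row by row. Conditioning on the observed row by keeping only the matching samples is a simple rejection-sampling way of inferring which epsilons are consistent with what we saw; we then read off the twin's PhD status.

```python
import numpy as np

# Hypothetical toy SEM (all numbers made up). Dict order is a topological order.
SEM = {
    "z":     lambda v, eps: eps["z"] < 0.3,
    "beard": lambda v, eps: eps["beard"] < 0.2 + 0.5 * v["z"],
    "phd":   lambda v, eps: eps["phd"] < 0.02 + 0.2 * v["z"] + 0.005 * v["beard"],
}

def sample_noise(sem, n, rng):
    return {name: rng.uniform(size=n) for name in sem}

def run(sem, eps):
    values = {}
    for name, f in sem.items():
        values[name] = np.asarray(f(values, eps), dtype=int)
    return values

def do(sem, name, value):
    mutilated = dict(sem)
    mutilated[name] = lambda v, eps: np.full(len(eps[name]), value)
    return mutilated

rng = np.random.default_rng(0)
eps = sample_noise(SEM, 400_000, rng)        # shared epsilons: the gray table
observed = run(SEM, eps)                     # blue table
twin = run(do(SEM, "beard", 0), eps)         # green table, same epsilons row by row

# Counterfactual p(phd* = 1 | beard* = 0, beard = 1, phd = 1), by rejection sampling.
mask = (observed["beard"] == 1) & (observed["phd"] == 1)
p_counterfactual = twin["phd"][mask].mean()  # analytically 0.049 / 0.05075
p_interventional = twin["phd"].mean()        # p(phd = 1 | do(beard = 0)), analytically 0.08
```

With these numbers the counterfactual comes out around 0.97, close to 1 as the intuition above suggests, while the population-level $p(🎓\vert do(🧔=0))$ is only about 0.08. Rows without a beard come out identical in both tables, just like the first datapoint in the figure.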
Mathematically, a counterfactual is the following conditional probability:
$$
p(🎓^\ast \vert 🧔^\ast = 0, 🧔=1, 💍=1, 💪=1, 🎓=1),
$$
where variables with an $^\ast$ are unobserved (and unobservable) variables that live in the counterfactual world, while variables without $^\ast$ are observable.
Looking at the data, it turns out that mirror-Ferenc, who does not have a beard, is married and has PhD, but is not quite as strong as observable Ferenc.
Here is another drawing that some of you might find more appealing, especially those who are into GANs, VAEs and similar generative models:
A SEM is essentially a generative model of data, which uses some noise variables $\epsilon_1, \epsilon_2, \ldots$ and turns them into observations $(U,Z,X,Y)$ in this example. This is shown in the left-hand branch of the graph above. Now if you want to make counterfactual statements under the intervention $X=\hat{x}$, you can construct a mutilated SEM, which is the same SEM except with $f_3$ deleted and replaced with the constant assignment $x = \hat{x}$. This modified SEM is shown in the right-hand branch. If you feed the $\epsilon$s into the mutilated SEM, you get another set of variables $(U^\ast,Z^\ast,X^\ast,Y^\ast)$, shown in green. These are the features of the twin as it were. This joint generative model over $(U,Z,X,Y)$ and $(U^\ast,Z^\ast,X^\ast,Y^\ast)$ defines a joint distribution over the combined set of variables $(U,Z,X,Y,U^\ast,Z^\ast,X^\ast,Y^\ast)$. Therefore, now you can calculate all sorts of conditionals and marginals of this joint.
Of particular interest are these conditionals:
$$
p(y^\ast \vert X^\ast = \hat{x}, X = x, Y = y, U = u, Z = z),
$$
which is a counterfactual prediction. In reality, since $X^\ast = \hat{x}$ holds with a probability of $1$, we can drop that conditioning.
My notation here is a bit sloppy: there are a lot of things going on implicitly under the hood that I'm not making explicit in the notation. I'm sorry if this irritates anyone; I want to avoid overcomplicating things at this point. Now is a good time to point out that Pearl's notation, including the do-notation, is often criticized, but people use it because it is now widely adopted.
We can also express the intervention conditional $p(y\vert do(x))$ using this (somewhat sloppy) notation as:
$$
p(y\vert do(X=\hat{x})) = p(y^\ast \vert X^\ast = \hat{x})
$$
We can see that the intervention conditional only contains variables with an $^\ast$, so it does not require the joint distribution of $(U,Z,X,Y,U^\ast,Z^\ast,X^\ast,Y^\ast)$, only the marginal of the $^\ast$ variables, $(U^\ast,Z^\ast,X^\ast,Y^\ast)$. As a consequence, in order to talk about $p(y\vert do(X=\hat{x}))$ we did not need to introduce SEMs or talk about the epsilons.
Furthermore, notice the following equality:
\begin{align}
p(y\vert do(X=\hat{x})) &= p(y^\ast \vert X^\ast = \hat{x}) \\
&= \int_{x,y,u,z} p(y^\ast \vert X^\ast = \hat{x}, X = x, Y = y, U = u, Z = z) p(x,y,u,z) dx dy du dz \\
&= \mathbb{E}_{p_{X,Y,U,Z}} p(y^\ast \vert X^\ast = \hat{x}, X = x, Y = y, U = u, Z = z),
\end{align}
in other words, the intervention conditional $p(y\vert do(X=\hat{x}))$ is the average of counterfactuals over the observable population. This was something that I did not realize before my MLSS tutorial; it was pointed out to me by a student in the form of a question. In hindsight, of course this is true!
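The identity can also be checked numerically on a hypothetical toy SEM (all functions and numbers made up), with `beard` in the role of $X$, `phd` in the role of $Y$ and the trait `z` in the role of $Z$ (there is no $U$ here). Averaging the per-configuration counterfactual probabilities, weighted by how often each observed configuration occurs, recovers the interventional probability. On a finite sample this holds exactly, since it is just the law of total expectation.

```python
import numpy as np

# Hypothetical toy SEM (all numbers made up). Dict order is a topological order.
SEM = {
    "z":     lambda v, eps: eps["z"] < 0.3,
    "beard": lambda v, eps: eps["beard"] < 0.2 + 0.5 * v["z"],
    "phd":   lambda v, eps: eps["phd"] < 0.02 + 0.2 * v["z"] + 0.005 * v["beard"],
}

def sample_noise(sem, n, rng):
    return {name: rng.uniform(size=n) for name in sem}

def run(sem, eps):
    values = {}
    for name, f in sem.items():
        values[name] = np.asarray(f(values, eps), dtype=int)
    return values

def do(sem, name, value):
    mutilated = dict(sem)
    mutilated[name] = lambda v, eps: np.full(len(eps[name]), value)
    return mutilated

rng = np.random.default_rng(1)
eps = sample_noise(SEM, 400_000, rng)
observed = run(SEM, eps)
twin = run(do(SEM, "beard", 0), eps)

# Group the population by observed configuration (z, beard, phd).
config = observed["z"] * 4 + observed["beard"] * 2 + observed["phd"]
average_of_counterfactuals = 0.0
for c in np.unique(config):
    group = config == c
    weight = group.mean()                       # empirical p(z, beard, phd)
    counterfactual = twin["phd"][group].mean()  # p(phd* = 1 | beard* = 0, z, beard, phd)
    average_of_counterfactuals += weight * counterfactual

p_do = twin["phd"].mean()                       # p(phd = 1 | do(beard = 0)), analytically 0.08
```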
God, this was a loooong post. If you're still reading, thanks, I hope it was useful. I wanted to close with a few slightly philosophical remarks on counterfactuals.
Counterfactuals are often said to be unscientific, primarily because they are not empirically testable. In normal ML we are used to benchmark datasets, and to the quality of our predictions always being testable on some test dataset. In causal ML, not everything can be directly tested or empirically benchmarked. For interventions, the best test is to run a randomized controlled trial to directly measure $p(y\vert do(X=x))$ if you can, and then use this experimental data to evaluate your causal inferences. But some interventions are impossible to carry out in practice. Think about all the work on fairness reasoning about interventions on gender or race. So what to do then?
In the world of counterfactuals this is an even bigger problem, as it is outright impossible to observe the variables you make predictions about. You can't go back in time and rerun history with exactly the same circumstances except for a tiny change. You can't travel to parallel universes (at least not before the 24th century, according to Star Trek). Counterfactual judgments remain hypothetical, subjective, untestable, unfalsifiable. There can be no MNIST or ImageNet for counterfactuals that satisfies everyone. Some good datasets do exist, but they cover specific scenarios where explicit testing is possible (e.g. offline A/B testing), or they make use of simulators instead of "real" data.
Despite being untestable and difficult to interpret, counterfactual statements are used by humans all the time, and intuitively it feels like they are pretty useful for intelligent behaviour. Being able to pinpoint the causes that led to a particular situation or outcome is certainly useful for learning, reasoning and intelligence. So my strategy for now is to ignore the philosophical debate about counterfactuals, and just get on with it, knowing that the tools are there if such predictions have to be made.
Last week I had the honor to lecture at the Machine Learning Summer School in Stellenbosch, South Africa. I chose to talk about Causal Inference, despite being a newcomer to this whole area. In fact, I chose it exactly because I'm a newcomer: causal inference has been a blind spot for me for a long time. I wanted to communicate some of the intuitions and ideas I learned and developed over the past few months, which I wish someone had explained to me earlier.
Now, I'm turning my presentation into a series of posts, starting with this one, building on the previous one I wrote in May. In this one, I will present the toy example I used in my talk to explain interventions. I call this the three scripts toy example. I was not sure whether people would get it, but I got good feedback on it from the audience, so I'm hoping you will find it useful, too. Here are links to the other posts:
Imagine you teach a programming course and you ask students to write a python script that samples from a 2D Gaussian distribution with a certain mean and covariance. Some of the solutions will be correct, but as there are multiple ways to sample from a Gaussian, you might see very different solutions. For example, here are three scripts that would implement the same, correct sampling behaviour:
Below each of the code snippets I plotted samples drawn by repeatedly executing the scripts. As you can see, all three scripts produce the same joint distribution between $x$ and $y$. You can feed these distributions into a two-sample test, and you will find that they are indeed indistinguishable from each other.
Based on the joint distribution the three scripts are indistinguishable.
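The original scripts are not reproduced in this text, so here are three hypothetical stand-ins with made-up parameters, each written as a list of Python statements run through `exec`. All three produce $x \sim \mathcal{N}(1, 1)$, $y \sim \mathcal{N}(1, 5)$ and $\mathrm{cov}(x, y) = 2$. Comparing sample means and covariances is a crude stand-in for a two-sample test, and it finds nothing to tell them apart.

```python
import numpy as np

# Hypothetical stand-ins for the three scripts (parameters made up).
# Each gives x ~ N(1, 1), y ~ N(1, 5), cov(x, y) = 2.
SCRIPTS = {
    "blue": [
        "x = rng.normal(1.0, 1.0)",
        "y = 2.0 * x - 1.0 + rng.normal()",
    ],
    "red": [
        "y = rng.normal(1.0, np.sqrt(5.0))",
        "x = 1.0 + 0.4 * (y - 1.0) + rng.normal(0.0, np.sqrt(0.2))",
    ],
    "green": [
        "z = rng.normal()",
        "y = 2.0 * z + 1.0 + rng.normal()",
        "x = z + 1.0",
    ],
}

def run_script(code, rng):
    """Execute the compiled lines in a fresh namespace and return the final (x, y)."""
    env = {"rng": rng, "np": np}
    for line in code:
        exec(line, env)
    return env["x"], env["y"]

def sample(lines, n=20_000, seed=0):
    rng = np.random.default_rng(seed)
    code = [compile(line, "<script>", "exec") for line in lines]
    return np.array([run_script(code, rng) for _ in range(n)])

samples = {name: sample(lines) for name, lines in SCRIPTS.items()}
for name, xy in samples.items():
    print(name, "mean:", xy.mean(axis=0).round(2), "cov:", np.cov(xy.T).round(2).tolist())
```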
But despite the three scripts being equivalent in that they generate the same distribution, they are not exactly the same. For example, they behave differently if we interefere or intervene to the execution.
Consider this thought experiment: I am a hacker, and I can inject code into the Python interpreter. For every line of code from the snippet, I can insert a line of code of my choice. Let's say that I really want to set the value of $x$ to $3$, so I use my code injection ability and insert the line x=3 after each line of your code. So what actually gets executed is this:
We can now run the scripts in this hacked interpreter and see how the intervention changes the distribution of $x$ and $y$:
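The hacked interpreter can be mimicked in a few lines of Python: store each script as a list of source lines and, when an injection is given, run it with Python's built-in exec after every line. The three scripts are hypothetical, with illustrative parameters of my own ($x$ and $y$ both have mean $1$, $\mathrm{Var}[x]=1$, $\mathrm{Var}[y]=5$, $\mathrm{Cov}[x,y]=2$). With these parameters, injecting x = 3 gives $y$ a mean of about $5$ in the first script and about $1$ in the other two.

```python
import numpy as np

# Hypothetical scripts stored as source lines (illustrative parameters, not the originals).
SCRIPTS = {
    "script_1": ["x = 1 + randn()",
                 "y = 2 * x - 1 + randn()"],
    "script_2": ["y = 1 + sqrt(5) * randn()",
                 "x = 1 + 0.4 * (y - 1) + sqrt(0.2) * randn()"],
    "script_3": ["z = randn()",
                 "y = 1 + 2 * z + randn()",
                 "x = 1 + z"],
}

def run(code_lines, rng, injected=None):
    """Execute compiled lines one by one; if `injected` is given, run it after every line."""
    env = {"randn": rng.standard_normal, "sqrt": np.sqrt}
    for code in code_lines:
        exec(code, env)
        if injected is not None:
            exec(injected, env)
    return env["x"], env["y"]

def sample(lines, n, inject=None, seed=0):
    """Draw n (x, y) samples, optionally through the 'hacked' interpreter."""
    rng = np.random.default_rng(seed)
    code_lines = [compile(line, "<script>", "exec") for line in lines]
    injected = None if inject is None else compile(inject, "<injected>", "exec")
    return np.array([run(code_lines, rng, injected) for _ in range(n)])

if __name__ == "__main__":
    for name, lines in SCRIPTS.items():
        y_do = sample(lines, 10_000, inject="x = 3")[:, 1]
        print(name, "mean of y with x = 3 injected:", round(float(y_do.mean()), 2))
```

In the first script the injected line runs before $y$ is computed, so $y$ sees $x=3$. In the other two, $y$ never reads $x$, so it keeps its original marginal distribution.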
Of course, the value of $x$ is no longer random: it's deterministically set to $3$, so all samples line up along the vertical line $x=3$. But, interestingly, the distribution of $y$ is different for the different scripts. In the blue script, $y$ has a mean around $5$, while the green and red scripts produce a distribution of $y$ centered around a mean of $1$. Here is a better look at the marginal distribution of $y$ under the intervention:
I labelled this plot $p(y\vert do(X=3))$, which semantically means the distribution of $y$ under the intervention where I set the value of $X$ to $3$. This is generally different from the conditional distribution $p(y\vert x=3)$, which of course is the same for all three scripts. Below I show these conditionals. Excuse the massive estimation errors here, I was lazy creating these plots, but believe me, they are technically all the same:
The important point here is that
the scripts behave differently under intervention.
We have a situation where the scripts are indistinguishable when you only look at the joint distribution of the samples they produce, yet they behave differently under intervention.
Consequently,
the joint distribution of data alone is insufficient to predict behaviour under interventions.
If the joint distribution is insufficient, what level of description would allow us to make predictions about how the scripts behave under intervention? If I have the full source code, I can of course execute the modified scripts, i.e. run an experiment and directly observe how the intervention affects the distribution.
However, it turns out, you don't need the full source code. It is sufficient to know the causal diagram corresponding to the source code. The causal diagram encodes causal relationships between variables, with an arrow pointing from causes to effects. Here is what the causal diagrams would look like for these scripts:
We can see that, even though they produce the same joint distribution, the scripts have different causal diagrams. And this additional knowledge of the causal structure allows us to make inferences about intervention without actually running experiments with that intervention. To do this in the general setting, we can use do-calculus, explained in a little more detail in my earlier post.
Graphically, to simulate the effect of an intervention, you mutilate the graph by removing all edges that point into the variable on which the intervention is applied, in this case $x$.
In the top row you see the three diagrams that describe the three scripts. In the second row are the mutilated graphs where all incoming edges to $x$ have been removed. In the first script, the graph looks the same after mutilation. From this, we can conclude that $p(y\vert do(x)) = p(y\vert x)$, i.e. that the distribution of $y$ under intervention $X=3$ is the same as the conditional distribution of $y$ conditioned on $X=3$. In the second script, after mutilation, $x$ and $y$ become disconnected, therefore independent. From this, we can conclude that $p(y\vert do(X=3)) = p(y)$. Changing the value of $x$ does nothing to change the value of $y$, so whatever you set $X$ to be, $y$ is just going to sample from its marginal distribution. The same argument holds for the third causal diagram.
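The mutilation rule is mechanical enough to sketch in code. Below, each diagram is a dictionary mapping a node to its parents; the three structures ($x \rightarrow y$, $y \rightarrow x$, and $x \leftarrow z \rightarrow y$) are my reading of the diagrams described above, and the function names are hypothetical. The sketch only handles the two cases discussed in the text: if $x$ had no parents, nothing changes and $p(y\vert do(x)) = p(y\vert x)$; if no directed path from $x$ to $y$ survives mutilation, $p(y\vert do(x)) = p(y)$. Anything else needs the general machinery of do-calculus.

```python
# Causal diagrams as {node: list of parents}; structures assumed from the text.
DIAGRAMS = {
    "script_1": {"x": [], "y": ["x"]},
    "script_2": {"y": [], "x": ["y"]},
    "script_3": {"z": [], "x": ["z"], "y": ["z"]},
}

def mutilate(graph, target):
    """Copy of the graph with every edge pointing into `target` removed."""
    return {node: ([] if node == target else list(parents))
            for node, parents in graph.items()}

def has_directed_path(graph, src, dst):
    """True if src -> ... -> dst exists (depth-first search over children)."""
    children = {n: [c for c, ps in graph.items() if n in ps] for n in graph}
    stack, seen = [src], set()
    while stack:
        node = stack.pop()
        if node == dst:
            return True
        if node not in seen:
            seen.add(node)
            stack.extend(children[node])
    return False

def do_query(graph, x="x", y="y"):
    """Reduce p(y | do(x)) to an observational quantity in the two simple cases."""
    mutilated = mutilate(graph, x)
    if not has_directed_path(mutilated, x, y):
        return f"p({y}|do({x})) = p({y})"
    if mutilated == graph:  # x had no parents, so there is nothing to sever
        return f"p({y}|do({x})) = p({y}|{x})"
    return "needs an adjustment formula (do-calculus)"

if __name__ == "__main__":
    for name, graph in DIAGRAMS.items():
        print(name, do_query(graph))
```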
The significance of this is the following: By only looking at the causal diagram, we are now able to predict how the scripts are going to behave under the intervention $X=3$. We can compute and plot $p(y\vert do(X=3))$ for the three scripts by only using data observed during the normal (non-intervened) condition, without ever having to run the experiment or simulate the intervention.
The causal diagram allows us to predict how the models will behave under intervention, without carrying out the intervention
Here is proof of this: I could estimate the distribution of $y$ observed during the intervention experiment using only samples from the script under the normal (non-intervention) condition. This is called causal inference from observational data.
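The same kind of estimate can be sketched in Python with hypothetical Gaussian scripts (illustrative parameters of mine: mean $(1, 1)$, $\mathrm{Var}[x]=1$, $\mathrm{Var}[y]=5$, $\mathrm{Cov}[x,y]=2$). Only observational samples are drawn. Following the mutilated graphs, the first script's $p(y\vert do(X=3))$ is estimated by the conditional $p(y\vert x=3)$, here via a least-squares line (adequate only because these toy scripts are jointly Gaussian), while the other two use the marginal of $y$. The estimates come out near $5$, $1$ and $1$, even though the conditional mean at $x=3$ is near $5$ for all three scripts.

```python
import numpy as np

def observe(script, n, rng):
    """Vectorised observational samples from three hypothetical scripts."""
    e1, e2 = rng.standard_normal(n), rng.standard_normal(n)
    if script == 1:                    # x -> y
        x = 1 + e1
        y = 2 * x - 1 + e2
    elif script == 2:                  # y -> x
        y = 1 + np.sqrt(5) * e1
        x = 1 + 0.4 * (y - 1) + np.sqrt(0.2) * e2
    else:                              # x <- z -> y
        z = e1
        y = 1 + 2 * z + e2
        x = 1 + z
    return x, y

def conditional_mean(x, y, x0):
    """E[y | x = x0] from a least-squares line (exact for jointly Gaussian data)."""
    slope, intercept = np.polyfit(x, y, 1)
    return slope * x0 + intercept

# Conclusions read off the mutilated graphs:
# script 1 -> p(y|do(x)) = p(y|x); scripts 2 and 3 -> p(y|do(x)) = p(y).
USE_CONDITIONAL = {1: True, 2: False, 3: False}

def interventional_mean_from_observations(script, x0=3.0, n=50_000, seed=0):
    """Estimate E[y | do(x = x0)] without ever running the intervention."""
    rng = np.random.default_rng(seed)
    x, y = observe(script, n, rng)
    return conditional_mean(x, y, x0) if USE_CONDITIONAL[script] else y.mean()

if __name__ == "__main__":
    for s in (1, 2, 3):
        print(s, round(float(interventional_mean_from_observations(s)), 2))
```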
The moral of this story is summed up in the following picture:
Let's consider all the questions you would want to be able to answer given some data (i.i.d. samples from a joint distribution). Having access to data, or more generally, the joint distribution it was sampled from, allows you to answer a great many questions, and solve many tasks. For example, you can do supervised machine learning by approximating $p(y \vert x)$ and then use it for many things, such as image labelling. These questions together make up the blue set.
However, we have already seen that some questions cannot be answered using data/joint distribution alone. Notably, if you want to make predictions about how the system you study would behave under certain interventions or perturbations, you typically won't be able to make such inferences based on the data you have. These types of questions lie outside the blue set.
However, if you complement your data with causal assumptions encoded in a causal diagram - a directed acyclic graph where nodes are your variables - you can exploit these extra assumptions to start answering these questions, shown by the green area.
I am also showing an even larger set of questions still, but I won't tell you what this refers to just yet. I'm leaving that to a future post.
The questions I got most frequently after the lecture were these: what if I don't know what the graph looks like? And what if I get the graph wrong? There are many ways to answer these questions.
To me, a rehabilitated Bayesian, the most appealing answer is that you have to accept that your analysis is conditional on the graph you choose, and your conclusions are valid under the assumptions encoded there. In a way, causal inference from observational data is subjective. When you publish a result, you should caveat it with "under these assumptions, this is true". Readers can then dispute and question your assumptions if they disagree.
As to how to obtain a graph, it varies by application. If you work with online systems such as a recommender system, the causal diagram is actually pretty simple to draw, as it corresponds to how various subsystems are hooked up and feed data to one another. In other applications, notably in healthcare, a little bit more guesswork and thought may be involved.
Finally, you can use various causal discovery techniques to try to identify the causal diagram from the data itself. Theoretically, recovering the full causal graph from the data is impossible in the general case. However, if you add certain additional smoothness or independence assumptions to the mix, you may be able to recover the graph from the data with a certain reliability.
We have seen that modeling the joint distribution can only get you so far, and if you want to predict the effect of interventions, i.e. calculate $p(y\vert do(x))$-like quantities, you have to add a causal graph to your analysis.
An optimistic take on this is that, in the i.i.d. setting, drawing a causal diagram and using do-calculus (or something equivalent) can significantly broaden the set of problems you can tackle with machine learning.
The pessimist's take on this is that if you are not aware of all this, you might be trying to address questions in the green bit, without realizing that the data won't be able to give you an answer to such questions.
Whichever way you look at it, causal inference from observational data is an important topic to be aware of.
I recently read a paper by Tencent and was surprised to learn that an online Bayesian deep learning algorithm is apparently deployed in production to power click-through-rate prediction in their ad system. I thought, therefore, that this is worth a mention. Here is the paper:
Though the paper has a few typos and sometimes inconsistent notation (like going back and forth between using $\omega$ and $w$) which can make it tricky to read, I like the paper's layout: starting from desiderata - what the system has to be able to do - arriving at an elegant solution.
The method relies on the approximate Bayesian online-learning technique often referred to as assumed density filtering.
ADF has been independently discovered by the statistics, machine learning and control communities (for citations see Minka, 2001). It is perhaps most elegantly described by Opper (1998), and by Minka (2001), who also extended the idea to develop expectation propagation, a highly influential method for approximate Bayesian inference not just for online settings.
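ADF can be explained as a recursive algorithm repeating the following steps for each new datapoint: (1) take the current approximate posterior $q_{t-1}(w)$ as the prior; (2) multiply it by the likelihood of the new datapoint, which generally yields a posterior outside the tractable family; (3) project this back onto the tractable family (for example, factorised Gaussians) by matching moments, giving the new approximation $q_{t}(w)$.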
The Tencent paper is based on probabilistic backpropagation (Hernandez-Lobato and Adams, 2015), which uses the ADF idea multiple times to perform the two non-trivial tasks required in a supervised Bayesian deep network: inference and learning.
I find the similarity of this algorithm to gradient calculations in neural network training beautiful: forward propagation to perform inference, forward and backward propagation for learning and parameter updates.
Crucially though, there is no gradient descent, or indeed no gradients or iterative optimization taking place here. After a single forward-backward cycle, information about the new datapoint (or minibatch of data) is - approximately - incorporated into the updated distribution $q_{t}(w)$ of network weights. Once this is done, the mini-batch can be discarded and never revisited. This makes this method a great candidate for online learning in large-scale streaming data situations.
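To see what a single ADF step looks like without any gradients, here is a sketch for the simplest possible case: Bayesian probit regression (a one-layer model) with a factorised Gaussian posterior $q_t(w)$ and likelihood $\Phi(y\, w^\top x / \beta)$ for $y \in \{-1, +1\}$. This is not the Tencent algorithm or full probabilistic backpropagation; the model, the function names and the noise scale beta are assumptions of the sketch. It only shows the kind of closed-form moment-matching update ADF performs: each datapoint is seen once, folded into the posterior, and then discarded.

```python
import math
import numpy as np

def norm_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2 * math.pi)

def norm_cdf(z):
    # erfc form stays accurate for large negative z
    return 0.5 * math.erfc(-z / math.sqrt(2))

def adf_probit_update(mean, var, x, y, beta=1.0):
    """One ADF step: q(w) = N(mean, diag(var)) times the probit likelihood of (x, y),
    projected back to a factorised Gaussian by matching first and second moments."""
    t = math.sqrt(beta ** 2 + float(np.sum(x ** 2 * var)))
    z = y * float(x @ mean) / t
    v = norm_pdf(z) / norm_cdf(z)          # drives the mean update
    w = v * (v + z)                        # in (0, 1): drives the variance shrinkage
    new_mean = mean + y * var * x * v / t
    new_var = var * (1 - var * x ** 2 * w / t ** 2)
    return new_mean, new_var

def run_adf(X, Y, beta=1.0, prior_var=1.0):
    """Single pass over the stream; no gradients, no revisiting of old datapoints."""
    mean = np.zeros(X.shape[1])
    var = np.full(X.shape[1], prior_var)
    for x, y in zip(X, Y):
        mean, var = adf_probit_update(mean, var, x, y, beta)
    return mean, var

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    w_true = np.array([1.5, -2.0, 0.5])            # hypothetical ground truth
    X = rng.standard_normal((2000, 3))
    Y = np.sign(X @ w_true + rng.standard_normal(2000))
    mean, var = run_adf(X, Y)
    print(mean.round(2), var.round(4))
```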
Another advantage of this method is that it can be parallelized in a data-parallel fashion: multiple workers can update the posterior simultaneously on different minibatches of data. Expectation propagation provides a principled way of combining the parameter updates computed by parallel workers. Indeed, this is what the PBODL algorithm of Tencent does.
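Combining updates from parallel workers is easiest to see in a conjugate toy case. In this sketch (my own illustration, not Tencent's PBODL procedure) the parameter is a scalar Gaussian mean with known noise variance. Each worker starts from the shared prior, processes its own shard, and reports its posterior in natural parameters (precision times mean, and precision). The server adds each worker's change relative to the prior, which is the kind of multiplicative message combination EP uses. In this conjugate case the result matches the posterior computed from all the data at once; with ADF in a non-conjugate model it would only be approximate.

```python
import numpy as np

def to_natural(mean, var):
    return mean / var, 1.0 / var            # (precision * mean, precision)

def from_natural(eta, lam):
    return eta / lam, 1.0 / lam             # (mean, variance)

def worker_update(prior_mean, prior_var, data, noise_var):
    """Posterior (natural parameters) for a scalar mean after seeing one shard."""
    eta, lam = to_natural(prior_mean, prior_var)
    eta += np.sum(data) / noise_var
    lam += len(data) / noise_var
    return eta, lam

def combine(prior_mean, prior_var, worker_posteriors):
    """Add every worker's change in natural parameters (its 'message') to the prior."""
    eta0, lam0 = to_natural(prior_mean, prior_var)
    eta = eta0 + sum(e - eta0 for e, _ in worker_posteriors)
    lam = lam0 + sum(l - lam0 for _, l in worker_posteriors)
    return from_natural(eta, lam)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    data = rng.normal(2.0, 1.0, size=400)
    shards = np.array_split(data, 4)         # four hypothetical workers
    posts = [worker_update(0.0, 10.0, s, 1.0) for s in shards]
    print(combine(0.0, 10.0, posts))
```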
Bayesian online learning comes with great advantages. In recommender systems and ads marketplaces, the Bayesian online learning approach handles the user and item cold-start problem gracefully. If there is a new ad in the system that has to be recommended, initially our network might say "I don't really know", and then gradually home in on a confident prediction as more data is observed.
One can use the uncertainty estimates in a Bayesian network to perform active learning: actively proposing which items to label in order to increase predictive performance in the future. In the context of ads, this may be a useful feature if implemented with care.
Finally, there is the useful principle that Bayes never forgets: if we perform exact Bayesian updates on global parameters of an exchangeable model, the posterior will retain information indefinitely about all datapoints, including very early ones. This advantage is clearly demonstrated in DeepMind's work on catastrophic forgetting (Kirkpatrick et al, 2017). This capacity to keep a long memory of course diminishes the more approximations we have to make, which leads me to drawbacks.
The ADF method is only approximate, and over time the approximations made in each step may accumulate, causing the network to essentially forget what it learned from earlier datapoints. It is also worth pointing out that in stationary situations, the posterior over parameters is expected to shrink, especially in the last hidden layers of the network, where there are often fewer parameters. This posterior compression is often countered by introducing explicit forgetting: without evidence to the contrary, the variance of each parameter is marginally increased in each step. Forgetting and the lack of infinite memory may turn out to be a feature, not a bug, in non-stationary situations where more recent datapoints are more representative of future test cases.
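Explicit forgetting can be sketched on a scalar toy problem: an online Gaussian posterior over a mean whose true value jumps from $0$ to $3$ halfway through the stream. The variance-inflation amount forget and the other settings are illustrative assumptions. With forget=0 the update is exact Bayesian updating for a constant mean, so the estimate settles near the average of the whole stream (about $1.5$); with a little forgetting the posterior keeps enough variance to track the recent value of about $3$.

```python
import numpy as np

def track(observations, noise_var=1.0, prior_var=10.0, forget=0.0):
    """Online Gaussian posterior over a scalar mean. Before each update the posterior
    variance is inflated by `forget` (forget=0 is exact Bayesian updating)."""
    mean, var = 0.0, prior_var
    for obs in observations:
        var += forget                        # explicit forgetting
        gain = var / (var + noise_var)       # conjugate Gaussian update
        mean += gain * (obs - mean)
        var *= 1 - gain
    return mean, var

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    truth = np.r_[np.zeros(500), np.full(500, 3.0)]   # non-stationary toy stream
    obs = truth + rng.standard_normal(truth.size)
    print("no forgetting:  ", track(obs))
    print("with forgetting:", track(obs, forget=0.01))
```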
A practical drawback of a method like this in production environments is that probabilistic forward- and backpropagation require non-trivial custom implementations. Probabilistic backprop does not use reverse-mode automatic differentiation, i.e. vanilla backprop, as a subroutine. As a consequence, one cannot rely on the extensively tested and widely used autodiff packages in TensorFlow or PyTorch, and performing probabilistic backprop in new layer types might require significant effort and custom implementations. It is possible that high-quality message passing packages such as the recently open-sourced Infer.NET will see a renaissance after all.
Boo.
I scared you, didn't I?
Happy back-to-school time everyone!
After a long vacation (in blogging terms), I'm returning with a brief paper review which conveniently allows me to continue to explain a few ideas from causal reasoning. Since I wrote this intro to causality, I have read a lot more about it, especially how it relates to recommender systems.
Here are the two recent papers this post is based on:
UPDATE: After publishing this post, commenters pointed me to excellent in-depth blog posts by Alex d'Amour on the limitations of these techniques: on non-identifiability and positivity. I recommend reading these posts if you want to learn more about the conditions under which the proposed methods work.
Simpson's paradox is a commonly used example to introduce confounders and how they can cause problems if not handled appropriately. I'm going to piggyback on the specific example from (Bottou et al, 2013), another excellent paper that should be on your reading list if you haven't read it already. We're looking at data about the success rate of two different treatments, A and B, for kidney stone removal:
Look at the first column, "Overall", first. The average success rate of treatment A is 78%, while the success rate of treatment B is 83%. You might stop here, conclude that treatment B is better because it's successful in a higher percentage of cases, and continue recommending treatment B to everyone. Not so fast! Look at what happens if we group the data by the size of the kidney stone. In the group of patients with small stones, treatment A is better, and in the group with large stones, treatment A is also better. How can treatment A be superior in both groups of patients, but treatment B be better overall?
The answer lies in the fact that patients weren't randomly allocated to treatments. Patients with small stones are more likely to be assigned treatment B. These are also patients who are more likely to recover whichever treatment is used. Conversely, when the stone is large, treatment A is more common. In these cases the success rate is lower overall, too. Therefore, when you pool data from the two groups, you ignore the fact that the treatments were assigned "unfairly", in the sense that treatment A was on average assigned to many more patients who had a worse outlook to begin with.
In this case, the size of the kidney stone is a confounder variable: a random variable that causally influences both the assignment of treatments and the outcome. Graphically, it looks like this:
If we use $T, C$ and $Y$ to denote the treatment, confounder and the outcome respectively, and denote treatments A and B as 0 and 1, this is how one could correctly estimate the causal effect of the treatment on the outcome, while controlling for the confounder $C$:
$$
\mathbb{E}_{c\sim p_C}[\mathbb{E}[Y\vert T=1,C=c]] - \mathbb{E}_{c\sim p_C}[\mathbb{E}[Y\vert T=0,C=c]]
$$
Or, alternatively, with the do-calculus notation:
$$
\mathbb{E}[Y\vert do(T=1)] = \mathbb{E}_{c\sim p_C}[\mathbb{E}[Y\vert T=1,C=c]]
$$
Let's decompress this. $\mathbb{E}[Y\vert T=1,C=c]$ is the average outcome, given that treatment $1$ was given and the confounder took value $c$. Simple. Then, you average this over the marginal distribution of the confounder $C$.
Let's look at how this differs from the non-causal association you would measure between treatment and outcome (i.e. when you don't control for $C$):
$$
\mathbb{E}[Y\vert T=1] = \mathbb{E}_{c\sim p_{C\vert T=1}}[\mathbb{E}[Y\vert T=1,C=c]]
$$
So the two expressions differ in whether you average over $p_{C\vert T=1}$ or $p_{C}$ on the right-hand side. How does this expression come about? Remember from the previous post that the effect of an intervention, $do(T=1)$, can be simulated in the causal graph by severing all edges that point into the variable we intervene on. In this case, we sever the edge from the confounder $C$ to the treatment $T$, as that is the only edge pointing to $T$. Therefore, in the alternative model where the link is severed, $C$ and $T$ become independent and thus $p_{C\vert do(T=1)} = p_{C}$. Meanwhile, as no links leading from either $C$ or $T$ to $Y$ have been severed, $p(Y\vert T, C)$ remains unchanged by the intervention.
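Here is the adjustment worked through in Python. I'm assuming the table shows the classic kidney-stone counts of Charig et al. (1986); they reproduce the 78% and 83% overall rates quoted above. Treatment A is coded $T=0$ and B is $T=1$, as in the text. The only difference between the two estimators is whether stone size is averaged over $p_{C\vert T=t}$ or over $p_C$.

```python
# Assumed data: classic kidney-stone counts (Charig et al., 1986),
# stored as (successes, patients); T=0 is treatment A, T=1 is treatment B.
COUNTS = {
    (0, "small"): (81, 87),
    (0, "large"): (192, 263),
    (1, "small"): (234, 270),
    (1, "large"): (55, 80),
}
STONES = ("small", "large")

def success_rate(t, c):
    """E[Y | T=t, C=c]."""
    successes, patients = COUNTS[(t, c)]
    return successes / patients

def p_c(c):
    """Marginal p(C=c) over all patients."""
    total = sum(patients for _, patients in COUNTS.values())
    return sum(COUNTS[(t, c)][1] for t in (0, 1)) / total

def p_c_given_t(c, t):
    """p(C=c | T=t): stone sizes among patients who received treatment t."""
    n_t = sum(COUNTS[(t, cc)][1] for cc in STONES)
    return COUNTS[(t, c)][1] / n_t

def observational(t):
    """E[Y | T=t]: average over p(C | T=t)."""
    return sum(p_c_given_t(c, t) * success_rate(t, c) for c in STONES)

def interventional(t):
    """E[Y | do(T=t)]: average over the marginal p(C)."""
    return sum(p_c(c) * success_rate(t, c) for c in STONES)

if __name__ == "__main__":
    for t, name in ((0, "A"), (1, "B")):
        print(name, round(observational(t), 3), round(interventional(t), 3))
```

With these counts the observational comparison favours B (about 0.78 vs 0.83), while the adjusted comparison favours A (about 0.83 vs 0.78).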
So, if correcting for confounders is so simple, why is it not a solved problem yet? Because you don't always know the confounders, or sometimes you can't measure them.
Controlling for too little: You may not be measuring all confounders. Either because you're not even aware of some, or because you're aware of them but genuinely can't measure their values.
Controlling for too much: When one wants to make sure all confounders are covered, it's commonplace to simply control for all observable variables, in case they turn out to be confounders. This can go wrong, or very wrong. In ML terms, every time you control for a confounder, you add an extra input to the model approximating $p(y\vert T, C)$, increasing the complexity of your model and likely reducing its data efficiency. Therefore, if you include a redundant variable in your analysis as a confounder, you're just making your task harder for no additional gain.
But even worse things can happen if you include a variable in your analysis which you must not condition on. Mediator variables are those whose value is causally influenced by the treatment and which causally influence the outcome. In graph notation, this is a directed chain: $T \rightarrow M \rightarrow Y$. If you control for mediators as if they were confounders, you block the part of the causal effect of $T$ on $Y$ that flows through $M$, so you are going to draw incorrect conclusions. The other situation is conditioning on a collider $W$, which is causally influenced by both the treatment $T$ and the outcome $Y$, i.e. $T \rightarrow W \leftarrow Y$. If you condition on $W$, the phenomenon of explaining away occurs: you induce a spurious association between $T$ and $Y$ and, again, draw incorrect conclusions. A common, often unnoticed, example of conditioning on a collider is sampling bias, when your datapoints are missing not at random.
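A short simulation illustrates collider bias. It's a toy model of my own: $T$ and $Y$ are generated independently, $W = T + Y + \text{noise}$ is a collider, and keeping only the samples with $W > 1$ mimics sampling bias. Overall the correlation between $T$ and $Y$ is close to zero; within the selected samples it becomes clearly negative (roughly $-0.5$ with these settings), which is explaining away in action.

```python
import numpy as np

def collider_demo(n=100_000, seed=0):
    """T and Y are independent; W = T + Y + noise is a collider.
    Returns corr(T, Y) over all samples and over samples selected by W > 1."""
    rng = np.random.default_rng(seed)
    t = rng.standard_normal(n)
    y = rng.standard_normal(n)
    w = t + y + 0.5 * rng.standard_normal(n)
    keep = w > 1                                # selection on the collider
    overall = np.corrcoef(t, y)[0, 1]
    selected = np.corrcoef(t[keep], y[keep])[0, 1]
    return overall, selected

if __name__ == "__main__":
    print(collider_demo())
```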
In the general case, given a causal model of all variables in question, do-calculus can identify the minimal set of variables you have to measure and condition on to obtain a correct answer to your causal query. But in many applications, we simply don't have enough information to name all the possible confounding factors, and often, even if we suspect the existence of a specific confounder, we won't be able to measure or directly observe it.
Wang and Blei (2018) address this problem with the following idea: instead of identifying and then measuring possible confounders, let's try to automatically discover and infer a random variable which can act as a substitute confounder in our analysis.
This is possible in a special setting when:
In a latent factor model (like matrix factorization, PCA, etc) the user-treatment matrix is modeled by two sets of random variables: one for each treatment, one for each user. A general factor model structure is shown below:
$A_{i,j}$ is the assignment of treatment $j$ to data instance $i$. We can imagine this as binary. Its distribution is determined by the instance-specific random variable $Z_i$, and the treatment-specific $\theta_j$. The key observation is, as we look across instances $i$, conditioned on the $Z_i$ vector, the treatment assignments $A_{i,j}$ become conditionally independent between treatments $j$. It is this $Z_i$ variable which will be used as a substitute confounder. Figure 1 of the paper illustrates why:
This figure shows the graphical structure of the problem. For the instance $i$ we have the $m$ different assignments $A_{i1}\ldots A_{im}$ and the outcome $Y_i$. Causal arrows between the assignments $A$ and the outcome $Y$ are not shown in this figure, but it is assumed some of them might exist. We also calculate the latent variable $Z_i$ and call it the substitute confounder. As $Z_i$ renders the otherwise dependent $A$s conditionally independent, it appears as a common parent node in this graph. The argument is the following: if there was any confounder $U_i$ which affects multiple assignment variables simultaneously, it must either be contained within $Z_i$ already (which is good), or our factor model cannot be perfect. Incorporating $U$ as part of $Z$ would improve our model of the $A$s.
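A toy sketch of the substitute-confounder idea follows. Several simplifying assumptions are mine, not the paper's: assignments are continuous rather than binary, they follow a linear one-factor model driven by an unobserved multi-cause confounder $U_i$, and plain PCA (via SVD) stands in for a probabilistic factor model. The recovered factor $Z_i$ is strongly correlated with $U_i$, and after regressing $Z_i$ out, the assignments are much closer to uncorrelated, i.e. approximately conditionally independent given $Z_i$.

```python
import numpy as np

def substitute_confounder_demo(n=5000, m=10, seed=0):
    """Return |corr(Z, U)| and the mean |correlation| between assignments
    before and after regressing out the first principal component Z."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(n)                              # unobserved confounder
    theta = rng.uniform(0.5, 1.5, size=m)                   # per-treatment loadings
    a = np.outer(u, theta) + rng.standard_normal((n, m))    # assignments A_ij
    a_c = a - a.mean(axis=0)
    # one-factor model via PCA: Z_i is the projection on the top component
    _, _, vt = np.linalg.svd(a_c, full_matrices=False)
    z = a_c @ vt[0]
    residual = a_c - np.outer(z, vt[0])
    off_diag = ~np.eye(m, dtype=bool)
    before = np.abs(np.corrcoef(a_c.T)[off_diag]).mean()
    after = np.abs(np.corrcoef(residual.T)[off_diag]).mean()
    return abs(np.corrcoef(z, u)[0, 1]), before, after

if __name__ == "__main__":
    print(substitute_confounder_demo())
```

This only covers the first step of the deconfounder; it does not estimate causal effects, and it does nothing for single-cause confounders.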
Therefore, we can force $Z_i$ to incorporate all multi-cause confounders by simply improving the factor model. Sadly, this argument relies on reasoning about marginal and conditional independence between the causes $A$, so it cannot work for single-cause confounders.
In a way this method, called the deconfounder algorithm, is a mixture of causal discovery (learning about structure of the causal graph) and causal inference (making inferences about variables under interventions). It identifies just enough about the causal structure (the substitute confounder variable) to then be able to make causal inferences of a certain type.
The obvious limitation of this method - and the authors are very transparent about this - is that it cannot deal with single-cause confounders. You have to assume that if there is a confounder in your problem, you will be able to find another cause variable causally influenced by the same confounder in order to use this method.
The other limitation is the assumption that the assignments follow a factor model distribution marginally. While this may seem like a relatively OK assumption, it rules a couple of interesting cases out. For example, the assumption implies that the treatments cannot directly causally influence each other. In many applications this is not a reasonable assumption. If I make you watch Lord of the Rings it will definitely causally influence whether you will want to watch Lord of the Rings II. Similarly, some treatments in hospitals may be given or ruled out because of other treatments.
Finally, while the argument guarantees that all multi-cause confounders will be contained in the substitute confounder, there is no guarantee that the method will identify a minimal set of confounders. Indeed, if there are factors which influence multiple $A$s but are unrelated to the outcome, those variables will be covered by $Z$ as well. It is not difficult to create toy examples where the majority of the information captured in $Z$ is redundant.
In the second paper, Wang et al (2018) apply this method to the specific problem of training recommender systems. The issue here is that vanilla ML approaches work only when items and users are randomly sampled in the dataset, which is rarely the case. Whether you watch a movie and whether you like it are confounded by a number of factors, many of which you cannot model or observe. The method does apply to this case, and it does better than the baselines, but the results lack the punch I was hoping to see.
Nevertheless, this is a very interesting direction, and also quite a bit different from many other approaches to the problem. UPDATE: As was evident from commenters' reactions to my post, there is a great deal of discussion taking place about the limitations of these types of 'latent confounder identification' methods. Alex d'Amour's blog posts on non-identifiability and positivity go into a lot more detail on these.
You might have come across Judea Pearl's new book, and a related interview which was widely shared in my social bubble. In the interview, Pearl dismisses most of what we do in ML as curve fitting. While I believe that's an overstatement (it conveniently ignores RL, for example), it's a nice reminder that the most productive debates are often triggered by controversial or outright arrogant comments. Calling machine learning alchemy was a great recent example. After reading the article, I decided to look into his famous do-calculus and the topic of causal inference once again.
Again, because this has happened to me semi-periodically. I first learned do-calculus in a (very unpopular but advanced) undergraduate course on Bayesian networks. Since then, I have re-encountered it every 2-3 years in various contexts, but somehow it never really struck a chord. I always just thought "this stuff is difficult and/or impractical" and eventually forgot about it and moved on. I never realized how fundamental this stuff was, until now.
This time around, I think I fully grasped the significance of causal reasoning and I turned into a full-on believer. I know I'm late to the game but I almost think it's basic hygiene for people working with data and conditional probabilities to understand the basics of this toolkit, and I feel embarrassed for completely ignoring this throughout my career.
In this post I'll try to explain the basics, and convince you why you should think about this, too. If you work on deep learning, that's an even better reason to understand this. Pearl's comments may be unhelpful if interpreted as contrasting deep learning with causal inference. Rather, you should interpret them as highlighting causal inference as a huge, relatively underexplored, application of deep learning. Don't get discouraged by causal diagrams looking a lot like Bayesian networks (not a coincidence, seeing as they were both pioneered by Pearl): they don't compete with deep learning, they complement it.
First of all, causal calculus differentiates between two types of conditional distributions one might want to estimate. tldr: in ML we usually estimate only one of them, but in some applications we should actually try to or have to estimate the other one.
To set things up, let's say we have i.i.d. data sampled from some joint $p(x,y,z,\ldots)$. Let's assume we have lots of data and the best tools (say, deep networks) to fully estimate this joint distribution, or any property, conditional or marginal distribution thereof. In other words, let's assume $p$ is known and tractable. Say we are ultimately interested in how variable $y$ behaves given $x$. At a high level, one can ask this question in two ways:
observational $p(y\vert x)$: What is the distribution of $Y$ given that I observe variable $X$ takes value $x$. This is what we usually estimate in supervised machine learning. It is a conditional distribution which can be calculated from $p(x,y,z,\ldots)$ as a ratio of two of its marginals. $p(y\vert x) = \frac{p(x,y)}{p(x)}$. We're all very familiar with this object and also know how to estimate this from data.
interventional $p(y\vert do(x))$: What is the distribution of $Y$ if I were to set the value of $X$ to $x$. This describes the distribution of $Y$ I would observe if I intervened in the data generating process by artificially forcing the variable $X$ to take value $x$, but otherwise simulating the rest of the variables according to the original process that generated the data. (note that the data generating procedure is NOT the same as the joint distribution $p(x,y,z,\ldots)$ and this is an important detail).
Is $p(y\vert do(x))$ the same as $p(y\vert x)$? No. The two are not generally the same, and you can verify this with several simple thought experiments. Say $Y$ is the pressure in my espresso machine's boiler, which ranges roughly between $0$ and $1.1$ bar depending on how long it's been turned on. Let $X$ be the reading of the built-in barometer. Let's say we jointly observe $X$ and $Y$ at random times. Assuming the barometer functions properly, $p(y\vert x)$ should be a unimodal distribution centered around $x$, with randomness due to measurement noise. However, $p(y\vert do(x))$ won't actually depend on the value of $x$ and is generally the same as $p(y)$, the marginal distribution of boiler pressure. This is because artificially setting my barometer to a value (say, by moving the needle) won't actually cause the pressure in the tank to go up or down.
In summary, $y$ and $x$ are correlated or statistically dependent and therefore seeing $x$ allows me to predict the value of $y$, but $y$ is not caused by $x$, so setting the value of $x$ won't affect the distribution of $y$. Hence, $p(y\vert x)$ and $p(y\vert do(x))$ behave very differently. This simple example is just the tip of the iceberg. The differences between interventional and observational conditionals can be a lot more nuanced and hard to characterize when there are lots of variables with complex interactions.
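The espresso machine example can be simulated directly. The model is my own stand-in: boiler pressure is uniform on $[0, 1.1]$ bar and the barometer reads the pressure plus a little Gaussian noise. Observationally, in samples where the needle reads about $0.8$ the pressure is close to $0.8$; forcing the needle to $0.8$ leaves the pressure at its marginal mean of about $0.55$.

```python
import numpy as np

def espresso(n, rng, do_x=None):
    """Toy model (an assumption): boiler pressure Y ~ Uniform(0, 1.1) bar causes the
    barometer reading X = Y + small noise. If do_x is given, the needle is forced
    to that value and X no longer depends on Y."""
    y = rng.uniform(0.0, 1.1, size=n)
    if do_x is None:
        x = y + 0.05 * rng.standard_normal(n)
    else:
        x = np.full(n, do_x)
    return x, y

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x, y = espresso(200_000, rng)
    near = np.abs(x - 0.8) < 0.02               # observe a reading of about 0.8
    print("E[Y | X ~ 0.8]    ", round(float(y[near].mean()), 3))
    _, y_do = espresso(200_000, rng, do_x=0.8)  # move the needle to 0.8
    print("E[Y | do(X = 0.8)]", round(float(y_do.mean()), 3))
```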
Depending on the application you want to solve, you should seek to estimate one of these conditionals. If your ultimate goal is diagnosis or forecasting (i.e. observing a naturally occurring $x$ and inferring the probable values of $y$) you want the observational conditional $p(y\vert x)$. This is what we already do in supervised learning, this is what Judea Pearl called curve fitting. This is all good for a range of important applications such as classification, image segmentation, super-resolution, voice transcription, machine translation, and many more.
In applications where you ultimately want to control or choose $x$ based on the conditional you estimated, you should seek to estimate $p(y\vert do(x))$ instead. For example, if $x$ is a medical treatment and $y$ is the outcome, you are not merely interested in observing a naturally occurring treatment $x$ and predicting the outcome; you want to proactively choose the treatment $x$ given your understanding of how it affects the outcome $y$. Similar situations occur in system identification, control and online recommender systems.
This is perhaps the main concept I haven't grasped before. $p(y\vert do(x))$ is in fact a vanilla conditional distribution, but it's not computed based on $p(x,z,y,\ldots)$, but a different joint $p_{do(X=x)}(x,z,y,\ldots)$ instead. This $p_{do(X=x)}$ is the joint distribution of data which we would observe if we actually carried out the intervention in question. $p(y\vert do(x))$ is the conditional distribution we would learn from data collected in randomized controlled trials or A/B tests where the experimenter controls $x$. Note that actually carrying out the intervention or randomized trials may be impossible or at least impractical or unethical in many situations. You can't do an A/B test forcing half your subjects to smoke weed and the other half to smoke placebo to understand the effect of marijuana on their health. Even if you can't directly estimate $p(y\vert do(x))$ from randomized experiments, the object still exists. The main point of causal inference and do-calculus is:
If I cannot measure $p(y\vert do(x))$ directly in a randomized controlled trial, can I estimate it based on data I observed outside of a controlled experiment?
Let's start with a diagram that shows what's going on if we only care about $p(y\vert x)$, i.e. the simple supervised learning case:
Let's say we observe 3 variables, $x, z, y$, in this order. Data is sampled i.i.d. from some observable joint distribution over the 3 variables, denoted by the blue factor graph labelled 'observable joint'. If you don't know what a factor graph is, don't worry: the circles represent random variables, and the little square represents a joint distribution of the variables it's connected to. We are interested in predicting $y$ from $x$, and say that $z$ is a third variable which we do not want to infer but which we can also measure (I included this for completeness). The observational conditional $p(y\vert x)$ is calculated from this joint via simple conditioning. From the training data we can build a model $q(y\vert x;\theta)$ to approximate this conditional, for example using a deep net minimizing cross-entropy or whatever.
Now, what if we're actually interested in $p(y\vert do(x))$ rather than $p(y\vert x)$? This is what it looks like:
So, we still have the blue observed joint and data is still sampled from this joint. However, the object we wish to estimate is on the bottom right, the red intervention conditional $p(y\vert do(x))$. This is related to the intervention joint which is denoted by the red factor graph above it. It's a joint distribution over the same domain as $p$ but it's a different distribution. If we could sample from this red distribution (e.g. actually run a randomized controlled trial where we get to pick $x$), the problem would be solved by simple supervised learning. We could generate data from the red joint, and estimate a model directly from there. However, we assume this is not possible, and all we have is data sampled from the blue joint. We have to see if we can somehow estimate the red conditional $p(y\vert do(x))$ from the blue joint.
If we want to establish a connection between the blue and the red joints, we must introduce additional assumptions about the causal structure of the data generating mechanism. The only way we can make predictions about how our distribution changes as a consequence of an intervention is if we know how the variables are causally related. This information about causal relationships is not captured in the joint distribution alone. We have to introduce something more expressive than that. Here is what this looks like:
In addition to the observable joint we now also have a causal model of the world (top left). This causal model contains more detail than the joint distribution: to use a classic example, it knows not only that atmospheric pressure and barometer readings are dependent but also that pressure causes the barometer reading to change, and not the other way around. The arrows in this model correspond to the assumed direction of causation, and the absence of an arrow represents the absence of direct causal influence between variables. The mapping from causal diagrams to joint distributions is many-to-one: several causal diagrams are compatible with the same joint distribution. Thus, it is generally impossible to conclusively choose between different causal explanations by looking at observed data only.
Coming up with a causal model is a modeling step where we have to consider assumptions about how the world works, what causes what. Once we have a causal diagram, we can emulate the effect of intervention by mutilating the causal network: deleting all edges that lead into nodes in a $do$ operator. This is shown on the middle-top panel. The mutilated causal model then gives rise to a joint distribution denoted by the green factor graph. This joint has a corresponding conditional distribution $\tilde{p}(y\vert do(x))$, which we can use as our approximation of $p(y\vert do(x))$. If we got the causal structure qualitatively right (i.e. there are no missing nodes and we got the direction of arrows all correct), this approximation is exact and $\tilde{p}(y\vert do(x)) = p(y\vert do(x))$. If our causal assumptions are wrong, the approximation may be bogus.
Critically, to get to this green stuff, and thereby to establish the bridge between observational data and interventional distributions, we had to combine data with additional assumptions, prior knowledge if you wish. Data alone would not enable us to do this.
Now the question is, how can we say anything about the green conditional when we only have data from the blue distribution? We are in a better situation than before, as we have the causal model relating the two. To cut a long story short, this is what the so-called do-calculus is for. Do-calculus allows us to massage the green conditional distribution until we can express it in terms of various marginals, conditionals and expectations under the blue distribution. Do-calculus extends our toolkit of working with conditional probability distributions with three additional rules we can apply to conditional distributions with $do$ operators in them. These rules take into account properties of the causal diagram. The details can't be compressed into a single blog post, but here is an introductory paper on them.
Ideally, a do-calculus derivation leaves you with an equivalent formula for $\tilde{p}(y\vert do(x))$ that no longer contains any do operators, so you can estimate it from observational data alone. If this is the case, we say that the causal query $\tilde{p}(y\vert do(x))$ is identifiable. Conversely, if this is not possible no matter how hard we try applying do-calculus, we call the causal query non-identifiable, which means that we won't be able to estimate it from the data we have. The diagram below summarizes this causal inference machinery in its full glory.
The new panel called "estimable formula" shows the equivalent expression for $\tilde{p}(y\vert do(x))$ obtained from a derivation that applies several do-calculus rules. Notice how the variable $z$, which is completely irrelevant if you only care about $p(y\vert x)$, is now needed to perform causal inference. If we can't observe $z$ we can still do supervised learning, but we won't be able to answer causal inference queries $p(y\vert do(x))$.
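To make the blue/red/green story concrete, here is a minimal numerical sketch in Python with NumPy. It assumes a causal structure of my own choosing, not necessarily the one in the diagram: a binary $z$ that causes both $x$ and $y$, with $x$ also causing $y$. Under that assumption, do-calculus (via the back-door adjustment) gives $\tilde{p}(y\vert do(x)) = \sum_z p(y\vert x,z)\,p(z)$, and the sketch checks that this observational formula matches what the mutilated model gives, while differing from plain conditioning. All probability tables are made up for illustration.

```python
import numpy as np

# Hypothetical binary model (an assumption, not the diagram above):
# z -> x, z -> y and x -> y, so z confounds the effect of x on y.
p_z = np.array([0.5, 0.5])                 # p(z)
p_x_given_z = np.array([[0.8, 0.2],        # p(x | z=0), indexed [z, x]
                        [0.2, 0.8]])       # p(x | z=1)
p_y1_given_xz = np.array([[0.1, 0.3],      # p(y=1 | x=0, z), indexed [x, z]
                          [0.4, 0.9]])     # p(y=1 | x=1, z)


def p_y1_given_x(x):
    """Observational conditional p(y=1 | x): plain conditioning on the blue joint."""
    p_zx = p_z * p_x_given_z[:, x]         # p(z, x) for both values of z
    return np.sum(p_zx * p_y1_given_xz[x, :]) / np.sum(p_zx)


def p_y1_do_x_adjustment(x):
    """Back-door adjustment: sum_z p(y=1 | x, z) p(z), only observational terms."""
    return np.sum(p_y1_given_xz[x, :] * p_z)


def p_y1_do_x_mutilated(x):
    """Mutilated graph: delete the edge z -> x, clamp x, then marginalize out z."""
    p_x_do = np.zeros(2)
    p_x_do[x] = 1.0                        # x no longer depends on z
    p_zx = p_z * p_x_do[x]                 # p_do(z, x) = p(z) * 1
    return np.sum(p_zx * p_y1_given_xz[x, :]) / np.sum(p_zx)


for x in (0, 1):
    print(f"x={x}: p(y=1|x)={p_y1_given_x(x):.3f}  "
          f"p(y=1|do(x))={p_y1_do_x_adjustment(x):.3f} (adjustment), "
          f"{p_y1_do_x_mutilated(x):.3f} (mutilated graph)")
```

With these made-up numbers, conditioning gives $p(y=1\vert x=1)=0.8$ while the interventional answer is $0.65$: part of the observational association flows through $z$. Computing the adjustment needs $p(z)$ and $p(y\vert x,z)$, so we really do need to observe $z$.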
You can never fully verify the validity and completeness of your causal diagram based on observed data alone. However, there are certain aspects of the causal model which are empirically testable. In particular, the causal diagram implies certain conditional independence or dependence relationships between sets of variables. These dependencies or independencies can be empirically tested, and if they are not present in the data, that is an indication that your causal model is wrong. Taking this idea forward you can attempt to perform full causal discovery: attempting to infer the causal model or at least aspects of it, from empirical data.
But the bottom line is: a full causal model is a form of prior knowledge that you have to add to your analysis in order to get answers to causal questions without actually carrying out interventions. Reasoning with data alone won't be able to give you this. Unlike priors in Bayesian analysis - which are a nice-to-have and can improve data-efficiency - causal diagrams in causal inference are a must-have. With a few exceptions, all you can do without them is running randomized controlled experiments.
Causal inference is indeed something fundamental. It allows us to answer "what-if-we-did-x" type questions that would normally require controlled experiments and explicit interventions to answer. And I haven't even touched on counterfactuals which are even more powerful.
You can live without this in some cases. Often, you really just want to do normal inference. In other applications such as model-free RL, the ability to explicitly control certain variables may allow you to sidestep answering causal questions explicitly. But there are several situations, and very important applications, where causal inference offers the only method to solve the problem in a principled way.
I wanted to emphasize again that this is not a question of whether you work on deep learning or causal inference. You can, and in many cases you should, do both. Causal inference and do-calculus allow you to understand a problem and establish what needs to be estimated from data based on your assumptions captured in a causal diagram. But once you've done that, you still need powerful tools to actually estimate that thing from data. Here, you can still use deep learning, SGD, variational bounds, etc. It is this cross-section of deep learning applied to causal inference which the recent article with Pearl claimed was under-explored.
UPDATE: In the comments below people actually pointed out some relevant papers (thanks!). If you are aware of any work, please add them there.
I wanted to highlight a recent paper I came across, which is also a nice follow-up to my earlier post on pruning neural networks:
Many people working on network pruning observed that, starting from a wide network and pruning it, one obtains better performance than training the slim, pruned architecture from scratch (random initialization). This suggests that the added redundancy of an overly wide network is somehow useful in training.
In this paper, Frankle and Carbin present the lottery ticket hypothesis to explain these observations. According to the hypothesis, good performance results from lucky initialization of a subnetwork or subcomponent of the original network. Since fat networks have exponentially more component subnetworks, starting from a fatter network increases the effective number of lottery tickets, thereby increasing the chances of containing a winning ticket. According to this hypothesis, pruning effectively identifies the subcomponent which is the winning ticket.
It's important to note that good initialization is a somewhat underrated but extremely important component of SGD-based deep learning. Indeed, I often find that if you use non-trivial architectures, training is essentially stuck until you tweak the off-the-shelf initialization schemes so that it somehow starts to work.
To test the hypothesis the authors have designed a set of cool experiments. The experiments basically go like this:
Here are some of the results from Figure 4.
Observe two main things in these results. The winning tickets, without the remaining redundancy of the wide network, train faster than the wide network. In fact, the skinnier they are, the faster they train (within reason). However, if you reinitialize the networks' weights randomly (control), the resulting nets now train slower than the full network. Therefore, pruning is not just about finding the right architecture, it's also about identifying the 'winning ticket', which is a particularly luckily initialized subcomponent of the network.
I thought this paper was pretty cool. It underscores the role that (random) initialization plays in successful training of neural networks, and gives an interesting justification for why redundancy and over-parametrization might actually be a desirable property when combined with SGD from random init.
There are various ways this line of research can be extended and built upon. For starters, it would be great to see how the results hold up when better pruning schemes and more general classes of initialization approaches are used. It would be interesting to see if the training+pruning method consistently identifies the same winning ticket, or if it finds different winning tickets each time. One could look at various properties of randomly sampled tickets and compare them to the winning tickets. For example, I would be very curious to see what these winning tickets look like on the information plane. Similarly, it would be great to find cheaper ways or heuristics to identify winning tickets without having to train and prune the fat network.
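Here is a rough NumPy sketch of the mask-and-rewind step behind these experiments, as I understand it: keep the largest-magnitude weights after training and reset the survivors to their original initial values (the 'winning ticket'), versus giving the same sparse architecture fresh random weights (the control). The 'training' step is a hypothetical stand-in that just perturbs the init, and the helper names are my own; the paper prunes layer by layer and repeats this over several rounds.

```python
import numpy as np


def winning_ticket(w_init, w_trained, keep_frac):
    """One prune-and-rewind round for a single weight matrix (sketch).

    Keeps the keep_frac fraction of weights with the largest magnitude AFTER
    training, and resets the survivors to their ORIGINAL initial values.
    """
    k = int(round(keep_frac * w_trained.size))
    keep_idx = np.argsort(np.abs(w_trained).ravel())[-k:]   # k largest |w|
    mask = np.zeros(w_trained.size, dtype=bool)
    mask[keep_idx] = True
    mask = mask.reshape(w_trained.shape)
    return mask, np.where(mask, w_init, 0.0)


def random_reinit_control(mask, rng, scale):
    """Control: same sparse architecture, freshly sampled random weights."""
    return np.where(mask, rng.normal(0.0, scale, size=mask.shape), 0.0)


rng = np.random.default_rng(0)
w_init = rng.normal(0.0, 0.1, size=(20, 10))
# Hypothetical stand-in for SGD training: just perturb the initial weights.
w_trained = w_init + rng.normal(0.0, 0.1, size=w_init.shape)

mask, ticket = winning_ticket(w_init, w_trained, keep_frac=0.2)
control = random_reinit_control(mask, rng, scale=0.1)
print(mask.mean())  # fraction of weights kept
```

In the actual experiments both `ticket` and `control` would then be retrained, and the finding described above is that the ticket trains faster than the control.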
This is a post about my takeaways from the DALI workshop on the Goals and Principles of Representation Learning which we co-organized with DeepMinders Shakir Mohamed and Andriy Mnih and Twitter colleague Lucas Theis. We had an amazing set of presentations, videos are available here.
DALI is my favourite meeting of the year. It's small, and participants are all very experienced and engaged. To make the most of the high-quality audience we decided to focus on slightly controversial topics, and asked speakers to present their general thoughts, their unique perspectives and opinions rather than talk about their latest research project. Broadly, we asked three types of questions:
From my perspective, the end-goal of any area of ML research should be for each algorithm to have a justification from first principles, like the chain above, starting from the Task all the way down to the precise algorithm we're going to run.
In my machine learning cookbook post I reviewed a number of fairly well understood problem transformations which one could apply to the bottom of this chain, typically to turn an intractable optimisation problem into a tractable one: variational inference, evolution strategies, convex relaxations, etc. In the workshop we focussed on the top of the chain, the computational level understanding: Do we actually have clarity on the ultimate task we solve? If the task is given, is the principle we're applying actually justified? Is the loss function consistent with the principle?
One thing that most speakers and attendees seemed to agree on is that the most interesting and most important task we use unsupervised representation learning for is to transfer knowledge from large datasets to new tasks. I'll try to capture one possible formulation of what this means:
There are many possible extensions of this basic setup. For example we may have a distribution of tasks, rather than a single task.
One of the principles we discussed was disentanglement. It is somewhat unclear what exactly disentanglement is, but one way to capture this is to say "each coordinate in the representation corresponds to one meaningful factor of variation". Of course this is kind of a circular definition inasmuch as it's not easy to define a meaningful factor of variation. Irina defined disentangled representations as "factorized and interpretable", which, as a definition, has its own issues, too.
In the context of transfer learning though, I can see a vague argument for why seeking disentanglement might be a useful principle. One way to improve data efficiency of an ML algorithm is to endow it with inductive biases or to reduce the complexity of the function class in some sense. If we believe that future downstream tasks we may want to solve can be solved with simple, say linear, models on top of a disentangled representation, then seeking disentanglement makes sense. It may be possible to use this argument itself to formulate a more precise definition of or objective function for disentanglement.
The idea of self-supervision also featured prominently in talks. Self-supervision is the practice of defining auxiliary supervised learning tasks involving the unlabelled data, with hopes that solving these auxiliary tasks will require us to build representations which are also useful in the eventual downstream task we'd like to solve. Autoencoders, pseudolikelihood, denoising, jigsaw-puzzles, temporal-contrastive learning and so on are examples of this approach.
One very interesting aspect Harri talked about was the question of relevance: how can we make sure that an auxiliary learning task is relevant for the primary inference task we need to solve? In a semi-supervised learning setup Harri showed how you can adapt an auxiliary denoising task to be more relevant: you can calculate a sort of saliency map to determine which input dimensions the primary inference network pays attention to, and then use saliency-weighted denoising as the auxiliary task.
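Here is a minimal sketch of what saliency-weighted denoising could look like. The details are my assumptions, not necessarily what Harri presented: the primary network is a toy $f(x) = v^\top \tanh(Ux)$, saliency is the average absolute input gradient $\vert\partial f/\partial x_i\vert$ over a batch, and the auxiliary loss is a squared reconstruction error weighted per input dimension by that saliency. The denoiser itself is a hypothetical placeholder.

```python
import numpy as np


def saliency(U, v, X):
    """Mean |df/dx_i| over a batch, for a toy primary net f(x) = v . tanh(U x)."""
    H = np.tanh(X @ U.T)                   # hidden activations, shape (n, hidden)
    G = (v * (1.0 - H ** 2)) @ U           # df/dx for every row of X, shape (n, d)
    s = np.abs(G).mean(axis=0)
    return s / s.sum() * s.size            # normalise so the weights average to 1


def weighted_denoising_loss(X_clean, X_recon, s):
    """Squared reconstruction error with each input dimension weighted by saliency."""
    return np.mean(np.sum(s * (X_clean - X_recon) ** 2, axis=1))


rng = np.random.default_rng(0)
d, hidden, n = 6, 8, 100
U = rng.normal(size=(hidden, d))
U[:, 4:] = 0.0                             # the primary net ignores the last two inputs
v = rng.normal(size=hidden)
X = rng.normal(size=(n, d))
s = saliency(U, v, X)

X_noisy = X + 0.1 * rng.normal(size=X.shape)
# Hypothetical denoiser: perfect on salient dims, badly wrong on ignored ones.
X_recon = X.copy()
X_recon[:, 4:] = X_noisy[:, 4:] + 5.0
print(s.round(3), weighted_denoising_loss(X, X_recon, s))
```

Input dimensions the primary network ignores get zero weight, so the auxiliary task no longer spends capacity on reconstructing them.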
After the talks we had an open discussion session, hoping for a heated debate with people representing opposing viewpoints. And the audience really delivered on this. 10 minutes or so into the discussion Zoubin walked in and stated "I don't believe in representation learning". I tried to capture the main points in the argument (dad being anti-representation) below:
Retrospectively, I would summarise Zoubin's (and others') criticism as follows: If we identified transfer learning as the primary task representation learning is supposed to solve, are we actually sure that representation learning is the way to solve it? Indeed, this is not one of the questions I asked in my slides, and it's a very good one. One can argue that there may be many ways to transfer information from some dataset over to a novel task. Learning a representation and transferring that is just one approach. Meta-learning, for example, might provide another approach.
I really enjoyed this workshop. The talks were great, participants were engaged, and the discussion at the end was a fantastic ending for DALI itself. I am hoping that, just like me, others have left the workshop with many things to talk about, and that the debates we had will inspire future work on representation learning (or anti-representation learning).
I wanted to briefly highlight two recent papers on pruning neural networks (disclaimer: one of them is ours):
What I generally refer to as pruning in the title of this post is reducing or controlling the number of non-zero parameters, or the number of featuremaps actively used in a neural network. At a high level, there are at least three ways one can go about this, pruning is really only one of them:
There are different reasons for pruning your network. The most obvious, perhaps, is to reduce computational cost while keeping the same performance. Removing features which aren't really used in your deep network architecture can speed up inference as well as training. You can also think of pruning as a form of architecture search: figuring out how many features you need in each layer for best performance.
The second argument is to improve generalization by reducing the number of parameters, and thus the redundancy in the parameter space. As we have seen in recent work on generalization in deep networks, the raw number of parameters ($L_0$ norm) is not actually a sufficient predictor of their generalization ability. That said, we empirically find that pruning a network tends to help generalization. Meanwhile, the community is developing (or, putting my Schmidhuber-hat on: maybe in some cases rediscovering) new parameter-dependent quantities to predict/describe generalization. The Fisher-Rao norm is a great example of these. Interestingly, Fisher pruning (Theis et al, 2018) turns out to have a nice connection to the Fisher-Rao norm, and this may hint at a deeper relationship between pruning, parameter redundancy and generalization.
I found the $L_0$ paper by Louizos et al. (2018) very interesting in that it can be seen as a straightforward application of the machine learning problem transformations I wrote up in the machine learning cookbook a few months ago. It's a good illustration of how you can use these general ideas to go from formulating an intractable ML optimization problem to something practical you can run SGD on.
So I will summarize the paper as a series of steps, each changing the optimization problem:
Interestingly, the connection between Eq. (3) and evolution strategies or variational optimization is not mentioned. Instead, a motivation based on a different connection to spike-and-slab priors is given. I recommend reading the paper, perhaps with this connection in mind.
The authors then show that this indeed works, and that it compares favourably to other methods designed to reduce the number of parameters.
Thinking about the paper in terms of these steps converting from one problem to another allows you to generalize or improve the idea. For example, the REBAR or RELAX gradient estimators provide an unbiased and lower-variance alternative to the concrete relaxation, which may work very well here, too.
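For reference, here is a NumPy sketch of the hard concrete gates used in the $L_0$ paper, assuming the stretch parameters $\gamma=-0.1$, $\zeta=1.1$ and temperature $\beta=2/3$, which I believe are the paper's defaults. Each weight is multiplied by a gate $z$ sampled from a stretched and clipped binary concrete distribution, so exact zeros occur with non-zero probability, and the expected $L_0$ norm has a closed form, $\sum_j \sigma\left(\log\alpha_j - \beta\log(-\gamma/\zeta)\right)$. In a real implementation $\log\alpha$ would be learned with SGD through the reparametrized sample; NumPy is only used here to check the closed form against a Monte Carlo estimate.

```python
import numpy as np

# Assumed hard concrete hyperparameters (stretch interval and temperature).
GAMMA, ZETA, BETA = -0.1, 1.1, 2.0 / 3.0


def sample_hard_concrete(log_alpha, rng):
    """Sample gates in [0, 1]; exact 0s and 1s occur with non-zero probability."""
    u = rng.uniform(1e-6, 1.0 - 1e-6, size=np.shape(log_alpha))
    # binary concrete sample with location log_alpha and temperature BETA
    s = 1.0 / (1.0 + np.exp(-(np.log(u) - np.log1p(-u) + log_alpha) / BETA))
    s_bar = s * (ZETA - GAMMA) + GAMMA       # stretch to the interval (GAMMA, ZETA)
    return np.clip(s_bar, 0.0, 1.0)          # hard-clip back into [0, 1]


def expected_l0(log_alpha):
    """Expected number of non-zero gates, sum_j P(z_j > 0), in closed form."""
    return np.sum(1.0 / (1.0 + np.exp(-(log_alpha - BETA * np.log(-GAMMA / ZETA)))))


rng = np.random.default_rng(0)
log_alpha = np.zeros(20000)                  # 20000 gates sharing one parameter value
z = sample_hard_concrete(log_alpha, rng)
print(np.mean(z > 0), expected_l0(log_alpha) / log_alpha.size)
print(np.mean(z == 0.0), np.mean(z == 1.0))  # exact zeros and ones both occur
```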
The second paper I wanted to talk about is something from our own lab. Rather than being a purely methods paper, (Theis et al, 2018) focusses on the specific application of building speedy neural networks to predict saliency in an image. The pruned network now powers the logic behind cropping photos on Twitter.
Our goal, too, was to reduce the computational cost of the network, specifically in the transfer learning setting: when building on top of a pre-trained neural network, you inherit a lot of complexity required to solve the original source task, much of which may be redundant for solving your target task. There is a difference in our high-level pruning objective: unlike the $L_0$ norm or group sparsity, we used a slightly more complicated formula that directly estimates the forward pass runtime of the network. This is a quadratic function of the number of parameters at each layer, with interactions between neighbouring layers. Interestingly, this results in architectures which tend to alternate thick and thin layers, like the one below:
We prune the trained network greedily by removing convolutional featuremaps one at a time. A meaningful principle for selecting the next featuremap to prune is to minimize the resulting increase in training loss. Starting from this criterion, using a second-order Taylor expansion of the loss and making some further assumptions, we obtain the following pruning signal for keeping a parameter $\theta_i$:
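The paper's exact runtime formula isn't reproduced in this post, so as an assumption consider a common proxy: the multiply-accumulate count of a plain convolutional stack, $\sum_l H_l W_l K_l^2 m_{l-1} m_l$, where $m_l$ is the number of featuremaps in layer $l$. It is quadratic in the featuremap counts and couples neighbouring layers, which is the property described above. The helper names are hypothetical.

```python
def conv_macs(channels, spatial, kernel):
    """Multiply-accumulate count of a plain conv stack (assumed runtime proxy).

    channels[0] is the number of input channels, channels[l] the number of
    featuremaps in layer l; spatial[l-1] and kernel[l-1] are the output H*W and
    kernel size of layer l.
    """
    return sum(spatial[l - 1] * kernel[l - 1] ** 2 * channels[l - 1] * channels[l]
               for l in range(1, len(channels)))


def saving_from_removing_one(channels, spatial, kernel, l):
    """Cost reduction from removing one featuremap in layer l (touches l and l+1)."""
    thinner = list(channels)
    thinner[l] -= 1
    return conv_macs(channels, spatial, kernel) - conv_macs(thinner, spatial, kernel)


spatial, kernel = [32 * 32] * 3, [3, 3, 3]
print(saving_from_removing_one([3, 64, 64, 64], spatial, kernel, 1))  # 1024*9*(3+64)
print(saving_from_removing_one([3, 64, 8, 64], spatial, kernel, 1))   # 1024*9*(3+8)
```

Under this proxy, and when all layers share the same spatial size and kernel, removing one featuremap from layer $l$ saves an amount proportional to $m_{l-1} + m_{l+1}$. Once a layer is thin, featuremaps in its neighbours become cheap to keep, which is one plausible reading of why thick and thin layers end up alternating.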
$$
\Delta_i \propto F_i \theta_i^2,
$$
where $F_i$ denotes the $i^{th}$ diagonal entry of the Fisher information matrix. The above formula deals with removing a single parameter, but we can generalize this to removing entire featuremaps. Pruning proceeds by removing the parameter or featuremap with the smallest $\Delta$ in each iteration, and retraining the network between iterations. For more details, please see the paper.
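A minimal sketch of the single-parameter signal, assuming a logistic regression model and the empirical diagonal Fisher (the mean squared per-example gradient of the log-loss). I include the factor $\tfrac{1}{2}$ from the second-order Taylor expansion; it doesn't change which parameter gets pruned. The featuremap version and the retraining between iterations are left out.

```python
import numpy as np


def diag_empirical_fisher(theta, X, y):
    """Diagonal empirical Fisher of logistic regression: mean squared per-example gradient."""
    p = 1.0 / (1.0 + np.exp(-X @ theta))
    per_example_grads = (p - y)[:, None] * X       # gradient of the log-loss, one row per example
    return np.mean(per_example_grads ** 2, axis=0)


def fisher_pruning_signal(theta, X, y):
    """Approximate loss increase from zeroing each parameter: 0.5 * F_i * theta_i^2."""
    return 0.5 * diag_empirical_fisher(theta, X, y) * theta ** 2


rng = np.random.default_rng(0)
X = rng.standard_normal((5000, 4))
theta = np.array([2.0, -1.0, 0.01, 0.5])           # parameter 2 is nearly unused
y = (rng.random(5000) < 1.0 / (1.0 + np.exp(-X @ theta))).astype(float)

delta = fisher_pruning_signal(theta, X, y)
print(delta, "-> prune parameter", np.argmin(delta))
```

Summing this signal over all parameters gives half of $\sum_i F_i\theta_i^2$, which is the diagonal form of the Fisher-Rao norm.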
Adding to what's presented in the paper, I wanted to point out a few connections between Fisher pruning and ideas I discussed on this blog before.
The first connection is to the Fisher-Rao norm. Assume for a minute that the Fisher information is diagonal - a big and unreasonable assumption in theory, but a pragmatic simplification resulting in useful algorithms in practice. With this assumption, the Fisher-Rao norm of $\theta$ becomes:
$$
|\theta|_{fr} = \sum_{i=1}^{I} F_i \theta_i^2
$$
Written in this form, you can hopefully see the connection between the FR-norm and the Fisher pruning criterion. Depending on the particular definition of Fisher information used, you can interpret the FR-norm, approximately, as
In the real world, the Fisher info matrix is not diagonal, and this is actually an important aspect of understanding generalization. For one, considering only diagonal entries makes Fisher pruning sensitive to certain reparametrizations (ones with non-diagonal Jacobian) of the network. But maybe there is a deeper connection to be observed here between Fisher-Rao norms and the redundancy of parameters.
Using the diagonal Fisher information values to guide pruning also bears resemblance to elastic weight consolidation (EWC) by (Kirkpatrick et al, 2017). In EWC, the Fisher information values are used to establish which weights are more or less important for solving previous tasks. There, the algorithm was derived from the perspective of Bayesian online learning, but you can also motivate it from a Taylor expansion perspective, just like Fisher pruning.
The metaphor I use to understand and explain EWC is that of a shared hard drive. (WARNING: like all metaphors, this may be completely missing the point). The parameters of a neural network are like a hard drive or storage volume of some sort. Training the NN on a task involves compressing the training data and saving the information to the hard drive. If you have no mechanism to keep data from being overwritten, it's going to be overwritten: in neural networks, catastrophic forgetting occurs the same way. EWC is like a protocol for sharing the hard-drive between multiple users, without the users overwriting each other's data. The Fisher information values in EWC can be seen as soft do-not-overwrite flags. After training on the first task, we calculate the Fisher information values which say which parameters store crucial information for the task. The ones with low Fisher value are redundant and can be reused to store new info. In this metaphor, it is satisfying to think about the sum of Fisher information values as measuring how full the hard-drive is, and pruning as throwing away parts of the drive not actually used to store anything.
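In code, the 'soft do-not-overwrite flags' amount to a quadratic penalty, $\frac{\lambda}{2}\sum_i F_i(\theta_i - \theta^{*}_i)^2$, added to the new task's loss, where $\theta^*$ are the parameters after the first task. The sketch below assumes a hypothetical toy quadratic loss for the new task and plain gradient descent, so that the effect of the Fisher values is easy to see.

```python
import numpy as np


def ewc_penalty(theta, theta_old, fisher_diag, lam):
    """EWC penalty (lam / 2) * sum_i F_i * (theta_i - theta_old_i)^2."""
    return 0.5 * lam * np.sum(fisher_diag * (theta - theta_old) ** 2)


def ewc_grad(theta, theta_old, fisher_diag, lam):
    """Gradient of the EWC penalty with respect to theta."""
    return lam * fisher_diag * (theta - theta_old)


theta_old = np.zeros(2)              # parameters after task A
fisher_a = np.array([100.0, 0.0])    # parameter 0 matters for task A, parameter 1 doesn't
target_b = np.ones(2)                # hypothetical task B loss: 0.5 * ||theta - target_b||^2

theta = theta_old.copy()
for _ in range(3000):
    grad = (theta - target_b) + ewc_grad(theta, theta_old, fisher_a, lam=1.0)
    theta -= 0.01 * grad
print(theta)  # parameter 0 stays near 0 (1/101 at the optimum), parameter 1 moves to 1
```

The parameter with a large Fisher value barely moves, while the one with zero Fisher value is free to be overwritten by the new task, like an unused region of the hard drive.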
I wrote about two recent methods for automatically learning neural network architecture by figuring out which parameters/features to throw away. In my mind, both methods/papers are interesting in their own right. The $L_0$ approach seems like a simpler optimization algorithm that may be preferable to the iterative, remove-one-feature-at-a-time nature of Fisher pruning. However, Fisher pruning is more applicable to the scenario when you start from a large pretrained model in a transfer learning setting.
After last week's post on the generalization mystery, people have pointed me to recent work connecting the Fisher-Rao norm to generalization (thanks!):
This short presentation on generalization by coauthor Sasha Rakhlin is also worth looking at, though I have to confess much of the references to learning theory are lost on me.
While I can't claim to have understood all the bounding and proofs going on in Section 4, I think I got the big picture, so I will try and summarize the main points in the section below. In addition, I wanted to add some figures I made which helped me understand the restricted model class the authors worked with and the "gradient structure" this restriction gives rise to. Feel free to point out if anything I say here is wrong or incomplete.
The main mantra of this paper is along the lines of results by Bartlett (1998) who observed that in neural networks, generalization is about the size of the weights, not the number of weights. This theory underlies the use of techniques such as weight decay and even early stopping, since both can be seen as ways to keep the neural network's weight vector small. Reasoning about a neural network's generalization ability in terms of the size, or norm, of its weight vector is called norm-based capacity control.
The main contribution of Liang et al (2017) is proposing the Fisher-Rao norm as a measure of how big the network's weights are, and hence as an indicator of a trained network's generalization ability. It is defined as follows:
$$
|\theta|_{fr} = \theta^\top I_\theta \theta
$$
where $I$ is the Fisher information matrix:
$$
I(\theta) = \mathbb{E}_{x,y} \left[ \nabla_\theta \ell(f_\theta(x),y) \nabla_\theta \ell(f_\theta(x),y)^\top \right]
$$
There are various versions of the Fisher information matrix, and therefore of the Fisher-Rao norm, depending on which distribution the expectation is taken under. The empirical form samples both $x$ and $y$ from the empirical data distribution. The model form samples $x$ from the data, but assumes that the loss is a log-loss of a probabilistic model, and we sample $y$ from this model.
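To make the two versions concrete, here is a sketch for a logistic regression model, chosen (as an assumption) because the model-form expectation over $y$ has a closed form: the per-example gradient of the log-loss is $(p - y)x$, and under $y\sim\text{Bernoulli}(p)$ we have $\mathbb{E}_y(p-y)^2 = p(1-p)$.

```python
import numpy as np


def fisher_matrices(theta, X, y):
    """Empirical and model-form Fisher matrices of logistic regression."""
    p = 1.0 / (1.0 + np.exp(-X @ theta))
    # per-example gradient of the log-loss is (p - y) * x
    emp_w = (p - y) ** 2          # empirical: y taken from the data
    model_w = p * (1.0 - p)       # model: E_{y ~ Bernoulli(p)} (p - y)^2 = p (1 - p)
    F_emp = (X * emp_w[:, None]).T @ X / len(X)
    F_model = (X * model_w[:, None]).T @ X / len(X)
    return F_emp, F_model


def fisher_rao(theta, F):
    """theta^T F theta, the quantity written |theta|_fr in the text."""
    return theta @ F @ theta


rng = np.random.default_rng(0)
theta = np.array([1.0, -2.0, 0.5])
X = rng.standard_normal((200_000, 3))
# synthetic labels drawn from the model itself
y = (rng.random(len(X)) < 1.0 / (1.0 + np.exp(-X @ theta))).astype(float)

F_emp, F_model = fisher_matrices(theta, X, y)
print(fisher_rao(theta, F_emp), fisher_rao(theta, F_model))
```

When the labels really are drawn from the model, as in this synthetic example, the two matrices agree up to sampling noise; with real data and a misspecified model they can differ substantially.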
Importantly, the Fisher-Rao norm is something which depends on the data distribution (at least the distribution of $x$). It is also invariant under reparametrization, which means that if there are two parameters $\theta_1$ and $\theta_2$ which implement the same function, their FR-norm is the same. Finally, it is a measure related to flatness inasmuch as the Fisher-information matrix approximates the Hessian at a minimum of the loss under certain conditions.
The one thing I wanted to add to this paper, is a little bit more detail on the particular model class - rectified linear networks without bias - that the authors studied here. This restriction turns out to guarantee some very interesting properties, without hurting the empirical performance of the networks (so the authors claim and to some degree demonstrate).
Let's first visualize what the output of a rectified multilayer perceptron with biases looks like. Here I used 3 hidden layers with 15 ReLUs in each and PyTorch-default random initialization. The network's input is 2D, and the output is 1D so I can easily plot contour surfaces:
The left-hand panel shows the function itself. The panels next to it show the gradients with respect to $x_1$ and $x_2$ respectively. The function is piecewise linear (which is hard to see because there are many, many linear pieces), which means that the gradients are piecewise constant (which is more visually apparent).
The piecewise linear structure of $f$ becomes more apparent if we superimpose the contour plot of the gradients (black) on top of the contour plot of $f$ itself (red-blue):
These functions are clearly very flexible, and by adding more layers, the number of linear pieces can grow exponentially.
Importantly, the above plot would look very similar had I plotted the function's output as a function of two components of $\theta$, keeping $x$ fixed. This is significantly more difficult to plot though, so I'm hoping you'll just believe me.
Now let's look at what happens when we remove all biases from the network, keeping only the weight matrices:
Wow, the function looks very different now, doesn't it? At $x=0$ it always takes the value $0$. It is composed of wedge-shaped (or in higher dimensions, generalized pyramid-shaped) regions within which the function is linear, but the slope in each wedge is different. Yet the surface is still continuous. Let's do the superimposed plot again:
It's less clear from these plots why a function like this can model data just as well as the more general piecewise linear one we get if we enable biases. One thing that helps is dimensionality: in high dimensions, the probability that two randomly sampled datapoints fall into the same "pyramid", i.e. share the same linear region, is extremely small. Unless your data has some structure that makes this likely to happen for many datapoints at once, you don't really have to worry about it, I guess.
Furthermore, if my network had three input dimensions, but I only use two dimensions $x_1$ and $x_2$ to encode data and fix the third coordinate $x_3=1$, I can implement the same kind of functions over my inputs. This is called using homogeneous coordinates, and a bias-less network with homogeneous coordinates can be nearly as powerful as one with biases in terms of the functions it can model. Below is an example of a function a rectified network with no biases can implement when using homogeneous coordinates.
This is because the third variable $x_3=1$ multiplied by its weights practically becomes a bias for the first hidden layer.
A second observation is that if we consider $f_\theta(x)$ as a function of the weight matrix of a particular layer, keeping all other weights and the input fixed, the function behaves exactly the same way as it does with respect to the input $x$. The same radial pattern would be observed in $f$ if I plotted it as a function of a weight matrix (though weight matrices are rarely 2-D, so I can't really plot that).
The authors note that these functions satisfy the following formula:
$$
f_\theta(x) = \nabla_x f_\theta(x)^\top x
$$
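The identity is easy to check numerically. This is a NumPy sketch with a bias-free ReLU network of the same shape as in the plots above (2 inputs, 3 hidden layers of 15 units, 1 output); the scaled-Gaussian initialization is my own choice rather than PyTorch's default, and the input gradient comes from a hand-written backward pass.

```python
import numpy as np


def make_net(sizes, rng):
    """Bias-free ReLU MLP: one weight matrix per layer and nothing else."""
    return [rng.standard_normal((m, n)) / np.sqrt(n) for n, m in zip(sizes[:-1], sizes[1:])]


def f_and_grad_x(Ws, x):
    """Scalar output f(x) and its input gradient df/dx via a manual backward pass."""
    h, masks = x, []
    for W in Ws[:-1]:
        pre = W @ h
        masks.append(pre > 0)            # which ReLUs are active for this input
        h = np.maximum(pre, 0.0)
    f = (Ws[-1] @ h)[0]                  # linear output layer, no bias
    g = Ws[-1][0]                        # df/dh for the last hidden layer
    for W, m in zip(reversed(Ws[:-1]), reversed(masks)):
        g = W.T @ (g * m)                # back through the ReLU, then through W
    return f, g


rng = np.random.default_rng(0)
Ws = make_net([2, 15, 15, 15, 1], rng)
x = rng.standard_normal(2)
f, g = f_and_grad_x(Ws, x)
print(f, g @ x)                                 # f(x) and grad_x f(x)^T x agree
print(f_and_grad_x(Ws, 3.0 * x)[0], 3.0 * f)    # positive homogeneity
```

The second check, $f(3x) = 3f(x)$, is the positive homogeneity that makes the linear regions wedge-shaped.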
(Moreover I think these are the only continuous functions for which the above equality holds, but I leave this to keen readers to prove or disprove)
Noting the symmetry between the network's inputs and weight matrices, a similar equality can be established with respect to parameters $\theta$:
$$
f_\theta(x) = \frac{1}{L+1}\nabla_\theta f_\theta(x)^\top \theta,
$$
where $L$ is the number of layers.
Here's my explanation, which differs slightly from the simple proof the authors give. A general rectified network is piecewise linear with respect to $x$, as discussed. The boundaries of the linear pieces, and the slopes, change as we change $\theta$. Let's fix $\theta$. Now, so long as $x$ and some $x_0$ fall within the same linear region, the function at $x$ equals its Taylor expansion around $x_0$:
\begin{align}
f_\theta(x) &= \nabla_{x} f_\theta(x_0)^\top (x- x_0) + f_{\theta}(x_0) \\
&= \nabla_x f_\theta(x)^\top (x - x_0) + f_{\theta}(x_0)
\end{align}
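The Taylor step can be sanity-checked numerically. In the sketch below (a hypothetical bias-free ReLU network with random weights), points $x_0 = tx$ on the ray from the origin through $x$ share the ReLU on/off pattern of $x$ for every $t>0$, because each pre-activation is simply scaled by $t$. They therefore lie in the same linear region, and the first-order expansion around $x_0$ reproduces $f_\theta(x)$ exactly, even for tiny $t$.

```python
import numpy as np

rng = np.random.default_rng(3)

# Hypothetical bias-free ReLU network 3 -> 10 -> 10 -> 1 with random weights.
Ws = [rng.normal(size=(10, 3)), rng.normal(size=(10, 10)), rng.normal(size=(1, 10))]

def forward(x):
    """Return f(x), its input Jacobian, and the on/off pattern of every ReLU."""
    h, J, pattern = x, np.eye(len(x)), []
    for W in Ws[:-1]:
        z = W @ h
        mask = z > 0
        pattern.append(mask)
        h = np.where(mask, z, 0.0)        # ReLU(z)
        J = (mask[:, None] * W) @ J       # diag(mask) @ W @ J
    return Ws[-1] @ h, Ws[-1] @ J, pattern

x = rng.normal(size=3)
f_x, J_x, pattern_x = forward(x)

for t in [1.0, 0.5, 1e-3, 1e-9]:
    x0 = t * x                            # a point on the ray through x
    f_x0, J_x0, pattern_x0 = forward(x0)
    # Same ReLU pattern, hence the same linear region and the same slope.
    assert all(np.array_equal(a, b) for a, b in zip(pattern_x, pattern_x0))
    # The first-order expansion around x0 is exact at x.
    assert np.allclose(J_x0 @ (x - x0) + f_x0, f_x)
print("expansion exact for every t tested")
```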
Now, if we have no biases, all the linear regions are wedge-shaped, and they all meet at the origin $x=0$. In fact, scaling the input by any $t>0$ scales every pre-activation by $t$ without changing its sign, so $x_0 = tx$ lies in the same linear region as $x$. We can therefore take the limit of the above Taylor expansion as $x_0 = tx \rightarrow 0$. (The limit is needed only for a technical reason: the function is non-differentiable exactly at $x=0$.) As $f_{\theta}(0)=0$, we find that
$$
f_\theta(x) = \nabla_x f_\theta(x)^\top x,
$$
just as we wanted. Now, treating layer $l$'s weights $\theta^{(l)}$ as if they were the input to the network formed by the subsequent layers, and the previous layer's activations as if they were the weights multiplying that input, we can derive a similar formula in terms of $\theta^{(l)}$:
$$
f_\theta(x) = \nabla_{\theta^{(l)}} f_\theta(x)^\top \theta^{(l)}.
$$
Applying this formula to every layer $l=1,\ldots,L+1$ and averaging, we obtain:
$$
f_\theta(x) = \frac{1}{L+1}\nabla_\theta f_\theta(x)^\top \theta
$$
We got the $L+1$ from the $L$ hidden layers plus the output layer.
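Here is a sketch of the per-layer identity and the averaging step, assuming a hypothetical bias-free ReLU network with $L=3$ hidden layers, random weights and a scalar output. The gradient with respect to each weight matrix comes from a hand-written backward pass, and the inner product $\langle \nabla_{\theta^{(l)}} f_\theta(x), \theta^{(l)}\rangle$ is compared with $f_\theta(x)$ for every layer.

```python
import numpy as np

rng = np.random.default_rng(4)

def relu(z):
    return np.maximum(z, 0.0)

# Hypothetical bias-free network 5 -> 12 -> 12 -> 12 -> 1: L = 3 hidden layers.
Ws = [rng.normal(size=(12, 5)), rng.normal(size=(12, 12)),
      rng.normal(size=(12, 12)), rng.normal(size=(1, 12))]
L = len(Ws) - 1

def f_and_weight_grads(x):
    """Scalar f(x) and df/dW_l for every layer, via a hand-written backward pass."""
    hs, masks = [x], []                    # hs[l] is the input to layer l
    for W in Ws[:-1]:
        z = W @ hs[-1]
        masks.append(z > 0)
        hs.append(relu(z))
    f = (Ws[-1] @ hs[-1]).item()
    grads = [None] * len(Ws)
    delta = np.ones(1)                     # df/dz for the output layer
    for l in range(len(Ws) - 1, -1, -1):
        grads[l] = np.outer(delta, hs[l])  # df/dW_l = delta_l hs[l]^T
        if l > 0:
            delta = (Ws[l].T @ delta) * masks[l - 1]
    return f, grads

x = rng.normal(size=5)
f, grads = f_and_weight_grads(x)
per_layer = [np.sum(G * W) for G, W in zip(grads, Ws)]  # <df/dW_l, W_l>
print(np.allclose(per_layer, f))                # True: holds layer by layer
print(np.isclose(sum(per_layer) / (L + 1), f))  # True: the averaged formula
```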
Using the formula above, and the chain rule, we can simplify the expression for the Fisher-Rao norm as follows:
\begin{align}
\|\theta\|^2_{fr} &= \mathbb{E} \theta^\top \nabla_\theta \ell(f_\theta(x),y) \nabla_\theta \ell(f_\theta(x),y)^\top \theta \\
&= \mathbb{E} \left( \theta^\top \nabla_\theta \ell(f_\theta(x),y) \right)^2 \\
&= \mathbb{E} \left( \theta^\top \nabla_\theta f_\theta(x) \nabla_f \ell(f,y) \right)^2\\
&= (L+1)^2\, \mathbb{E} \left( f_\theta(x)^\top \nabla_f \ell(f,y)\right)^2
\end{align}
It can be seen very clearly in this form that the Fisher-Rao norm only depends on the output of the function $f_\theta(x)$ and properties of the loss function. This means that if two parameter vectors $\theta_1$ and $\theta_2$ of the same architecture implement the same input-output function $f$, their F-R norms will be the same.
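A numerical sketch of both points, under assumptions the post does not make: a scalar-output network, the logistic loss $\ell(f,y)=\log(1+e^{-yf})$ with $y\in\{-1,+1\}$, and a synthetic sample standing in for the expectation. Because the simplification holds example by example, the choice of distribution for $(x,y)$ does not affect the check. The code computes the Fisher-Rao quantity once through backpropagated gradients and once from the forward pass alone (the closed form carries a factor $(L+1)^2$, since $\nabla_\theta f_\theta(x)^\top\theta = (L+1) f_\theta(x)$ by the parameter identity), and then repeats this for a rescaled parameter vector (first layer times $c$, last layer divided by $c$) that implements the same function. All helper names are made up for this example.

```python
import numpy as np

rng = np.random.default_rng(5)

def relu(z):
    return np.maximum(z, 0.0)

def forward(Ws, x):
    """Scalar output of a bias-free ReLU network."""
    h = x
    for W in Ws[:-1]:
        h = relu(W @ h)
    return (Ws[-1] @ h).item()

def forward_and_grads(Ws, x):
    """Scalar f(x) and df/dW_l for every layer, via a hand-written backward pass."""
    hs, masks = [x], []
    for W in Ws[:-1]:
        z = W @ hs[-1]
        masks.append(z > 0)
        hs.append(relu(z))
    f = (Ws[-1] @ hs[-1]).item()
    grads, delta = [None] * len(Ws), np.ones(1)
    for l in range(len(Ws) - 1, -1, -1):
        grads[l] = np.outer(delta, hs[l])
        if l > 0:
            delta = (Ws[l].T @ delta) * masks[l - 1]
    return f, grads

def dloss_df(f, y):
    """d/df of the logistic loss log(1 + exp(-y f)), labels y in {-1, +1}."""
    return -y / (1.0 + np.exp(y * f))

def fr_via_gradients(Ws, X, Y):
    """Sample mean of (theta^T grad_theta loss)^2, using backpropagated gradients."""
    vals = []
    for x, y in zip(X, Y):
        f, grads = forward_and_grads(Ws, x)
        g = dloss_df(f, y)                 # chain rule: grad loss = g * grad f
        vals.append(sum(g * np.sum(G * W) for G, W in zip(grads, Ws)) ** 2)
    return np.mean(vals)

def fr_closed_form(Ws, X, Y):
    """(L+1)^2 times the sample mean of (f * dloss/df)^2: forward pass only."""
    L = len(Ws) - 1
    fs = np.array([forward(Ws, x) for x in X])
    return (L + 1) ** 2 * np.mean((fs * dloss_df(fs, Y)) ** 2)

# Hypothetical 4 -> 16 -> 16 -> 1 network and synthetic labelled sample.
theta_1 = [rng.normal(size=(16, 4)) / 2, rng.normal(size=(16, 16)) / 4,
           rng.normal(size=(1, 16)) / 4]
c = 10.0  # same function, different parameters: first layer * c, last layer / c
theta_2 = [c * theta_1[0], theta_1[1], theta_1[2] / c]
X = rng.normal(size=(200, 4))
Y = rng.choice([-1.0, 1.0], size=200)

for theta in (theta_1, theta_2):
    print(np.isclose(fr_via_gradients(theta, X, Y), fr_closed_form(theta, X, Y)))  # True
print(np.isclose(fr_closed_form(theta_1, X, Y), fr_closed_form(theta_2, X, Y)))    # True
print([sum(np.sum(W ** 2) for W in th) for th in (theta_1, theta_2)])  # differ a lot
```

The squared Euclidean norms of the two parameter vectors differ substantially, while the Fisher-Rao quantity does not change, which is the invariance described above.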
I think this paper presented a very interesting insight into the geometry of rectified linear neural networks, and highlighted some interesting connections between information geometry and norm-based generalization arguments.
What I think is still missing is the kind of insight that would explain why SGD finds solutions with low F-R norm, or how the F-R norm of a solution is affected by the batch size of SGD, if at all. The other open question is whether the F-R norm can be an effective regularizer. It seems that for this particular class of networks, which don't have any bias parameters, the F-R norm of the model could be calculated relatively cheaply and added as a regularizer, since we already compute the forward pass of the network anyway.
]]>