One of my favourite recent innovations in neural network architectures is Deep Sets. This relatively simple architecture can implement arbitrary set functions: functions over collections of items where the order of the items does not matter.
This is a guest post by Fabian Fuchs, Ed Wagstaff and Martin Engelcke, authors of a recent paper on the representational power of such architectures and why the deep sets architecture can represent arbitrary set functions in theory. It's a great paper. Imagine what these guys could achieve if their lab was in Cambridge rather than Oxford!
Here are the links to the original Deep Sets paper, and the more recent paper by the authors of this post:
Over to Fabian, Ed and Martin for the rest of the post. Enjoy.
Most successful deep learning approaches make use of the structure in their inputs: CNNs work well for images, RNNs and temporal convolutions for sequences, etc. The success of convolutional networks boils down to exploiting a key invariance property: translation invariance. This allows CNNs to share the same weights across all spatial locations, rather than learning a separate detector for every position.
But images are far from the only data we want to build neural networks for. Often our inputs are sets: collections of items where the ordering of items carries no information for the task at hand. In such a situation, the invariance property we can exploit is permutation invariance.
To give a short, intuitive explanation for permutation invariance, this is what a permutation invariant function with three inputs would look like: $f(a, b, c) = f(a, c, b) = f(b, a, c) = \dots$.
Some practical examples where we want to treat data or different pieces of higher order information as sets (i.e. where we want permutation invariance) are:
We will talk more about applications later in this post.
Note from Ferenc: I would like to jump in here - because it's my blog so I get to do that - to say that I think the killer application for this is actually meta-learning and few-shot learning. By meta-learning, don't think of anything fancy, I consider amortized variational inference, like a VAE, as a form of meta-learning. Consider a conditionally i.i.d model where you have a global parameter $\theta$, and a bunch of observations $x_i$ drawn conditionally i.i.d from a distribution $p_{X\vert \theta}$. Given a set of observations $x_1, \ldots, x_N$ we'd like to approximate the posterior $p(\theta\vert x_1, \ldots, x_N)$ by some parametric $q(\theta\vert x_1, \ldots, x_N; \psi)$, and we want this to work for any number of observations $N$. Clearly, the real posterior $p$ has a permutation invariance with respect to $x_n$, so it would make sense to make the recognition model, $q$, a permutation-invariant architecture. To me, this is the killer application of deep sets, especially in an on-line learning setting, where one wants to update our posterior estimate over some parameters with each new data point we observe.
Having established that there is a need for permutation-invariant neural networks, let's see how to enforce permutation invariance in practice. One approach is to make use of some operation $P$ which is already known to be permutation-invariant. We map each of our inputs separately to some latent representation and apply our $P$ to the set of latents to obtain a latent representation of the set as a whole. $P$ destroys the ordering information, leaving the overall model permutation invariant.
In particular, Deep Sets does this by setting $P$ to be summation in the latent space. Other operations are used as well, e.g. elementwise max. We call the case where the sum is used sum-decomposition via the latent space. The high-level description of the full architecture is now reasonably straightforward - transform your inputs into some latent space, destroy the ordering information in the latent space by applying the sum, and then transform from the latent space to the final output. This is illustrated in the following figure:
If we want to actually implement this architecture, we'll need to choose our latent space (in the guts of the model this will mean something like choosing the size of the output layer of a neural network). As it turns out, the choice of latent space will place a limit on how expressive the model is. In general, neural networks are universal function approximators (in the limit), and we'd like to preserve this property. Zaheer et al. provide a theoretical analysis of the ability of this architecture to represent arbitrary functions - that is, can the architecture, in theory, achieve exact equality with any target function, allowing us to use e.g. neural networks to approximate the necessary mappings? In our paper, we build on and extend this analysis, and discuss what implications it has for the choice of latent space.
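Before getting into the theory, here is a minimal sketch of a sum-decomposition model in numpy. The random single-layer maps stand in for the learned networks $\phi$ and $\rho$ (the names and sizes are illustrative, not from any paper); only the structure matters here:

```python
import numpy as np

rng = np.random.default_rng(0)

# A minimal sum-decomposition: phi maps each element to a latent vector,
# the sum destroys the ordering information, and rho maps the pooled
# latent to the output. Random single-layer maps stand in for trained MLPs.
W_phi = rng.normal(size=(1, 16))   # element dim 1 -> latent dim 16
W_rho = rng.normal(size=(16, 1))   # latent dim 16 -> output dim 1

def phi(x):
    return np.tanh(x @ W_phi)      # per-element embedding

def rho(z):
    return z @ W_rho               # readout from the pooled latent

def deep_sets(xs):
    # xs: array of shape (set_size, element_dim)
    return rho(phi(xs).sum(axis=0))

xs = np.array([[1.0], [2.0], [3.0]])
perm = np.array([[3.0], [1.0], [2.0]])
# The sum is the only place where elements interact, so the model is
# permutation invariant by construction:
assert np.allclose(deep_sets(xs), deep_sets(perm))
```

Because the ordering information is destroyed inside the sum, invariance holds for any choice of $\phi$ and $\rho$; learning only has to worry about *what* to represent, not about enforcing the symmetry.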
Zaheer et al. show that, if we're only interested in sets drawn from a countable domain (e.g. $\mathbb{Z}$ or $\mathbb{Q}$), a 1-dimensional latent space is enough to represent any function. Their proof works by defining an injective mapping from sets to real numbers. Once you have an injective mapping, you can recover all the information about the original set, and can, therefore, represent any function. This sounds like good news -- we can do anything we like with a 1-D latent space! Unfortunately, there's a catch -- the mapping that we rely on is not continuous. The implication of this is that to recover the original set, even approximately, we need to know the exact real number that we mapped to -- knowing to within some tolerance doesn't help us. This is impossible on real hardware.
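The flavour of this construction, and its catch, can be seen in a toy version for sets of small non-negative integers: map each element $x$ to $2^{-(x+1)}$ and sum. This writes the set's membership indicators into the binary digits of a real number, so the mapping is injective. (This is a simplified stand-in for the paper's construction, not Zaheer et al.'s exact mapping.)

```python
def encode(s):
    """Injective sum-decomposition for sets of small non-negative integers:
    element x contributes the binary digit at position x+1."""
    return sum(2.0 ** -(x + 1) for x in s)

def decode(z, max_element=20):
    """Greedily read the binary digits back off. Works because the tail sum
    of all smaller weights is strictly less than the current weight -- but
    it needs the *exact* value of z."""
    out = set()
    for x in range(max_element):
        w = 2.0 ** -(x + 1)
        if z >= w - 1e-12:
            z -= w
            out.add(x)
    return out

assert encode({0, 2}) == 0.625
assert decode(encode({1, 4, 7})) == {1, 4, 7}
# The catch: decoding is not continuous. A perturbation far smaller than
# any element's weight completely changes the decoded set.
assert decode(encode({0}) - 1e-6) != {0}
```

The final assertion is exactly the problem described above: knowing the latent value only to within some tolerance tells you essentially nothing about the original set.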
Above we considered a countable domain, but it's important to consider instead the uncountable domain $\mathbb{R}$, the real numbers. This is because continuity is a much stronger property on $\mathbb{R}$ than on $\mathbb{Q}$, and we need this stronger notion of continuity. The figure below illustrates this, showing a function which is continuous on $\mathbb{Q}$ but not continuous on $\mathbb{R}$ (and certainly not continuous in an intuitive sense). The figure is explained in detail in our paper. Using $\mathbb{R}$ is particularly important if we want to work with neural networks. Neural networks are universal approximators for continuous functions on compact subsets of $\mathbb{R}^M$. Continuity on $\mathbb{Q}$ won't do.
Zaheer et al. go on to provide a proof using continuous functions on $\mathbb{R}$, but it places a limit on the set size for a fixed finite-dimensional latent space. In particular, it shows that with a latent space of $M+1$ dimensions, we can represent any function which takes as input sets of size $M$. If you want to feed the model larger sets, there's no guarantee that it can represent your target function.
As for the countable case, the proof of this statement uses an injective mapping. But the functions we're interested in modeling aren't going to be injective -- we're distilling a large set down into a smaller representation. So maybe we don't need injectivity -- maybe there's some clever lower-dimensional mapping to be found, and we can still get away with a smaller latent space?
No.
As it turns out, you often do need injectivity into the latent space. This is true even for simple functions, e.g. max, which is clearly far from injective. This means that if we want to use continuous mappings, the dimension of the latent space must be at least the maximum set size. We were also able to show that this dimension suffices for universal function representation. That is, we've improved on the result from Zaheer et al. (latent dimension $N \geq M+1$ is sufficient) to obtain both a weaker sufficient condition and a necessary condition: latent dimension $N \geq M$ is sufficient and necessary. Finally, we've shown that it's possible to be flexible about the input set size: while Zaheer et al.'s proof applies to sets of size exactly $M$, we showed that $N=M$ also works if the set size is allowed to vary, as long as it stays at most $M$.
Why do we care about all of this? Sum-decomposition is in fact used in many different contexts - some more obvious than others - and the above findings directly apply in some of these.
Self-attention via {keys, queries, and values}, as in the Attention Is All You Need paper by Vaswani et al. 2017, is closely linked to Deep Sets. Self-attention is itself permutation-invariant unless you use positional encodings, as is often done in language applications. In a way, self-attention "generalises" the summation operation, as it performs a weighted summation of different attention vectors. You can show that when setting all keys and queries to 1.0, you effectively end up with the Deep Sets architecture.
Therefore, self-attention inherits all the sufficiency statements ('with $N=M$ you can represent everything'), but not the necessity part: it is not clear that $N=M$ is needed in the self-attention architecture, just because it was proved that it is needed in the Deep Sets architecture.
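This reduction is easy to check numerically. If the key and query maps are constant (here zero, so all attention logits are equal), the softmax weights become uniform and every output is the mean of the value vectors, i.e. sum-pooling up to a constant $1/M$ factor that the readout network can absorb. A sketch, with names and sizes invented for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)

def self_attention(X, Wq, Wk, Wv):
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    logits = Q @ K.T / np.sqrt(K.shape[1])
    A = np.exp(logits - logits.max(axis=1, keepdims=True))
    A = A / A.sum(axis=1, keepdims=True)     # softmax over the set
    return A @ V

M, d = 5, 4
X = rng.normal(size=(M, d))
Wv = rng.normal(size=(d, d))
# Constant key/query maps -> all logits equal -> uniform attention weights.
Wq = np.zeros((d, d))
Wk = np.zeros((d, d))
out = self_attention(X, Wq, Wk, Wv)
# Every output row is the mean of the value vectors: sum-pooling up to 1/M.
assert np.allclose(out, np.tile((X @ Wv).mean(axis=0), (M, 1)))
```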
Point clouds are unordered, variable length lists (aka sets) of $xyz$ coordinates. We can also view them as (sparse) 3D occupancy tensors, but there is no 'natural' 1D ordering because we have three equal spatial dimensions. We could e.g. build a kd-tree but again this imposes a somewhat 'unnatural' ordering.
As a specific example, PointNet by Qi et al. 2017 is an expressive set-based model with some more bells and whistles. It handles interactions between points by (1) computing a permutation-invariant global descriptor, (2) concatenating it with point-wise features, and (3) repeating the first two steps several times. They also use (spatial) transformer modules for translation and rotation invariance --- So. Much. Invariance!
A stochastic process corresponds to a set of random variables. Here we want to model the joint distributions of the values those random variables take. These distributions need to satisfy the condition of exchangeability, i.e. they need to be invariant to the order of the random variables.
Neural Processes and Conditional Neural Processes (both by Garnelo et al. 2018) achieve this by computing a global latent variable via summation. One well-known instance of this is Generative Query Networks by Eslami et al. 2018 which aggregate information from different views via summation to obtain a latent scene representation.
Hi, this is Ferenc again. Thanks to Fabian, Ed and Martin for the great post.
Update: As commenters pointed out, these papers are, of course, not the only ones dealing with permutation invariance and set functions. Here are a couple more things you might want to look at (and there are quite likely many more that I don't mention here - feel free to add more in the comments section below)
As I said before, I think that the coolest application of this type of architecture is in meta-learning situations. When someone mentions meta-learning, many people think of complicated "learning to learn to learn via gradient descent via gradient descent via gradient descent" kind of things. But in reality, simpler variants of meta-learning are a lot closer to being practically useful.
Here is an example of a recommender system developed for Twitter by Vartak et al. (2017), using this idea. Here, a user's preferences are summarized by the set of tweets they recently engaged with on the platform. This set is processed by a Deep Sets architecture (the sequence in which they engaged with tweets is assumed to carry little information in this application). The output of this set function is then fed into another neural network that scores new tweets the user might find interesting.
Such architectures can prove useful in online learning or streaming data settings, where new datapoints arrive over time, in a sequence. For every new datapoint, one can apply the $\phi$ mapping, and then simply maintain a moving average of these $\phi$ values. For binary classification, one can keep one moving average of $\phi(x)$ for all negative examples, and another for all positive examples. These moving averages then provide a useful, permutation-invariant summary statistic of all the data received so far.
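A sketch of this streaming summary (the embedding $\phi$ here is a random stand-in for a learned network; the class name is made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(2)
W = rng.normal(size=(3, 8))

def phi(x):
    return np.tanh(x @ W)   # stand-in for a learned per-element embedding

class StreamingSetSummary:
    """Running mean of phi(x): a permutation-invariant summary of a stream."""
    def __init__(self, dim):
        self.mean = np.zeros(dim)
        self.n = 0
    def update(self, x):
        # Incremental mean update: no need to store past datapoints.
        self.n += 1
        self.mean += (phi(x) - self.mean) / self.n

s = StreamingSetSummary(8)
xs = rng.normal(size=(100, 3))
for x in xs:
    s.update(x)
# The streaming summary matches the batch statistic exactly, regardless
# of the order in which the datapoints arrived.
assert np.allclose(s.mean, phi(xs).mean(axis=0))
```

The running mean is updated one datapoint at a time in constant memory, yet always equals the batch average, so the summary is invariant to arrival order by construction.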
In summary, I'm a big fan of this architecture. I think that the work of Wagstaff et al (2019) provides further valuable intuition on their ability to represent arbitrary set functions.
Counterfactuals are weird. I wasn't going to talk about them in my MLSS lectures on Causal Inference, mainly because I wasn't sure I fully understood what they were all about, let alone how to explain them to others. But during the Causality Panel, David Blei made comments about how weird counterfactuals are: how difficult they are to explain and wrap one's head around. So after that panel discussion, and with a grand total of 5 slides finished for my lecture on Thursday, I said "challenge accepted". The figures and story I'm sharing below are what I came up with after that panel discussion, as I finally understood how to think about counterfactuals. I'm hoping others will find them illuminating, too.
This is the third in a series of tutorial posts on causal inference. If you're new to causal inference, I recommend you start from the earlier posts:
Let me first point out that counterfactual is one of those overloaded words. You can use it, like Judea Pearl, to talk about a very specific definition of counterfactuals: a probabilistic answer to a "what would have happened if" question (I will give concrete examples below). Others use terms like counterfactual machine learning or counterfactual reasoning more liberally, to refer to broad sets of techniques that have anything to do with causal analysis. In this post, I am going to focus on the narrow Pearlian definition of counterfactuals. As promised, I will start with a few examples:
This is an example David brought up during the Causality Panel and I referred back to this in my talk. I'm including it here for the benefit of those who attended my MLSS talk:
Given that Hillary Clinton did not win the 2016 presidential election, and given that she did not visit Michigan 3 days before the election, and given everything else we know about the circumstances of the election, what can we say about the probability of Hillary Clinton winning the election, had she visited Michigan 3 days before the election?
Let's try to unpack this. We are interested in the probability that:
conditioned on four sets of things:
It's a weird beast: you're simultaneously conditioning on her visiting Michigan and not visiting Michigan. And you're interested in the probability of her winning the election given that she did not. WHAT?
Why would quantifying this probability be useful? Mainly for credit assignment. We want to know why she lost the election, and to what degree the loss can be attributed to her failure to visit Michigan three days before the election. Quantifying this is useful: it can help political advisors make better decisions next time.
Here's a real-world application of counterfactuals: evaluating the fairness of individual decisions. Consider this counterfactual question:
Given that Alice did not get promoted in her job, and given that she is a woman, and given everything else we can observe about her circumstances and performance, what is the probability of her getting a promotion if she was a man instead?
Again, the main reason for asking this question is to establish to what degree being a woman is directly responsible for the observed outcome. Note that this is an individual notion of fairness, unlike the aggregate assessment of whether the promotion process is fair or unfair statistically speaking. It may be that the promotion system is pretty fair overall, but in the particular case of Alice unfair discrimination took place.
A counterfactual question is about a specific datapoint, in this case Alice.
Another weird thing to note about this counterfactual is that the intervention (Alice's gender magically changing to male) is not something we could ever implement or experiment with in practice.
Here's the example I used in my talk, and will use throughout this post: I want to understand to what degree having a beard contributed to getting a PhD:
Given that I have a beard, and that I have a PhD degree, and everything else we know about me, with what probability would I have obtained a PhD degree, had I never grown a beard?
Before I start describing how to express this as a probability, let's first think about what we intuitively expect the answer to be. In the grand scheme of things, my beard probably was not a major contributing factor to getting a PhD. I would have pursued PhD studies, and probably completed my degree, even if something had prevented me from keeping my beard. So
We expect the answer to this counterfactual to be a high probability, something close to 1.
Let's start with the simplest thing one can do to attempt to answer my counterfactual question: collect some data about individuals, whether they have beards, whether they have PhDs, whether they are married, whether they are fit, etc. Here's a cartoon illustration of such dataset:
I am in this dataset, or someone just like me is in the dataset, in row number four: I have a beard, I am married, I am obviously very fit (this was the point where I hoped the audience would get the joke and laugh, and thankfully they did), and I have a PhD degree. I have it all.
If we have this data and do normal statistical machine learning, without causal reasoning, we'd probably attempt to estimate $p(Y \vert X=0)$: the conditional probability of possessing a PhD degree ($Y$) given the absence of a beard ($X=0$). As I show at the bottom, this is like predicting values in one column from values in another column.
You hopefully know enough about causal inference by now to know that $p(Y \vert X=0)$ is certainly not the quantity we seek. Without additional knowledge of causal structure, it can't generalize to the hypothetical scenarios and interventions we care about here. There might be hidden confounders. Perhaps scoring high on the autism spectrum makes it more likely that you grow a beard, and may also make it more likely that you obtain a PhD. This conditional isn't aware of whether the two quantities are causally related or not.
We've learned in the previous two posts that if we want to reason about interventions, we have to express a different conditional distribution, $p(Y \vert do(X=0))$. We also know that in order to reason about this distribution, we need more than just a dataset; we also need a causal diagram. Let's add a few things to our figure:
Here, I drew a cartoon causal diagram on top of the data just for illustration, it was simply copy-pasted from previous figures, and does not represent a grand unifying theory of beards and PhD degrees. But let's assume our causal diagram describes reality.
The causal diagram lets us reason about the distribution of data in an alternative world, a parallel universe if you like, in which everyone is somehow magically prevented from growing a beard. You can imagine sampling a dataset from this distribution, shown in the green table. We can measure the association between PhD degrees and beards in this green distribution, which is precisely what $p(Y \vert do(X=0))$ means. As shown by the arrow below the tables, $p(Y \vert do(X=0))$ is about predicting columns of the green dataset from other columns of the green dataset.
Can $p(Y \vert do(X=0))$ express the counterfactual probability we seek? Well, remember that we expected that I would have obtained a PhD degree with high probability even without a beard. However, $p(Y \vert do(X=0))$ talks about the PhD of a random individual after a no-beard intervention. If you take a random person off the street and shave their beard if they have one, it is not very likely that your intervention will cause them to get a PhD. Not to mention that your intervention has no effect on most women, or on men without beards. We intuitively expect $p(Y \vert do(X=0))$ to be close to the base rate of PhD degrees, $p(Y)$, which is apparently somewhere around 1-3%.
$p(Y \vert do(X=0))$ talks about a randomly sampled individual, while a counterfactual talks about a specific individual
Counterfactuals are "personalized" in the sense that you'd expect the answer to change if you substitute a different person in there. My father has a mustache (let's classify that as a type of beard for pedagogical purposes), but he does not have a PhD degree. I expect that preventing him from growing a mustache would not have made him any more likely to obtain a PhD. So his counterfactual probability would be close to 0.
The counterfactual probabilities vary from person to person. If you calculate them for random individuals and average the probabilities, you should expect to get something like $p(Y \vert do(X=0))$ in expectation. (More on this connection later.) But we are not interested in the population mean now; we are interested in calculating the probabilities for each individual.
To finally explain counterfactuals, I have to step beyond causal graphs and introduce another concept: structural equation models.
A causal graph encodes which variables have a direct causal effect on any given node - we call these the causal parents of the node. A structural equation model goes one step further and specifies this dependence more explicitly: for each variable, it has a function which describes the precise relationship between the value of that node and the values of its causal parents.
It's easiest to illustrate this through an example: here's a causal graph of an online advertising system, taken from the excellent paper of Bottou et al. (2013):
It doesn't really matter what these variables mean, if interested, read the paper. The dependencies shown by the diagram are equivalently encoded by the following set of equations:
For each node in the graph above we now have a corresponding function $f_i$. The arguments of each function are the causal parents of the variable it instantiates, e.g. $f_1$ computes $x$ from its causal parent $u$, and $f_2$ computes $a$ from its causal parents $x$ and $v$. In order to allow for nondeterministic relationship between the variables, we additionally allow each function $f_i$ to take another input, $\epsilon_i$ which you can think of as a random number. Through the random input $\epsilon_1$, the output of $f_1$ can be random given a fixed value of $u$, hence giving rise to a conditional distribution $p(x \vert u)$.
The structural equation model (SEM) entails the causal graph, in that you can reconstruct the causal graph by looking at the inputs of each function. It also entails the joint distribution, in that you can "sample" from an SEM by evaluating the functions in order, plugging in the random $\epsilon$s where needed.
In a SEM an intervention on a variable, say $q$, can be modelled by deleting the corresponding function, $f_4$, and replacing it with another function. For example $do(Q = q_0)$ would correspond to a simple assignment to a constant $\tilde{f}_4(x, a) = q_0$.
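The two operations described above - sampling by evaluating the functions in order, and intervening by replacing a function with a constant assignment - can be sketched in a few lines of python. This is a made-up three-variable chain, not Bottou et al.'s model:

```python
import numpy as np

rng = np.random.default_rng(3)

# A hypothetical three-variable SEM: u -> x -> y, with independent noise.
def f1(eps):     return eps              # u := f1(eps1)
def f2(u, eps):  return 2 * u + eps      # x := f2(u, eps2)
def f3(x, eps):  return x - 1 + eps      # y := f3(x, eps3)

def sample(eps1, eps2, eps3, do_x=None):
    u = f1(eps1)
    # Intervening on x means deleting f2 and assigning a constant instead.
    x = f2(u, eps2) if do_x is None else do_x
    y = f3(x, eps3)
    return u, x, y

eps = rng.normal(size=(10_000, 3))
ys = np.array([sample(*e)[2] for e in eps])
ys_do = np.array([sample(*e, do_x=3.0)[2] for e in eps])
# Observationally E[y] = 2*0 - 1 = -1; under do(x=3), E[y] = 3 - 1 = 2.
assert abs(ys.mean() - (-1)) < 0.1
assert abs(ys_do.mean() - 2) < 0.1
```

Note that the same set of functions and noise distributions gives us both the observational joint distribution and every intervention distribution; the intervention just swaps out one function.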
Now that we know what SEMs are we can return to our example of beards and degrees. Let's add a few more things to the figure:
The first change is that, instead of just a causal graph, I now assume that we model the world by a fully specified structural equation model. I signify this lazily by labelling the causal graph with the functions $f_1, f_2, f_3$ over the graph. Notice that the SEM of the green situation is the same as the SEM in the blue case, except that I deleted $f_1$ and replaced it with a constant assignment; $f_2$ and $f_3$ are the same between the blue and the green models.
Secondly, I make the existence of the $\epsilon_i$ noise variables explicit, and show their values (all made up, of course) in the gray table. If you feed the first row of epsilons into the blue structural equation model, you get the first blue datapoint $(0110)$. If you feed the same epsilons into the green SEM, you get the first green datapoint $(0110)$. If you feed the second row of epsilons into the models, you get the second rows in the blue and green tables, and so on...
I like to think about the first green datapoint as the parallel twin of the first blue datapoint. When talking about interventions, I talked about making predictions about a parallel universe where nobody has a beard. Now imagine that for every person who lives in our observable universe, there is a corresponding person, their parallel twin, in this parallel universe. Your twin is the same as you in every respect, except for the absence of any beard you might have, and any downstream consequences of having a beard. If you don't have a beard in this universe, your twin is an exact copy of you in every respect. Indeed, notice that the first blue datapoint is the same as the first green datapoint in my example.
You may know I'm a Star Trek fan, and I like to use Star Trek analogies to explain things. In Star Trek, there is this concept called the mirror universe. It's a parallel universe populated by the same people who live in the normal universe, except that everyone who is good in the real universe is evil in the mirror universe. Hilariously, the mirror version of Spock, one of the main protagonists, has a goatee in this mirror universe. This is how you could tell whether you were looking at evil mirror-Spock or normal Spock when watching an episode. This explains, sadly, why I'm using beards to explain counterfactuals. Here are Spock and mirror-Spock:
Now that we established the twin datapoint metaphor, we can say that counterfactuals are
making a prediction about features of the unobserved twin datapoint based on features of the observed datapoint.
Crucially, this was possible because we used the same $\epsilon$s in both the blue and the green SEM. This induces a joint distribution between variables in the observable regime, and variables in the unobserved, counterfactual regime. Columns of the green table are no longer independent of columns of the blue table. You can start predicting values in the green table using values in the blue table, as illustrated by the arrows below them.
Mathematically, a counterfactual is the following conditional probability:
$$
p(Y^\ast \vert X^\ast = 0, X = 1, Y = 1, U = 1, Z = 1),
$$
where variables with an $^\ast$ are unobserved (and unobservable) variables that live in the counterfactual world, while variables without $^\ast$ are observable.
Looking at the data, it turns out that mirror-Ferenc, who does not have a beard, is married and has a PhD, but is not quite as strong as observable Ferenc.
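The whole recipe - feed the same epsilons into the factual and the mutilated SEM, then condition on what we observed - can be simulated directly. Here is a toy binary SEM for the beard/PhD story; all the structure and numbers are invented for illustration, not taken from the post's figures:

```python
import numpy as np

rng = np.random.default_rng(4)

# Hypothetical binary SEM: z (beard) := f1(eps1), y (PhD) := f2(z, eps2).
# All numbers are made up for illustration.
def f1(e1):
    return int(e1 < 0.30)               # 30% of people grow a beard

def f2(z, e2):
    return int(e2 < 0.10 + 0.05 * z)    # a beard nudges the PhD odds up

def run(e1, e2, do_z=None):
    z = f1(e1) if do_z is None else do_z   # intervention replaces f1
    return z, f2(z, e2)

# Counterfactual p(y* = 1 | z = 1, y = 1): feed the same epsilons into the
# factual and the mutilated (no-beard) SEM, keeping only noise values
# consistent with the observed datapoint (z = 1, y = 1).
eps = rng.uniform(size=(200_000, 2))
twin_y = [run(e1, e2, do_z=0)[1] for e1, e2 in eps if run(e1, e2) == (1, 1)]

# Analytically: given the observation, e2 < 0.15; the twin has y* = 1
# iff e2 < 0.10, so the counterfactual probability is 0.10 / 0.15 = 2/3.
assert abs(np.mean(twin_y) - 2 / 3) < 0.02
```

Note how different this counterfactual probability (2/3 here) is from the intervention conditional $p(y \vert do(z=0)) = 0.1$ in the same toy model, which ignores everything we observed about the specific individual.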
Here is another drawing that some of you might find more appealing, especially those who are into GANs, VAEs and similar generative models:
A SEM is essentially a generative model of data, which uses some noise variables $\epsilon_1, \epsilon_2, \ldots$ and turns them into observations $(U,Z,X,Y)$ in this example. This is shown in the left-hand branch of the graph above. Now if you want to make counterfactual statements under the intervention $X=\hat{x}$, you can construct a mutilated SEM, which is the same SEM except with $f_3$ deleted and replaced with the constant assignment $x = \hat{x}$. This modified SEM is shown in the right-hand branch. If you feed the $\epsilon$s into the mutilated SEM, you get another set of variables $(U^\ast,Z^\ast,X^\ast,Y^\ast)$, shown in green. These are the features of the twin as it were. This joint generative model over $(U,Z,X,Y)$ and $(U^\ast,Z^\ast,X^\ast,Y^\ast)$ defines a joint distribution over the combined set of variables $(U,Z,X,Y,U^\ast,Z^\ast,X^\ast,Y^\ast)$. Therefore, now you can calculate all sorts of conditionals and marginals of this joint.
Of particular interest are these conditionals:
$$
p(y^\ast \vert X^\ast = \hat{x}, X = x, Y = y, U = u, Z = z),
$$
which is a counterfactual prediction. In reality, since $X^\ast = \hat{x}$ holds with a probability of $1$, we can drop that conditioning.
My notation here is a bit sloppy; there are a lot of things going on implicitly under the hood which I'm not making explicit in the notation. I'm sorry if this causes any irritation, but I want to avoid overcomplicating things at this point. Now is a good time to point out that Pearl's notation, including the do-notation, is often criticized, but people use it because it is now widely adopted.
We can also express the intervention conditional $p(y\vert do(x))$ using this (somewhat sloppy) notation as:
$$
p(y\vert do(X=\hat{x})) = p(y^\ast \vert X^\ast = \hat{x})
$$
We can see that the intervention conditional only contains variables with an $^\ast$, so it does not require the joint distribution of $(U,Z,X,Y,U^\ast,Z^\ast,X^\ast,Y^\ast)$, only the marginal of the $^\ast$ variables $(U^\ast, Z^\ast, X^\ast, Y^\ast)$. As a consequence, in order to talk about $p(y\vert do(X=\hat{x}))$, we did not need to introduce SEMs or talk about the epsilons.
Furthermore, notice the following equality:
\begin{align}
p(y\vert do(X=\hat{x})) &= p(y^\ast \vert X^\ast = \hat{x}) \\
&= \int_{x,y,u,z} p(y^\ast \vert X^\ast = \hat{x}, X = x, Y = y, U = u, Z = z) p(x,y,u,z) dx dy du dz \\
&= \mathbb{E}_{p_{X,Y,U,Z}} p(y^\ast \vert X^\ast = \hat{x}, X = x, Y = y, U = u, Z = z),
\end{align}
in other words, the intervention conditional $p(y\vert do(X=\hat{x}))$ is the average of counterfactuals over the observable population. This was something that I did not realize before my MLSS tutorial; it was pointed out to me by a student in the form of a question. In hindsight, of course this is true!
God, this was a loooong post. If you're still reading, thanks, I hope it was useful. I wanted to close with a few slightly philosophical remarks on counterfactuals.
Counterfactuals are often said to be unscientific, primarily because they are not empirically testable. In normal ML we are used to benchmark datasets, where the quality of our predictions can always be tested on some held-out test set. In causal ML, not everything can be directly tested or empirically benchmarked. For interventions, the best test is to run a randomized controlled trial to directly measure $p(y\vert do(X=x))$ if you can, and then use this experimental data to evaluate your causal inferences. But some interventions are impossible to carry out in practice. Think about all the work on fairness reasoning about interventions on gender or race. So what to do then?
In the world of counterfactuals this is an even bigger problem, as it is outright impossible to observe the variables you make predictions about. You can't go back in time and rerun history with exactly the same circumstances except for a tiny change. You can't travel to parallel universes (at least not before the 24th century, according to Star Trek). Counterfactual judgments remain hypothetical, subjective, untestable, unfalsifiable. There can be no MNIST or ImageNet for counterfactuals that satisfies everyone. Some good datasets do exist, but they are for specific scenarios where explicit testing is possible (e.g. offline A/B testing), or they make use of simulators instead of "real" data.
Despite counterfactuals being untestable and difficult to interpret, humans make use of counterfactual statements all the time, and intuitively it feels like they are pretty useful for intelligent behaviour. Being able to pinpoint the causes that led to a particular situation or outcome is certainly useful for learning, reasoning and intelligence. So my strategy for now is to ignore the philosophical debate about counterfactuals and just get on with it, knowing that the tools are there if such predictions have to be made.
Last week I had the honor to lecture at the Machine Learning Summer School in Stellenbosch, South Africa. I chose to talk about Causal Inference, despite being a newcomer to this whole area. In fact, I chose it exactly because I'm a newcomer: causal inference has been a blindspot for me for a long time. I wanted to communicate some of the intuitions and ideas I learned and developed over the past few months, which I wish someone had explained to me earlier.
Now, I'm turning my presentation into a series of posts, starting with this one, building on the previous one I wrote in May. In this one, I will present the toy example I used in my talk to explain interventions. I call this the three scripts toy example. I was not sure whether people were going to get it, but I got good feedback from the audience, so I'm hoping you will find it useful, too. Here are links to the other posts:
Imagine you teach a programming course and you ask students to write a python script that samples from a 2D Gaussian distribution with a certain mean and covariance. Some of the solutions will be correct, but as there are multiple ways to sample from a Gaussian, you might see very different solutions. For example, here are three scripts that would implement the same, correct sampling behaviour:
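The original snippets are included as images. A minimal reconstruction consistent with the plots described below might look like this - note that the specific parameterization (mean $(1, 1)$, covariance $[[1, 2], [2, 5]]$) is my own choice for illustration, not necessarily the exact numbers from the talk:

```python
import numpy as np

rng = np.random.default_rng(0)

# Script 1 ("blue"): sample x first, then y given x.
def script_blue(n):
    x = rng.normal(1, 1, n)
    y = rng.normal(2 * x - 1, 1)
    return x, y

# Script 2 ("green"): sample y first, then x given y.
def script_green(n):
    y = rng.normal(1, np.sqrt(5), n)
    x = rng.normal((2 * y + 3) / 5, np.sqrt(1 / 5), n)
    return x, y

# Script 3 ("red"): sample a shared latent z, then both x and y from z.
def script_red(n):
    z = rng.normal(0, 1, n)
    x = z + 1
    y = rng.normal(2 * z + 1, 1)
    return x, y

# All three produce (approximately) the same joint: mean (1, 1),
# covariance [[1, 2], [2, 5]].
for script in (script_blue, script_green, script_red):
    x, y = script(100_000)
    print(script.__name__, x.mean(), y.mean(), np.cov(x, y))
```

The three factorizations - $p(x)p(y\vert x)$, $p(y)p(x\vert y)$, and a common-cause model - are deliberately chosen so that the resulting joint distributions agree.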
Below each of the code snippets I plotted samples drawn by repeatedly executing the scripts. As you can see, all three scripts produce the same joint distribution between $x$ and $y$. You can feed these distributions into a two-sample test, and you will find that they are indeed indistinguishable from each other.
Based on the joint distribution the three scripts are indistinguishable.
But despite the three scripts being equivalent in that they generate the same distribution, they are not exactly the same. For example, they behave differently if we interfere with, or intervene in, the execution.
Consider this thought experiment: I am a hacker, and I can inject code into the python interpreter. After every line of code from your snippet, I can insert a line of code of my choice. Let's say that I really want to set the value of $x$ to $3$, so I use my code injection ability and insert the line x=3 after each of your lines. So what actually gets executed is this:
We can now run the scripts in this hacked interpreter and see how the intervention changes the distribution of $x$ and $y$:
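As a sketch of this hacked execution (assuming a particular parameterization of the scripts - blue: $x \sim N(1,1)$, $y \sim N(2x-1, 1)$; green: $y$ drawn first, then $x$ given $y$; red: both drawn from a shared latent $z$ - rather than the exact scripts from the figures), only the lines that still matter after the hacker's `x=3` injections are kept:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Intervened "blue" script: x is overwritten with 3 before y is drawn,
# so y depends on the intervened value.
def do_blue(n):
    x = np.full(n, 3.0)
    return rng.normal(2 * x - 1, 1)

# Intervened "green" script: y is drawn first; the subsequent x assignment
# is immediately overwritten with 3, so the intervention never touches y.
def do_green(n):
    return rng.normal(1, np.sqrt(5), n)

# Intervened "red" script: x is overwritten, but y is generated from the
# shared latent z, which the intervention does not touch.
def do_red(n):
    z = rng.normal(0, 1, n)
    return rng.normal(2 * z + 1, 1)

print(do_blue(n).mean(), do_green(n).mean(), do_red(n).mean())  # ~5, ~1, ~1
```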
Of course, we see that the value of $x$ is no longer random: it's deterministically set to $3$, which results in all samples lining up along the $x=3$ vertical line. But, interestingly, the distribution of $y$ is different for the different scripts. In the blue script, $y$ has a mean around $5$, while the green and red scripts produce a distribution of $y$ centered around a mean of $1$. Here is a better look at the marginal distribution of $y$ under the intervention:
I labelled this plot $p(y\vert do(X=3))$, which denotes the distribution of $y$ under the intervention where I set the value of $X$ to $3$. This is generally different from the conditional distribution $p(y\vert x=3)$, which of course is the same for all three scripts. Below I show these conditionals - excuse the massive estimation errors here, I was lazy creating these plots, but believe me, they technically are all the same:
The important point here is that
the scripts behave differently under intervention.
We have a situation where the scripts are indistinguishable when you only look at the joint distribution of the samples they produce, yet they behave differently under intervention.
Consequently,
the joint distribution of data alone is insufficient to predict behaviour under interventions.
If the joint distribution is insufficient, what level of description would allow us to make predictions about how the scripts behave under intervention? If I have the full source code, I can of course execute the modified scripts, i.e. run an experiment and directly observe how the intervention affects the distribution.
However, it turns out, you don't need the full source code. It is sufficient to know the causal diagram corresponding to the source code. The causal diagram encodes causal relationships between variables, with an arrow pointing from causes to effects. Here is what the causal diagrams would look like for these scripts:
We can see that, even though they produce the same joint distribution, the scripts have different causal diagrams. And this additional knowledge of the causal structure allows us to make inferences about the intervention without actually running experiments with that intervention. To do this in the general setting, we can use do-calculus, explained in a little bit more detail in my earlier post.
Graphically, to simulate the effect of an intervention, you mutilate the graph by removing all edges that point into the variable on which the intervention is applied, in this case $x$.
At the top row you see the three diagrams that describe the three scripts. In the second row are the mutilated graphs where all incoming edges to $x$ have been removed. In the first script, the graph looks the same after mutilation. From this, we can conclude that $p(y\vert do(x)) = p(y\vert x)$, i.e. that the distribution of $y$ under intervention $X=3$ is the same as the conditional distribution of $y$ conditioned on $X=3$. In the second script, after mutilation, $x$ and $y$ become disconnected, therefore independent. From this, we can conclude that $p(y\vert do(X=3)) = p(y)$. Changing the value of $x$ does nothing to change the value of $y$, so whatever you set $X$ to be, $y$ is just going to sample from its marginal distribution. The same argument holds for the third causal diagram.
The significance of this is the following: By only looking at the causal diagram, we are now able to predict how the scripts are going to behave under the intervention $X=3$. We can compute and plot $p(y\vert do(X=3))$ for the three scripts by only using data observed during the normal (non-intervened) condition, without ever having to run the experiment or simulate the intervention.
The causal diagram allows us to predict how the models will behave under intervention, without carrying out the intervention
Here is a demonstration: I could estimate the distribution of $y$ observed during the intervention experiment using only samples from the scripts under the normal (non-intervention) situation. This is called causal inference from observational data.
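As a sketch of this observational estimation (again with an assumed parameterization of the joint - mean $(1, 1)$, with $y \sim N(2x - 1, 1)$ given $x \sim N(1, 1)$ - rather than the post's exact numbers): we only ever sample from the non-intervened process, and read off $p(y\vert do(X=3))$ using the identities derived from the mutilated graphs:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# Observational samples (any of the three scripts would do; the joint is
# the same, so here we use the blue script's factorization for sampling).
x = rng.normal(1, 1, n)
y = rng.normal(2 * x - 1, 1)

# Script 1's diagram (x -> y) implies p(y | do(X=3)) = p(y | x = 3):
mask = np.abs(x - 3) < 0.1
print(y[mask].mean())   # ~5, matching the blue script's intervened mean

# Scripts 2 and 3's diagrams imply p(y | do(X=3)) = p(y), the marginal:
print(y.mean())         # ~1, matching the green and red scripts
```

No intervened samples were used anywhere: the causal diagram alone told us which observational quantity to compute.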
The morale of this story is summed up in the following picture:
Let's consider all the questions you would want to be able to answer given some data (i.i.d. samples from a joint distribution). Having access to data, or more generally, the joint distribution it was sampled from, allows you to answer a great many questions, and solve many tasks. For example, you can do supervised machine learning by approximating $p(y \vert x)$ and then use it for many things, such as image labelling. These questions together make up the blue set.
However, we have already seen that some questions cannot be answered using data/joint distribution alone. Notably, if you want to make predictions about how the system you study would behave under certain interventions or perturbations, you typically won't be able to make such inferences based on the data you have. These types of questions lie outside the blue set.
However, if you complement your data with causal assumptions encoded in a causal diagram - a directed acyclic graph where nodes are your variables - you can exploit these extra assumptions to start answering these questions, shown by the green area.
I am also showing an even larger set of questions still, but I won't tell you what this refers to just yet. I'm leaving that to a future post.
The questions I got most frequently after the lecture were these: what if I don't know what the graph looks like? And what if I get the graph wrong? There are many ways to answer these questions.
To me, a rehabilitated Bayesian, the most appealing answer is that you have to accept that your analysis is conditional on the graph you choose, and your conclusions are valid under the assumptions encoded there. In a way, causal inference from observational data is subjective. When you publish a result, you should caveat it with "under these assumptions, this is true". Readers can then dispute and question your assumptions if they disagree.
As to how to obtain a graph, it varies by application. If you work with online systems such as a recommender system, the causal diagram is actually pretty simple to draw, as it corresponds to how various subsystems are hooked up and feed data to one another. In other applications, notably in healthcare, a little bit more guesswork and thought may be involved.
Finally, you can use various causal discovery techniques to try to identify the causal diagram from the data itself. Theoretically, recovering the full causal graph from data is impossible in the general case. However, if you add certain additional smoothness or independence assumptions to the mix, you may be able to recover the graph from the data with some reliability.
We have seen that modeling the joint distribution can only get you so far, and if you want to predict the effect of interventions, i.e. calculate $p(y\vert do(x))$-like quantities, you have to add a causal graph to your analysis.
An optimistic take on this is that, in the i.i.d. setting, drawing a causal diagram and using do-calculus (or something equivalent) can significantly broaden the set of problems you can tackle with machine learning.
The pessimist's take on this is that if you are not aware of all this, you might be trying to address questions in the green bit, without realizing that the data won't be able to give you an answer to such questions.
Whichever way you look at it, causal inference from observational data is an important topic to be aware of.
]]>I recently read a paper by Tencent and was surprised to learn that an online Bayesian deep learning algorithm is apparently deployed in production, powering click-through-rate prediction in their ad system. I thought, therefore, that this is worth a mention. Here is the paper:
Though the paper has a few typos and some inconsistent notation (like going back and forth between $\omega$ and $w$), which can make it tricky to read, I like the paper's layout: starting from desiderata - what the system has to be able to do - and arriving at an elegant solution.
The method relies on the approximate Bayesian online-learning technique often referred to as assumed density filtering.
ADF has been independently discovered by the statistics, machine learning and control communities (for citations, see Minka, 2001). It is perhaps most elegantly described by Opper (1998), and by Minka (2001), who also extended the idea to develop expectation propagation, a highly influential method for approximate Bayesian inference, not just in online settings.
ADF can be explained as a recursive algorithm repeating the following steps:
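The recursion - keep a tractable (here, Gaussian) approximate posterior, multiply in the likelihood of each new datapoint, project the result back onto the tractable family by moment matching, then discard the datapoint - can be sketched in one dimension. This is an illustrative toy model with a heavy-tailed (Cauchy-shaped) likelihood and grid-based moment matching, not the paper's probabilistic backpropagation:

```python
import numpy as np

grid = np.linspace(-10, 10, 4001)
dw = grid[1] - grid[0]

def gaussian(w, mean, var):
    return np.exp(-0.5 * (w - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)

def adf_step(mean, var, likelihood):
    # 1. multiply the current Gaussian approximation by the new likelihood
    post = gaussian(grid, mean, var) * likelihood(grid)
    post /= post.sum() * dw
    # 2. project back onto the Gaussian family by moment matching
    new_mean = (grid * post).sum() * dw
    new_var = ((grid - new_mean) ** 2 * post).sum() * dw
    # 3. discard the datapoint; only (mean, var) are carried forward
    return new_mean, new_var

# Example: observations centred on w = 2, processed strictly one at a time.
rng = np.random.default_rng(0)
mean, var = 0.0, 10.0
for y in rng.standard_t(df=3, size=200) + 2.0:
    mean, var = adf_step(mean, var, lambda w, y=y: 1.0 / (1.0 + (y - w) ** 2))

print(mean, var)  # the posterior mean homes in on 2 and the variance shrinks
```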
The Tencent paper is based on probabilistic backpropagation (Hernandez-Lobato and Adams, 2015), which uses the ADF idea multiple times to perform the two non-trivial tasks required in a supervised Bayesian deep network: inference and learning.
I find the similarity of this algorithm to gradient calculations in neural network training beautiful: forward propagation to perform inference, forward and backward propagation for learning and parameter updates.
Crucially though, there is no gradient descent here - indeed, no gradients or iterative optimization at all. After a single forward-backward cycle, information about the new datapoint (or minibatch of data) is - approximately - incorporated into the updated distribution $q_{t}(w)$ of network weights. Once this is done, the mini-batch can be discarded and never revisited. This makes the method a great candidate for online learning in large-scale streaming data situations.
Another advantage of this method is that it can be parallelized in a data-parallel fashion: multiple workers can update the posterior simultaneously on different minibatches of data. Expectation propagation provides a principled way of combining the parameter updates computed by parallel workers. Indeed, this is what the PBODL algorithm of Tencent does.
Bayesian online learning comes with great advantages. In recommender systems and ads marketplaces, the Bayesian online learning approach handles the user and item cold-start problem gracefully. If there is a new ad in the system that has to be recommended, initially our network might say "I don't really know", and then gradually hone in on a confident prediction as more data is observed.
One can use the uncertainty estimates in a Bayesian network to perform active learning: proposing which items to label in order to increase predictive performance in the future. In the context of ads, this may be a useful feature if implemented with care.
Finally, there is the useful principle that Bayes never forgets: if we perform exact Bayesian updates on global parameters of an exchangeable model, the posterior will store information indefinitely long about all datapoints, including very early ones. This advantage is clearly demonstrated in DeepMind's work on catastrophic forgetting (Kirkpatrick et al, 2017). This capacity to keep a long memory of course diminishes the more approximations we have to make, which leads me to drawbacks.
The ADF method is only approximate, and over time the approximations in each step may accumulate, causing the network to essentially forget what it learned from earlier datapoints. It is also worth pointing out that in stationary situations, the posterior over parameters is expected to shrink, especially in the last hidden layers of the network where there are often fewer parameters. This posterior compression is often countered by introducing explicit forgetting: without evidence to the contrary, the variance of each parameter is marginally increased in each step. Forgetting and the lack of infinite memory may turn out to be an advantage, not a bug, in non-stationary situations where more recent datapoints are more representative of future test cases.
A practical drawback of a method like this in production environments is that probabilistic forward- and backpropagation require non-trivial custom implementations. Probabilistic backprop does not use reverse mode automatic differentiation, i.e. vanilla backprop as a subroutine. As a consequence, one cannot rely on extensively tested and widely used autodiff packages in tensorflow or pytorch, and performing probabilistic backprop in new layer-types might require significant effort and custom implementations. It is possible that high-quality message passing packages such as the recently open-sourced infer.net will see a renaissance after all.
]]>Boo.
I scared you, didn't I?
]]>Happy back-to-school time everyone!
After a long vacation (in blogging terms), I'm returning with a brief paper review which conveniently allows me to continue to explain a few ideas from causal reasoning. Since I wrote this intro to causality, I have read a lot more about it, especially how it relates to recommender systems.
Here are the two recent papers this post is based on:
UPDATE: After publishing this post, commenters pointed me to excellent in-depth blog posts by Alex d'Amour on the limitations of these techniques: on non-identifiability and positivity. I recommend you go on and read these posts if you want to learn more about the conditions in which the proposed methods work.
Simpson's paradox is a commonly used example to introduce confounders and how they can cause problems if not handled appropriately. I'm going to piggyback on the specific example from (Bottou et al, 2013), another excellent paper that should be on your reading list if you haven't read it already. We're looking at data about the success rate of two different treatments, A and B, for kidney stone removal:
Look at the first column, "Overall", first. The average success rate of treatment A is 78%, while the success rate of treatment B is 83%. You might stop here, conclude that treatment B is better because it's successful in a higher percentage of cases, and continue recommending treatment B to everyone. Not so fast! Look at what happens if we group the data by the size of kidney stone. Now we see that in the group of patients with small stones, treatment A is better; in the group with large stones, treatment A is also better. How can treatment A be superior in both groups of patients, but treatment B be better overall?
The answer lies in the fact that patients weren't randomly allocated to treatments. Patients with small stones are more likely to be assigned treatment B. These are also patients who are more likely to recover whichever treatment is used. Conversely, when the stone is large, treatment A is more common. In these cases the success rate is lower overall, too. Therefore, when you pool the data from the two groups, you ignore the fact that the treatments were assigned "unfairly", in the sense that treatment A was on average assigned to many more patients who had a worse outlook to begin with.
In this case, the size of the kidney stone is a confounder variable: a random variable that causally influences both the assignment of treatments and the outcome. Graphically, it looks like this:
If we use $T, C$ and $Y$ to denote the treatment, confounder and the outcome respectively, and denote treatments A and B as 0 and 1, this is how one could correctly estimate the causal effect of the treatment on the outcome, while controlling for the confounder $C$:
$$
\mathbb{E}_{c\sim p_C}[\mathbb{E}[Y\vert T=1,C=c]] - \mathbb{E}_{c\sim p_C}[\mathbb{E}[Y\vert T=0,C=c]]
$$
Or, alternatively, with the do-calculus notation:
$$
\mathbb{E}[Y\vert do(T=1)] = \mathbb{E}_{c\sim p_C}[\mathbb{E}[Y\vert T=1,C=c]]
$$
Let's decompress this. $\mathbb{E}[Y|T=1,C=c]$ is the average outcome, given that treatment $1$ was given and the confounder took value $c$. Simple. Then, you average this over the marginal distribution of the confounder $C$.
Let's look at how this differs from the non-causal association you would measure between treatment and outcome (i.e. when you don't control for $C$):
$$
\mathbb{E}[Y\vert T=1] = \mathbb{E}_{c\sim p_{C\vert T=1}}[\mathbb{E}[Y\vert T=1,C=c]]
$$
So the two expressions differ in whether you average over $p_{C\vert T=1}$ or $p_{C}$ on the right-hand side. How does this expression come about? Remember from the previous post that the effect of an intervention, $do(T=1)$, can be simulated in the causal graph by severing all edges that point into the variable we intervene on. In this case, we sever the edge from the confounder $C$ to the treatment $T$, as that is the only edge pointing into $T$. Therefore, in the alternative model where the link is severed, $C$ and $T$ become independent and thus $p_{C\vert do(T=1)} = p_{C}$. Meanwhile, as no links leading from either $C$ or $T$ to $Y$ have been severed, $p(Y\vert T, C)$ remains unchanged by the intervention.
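The adjustment can be computed directly. Assuming the commonly cited counts behind this example (from the original kidney stone study, Charig et al., 1986 - these reproduce the 78% and 83% overall rates quoted above, but treat the exact numbers as illustrative):

```python
# Success counts / totals per (treatment, stone size).
data = {
    ("A", "small"): (81, 87),   ("A", "large"): (192, 263),
    ("B", "small"): (234, 270), ("B", "large"): (55, 80),
}

def rate(t, c):
    s, n = data[(t, c)]
    return s / n

# Marginal p(C): stone size is the confounder.
n_small = data[("A", "small")][1] + data[("B", "small")][1]
n_large = data[("A", "large")][1] + data[("B", "large")][1]
p_small = n_small / (n_small + n_large)
p_large = 1 - p_small

# Observational: E[Y | T = t], pooling each treatment's own patients.
def observational(t):
    s = data[(t, "small")][0] + data[(t, "large")][0]
    n = data[(t, "small")][1] + data[(t, "large")][1]
    return s / n

# Interventional: E[Y | do(T = t)] = sum_c p(c) E[Y | T = t, C = c].
def interventional(t):
    return p_small * rate(t, "small") + p_large * rate(t, "large")

print(observational("A"), observational("B"))    # ~0.78, ~0.83: B looks better
print(interventional("A"), interventional("B"))  # ~0.83, ~0.78: A is better
```

Averaging over the marginal $p_C$ rather than $p_{C \vert T}$ reverses the naive conclusion, exactly as Simpson's paradox demands.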
So, if correcting for confounders is so simple, why is it not a solved problem yet? Because you don't always know the confounders, or sometimes you can't measure them.
Controlling for too little: You may not be measuring all confounders. Either because you're not even aware of some, or because you're aware of them but genuinely can't measure their values.
Controlling for too much: You may also be controlling for too much. When one wants to make sure all confounders are covered, it's commonplace to simply control for all observable variables, in case they turn out to be confounders. This can go wrong, sometimes very wrong. In ML terms, every time you control for a confounder, you add an extra input to the model approximating $p_{Y\vert T, C}$, increasing the complexity of your model and likely reducing its data efficiency. Therefore, if you include a redundant variable in your analysis as a confounder, you're just making your task harder for no additional gain.
But even worse things can happen if you include a variable in your analysis which you must not condition on. Mediator variables are ones whose value is causally influenced by the treatment and which in turn causally influence the outcome. In graph notation, this is a directed chain: $T \rightarrow M \rightarrow Y$. If you control for mediators as if they were confounders, you are blocking the causal effect of $T$ on $Y$ via $M$, so you are going to draw incorrect conclusions. The other situation is conditioning on a collider $W$, which is causally influenced by both the treatment $T$ and the outcome $Y$, i.e. $T \rightarrow W \leftarrow Y$. If you condition on $W$, the phenomenon of explaining away occurs: you induce spurious association between $T$ and $Y$, and again you draw incorrect conclusions. A common, often unnoticed, example of conditioning on a collider is sampling bias, when your datapoints are missing not at random.
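A quick simulation of the collider case (with made-up variables) shows explaining away in action: $T$ and $Y$ are independent by construction, but restricting attention to a slice of the collider $W$ - which is what sampling bias effectively does - induces a strong spurious association.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# T and Y are independent, but both cause the collider W.
t = rng.normal(size=n)
y = rng.normal(size=n)
w = t + y + 0.1 * rng.normal(size=n)

print(np.corrcoef(t, y)[0, 1])  # ~0: no association in the full population

# "Controlling for" W by selecting a slice of it (as sampling bias does):
mask = np.abs(w) < 0.5
print(np.corrcoef(t[mask], y[mask])[0, 1])  # strongly negative: explaining away
```

Within the slice, $t + y$ is pinned near zero, so a large $t$ can only co-occur with a small $y$: the hallmark of explaining away.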
In the general case, given a causal model of all variables in question, do-calculus can identify the minimal set of variables you have to measure and condition on to obtain a correct answer to your causal query. But in many applications we simply don't have enough information to name all the possible confounding factors, and often, even if we suspect the existence of a specific confounder, we won't be able to measure or directly observe it.
Wang and Blei (2018) address this problem with the following idea: instead of identifying and then measuring possible confounders, let's try to automatically discover and infer a random variable which can act as a substitute confounder in our analysis.
This is possible in a special setting when:
In a latent factor model (like matrix factorization, PCA, etc) the user-treatment matrix is modeled by two sets of random variables: one for each treatment, one for each user. A general factor model structure is shown below:
$A_{i,j}$ is the assignment of treatment $j$ to data instance $i$. We can imagine this as binary. Its distribution is determined by the instance-specific random variable $Z_i$, and the treatment-specific $\theta_j$. The key observation is that, conditioned on the vector $Z_i$, the treatment assignments $A_{i,j}$ become conditionally independent across treatments $j$. It is this $Z_i$ variable which will be used as a substitute confounder. Figure 1 of the paper illustrates why:
This figure shows the graphical structure of the problem. For instance $i$ we have the $m$ different assignments $A_{i,1}, \ldots, A_{i,m}$ and the outcome $Y_i$. Causal arrows between the assignments $A$ and the outcome $Y$ are not shown in this figure, but it is assumed some of them might exist. We also calculate the latent variable $Z_i$ and call it the substitute confounder. As $Z_i$ renders the otherwise dependent $A$s conditionally independent, it appears as a common parent node in this graph. The argument is the following: if there was any confounder $U_i$ which affects multiple assignment variables simultaneously, it must either be contained within $Z_i$ already (which is good), or our factor model cannot be perfect. Incorporating $U$ as part of $Z$ would improve our model of the $A$s.
Therefore, we can force $Z_i$ to incorporate all multi-cause confounders by simply improving the factor model. Sadly, this argument relies on reasoning about marginal and conditional independence between the causes $A$, so it cannot work for single-cause confounders.
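The conditional-independence property at the heart of this argument can be illustrated in a toy simulation (not the paper's algorithm; I use plain PCA as a crude stand-in for the probabilistic factor models the paper fits): a single hidden confounder drives all assignments, and conditioning on the inferred factor renders pairs of assignments nearly independent.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 5000, 20

# Toy multi-cause setting: a single hidden confounder u drives all m assignments.
u = rng.normal(size=n)
w = np.linspace(0.5, 1.5, m)                    # per-treatment loadings
A = np.outer(u, w) + 0.5 * rng.normal(size=(n, m))

# Crude stand-in factor model: the first principal component of A plays
# the role of the substitute confounder z_hat.
A_c = A - A.mean(axis=0)
_, _, Vt = np.linalg.svd(A_c, full_matrices=False)
z_hat = A_c @ Vt[0]

def partial_corr(a, b, z):
    # correlation between a and b after regressing out z
    ra = a - z * (a @ z) / (z @ z)
    rb = b - z * (b @ z) / (z @ z)
    return np.corrcoef(ra, rb)[0, 1]

raw = np.corrcoef(A_c[:, 0], A_c[:, 1])[0, 1]
adj = partial_corr(A_c[:, 0], A_c[:, 1], z_hat)
print(raw, adj)  # sizeable raw correlation; near zero given z_hat
```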
In a way this method, called the deconfounder algorithm, is a mixture of causal discovery (learning about structure of the causal graph) and causal inference (making inferences about variables under interventions). It identifies just enough about the causal structure (the substitute confounder variable) to then be able to make causal inferences of a certain type.
The obvious limitation of this method - and the authors are very transparent about this - is that it cannot deal with single-cause confounders. You have to assume if there is a confounder in your problem, you will be able to find another cause variable causally influenced by the same confounder in order to use this method.
The other limitation is the assumption that the assignments follow a factor model distribution marginally. While this may seem like a relatively OK assumption, it rules a couple of interesting cases out. For example, the assumption implies that the treatments cannot directly causally influence each other. In many applications this is not a reasonable assumption. If I make you watch Lord of the Rings it will definitely causally influence whether you will want to watch Lord of the Rings II. Similarly, some treatments in hospitals may be given or ruled out because of other treatments.
Finally, while the argument guarantees that all multi-cause confounders will be contained in the substitute confounder, there is no guarantee that the method will identify a minimal set of confounders. Indeed, if there are factors which influence multiple $A$s but are unrelated to the outcome, those variables will be covered by $Z$ as well. It is not difficult to create toy examples where the majority of the information captured in $Z$ is redundant.
In the second paper, Wang et al (2018) apply this method to the specific problem of training recommender systems. The issue here is that vanilla ML approaches work only when items and users are randomly sampled in the dataset, which is rarely the case. Whether you watch a movie, and whether you like it, are confounded by a number of factors, many of which you cannot model or observe. The method does apply to this case, and it does better than the baselines, but the results lack the punch I was hoping to see.
Nevertheless, this is a very interesting direction, and also quite a bit different from many other approaches to the problem. UPDATE: As it was evident from commenters' reaction to my post, there is a great deal of discussion taking place about the limitations of these types of 'latent confounder identification' methods. Alex d'Amour's blog posts on non-identifiability and positivity go into a lot more detail on these.
]]>You might have come across Judea Pearl's new book, and a related interview which was widely shared in my social bubble. In the interview, Pearl dismisses most of what we do in ML as curve fitting. While I believe that's an overstatement (it conveniently ignores RL, for example), it's a nice reminder that the most productive debates are often triggered by controversial or outright arrogant comments. Calling machine learning alchemy was a great recent example. After reading the article, I decided to look into his famous do-calculus and the topic of causal inference once again.
Again, because this has happened to me semi-periodically. I first learned do-calculus in a (very unpopular but advanced) undergraduate course on Bayesian networks. Since then, I have re-encountered it every 2-3 years in various contexts, but somehow it never really struck a chord. I always just thought "this stuff is difficult and/or impractical", and eventually forgot about it and moved on. I never realized how fundamental this stuff was, until now.
This time around, I think I fully grasped the significance of causal reasoning and I turned into a full-on believer. I know I'm late to the game but I almost think it's basic hygiene for people working with data and conditional probabilities to understand the basics of this toolkit, and I feel embarrassed for completely ignoring this throughout my career.
In this post I'll try to explain the basics, and convince you why you should think about this, too. If you work on deep learning, that's an even better reason to understand this. Pearl's comments may be unhelpful if interpreted as contrasting deep learning with causal inference. Rather, you should interpret them as highlighting causal inference as a huge, relatively underexplored application of deep learning. Don't get discouraged by causal diagrams looking a lot like Bayesian networks (not a coincidence, seeing as they were both pioneered by Pearl): they don't compete with deep learning, they complement it.
First of all, causal calculus differentiates between two types of conditional distributions one might want to estimate. tldr: in ML we usually estimate only one of them, but in some applications we should actually try to or have to estimate the other one.
To set things up, let's say we have i.i.d. data sampled from some joint $p(x,y,z,\ldots)$. Let's assume we have lots of data and the best tools (say, deep networks) to fully estimate this joint distribution, or any property, conditional or marginal distribution thereof. In other words, let's assume $p$ is known and tractable. Say we are ultimately interested in how variable $y$ behaves given $x$. At a high level, one can ask this question in two ways:
observational $p(y\vert x)$: What is the distribution of $Y$ given that I observe variable $X$ takes value $x$. This is what we usually estimate in supervised machine learning. It is a conditional distribution which can be calculated from $p(x,y,z,\ldots)$ as a ratio of two of its marginals. $p(y\vert x) = \frac{p(x,y)}{p(x)}$. We're all very familiar with this object and also know how to estimate this from data.
interventional $p(y\vert do(x))$: What is the distribution of $Y$ if I were to set the value of $X$ to $x$. This describes the distribution of $Y$ I would observe if I intervened in the data generating process by artificially forcing the variable $X$ to take value $x$, but otherwise simulating the rest of the variables according to the original process that generated the data. (note that the data generating procedure is NOT the same as the joint distribution $p(x,y,z,\ldots)$ and this is an important detail).
No. $p(y\vert do(x))$ and $p(y\vert x)$ are not generally the same, and you can verify this with simple thought experiments. Say $Y$ is the pressure in my espresso machine's boiler, which ranges roughly between $0$ and $1.1$ bar depending on how long it's been turned on. Let $X$ be the reading of the built-in barometer. Let's say we jointly observe $X$ and $Y$ at random times. Assuming the barometer functions properly, $p(y|x)$ should be a unimodal distribution centered around $x$, with randomness due to measurement noise. However, $p(y|do(x))$ won't actually depend on the value of $x$ and is generally the same as $p(y)$, the marginal distribution of boiler pressure. This is because artificially setting my barometer to a value (say, by moving the needle) won't actually cause the pressure in the tank to go up or down.
In summary, $y$ and $x$ are correlated or statistically dependent, and therefore seeing $x$ allows me to predict the value of $y$; but $y$ is not caused by $x$, so setting the value of $x$ won't affect the distribution of $y$. Hence, $p(y\vert x)$ and $p(y\vert do(x))$ behave very differently. This simple example is just the tip of the iceberg. The differences between interventional and observational conditionals can be a lot more nuanced and hard to characterize when there are lots of variables with complex interactions.
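A quick simulation of the barometer example (with made-up numbers) makes the gap concrete: the true process has pressure causing the reading, so conditioning on a high reading predicts high pressure, while forcing the needle leaves pressure at its marginal distribution.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Data generating process: boiler pressure y causes the barometer reading x.
y = rng.uniform(0.0, 1.1, n)           # pressure varies over time
x = y + 0.02 * rng.normal(size=n)      # noisy barometer reading

# Observational p(y | x ~ 1.0): a high reading predicts high pressure.
mask = np.abs(x - 1.0) < 0.01
print(y[mask].mean())                  # ~1.0

# Interventional p(y | do(X = 1.0)): rerun the process, forcing the needle.
y_do = rng.uniform(0.0, 1.1, n)
x_do = np.full(n, 1.0)                 # needle moved by hand; y unaffected
print(y_do.mean())                     # ~0.55, just the marginal mean of y
```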
Depending on the application you want to solve, you should seek to estimate one of these conditionals. If your ultimate goal is diagnosis or forecasting (i.e. observing a naturally occurring $x$ and inferring the probable values of $y$) you want the observational conditional $p(y\vert x)$. This is what we already do in supervised learning, this is what Judea Pearl called curve fitting. This is all good for a range of important applications such as classification, image segmentation, super-resolution, voice transcription, machine translation, and many more.
In applications where you ultimately want to control or choose $x$ based on the conditional you estimated, you should seek to estimate $p(y\vert do(x))$ instead. For example, if $x$ is a medical treatment and $y$ is the outcome, you are not merely interested in observing a naturally occurring treatment $x$ and predicting the outcome; you want to proactively choose the treatment $x$ given your understanding of how it affects the outcome $y$. Similar situations occur in system identification, control and online recommender systems.
This is perhaps the main concept I haven't grasped before. $p(y\vert do(x))$ is in fact a vanilla conditional distribution, but it's not computed based on $p(x,z,y,\ldots)$, but a different joint $p_{do(X=x)}(x,z,y,\ldots)$ instead. This $p_{do(X=x)}$ is the joint distribution of data which we would observe if we actually carried out the intervention in question. $p(y\vert do(x))$ is the conditional distribution we would learn from data collected in randomized controlled trials or A/B tests where the experimenter controls $x$. Note that actually carrying out the intervention or randomized trials may be impossible or at least impractical or unethical in many situations. You can't do an A/B test forcing half your subjects to smoke weed and the other half to smoke placebo to understand the effect of marijuana on their health. Even if you can't directly estimate $p(y\vert do(x))$ from randomized experiments, the object still exists. The main point of causal inference and do-calculus is:
If I cannot measure $p(y\vert do(x))$ directly in a randomized controlled trial, can I estimate it based on data I observed outside of a controlled experiment?
Let's start with a diagram that shows what's going on if we only care about $p(y\vert x)$, i.e. the simple supervised learning case:
Let's say we observe 3 variables, $x, z, y$, in this order. Data is sampled i.i.d. from some observable joint distribution over 3 variables, denoted by the blue factor graph labelled 'observable joint'. If you don't know what a factor graph is, it's not important, the circles represent random variables, the little square represents a joint distribution of the variables it's connected to. We are interested in predicting $y$ from $x$, and say that $z$ is a third variable which we do not want to infer but we can also measure (I included this for completeness). The observational conditional $p(y\vert x)$ is calculated from this joint via simple conditioning. From the training data we can build a model $q(y\vert x;\theta)$ to approximate this conditional, for example using a deep net minimizing cross-entropy or whatever.
Now, what if we're actually interested in $p(y\vert do(x))$ rather than $p(y\vert x)$? This is what it looks like:
So, we still have the blue observed joint and data is still sampled from this joint. However, the object we wish to estimate is on the bottom right, the red intervention conditional $p(y\vert do(x))$. This is related to the intervention joint which is denoted by the red factor graph above it. It's a joint distribution over the same domain as $p$ but it's a different distribution. If we could sample from this red distribution (e.g. actually run a randomized controlled trial where we get to pick $x$), the problem would be solved by simple supervised learning. We could generate data from the red joint, and estimate a model directly from there. However, we assume this is not possible, and all we have is data sampled from the blue joint. We have to see if we can somehow estimate the red conditional $p(y\vert do(x))$ from the blue joint.
If we want to establish a connection between the blue and the red joints, we must introduce additional assumptions about the causal structure of the data generating mechanism. The only way we can make predictions about how our distribution changes as a consequence of an intervention is if we know how the variables are causally related. This information about causal relationships is not captured in the joint distribution alone. We have to introduce something more expressive than that. Here is what this looks like:
In addition to the observable joint we now also have a causal model of the world (top left). This causal model contains more detail than the joint distribution: it knows not only that pressure and barometer readings are dependent but also that pressure causes the barometer to go up and not the other way around. The arrows in this model correspond to the assumed direction of causation, and the absence of an arrow represents the absence of direct causal influence between variables. The mapping from causal diagrams to joint distributions is many-to-one: several causal diagrams are compatible with the same joint distribution. Thus, it is generally impossible to conclusively choose between different causal explanations by looking at observed data only.
Coming up with a causal model is a modeling step where we have to consider assumptions about how the world works, what causes what. Once we have a causal diagram, we can emulate the effect of intervention by mutilating the causal network: deleting all edges that lead into nodes in a $do$ operator. This is shown on the middle-top panel. The mutilated causal model then gives rise to a joint distribution denoted by the green factor graph. This joint has a corresponding conditional distribution $\tilde{p}(y\vert do(x))$, which we can use as our approximation of $p(y\vert do(x))$. If we got the causal structure qualitatively right (i.e. there are no missing nodes and we got the direction of arrows all correct), this approximation is exact and $\tilde{p}(y\vert do(x)) = p(y\vert do(x))$. If our causal assumptions are wrong, the approximation may be bogus.
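Mutilation itself is mechanically trivial. Here's a sketch, using a hypothetical minimal encoding of a causal diagram as a dict mapping each node to the set of its direct causes:

```python
def mutilate(parents, do_vars):
    """Graph surgery: delete every edge leading into an intervened-on node."""
    return {v: set() if v in do_vars else set(ps) for v, ps in parents.items()}

# z confounds x and y; x also causes y
g = {"z": set(), "x": {"z"}, "y": {"x", "z"}}
print(mutilate(g, {"x"}))  # x loses its parents; everything else is untouched
```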
Critically, to get to this green stuff, and thereby to establish the bridge between observational data and interventional distributions, we had to combine data with additional assumptions, prior knowledge if you wish. Data alone would not enable us to do this.
Now the question is: how can we say anything about the green conditional when we only have data from the blue distribution? We are in a better situation than before, as we have the causal model relating the two. To cut a long story short, this is what the so-called do-calculus is for. Do-calculus allows us to massage the green conditional distribution until we can express it in terms of various marginals, conditionals and expectations under the blue distribution. Do-calculus extends our toolkit for working with conditional probability distributions with three additional rules we can apply to conditional distributions with $do$ operators in them. These rules take into account properties of the causal diagram. The details can't be compressed into a single blog post, but here is an introductory paper on them.
Ideally, as a result of a do-calculus derivation you end up with an equivalent formula for $\tilde{p}(y\vert do(x))$ which no longer has any do operators in it, so you can estimate it from observational data alone. If this is the case we say that the causal query $\tilde{p}(y\vert do(x))$ is identifiable. Conversely, if this is not possible, no matter how hard we try applying do-calculus, we call the causal query non-identifiable, which means that we won't be able to estimate it from the data we have. The diagram below summarizes this causal inference machinery in its full glory.
The new panel called "estimable formula" shows the equivalent expression for $\tilde{p}(y\vert do(x))$ obtained as a result of the derivation including several do-calculus rules. Notice how the variable $z$ which is completely irrelevant if you only care about $p(y\vert x)$ is now needed to perform causal inference. If we can't observe $z$ we can still do supervised learning, but we won't be able to answer causal inference queries $p(y\vert do(x))$.
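To see this on data, here is a small simulated example where $z$ is a confounder with the classic back-door structure ($z \rightarrow x$, $z \rightarrow y$, $x \rightarrow y$; the numbers are made up). For this diagram, do-calculus yields the textbook back-door adjustment formula $p(y\vert do(x)) = \sum_z p(y\vert x, z)\, p(z)$, which we can estimate from observational data, while the naive conditional is biased:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 500_000
z = rng.integers(0, 2, n)                                      # confounder
x = (rng.random(n) < np.where(z == 1, 0.8, 0.2)).astype(int)   # z causes x
y = (rng.random(n) < 0.1 + 0.3 * z + 0.2 * x).astype(int)      # z and x cause y

def p_y1_obs(xv):
    """Naive observational estimate of p(y=1 | x)."""
    return y[x == xv].mean()

def p_y1_do(xv):
    """Back-door adjustment: sum_z p(y=1 | x, z) p(z)."""
    return sum(y[(x == xv) & (z == zv)].mean() * (z == zv).mean()
               for zv in (0, 1))

# The true effect of flipping x from 0 to 1 is +0.2 by construction.
print(p_y1_obs(1) - p_y1_obs(0))  # about 0.38: biased upwards by z
print(p_y1_do(1) - p_y1_do(0))    # close to the true 0.2
```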
You can never fully verify the validity and completeness of your causal diagram based on observed data alone. However, there are certain aspects of the causal model which are empirically testable. In particular, the causal diagram implies certain conditional independence or dependence relationships between sets of variables. These dependencies or independencies can be empirically tested, and if they are not present in the data, that is an indication that your causal model is wrong. Taking this idea forward you can attempt to perform full causal discovery: attempting to infer the causal model or at least aspects of it, from empirical data.
But the bottom line is: a full causal model is a form of prior knowledge that you have to add to your analysis in order to get answers to causal questions without actually carrying out interventions. Reasoning with data alone won't be able to give you this. Unlike priors in Bayesian analysis - which are a nice-to-have and can improve data-efficiency - causal diagrams in causal inference are a must-have. With a few exceptions, all you can do without them is running randomized controlled experiments.
Causal inference is indeed something fundamental. It allows us to answer "what-if-we-did-x" type questions that would normally require controlled experiments and explicit interventions to answer. And I haven't even touched on counterfactuals which are even more powerful.
You can live without this in some cases. Often, you really just want to do normal inference. In other applications such as model-free RL, the ability to explicitly control certain variables may allow you to sidestep answering causal questions explicitly. But there are several situations, and very important applications, where causal inference offers the only method to solve the problem in a principled way.
I wanted to emphasize again that this is not a question of whether you work on deep learning or causal inference. You can, and in many cases you should, do both. Causal inference and do-calculus allows you to understand a problem and establish what needs to be estimated from data based on your assumptions captured in a causal diagram. But once you've done that, you still need powerful tools to actually estimate that thing from data. Here, you can still use deep learning, SGD, variational bounds, etc. It is this cross-section of deep learning applied to causal inference which the recent article with Pearl claimed was under-explored.
UPDATE: In the comments below people have pointed out some relevant papers (thanks!). If you are aware of any other relevant work, please add it there.
I wanted to highlight a recent paper I came across, which is also a nice follow-up to my earlier post on pruning neural networks:
Many people working on network pruning observed that, starting from a wide network and pruning it, one obtains better performance than training the slim, pruned architecture from scratch (random initialization). This suggests that the added redundancy of an overly wide network is somehow useful in training.
In this paper, Frankle and Carbin present the lottery ticket hypothesis to explain these observations. According to the hypothesis, good performance results from lucky initialization of a subnetwork or subcomponent of the original network. Since fat networks have exponentially more component subnetworks, starting from a fatter network increases the effective number of lottery tickets, thereby increasing the chances of containing a winning ticket. According to this hypothesis, pruning effectively identifies the subcomponent which is the winning ticket.
It's important to note that good initialization is a somewhat underrated but extremely important component of SGD-based deep learning. Indeed, I often find that if you use non-trivial architectures, training is essentially stuck until you tweak the off-the-shelf initialization schemes so that it somehow starts to work.
To test the hypothesis the authors have designed a set of cool experiments. The experiments basically go like this:
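Roughly, in code (my own sketch of the train-prune-rewind loop, not the authors' exact protocol; `train_fn` and the pruning fraction are placeholders):

```python
import numpy as np

def lottery_round(init_w, train_fn, mask, prune_frac=0.2):
    """One round: train the masked net, prune the smallest-magnitude surviving
    weights, then rewind the survivors to their ORIGINAL initial values."""
    trained = train_fn(init_w * mask) * mask
    alive = np.abs(trained[mask == 1])
    threshold = np.quantile(alive, prune_frac)
    new_mask = mask * (np.abs(trained) >= threshold)
    return init_w * new_mask, new_mask    # the candidate "winning ticket"

rng = np.random.default_rng(0)
init_w = rng.standard_normal(100)
mask = np.ones(100)
# stand-in for real SGD training; in the paper this is full training to convergence
ticket, mask = lottery_round(init_w, lambda w: w, mask)
print(mask.sum())  # roughly 80 weights survive the first round
```

The control experiment in the paper corresponds to replacing `init_w * new_mask` with a fresh random initialization under the same mask.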
Here are some of the results from Figure 4.
Observe two main things in these results. The winning tickets, without the remaining redundancy of the wide network, train faster than the wide network. In fact, the skinnier they are, the faster they train (within reason). However, if you reinitialize the networks' weights randomly (control), the resulting nets now train slower than the full network. Therefore, pruning is not just about finding the right architecture, it's also about identifying the 'winning ticket': a particularly luckily initialized subcomponent of the network.
I thought this paper was pretty cool. It underscores the role that (random) initialization plays in successful training of neural networks, and gives an interesting justification for why redundancy and over-parametrization might actually be a desirable property when combined with SGD from random init.
There are various ways this line of research can be extended and built upon. For starters, it would be great to see how the results hold up when better pruning schemes and more general classes of initialization approaches are used. It would be interesting to see if the training+pruning method consistently identifies the same winning ticket, or if it finds different winning tickets each time. One could look at various properties of randomly sampled tickets and compare them to the winning tickets. For example, I would be very curious to see what these winning tickets look like on the information plane. Similarly, it would be great to find cheaper ways or heuristics to identify winning tickets without having to train and prune the fat network.
This is a post about my takeaways from the DALI workshop on the Goals and Principles of Representation Learning which we co-organized with DeepMinders Shakir Mohamed and Andriy Mnih and Twitter colleague Lucas Theis. We had an amazing set of presentations, videos are available here.
DALI is my favourite meeting of the year. It's small, and participants are all very experienced and engaged. To make the most of the high quality audience we decided to focus on slightly controversial topics, and asked speakers to present their general thoughts, their unique perspectives and opinions rather than talk about their latest research project. Broadly, we asked three types of questions:
From my perspective, the end-goal of any area of ML research should be for each algorithm to have a justification from first principles. Like the chain above starting from the Task all the way down to the precise algorithm we're going to run.
In my machine learning cookbook post I reviewed a number of fairly well understood problem transformations which one could apply to the bottom of this chain, typically to turn an intractable optimisation problem into a tractable one: variational inference, evolution strategies, convex relaxations, etc. In the workshop we focussed on the top of the chain, the computational-level understanding: Do we actually have clarity on the ultimate task we solve? If the task is given, is the principle we're applying actually justified? Is the loss function consistent with the principle?
One thing that most speakers and attendees seemed to agree on is that the most interesting and most important task we use unsupervised representation learning for is to transfer knowledge from large datasets to new tasks. I'll try to capture one possible formulation of what this means:
There are many possible extensions of this basic setup. For example we may have a distribution of tasks, rather than a single task.
One of the principles we discussed was disentanglement. It is somewhat unclear what exactly disentanglement is, but one way to capture this is to say "each coordinate in the representation corresponds to one meaningful factor of variation". Of course this is kind of a circular definition inasmuch as it's not easy to define a meaningful factor of variation. Irina defined disentangled representations as "factorized and interpretable", which, as a definition, has its own issues, too.
In the context of transfer learning though, I can see a vague argument for why seeking disentanglement might be a useful principle. One way to improve data efficiency of a ML algorithm is to endow it with inductive biases or to reduce the complexity of the function class in some sense. If we believe that future downstream tasks we may want to solve can be solved with simple, say linear, models on top of a disentangled representation, then seeking disentanglement makes sense. It may be possible to use this argument itself to formulate a more precise definition of or objective function for disentanglement.
The idea of self-supervision also featured prominently in talks. Self-supervision is the practice of defining auxiliary supervised learning tasks involving the unlabelled data, with hopes that solving these auxiliary tasks will require us to build representations which are also useful in the eventual downstream task we'd like to solve. Autoencoders, pseudolikelihood, denoising, jigsaw-puzzles, temporal-contrastive learning and so on are examples of this approach.
One very interesting aspect Harri talked about was the question of relevance: how can we make sure that an auxiliary learning task is relevant for the primary inference task we need to solve? In a semi-supervised learning setup Harri showed how you can adapt an auxiliary denoising task to be more relevant: you can calculate a sort of saliency map to determine which input dimensions the primary inference network pays attention to, and then use saliency-weighted denoising as the auxiliary task.
After the talks we had an open discussion session, hoping for a heated debate with people representing opposing viewpoints. And the audience really delivered on this. 10 minutes or so into the discussion Zoubin walked in and stated "I don't believe in representation learning". I tried to capture the main points in the argument (Zoubin being anti-representation) below:
Retrospectively, I would summarise Zoubin's (and others') criticism as follows: If we identified transfer learning as the primary task representation learning is supposed to solve, are we actually sure that representation learning is the way to solve it? Indeed, this is not one of the questions I asked in my slides, and it's a very good one. One can argue that there may be many ways to transfer information from some dataset over to a novel task. Learning a representation and transferring that is just one approach. Meta-learning, for example, might provide another approach.
I have really enjoyed this workshop. The talks were great, participants were engaged, and the discussion at the end was a fantastic ending for DALI itself. I am hoping that, just like me, others have left the workshop with many things to talk about, and that the debates we had will inspire future work on representation learning (or anti-representation learning).
I wanted to briefly highlight two recent papers on pruning neural networks (disclaimer, one of them is ours):
What I generally refer to as pruning in the title of this post is reducing or controlling the number of non-zero parameters, or the number of featuremaps actively used in a neural network. At a high level, there are at least three ways one can go about this, pruning is really only one of them:
There are different reasons for pruning your network. The most obvious, perhaps, is to reduce computational cost while keeping the same performance. Removing features which aren't really used in your deep network architecture can speed up inference as well as training. You can also think of pruning as a form of architecture search: figuring out how many features you need in each layer for best performance.
The second argument is to improve generalization by reducing the number of parameters, and thus the redundancy in the parameter space. As we have seen in recent work on generalization in deep networks, the raw number of parameters ($L_0$ norm) is not actually a sufficient predictor of their generalization ability. That said, we empirically find that pruning a network tends to help generalization. Meanwhile, the community is developing (or, putting my Schmidhuber-hat on: maybe in some cases rediscovering) new parameter-dependent quantities to predict/describe generalization. The Fisher-Rao norm is a great example of these. Interestingly, Fisher pruning (Theis et al, 2018) turns out to have a nice connection to the Fisher-Rao norm, and this may hint at a deeper relationship between pruning, parameter redundancy and generalization.
I found the $L_0$ paper by Louizos et al, (2018) very interesting in that it can be seen as a straightforward application of the machine learning problem transformations I wrote up in the machine learning cookbook a few months ago. It's a good illustration of how you can use these general ideas go from formulating an intractable ML optimization problem to something practical you can run SGD on.
So I will summarize the paper as a series of steps, each changing the optimization problem:
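To ground the end result of these steps in code, here is a minimal sketch of the hard concrete gate as I understand it from the paper: a stretched, clipped concrete sample that multiplies each weight (or group of weights), plus the differentiable surrogate for the expected $L_0$ penalty. The stretch parameters follow the paper's defaults; treat this as illustrative rather than a reference implementation:

```python
import numpy as np

def hard_concrete_sample(log_alpha, rng, beta=2/3, gamma=-0.1, zeta=1.1):
    """Reparametrized sample of a stretched, clipped concrete gate in [0, 1]."""
    u = rng.uniform(1e-6, 1 - 1e-6, size=np.shape(log_alpha))
    s = 1.0 / (1.0 + np.exp(-(np.log(u) - np.log(1 - u) + log_alpha) / beta))
    return np.clip(s * (zeta - gamma) + gamma, 0.0, 1.0)

def expected_l0(log_alpha, beta=2/3, gamma=-0.1, zeta=1.1):
    """Differentiable surrogate for the expected number of non-zero gates."""
    return np.sum(1.0 / (1.0 + np.exp(-(log_alpha - beta * np.log(-gamma / zeta)))))

rng = np.random.default_rng(0)
log_alpha = np.zeros(10)                  # one gate per parameter (or group)
z = hard_concrete_sample(log_alpha, rng)  # multiply the weights by these gates
print(z.min() >= 0.0 and z.max() <= 1.0)  # True: gates live in [0, 1]
```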
Interestingly, the connection between Eq. (3) and evolution strategies or variational optimization is not mentioned. Instead, a motivation based on a different connection to spike-and-slab priors is given. I recommend reading the paper, perhaps with this connection in mind.
The authors then show that this indeed works, and that it compares favourably to other methods designed to reduce the number of parameters.
Thinking about the paper in terms of these steps converting from one problem to another allows you to generalize or improve the idea. For example, the REBAR or RELAX gradient estimators provide an unbiased and lower-variance alternative to the concrete relaxation, which may work very well here, too.
The second paper I wanted to talk about is something from our own lab. Rather than being a purely methods paper, (Theis et al, 2018) focusses on the specific application of building speedy neural networks to predict saliency in an image. The pruned network now powers the logic behind cropping photos on Twitter.
Our goal, too, was to reduce the computational cost of the network, and specifically in the transfer learning setting: when building on top of a pre-trained neural network, you inherit a lot of complexity required to solve the original source task, much of which may be redundant for solving your target task. There is a difference in our high-level pruning objective: unlike the $L_0$ norm or group sparsity, we used a slightly more complicated formula that directly estimates the forward-pass runtime of the method. This is a quadratic function of the number of parameters at each layer, with interactions between neighbouring layers. Interestingly, this results in architectures which tend to alternate thick and thin layers, like the one below:
We prune the trained network greedily by removing convolutional featuremaps one at a time. A meaningful principle for selecting the next feature map to prune is to minimize the resulting increase in training loss. Starting from this criterion, using a second order Taylor-expansion of the loss, making some more assumptions, we obtain the following pruning signal for keeping a parameter $\theta_i$:
$$
\Delta_i \propto F_i \theta_i^2,
$$
where $F_i$ denotes the $i^{th}$ diagonal entry of the Fisher information matrix. The above formula deals with removing a single parameter, but we can generalize this to removing entire featuremaps. Pruning proceeds by removing the parameter or featuremap with the smallest $\Delta$ in each iteration, and retraining the network between iterations. For more details, please see the paper.
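In code, the per-parameter pruning signal can be sketched like this, using the mean squared per-example gradient as the diagonal Fisher estimate (the numbers here are made up for illustration):

```python
import numpy as np

def fisher_pruning_signal(per_example_grads, theta):
    """Delta_i proportional to F_i * theta_i^2, with F_i the diagonal of the
    empirical Fisher (mean squared per-example gradient)."""
    F_diag = np.mean(per_example_grads ** 2, axis=0)
    return F_diag * theta ** 2

rng = np.random.default_rng(0)
grads = rng.standard_normal((1000, 5))        # per-example gradients, n x d
theta = np.array([0.1, 2.0, -0.5, 0.01, 1.0])
delta = fisher_pruning_signal(grads, theta)
print(np.argmin(delta))  # 3: the near-zero parameter gets pruned first
```

In the paper the same idea is applied to whole featuremaps rather than single parameters, with retraining between pruning iterations.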
Adding to what's presented in the paper, I wanted to point out a few connections between Fisher pruning to ideas I discussed on this blog before.
The first connection is to the Fisher-Rao norm. Assume for a minute that the Fisher information is diagonal - a big and unreasonable assumption in theory, but a pragmatic simplification resulting in useful algorithms in practice. With this assumption, the Fisher-Rao norm of $\theta$ becomes:
$$
|\theta|_{fr} = \sum_{i} F_i \theta_i^2
$$
Written in this form, you can hopefully see the connection between the FR-norm and the Fisher pruning criterion. Depending on the particular definition of Fisher information used, you can interpret the FR-norm, approximately, as
In the real world, the Fisher info matrix is not diagonal, and this is actually an important aspect of understanding generalization. For one, considering only diagonal entries makes Fisher pruning sensitive to certain reparametrizations (ones with non-diagonal Jacobian) of the network. But maybe there is a deeper connection to be observed here between Fisher-Rao norms and the redundancy of parameters.
Using the diagonal Fisher information values to guide pruning also bears resemblance to elastic weight consolidation by (Kirkpatrick et al, 2017). In EWC, the Fisher information values are used to establish which weights are more or less important for solving previous tasks. There, the algorithm was derived from the perspective of Bayesian on-line learning, but you can also motivate it from a Taylor expansion perspective, just like Fisher pruning.
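As a tiny illustration of how these soft importance weights act (a sketch of the idea, not the paper's implementation), the EWC regularizer is just a Fisher-weighted quadratic penalty anchoring the new parameters to the old ones:

```python
import numpy as np

def ewc_penalty(theta, theta_star, fisher_diag, lam):
    """Penalize moving parameters that were important for the old task
    (large F_i) away from their old values theta_star."""
    return 0.5 * lam * np.sum(fisher_diag * (theta - theta_star) ** 2)

theta_star = np.array([1.0, -2.0, 0.5])
fisher = np.array([10.0, 0.01, 1.0])   # first parameter is "protected"
# moving the protected parameter is much more expensive than the redundant one
print(ewc_penalty(theta_star + np.array([0.1, 0, 0]), theta_star, fisher, 1.0))
print(ewc_penalty(theta_star + np.array([0, 0.1, 0]), theta_star, fisher, 1.0))
```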
The metaphor I use to understand and explain EWC is that of a shared hard drive. (WARNING: like all metaphors, this may be completely missing the point). The parameters of a neural network are like a hard drive or storage volume of some sort. Training the NN on a task involves compressing the training data and saving the information to the hard drive. If you have no mechanism to keep data from being overwritten, it's going to be overwritten: in neural networks, catastrophic forgetting occurs the same way. EWC is like a protocol for sharing the hard-drive between multiple users, without the users overwriting each other's data. The Fisher information values in EWC can be seen as soft do-not-overwrite flags. After training on the first task, we calculate the Fisher information values which say which parameters store crucial information for the task. The ones with low Fisher value are redundant and can be reused to store new info. In this metaphor, it is satisfying to think about the sum of Fisher information values as measuring how full the hard-drive is, and pruning as throwing away parts of the drive not actually used to store anything.
I wrote about two recent methods for automatically learning neural network architecture by figuring out which parameters/features to throw away. In my mind, both methods/papers are interesting in their own right. The $L_0$ approach seems like a simpler optimization algorithm that may be preferable to the iterative, remove-one-feature-at-a-time nature of Fisher pruning. However, Fisher pruning is more applicable to the scenario when you start from a large pretrained model in a transfer learning setting.
After last week's post on the generalization mystery, people have pointed me to recent work connecting the Fisher-Rao norm to generalization (thanks!):
This short presentation on generalization by coauthor Sasha Rakhlin is also worth looking at - though I have to confess many of the references to learning theory are lost on me.
While I can't claim to have understood all the bounding and proofs going on in Section 4, I think I got the big picture so I will try and summarize the main points in the section below. In addition, I wanted to add some figures I did which helped me understand the restricted model class the authors worked with and to understand the "gradient structure" this restriction gives rise to. Feel free to point out if anything I say here is wrong or incomplete.
The main mantra of this paper is along the lines of results by Bartlett (1998) who observed that in neural networks, generalization is about the size of the weights, not the number of weights. This theory underlies the use of techniques such as weight decay and even early stopping, since both can be seen as ways to keep the neural network's weight vector small. Reasoning about a neural network's generalization ability in terms of the size, or norm, of its weight vector is called norm-based capacity control.
The main contribution of Liang et al (2017) is proposing the Fisher-Rao norm as a measure of how big the networks' weights are, and hence as an indicator of a trained network's generalization ability. It is defined as follows:
$$
|\theta|_{fr} = \theta^\top I(\theta)\, \theta
$$
where $I$ is the Fisher information matrix:
$$
I(\theta) = \mathbb{E}_{x,y} \left[ \nabla_\theta \ell(f_\theta(x),y) \nabla_\theta \ell(f_\theta(x),y)^\top \right]
$$
There are various versions of the Fisher information matrix, and therefore of the Fisher-Rao norm, depending on which distribution the expectation is taken under. The empirical form samples both $x$ and $y$ from the empirical data distribution. The model form samples $x$ from the data, but assumes that the loss is a log-loss of a probabilistic model, and we sample $y$ from this model.
Importantly, the Fisher-Rao norm is something which depends on the data distribution (at least the distribution of $x$). It is also invariant under reparametrization, which means that if there are two parameters $\theta_1$ and $\theta_2$ which implement the same function, their FR-norm is the same. Finally, it is a measure related to flatness inasmuch as the Fisher-information matrix approximates the Hessian at a minimum of the loss under certain conditions.
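To make the empirical form concrete, here is a small sketch computing $\theta^\top I(\theta)\, \theta$ for logistic regression, where per-example gradients are available in closed form (the data here is synthetic):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 1000, 3
X = rng.standard_normal((n, d))
theta = rng.standard_normal(d)
y = (rng.random(n) < 1.0 / (1.0 + np.exp(-X @ theta))).astype(float)

# per-example gradient of the logistic log-loss: (sigmoid(x.theta) - y) x
p = 1.0 / (1.0 + np.exp(-X @ theta))
G = (p - y)[:, None] * X            # n x d matrix of per-example gradients

F = G.T @ G / n                     # empirical Fisher information matrix
fr = theta @ F @ theta              # theta^T I(theta) theta
# equivalently: the mean squared directional derivative of the loss along theta
print(np.isclose(fr, np.mean((G @ theta) ** 2)))  # True
```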
The one thing I wanted to add to this paper, is a little bit more detail on the particular model class - rectified linear networks without bias - that the authors studied here. This restriction turns out to guarantee some very interesting properties, without hurting the empirical performance of the networks (so the authors claim and to some degree demonstrate).
Let's first visualize what the output of a rectified multilayer perceptron with biases looks like. Here I used 3 hidden layers with 15 ReLUs in each and PyTorch-default random initialization. The network's input is 2D, and the output is 1D so I can easily plot contour surfaces:
The left-hand panel shows the function itself. The panels next to it show the gradients with respect to $x_1$ and $x_2$ respectively. The function is piecewise linear (which is hard to see because there are many, many linear pieces), which means that the gradients are piecewise constant (which is more visually apparent).
The piecewise linear structure of $f$ becomes more apparent when we superimpose the contour plot of the gradients (black) on top of the contour plot of $f$ itself (red-blue):
These functions are clearly very flexible and by adding more layers, the number of linear pieces grows exponentially.
Importantly, the above plot would look very similar had I plotted the function's output as a function of two components of $\theta$, keeping $x$ fixed. This is significantly more difficult to plot though, so I'm hoping you'll just believe me.
Now let's look at what happens when we remove all biases from the network, keeping only the weight matrices:
Wow, the function looks very different now, doesn't it? At $x=0$ it always takes the value $0$. It is composed of wedge-shaped (or, in higher dimensions, generalized pyramid-shaped) regions within which the function is linear, but the slope in each wedge is different. Yet the surface is still continuous. Let's do the superimposed plot again:
It's less clear from these plots why a function like this can model data just as well as the more general piece-wise linear one we get if we enable biases. One thing that helps is dimensionality: in high dimensions, the probability that two randomly sampled datapoints fall into the same "pyramid", i.e. share the same linear region, is extremely small. Unless your data has some structure that makes this likely to happen for many datapoints at once, you don't really have to worry about it, I guess.
Furthermore, if my network had three input dimensions, but I only use two dimensions $x_1$ and $x_2$ to encode data and fix the third coordinate $x_3=1$, I can implement the same kind of functions over my inputs. This is called using homogeneous coordinates, and a bias-less network with homogeneous coordinates can be nearly as powerful as one with biases in terms of the functions it can model. Below is an example of a function a rectified network with no biases can implement when using homogeneous coordinates.
This is because the third variable $x_3=1$ multiplied by its weights practically becomes a bias for the first hidden layer.
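A quick numpy sanity check of these two claims (a minimal sketch of mine, not from any paper): a bias-free ReLU network is positively homogeneous and pinned to zero at the origin, and fixing a constant extra input coordinate turns the corresponding column of the first weight matrix into an effective bias.

```python
import numpy as np

rng = np.random.default_rng(1)
relu = lambda z: np.maximum(z, 0.0)
W1, W2 = rng.standard_normal((5, 3)), rng.standard_normal((1, 5))
f = lambda x: (W2 @ relu(W1 @ x)).item()   # two-layer bias-free ReLU net

x = rng.standard_normal(3)
assert abs(f(3.7 * x) - 3.7 * f(x)) < 1e-9   # positive homogeneity: f(cx) = c f(x) for c > 0
assert f(np.zeros(3)) == 0.0                 # the function is pinned to 0 at the origin

# Homogeneous coordinates: fixing x3 = 1 turns W1's last column into a bias.
x2 = rng.standard_normal(2)
x_hom = np.append(x2, 1.0)
b = W1[:, 2]   # the effective bias of the first hidden layer
assert np.allclose(W1 @ x_hom, W1[:, :2] @ x2 + b)
```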
The second observation is that if we consider $f_\theta(x)$ as a function of the weight matrix of a particular layer, keeping all other weights and the input fixed, the function behaves exactly the same way as it does with respect to the input $x$. The same radial pattern would be observed in $f$ if I plotted it as a function of a weight matrix (though weight matrices are rarely 2-D, so I can't really plot that).
The authors note that these functions satisfy the following formula:
$$
f_\theta(x) = \nabla_x f_\theta(x)^\top x
$$
(Moreover I think these are the only continuous functions for which the above equality holds, but I leave this to keen readers to prove or disprove)
Noting the symmetry between the network's inputs and weight matrices, a similar equality can be established with respect to parameters $\theta$:
$$
f_\theta(x) = \frac{1}{L+1}\nabla_\theta f_\theta(x)^\top \theta,
$$
where $L$ is the number of layers.
Here's my explanation, which differs slightly from the simple proof the authors give. A general rectified network is piecewise linear with respect to $x$, as discussed. The boundaries of the linear pieces, and the slopes, change as we change $\theta$. Let's fix $\theta$. Now, so long as $x$ and some $x_0$ fall within the same linear region, the function at $x$ equals its Taylor expansion around $x_0$:
\begin{align}
f_\theta(x) &= \nabla_{x} f_\theta(x_0)^\top (x- x_0) + f_{\theta}(x_0) \\
&= \nabla_x f_\theta(x)^\top (x - x_0) + f_{\theta}(x_0)
\end{align}
Now, if we have no biases, all the linear segments are wedge-shaped and they all meet at the origin $x=0$. So we can consider the above Taylor expansion in the limit as $x_0\rightarrow 0$ (we only need a limit, technically, because the function is non-differentiable exactly at $x=0$). As $f_{\theta}(0)=0$, we find that
$$
f_\theta(x) = \nabla_x f_\theta(x)^\top x,
$$
just as we wanted. Now, treating layer $l$'s weights $\theta^{(l)}$ as if they were the input to the network consisting of the subsequent layers, and the previous layer's activations as if they were the weight multiplying these inputs, we can derive a similar formula in terms of $\theta^{(l)}$:
$$
f_\theta(x) = \nabla_{\theta^{(l)}} f_\theta(x)^\top \theta^{(l)}.
$$
Applying this formula for all layers $l=1\ldots L+1$, and taking the average we obtain:
$$
f_\theta(x) = \frac{1}{L+1}\nabla_\theta f_\theta(x)^\top \theta
$$
We got the $L+1$ from the $L$ hidden layers plus the output layer.
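Both identities are easy to verify numerically (a sketch of my own using finite-difference gradients, so all names here are mine): for a bias-free rectified network with two hidden layers we should get $f = \nabla_x f^\top x$ and $f = \frac{1}{3}\nabla_\theta f^\top \theta$.

```python
import numpy as np

rng = np.random.default_rng(2)
relu = lambda z: np.maximum(z, 0.0)
shapes = [(8, 4), (8, 8), (1, 8)]                  # two hidden layers + output: L + 1 = 3
theta = [rng.standard_normal(s) * 0.5 for s in shapes]

def f(x, ws):
    h = x
    for W in ws[:-1]:
        h = relu(W @ h)
    return (ws[-1] @ h).item()

def num_grad(fun, v, eps=1e-6):
    # central-difference gradient, perturbing v in place
    g = np.zeros_like(v)
    for i in np.ndindex(v.shape):
        v[i] += eps; fp = fun()
        v[i] -= 2 * eps; fm = fun()
        v[i] += eps
        g[i] = (fp - fm) / (2 * eps)
    return g

x = rng.standard_normal(4)
gx = num_grad(lambda: f(x, theta), x)
assert abs(f(x, theta) - gx @ x) < 1e-4            # f = <grad_x f, x>

dot = sum((num_grad(lambda: f(x, theta), W) * W).sum() for W in theta)
assert abs(f(x, theta) - dot / len(theta)) < 1e-4  # f = <grad_theta f, theta> / (L + 1)
```

The finite differences are exact up to floating-point noise here because $f$ is exactly linear within each region.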
Using the formula above, and the chain rule, we can simplify the expression for the Fisher-Rao norm as follows:
\begin{align}
\|\theta\|_{fr}^2 &= \mathbb{E}\, \theta^\top \nabla_\theta \ell(f_\theta(x),y) \nabla_\theta \ell(f_\theta(x),y)^\top \theta \\
&= \mathbb{E} \left( \theta^\top \nabla_\theta \ell(f_\theta(x),y) \right)^2 \\
&= \mathbb{E} \left( \theta^\top \nabla_\theta f_\theta(x) \nabla_f \ell(f,y) \right)^2\\
&= (L+1)^2\, \mathbb{E} \left( f_\theta(x)^\top \nabla_f \ell(f,y)\right)^2
\end{align}
It can be seen very clearly in this form that the Fisher-Rao norm only depends on the output of the function $f_\theta(x)$ and properties of the loss function. This means that if two parameters $\theta_1$ and $\theta_2$ implement the same input-output function $f$, their F-R norm will be the same.
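We can check this invariance directly (an illustrative sketch of mine, assuming a two-layer bias-free net): rescaling one layer by $c$ and the next by $1/c$ leaves the implemented function, and hence the F-R norm, untouched, while the Euclidean norm of $\theta$ changes wildly.

```python
import numpy as np

rng = np.random.default_rng(3)
relu = lambda z: np.maximum(z, 0.0)
W1, W2 = rng.standard_normal((6, 4)), rng.standard_normal((1, 6))
f = lambda x, A, B: (B @ relu(A @ x)).item()

c = 10.0
X = rng.standard_normal((100, 4))
outputs = [f(x, W1, W2) for x in X]
rescaled = [f(x, c * W1, W2 / c) for x in X]   # same function, different theta
assert np.allclose(outputs, rescaled)

# The Euclidean norm of theta is clearly not invariant...
norm = lambda A, B: np.sqrt((A ** 2).sum() + (B ** 2).sum())
assert abs(norm(c * W1, W2 / c) - norm(W1, W2)) > 1.0
# ...while the F-R norm, which depends only on f's outputs and the loss, is unchanged.
```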
I think this paper presented a very interesting insight into the geometry of rectified linear neural networks, and highlighted some interesting connections between information geometry and norm-based generalization arguments.
What I think is still missing is the kind of insight that would explain why SGD finds solutions with a low F-R norm, or how the F-R norm of a solution is affected by the batch size of SGD, if at all. The other open question is whether the F-R norm can be an effective regularizer. For this particular class of networks, which don't have any bias parameters, the F-R norm could be calculated relatively cheaply and added as a regularizer, since we already compute the forward pass of the network anyway.
]]>
I set out to write about the following paper I saw people talk about on twitter and reddit:
It's related to this pretty insightful paper:
Inevitably, I started thinking more generally about flat and sharp minima and generalization, so rather than describing these papers in detail, I ended up dumping some thoughts of my own. Feedback and pointers to literature are welcome, as always.
The loss surface of deep nets tends to have many local minima. Many of these might be equally good in terms of training error, but they may have widely different generalization performance: a network with minimal training loss might perform very well or very poorly on a held-out test set. Interestingly, stochastic gradient descent (SGD) with small batch sizes appears to locate minima with better generalization properties than large-batch SGD. So the big question is: what measurable property of a local minimum can we use to predict generalization? And how does this relate to SGD?
There is speculation dating back to at least Hochreiter and Schmidhuber (1997) that the flatness of the minimum is a good measure to look at. However, as Dinh et al (2017) pointed out, flatness is sensitive to reparametrizations of the neural network: we can reparametrize a neural network without changing its outputs while making sharp minima look arbitrarily flat and vice versa. As a consequence the flatness alone cannot explain or predict good generalization.
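Here is a toy version of the reparametrization argument (my own illustrative example, not the one Dinh et al use): take the loss $\ell(a, b) = (ab - 1)^2$. Every point $(t, 1/t)$ is a global minimum implementing the same product $ab = 1$, yet the Hessian, and hence the apparent sharpness, depends strongly on $t$.

```python
import numpy as np

loss = lambda a, b: (a * b - 1.0) ** 2

def hessian(a, b, eps=1e-4):
    # numerical Hessian of the loss at (a, b) via central differences
    p = np.array([a, b])
    H = np.zeros((2, 2))
    for i in range(2):
        for j in range(2):
            e_i, e_j = np.eye(2)[i] * eps, np.eye(2)[j] * eps
            H[i, j] = (loss(*(p + e_i + e_j)) - loss(*(p + e_i - e_j))
                       - loss(*(p - e_i + e_j)) + loss(*(p - e_i - e_j))) / (4 * eps ** 2)
    return H

flat = max(np.linalg.eigvalsh(hessian(1.0, 1.0)))      # minimum at (1, 1)
sharp = max(np.linalg.eigvalsh(hessian(100.0, 0.01)))  # same function ab = 1, reparametrized
assert sharp > 100 * flat   # the second minimum looks vastly sharper
```

Both points are global minima of the same loss, implementing the same product, yet the largest Hessian eigenvalue differs by orders of magnitude.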
Li et al (2017) proposed a normalization scheme which scales the space around a minimum in such a way that the apparent flatness in 1D and 2D plots is kind of invariant to the type of reparametrization Dinh et al used. This, they say, allows us to produce more faithful visualizations of the loss surfaces around a minimum. They even use 1D and 2D plots to illustrate differences between different architectures, such as a VGG and a ResNet. I personally do not buy the conclusions of this paper, and it seems the reviewers of the ICLR submission largely agreed on this. The proposed method is weakly motivated and only addresses one possible type of reparametrization.
Following the thinking of Dinh et al, if generalization is a property which is invariant under reparametrization, the quantity we use to predict generalization should also be invariant. My intuition is that a good way to achieve invariance is to consider the ratio between two quantities - maybe two flatness measures - which are affected by reparametrization in the same way.
One thing I think would make sense to look at is the average flatness of the loss in a single minibatch vs. the flatness of the average loss. Why would this make sense? The average loss can be flat around a minimum in different ways: it can be flat because it is the average of flat functions which all look very similar and whose minima are very close to the same location; or it can be flat because it is the average of many sharp functions with minima at locations scattered around the minimum of the average.
Intuitively, the former solution is more stable with respect to subsampling of data, therefore it should be more favourable from a generalization viewpoint. The latter solution is very sensitive to which particular minibatch we are looking at, so presumably it may give rise to worse generalization.
As a conclusion of this section: I don't think it makes sense to look only at the flatness of the average loss; looking at how that flatness is affected by subsampling the data feels more key to generalization.
After Jorge Nocedal's ICLR talk on large-batch SGD, Léon Bottou made a comment which I think hit the nail on the head. The process of sampling minibatches from training data kind of simulates the effect of sampling the training set and the test set from some underlying data distribution. Therefore, you might think of generalization from one minibatch to another as a proxy for how well a method would generalize from a training set to a test set.
How can we use this insight to come up with some sort of measure of generalization based on minibatches, especially along the lines of sharpness or local derivatives?
First of all, let's consider the stochastic process $f(\theta)$ which we obtain by evaluating the loss function on a random minibatch. The randomness comes from subsampling the data. This is a probability distribution over loss functions over $\theta$. I think it's useful to seek an indicator of generalization ability as a local property of this stochastic process at any given $\theta$ value.
Let's pretend for a minute that each draw $f(\theta)$ from this process is convex, or at least has a unique global minimum. How would one describe a model's generalization from one minibatch to another in terms of this stochastic process?
Let's draw two functions $f_1(\theta)$ and $f_2(\theta)$ independently (i.e. evaluate the loss on two separate minibatches). I propose that the following would be a meaningful measure:
$$
R = f_2 (\operatorname{argmin}_\theta f_1(\theta)) - \min_\theta f_2(\theta)
$$
Basically: you care about finding low error according to $f_2$ but all you have access to is $f_1$. You therefore look at what the value of $f_2$ is at the location of the minimum of $f_1$ and compare that to the global minimal value of $f_2$. This is a sort of regret expression, hence my use of $R$ to denote it.
Now, in deep learning the loss functions $f_1$ and $f_2$ are not convex and have many local minima, so this definition is not particularly useful in general. However, it makes sense to calculate this value locally, in a small neighbourhood of a particular parameter value $\theta$. Let's consider fitting a restricted neural network model, where only parameters within a certain $\epsilon$ distance from $\theta$ are allowed. If $\epsilon$ is small enough, we can assume the loss functions have a unique global minimum within this $\epsilon$-ball. Furthermore, if $\epsilon$ is small enough, one can use a first-order Taylor approximation to $f_1$ and $f_2$ to analytically find approximate minima within the $\epsilon$-ball. To do this, we just need to evaluate the gradient at $\theta$. This is illustrated in the figure below:
The left-hand panel shows an imaginary loss function evaluated on some minibatch $f_1$, restricted to the $\epsilon$-ball around $\theta$. We can assume $\epsilon$ is small enough so $f_1$ is linear within this local region. Unless the gradient is exactly $0$, the minimum will fall on the surface of the $\epsilon$-ball, exactly at $\theta - \epsilon \frac{g_1}{|g_1|}$ where $g_1$ is the gradient of $f_1$ at $\theta$. This is shown by the yellow star. On the right-hand panel I show $f_2$. This is also locally linear, but its gradient $g_2$ might be different. The minimum of $f_2$ within the $\epsilon$-ball is at $\theta - \epsilon \frac{g_2}{|g_2|}$, shown by the red star. We can consider the regret-type expression as above, by evaluating $f_2$ at the yellow star, and subtracting its value at the red star. This can be expressed as follows (I divided by $\epsilon$):
$$
\frac{R(\theta, f_1, f_2)}{\epsilon} \rightarrow - \frac{g_2^\top g_1}{|g_1|} + \frac{g_2^\top g_2}{|g_2|} = |g_2| - \frac{g_2^\top g_1}{|g_1|} = |g_2|\left(1 - \cos(g_1, g_2)\right)
$$
In practice one would consider taking an expectation with respect to the two minibatches to obtain an expression that depends on $\theta$. So, we have just come up with a local measure of generalization ability, which is expressed in terms of expectations involving gradients over different minibatches. The measure is local as it is specific for each value of $\theta$. It is data-dependent in that it depends on the distribution $p_\mathcal{D}$ from which we sample minibatches.
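The measure is also cheap to estimate in code (a sketch under the local-linearity assumption above; the function name is mine):

```python
import numpy as np

def local_regret(g1, g2):
    """R(theta, f1, f2) / eps: the regret of trusting minibatch 1's gradient
    when we actually care about minibatch 2's loss."""
    cos = g1 @ g2 / (np.linalg.norm(g1) * np.linalg.norm(g2))
    return np.linalg.norm(g2) * (1.0 - cos)

rng = np.random.default_rng(4)
g = rng.standard_normal(10)
assert local_regret(g, g) < 1e-8                                # same minibatch: no regret
assert np.isclose(local_regret(g, -g), 2 * np.linalg.norm(g))   # opposite gradients: worst case
```

In practice one would average this quantity over many pairs of minibatch gradients at the same $\theta$.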
This measure depends on two things: the norm of the minibatch gradient $g_2$, and how well gradients computed on different minibatches align, via the cosine similarity term.
When we take the expectation over this, assuming that the $1 - \cos(g_1, g_2)$ term is mostly $1$ (i.e. that gradients on different minibatches are nearly orthogonal), we end up with the expression $\mathbb{E}_g \sqrt{\operatorname{trace}\left(g g^\top\right)}$, where the expectation is taken over minibatches. Note that the trace-norm of the empirical Fisher information matrix, $\sqrt{ \operatorname{trace} \mathbb{E}_g \left(g g^\top\right)}$, can be used as a measure of flatness of the average loss around minima, so there may be some interesting connections there. However, due to Jensen's inequality, the two quantities are not actually the same.
Update - thanks to reddit user bbsome for pointing this out:
Note that $R$ is not invariant under reparametrization either. The source of this sensitivity is the fact that I considered an $\epsilon$-ball in Euclidean norm around $\theta$. The right way to get rid of this is to consider an $\epsilon$-ball using the symmetrized KL divergence instead of the Euclidean norm, similarly to how natural gradient methods can be derived. If we do this, the formula becomes dependent only on the function the neural network implements, not on the particular choice of parametrization. I leave it as homework for readers to work out how this would change the formulae.
This post started out as a paper review, but in the end I didn't find the paper too interesting and instead resorted to sharing ideas about tackling the generalization puzzle a bit differently. It's entirely possible that people have done this analysis before, or that it's completely useless. In any case, I welcome feedback.
The first observation here was that a good indicator may involve not just the flatness of the average loss around the minimum, but a ratio between two flatness indicators. Such metrics may end up invariant under reparametrization by construction.
Taking this idea further I attempted to develop a local indicator of generalization performance which goes beyond flatness. It also includes terms that measure the sensitivity of gradients to data subsampling.
Because data subsampling is something that occurs both in generalization (training vs test set) and in minibatch-SGD, it may be possible that these kind of measures might shed some light on how SGD enables better generalization.
]]>
]]>Like many of you, I thoroughly enjoyed Ali Rahimi's NIPS talk in response to winning the test-of-time award for their work on random kitchen sinks. I recommend everyone watch it if you haven't already.
I also read Yann LeCun's rebuttal to Ali's talk. He says what Ali calls alchemy is in fact engineering. Although I think Yann ends up arguing against points Ali didn't make in his talk, he raises important and good points about the different roles that tricks, empirical evidence and theory can play in engineering.
I wanted to add my own thoughts and experiences to the mix.
We can think about machine learning knowledge as a graph, where methods are nodes and edges represent connections or analogies between the methods.
Innovation means growing this graph, which one can do in different ways:
It's easy to forget what the original GAN paper's results looked like. These were really good back then and would be laughable today:
GANs were, arguably, a highly influential, great new idea. Today, that same paper would be unpublishable because the pictures don't look pretty enough. Wasserstein GANs were a great idea too, and frankly, to recognize that it is a good idea I don't need to look at experimental results at all.
The same way Yann argues neural nets were unfairly abandoned in the 90s because of the lack of convergence guarantees convex optimization methods enjoy, today we are unfairly dismissing any method or idea that does not produce state-of-the-art or near-SOTA results. I once reviewed a paper where another reviewer wrote "The work can be novel if the method turns out to be the most efficient method of [...] compared to existing methods". This is at least as wrong as dismissing novel ideas for lack of theory (which is, BTW, not what Ali suggested).
As far as I'm concerned, I'm now comfortable with using methods that are non-rigorous or for which the theoretical framework is underdeveloped or does not exist. However,
nobody should feel good about any paper where the evaluation is non-rigorous
In GAN papers we show pretty pictures, but we have absolutely no rigorous way to assess the diversity of samples, or whether any form of overfitting has occurred, at least not as far as I'm aware. I know from experience that getting a sufficiently novel deep learning idea to work is a fragile process: it starts with nothing working, then maybe it starts to work but doesn't converge, then it converges but to the wrong thing. The whole thing just doesn't work until it suddenly starts to, and very often it's unclear what specifically made it work. This process is akin to multiple hypothesis testing. You ran countless experiments and report the result that looks best and behaves the way you expected it to behave. The underlying problem is, we conflate the software development process required to implement a method with manual hyperparameter search and cherry-picking of results. As a result, our reported "empirical evidence" may be more biased and less reliable than one would hope.
I agree with Yann that there is merit in starting to adopt techniques before theory or rigorous analysis becomes available. However, once theoretical insight becomes available, reasoning about empirical performance often continues to trump rigour.
Let me tell you about a pattern I encountered a few times (you might say I'm just bitter about this). The pattern: Someone comes up with an idea, they demonstrate it works very well on some large, complicated problem using a very large neural network and loads of tricks and presumably months of manual tweaking and hyper-parameter search. I find a theoretical problem with the method. People say: it still works well in practice, so I don't see a problem.
This was the kind of response I got to my critique of scheduled sampling and my critique of elastic weight consolidation. In both of these cases reviewers pointed out the methods work just fine on "real-world problems", and in the case of scheduled sampling people commented "after all the method came first in a benchmark competition so it must be correct". No, if a method works, but works for the wrong reasons, or for different reasons the authors gave, we have a problem.
You can think of "making a deep learning method work on a dataset" as a statistical test. I would argue that the statistical power of such experiments is very weak. We do a lot of things like early stopping, manual tweaking of hyperparameters, running multiple experiments and only reporting the best results. We probably all know we should not be doing these things when testing hypotheses. Yet, these practices are considered fine when reporting empirical results in ML papers. Many go on to consider these reported empirical results as "strong empirical evidence" in favour of a method.
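The selection bias is easy to simulate (an illustrative toy example of my own, not from the talk): even if a method has zero true effect, reporting the best of many noisy runs produces an impressive-looking number.

```python
import numpy as np

rng = np.random.default_rng(5)
n_runs = 50
# A "method" with zero true effect: each experiment returns pure noise.
scores = rng.normal(loc=0.0, scale=1.0, size=n_runs)

honest_estimate = scores.mean()   # close to the true effect, 0
cherry_picked = scores.max()      # the "best run" we are tempted to report
assert cherry_picked > honest_estimate + 0.5   # selection alone manufactures an effect
```

Reporting `cherry_picked` is exactly the multiple-hypothesis-testing mistake: the number is large purely because we selected for it.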
I want to thank Ali for giving this talk. Yes, it was confrontational. Insulting? I think provocative is a better word. It was at least somewhat controversial, apparently. It contains a few points I disagree with. But I do not think it was wrong.
It touched upon a lot of problems which I think should be recognized and appreciated by the community. Rigour is not about learning theory, convergence guarantees, bounds or theorem proving. Intellectual rigour applies and can be applied to all of machine learning whether or not we have fully developed mathematical tools for analysis.
Rigour means being thorough, exhaustive, meticulous. It includes good practices like honestly describing the potential weaknesses of a method; thinking about what might go wrong; designing experiments which highlight and analyze these weaknesses; making predictions about your algorithm's behaviour in certain toy cases and demonstrating empirically that it indeed behaves as expected; refusing to use unjustified evaluation methods; accepting and addressing criticism. All of these should apply to machine learning, deep or not, and indeed they apply to engineering as a whole.
]]>
I wrote this post in response to a challenge by Roger Grosse:
Theorem 2 of my moment averaging paper is one of the most surprising mathematical results I've seen in machine learning. The proof is very short, but I still have no intuition for why it's true. $500 reward for a clear & convincing intuitive explanation.https://t.co/WjAMg3qIRm
— Roger Grosse (@RogerGrosse) November 27, 2017
As a consequence of responding to a question, this will be a bit of an obscure post, and it may not be very meaningful outside the context of annealed importance sampling. Nevertheless, the post gives me a chance to talk about some interesting aspects of Bregman divergences and exponential families, and hopefully teach you some new things you may not have come across before.
If you look further down the thread below Roger's tweet, it seems @AIActorCritic started answering the question in terms of duality of parametrizations - my post also generalizes things along the lines of duality, but goes into a bit more detail.
Context: the authors analyze the bias of an Annealed Importance Sampling (AIS) estimator, whose goal is to estimate the normalizing constant of an intractable distribution $p$. AIS constructs a path in the space of probability distributions starting from a tractable distribution $q$ and ending at $p$ via a sequence of intermediate distributions $q=p_0,p_1, \ldots, p_{K-1}, p_K=p$. This path is then used to construct an estimator to the log-partition function (normalizing constant) of $p$. Crucially, the bias of this estimator depends on, among other things, the particular path we take between the two distributions: there are multiple paths between $q$ and $p$, and these paths can result in lower or higher bias. Under some simplifying assumptions the bias is given by:
$$
\delta = \sum_{k=0}^{K-1} D_{KL}[p_k|p_{k+1}].
$$
It is not hard to show that as $K\rightarrow \infty$, this bias vanishes, but what the paper looks at is its precise asymptotic behaviour, with the main insight being Eqn (4):
$$
K\delta \rightarrow \frac{1}{2}\int_0^1 \dot{\theta}(\beta)^\top G_\theta(\beta)\dot{\theta}(\beta) d\beta =:\mathcal{F}(\gamma),
$$
where $\gamma$ denotes the path as a whole, $\theta(\beta), \beta\in[0,1]$ is the path we trace in parameter space $\theta$, and $G_\theta(\beta)$ is the Fisher information matrix with respect to $\theta$ evaluated at $\theta(\beta)$.
Intriguingly, if you calculate $\mathcal{F}(\gamma)$ along two very different paths in distribution space - a straight line in natural parameter space or a straight line in moment space - you get exactly the same value. This is despite the fact that these paths clearly pass through very different distributions, as illustrated very nicely in Figure 1 of the paper.
So the question is: why are these two paths equivalent in this sense? The answer might lie in convex conjugate duality.
The first observation we can make is that Eqns. (18) and (19) do not only show that the two paths are equivalent, they also show that the cost can be expressed in terms of the symmetrized KL divergence (sometimes called the Jeffreys divergence) between the endpoints:
$$
\mathcal{F}(\gamma_{MA}) = \mathcal{F}(\gamma_{GA}) = \frac{1}{2}\left( D_{KL}[p_{\beta=0}|p_{\beta=1}] + D_{KL}[p_{\beta=1}|p_{\beta=0}] \right).
$$
To prove this, one can simply use the following formula for the $KL$ divergence between two exponential family distributions with natural parameters $\eta'$ and $\eta$:
$$
D_{KL}[p_{\eta}|p_{\eta'}] = A(\eta') - A(\eta) - s^\top(\eta' - \eta),
$$
where $A(\eta) = \log\mathcal{Z}(\eta)$ is the log-partition function, and $s$ denotes the moments of $p_{\eta}$ just like in the paper. This observation is interesting as we may be looking at special cases of a more general result.
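We can sanity-check an expression of this form on the simplest example, a Bernoulli family (a sketch of my own): the natural parameter is the logit $\eta$, the log-partition function is $A(\eta) = \log(1 + e^\eta)$, and the moment is $s = \sigma(\eta)$.

```python
import numpy as np

A = lambda eta: np.log1p(np.exp(eta))            # Bernoulli log-partition function
sigmoid = lambda eta: 1.0 / (1.0 + np.exp(-eta))

def kl_bernoulli(s, t):
    # KL divergence between Bernoulli(s) and Bernoulli(t)
    return s * np.log(s / t) + (1 - s) * np.log((1 - s) / (1 - t))

eta, eta_p = -0.3, 1.7
s = sigmoid(eta)   # moments of p_eta
bregman = A(eta_p) - A(eta) - s * (eta_p - eta)
assert np.isclose(bregman, kl_bernoulli(s, sigmoid(eta_p)))
```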
Any convex functional $H$ over a convex domain induces a Bregman divergence, defined between two points $p$ and $q$ as:
$$
D_H(p|q) = H(p) - H(q) - \langle\nabla H(q), p-q \rangle,
$$
where $\langle\cdot,\cdot\rangle$ denotes an inner product. I used this notation rather than vector products to emphasize that $p$ and $q$ might be functions or probability distributions, not necessarily just finite-dimensional vectors. In what follows I switch back to the usual vector notation.
Examples of Bregman divergences include the KL divergence between distributions, the squared Euclidean distance between points in a Euclidean space, the maximum mean discrepancy between probability distributions, and many more. To read more about them, I recommend Mark Reid's great summary.
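A generic Bregman divergence is a three-liner; here is a sketch (mine) recovering two of these examples numerically: the squared Euclidean distance from $H(x) = \|x\|^2$, and the KL divergence from the negative entropy on the simplex.

```python
import numpy as np

def bregman(H, grad_H, p, q):
    """D_H(p|q) = H(p) - H(q) - <grad H(q), p - q>."""
    return H(p) - H(q) - grad_H(q) @ (p - q)

p = np.array([0.2, 0.3, 0.5])
q = np.array([0.4, 0.4, 0.2])

# H = squared Euclidean norm  ->  squared Euclidean distance
sq = lambda x: x @ x
assert np.isclose(bregman(sq, lambda x: 2 * x, p, q), np.sum((p - q) ** 2))

# H = negative entropy (on the simplex)  ->  KL divergence
negent = lambda x: np.sum(x * np.log(x))
grad_negent = lambda x: np.log(x) + 1.0
assert np.isclose(bregman(negent, grad_negent, p, q), np.sum(p * np.log(p / q)))
```

Note that the $+1$ in the gradient of the negative entropy cancels because $p$ and $q$ both sum to one.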
We can replace the KL divergence in $\delta$ by an arbitrary Bregman divergence, which yields the following generalization of Eqn. (4):
\begin{align}
K \sum_{k=0}^{K-1}D_H(p_k|p_{k+1}) &\rightarrow \frac{1}{2}\int_0^1 \dot{p}(\beta)^\top \nabla^2 H(\beta) \dot{p}(\beta) d\beta\\
&= \frac{1}{2}\int_0^1 \dot{\nabla H}(\beta)^\top \dot{p}(\beta) d\beta,
\end{align}
where $\dot{\nabla H}(\beta)$ is the derivative of $\nabla H(p(\beta))$ with respect to $\beta$, and the second line was obtained by applying the chain rule $\dot{\nabla H}(\beta) = \dot{p}(\beta)\nabla^2 H(\beta)$. $\nabla^2 H(\beta)$ is called the metric tensor and it is a generalization of the Fisher information matrix $G_\theta$ before.
It turns out, if we interpolate in $p$-space linearly, that is, trace a path where $p(\beta) = \beta p_1 + (1 - \beta) p_0$, we can express the above integral analytically in terms of the endpoints of the path $p_0$ and $p_1$:
\begin{align}
\frac{1}{2}\int_0^1 \dot{\nabla H}(\beta)^\top \dot{p}(\beta) d\beta &= \frac{1}{2}\int_0^1 \dot{\nabla H}(\beta)^\top d\beta \, (p_1 - p_0) \\
&= \frac{1}{2} \left( \nabla H(p_1) - \nabla H(p_0) \right)^\top \left( p_1 - p_0 \right)\\
&= \frac{1}{2}\left( D_H(p_1|p_0) + D_H(p_0|p_1) \right).
\end{align}
This final expression is the symmetrized Bregman divergence between the endpoints of the path $p_1$ and $p_0$, a generalization of the symmetrised KL that we found before.
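We can watch this convergence happen numerically (an illustrative sketch of my own), using $H$ = negative entropy, so that $D_H$ is the KL divergence, and a linear path between two discrete distributions:

```python
import numpy as np

# H = negative entropy, so D_H is the KL divergence between discrete distributions
D = lambda p, q: np.sum(p * np.log(p / q))

p0 = np.array([0.7, 0.2, 0.1])
p1 = np.array([0.1, 0.3, 0.6])
sym = 0.5 * (D(p0, p1) + D(p1, p0))   # half the symmetrized divergence

errors = []
for K in [10, 100, 1000]:
    path = [(1 - b) * p0 + b * p1 for b in np.linspace(0, 1, K + 1)]
    total = K * sum(D(path[k], path[k + 1]) for k in range(K))
    errors.append(abs(total - sym))   # K * sum of divergences vs the limit

assert errors[-1] < errors[0] and errors[-1] < 1e-2
```

The rescaled sum of divergences along the path approaches the halved symmetrized divergence as the path gets finer.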
Before showing how this general result helps explain the equivalence of Eqns. (18) and (19), I want to take some time to draw some pretty pictures visualizing the Bregman divergences we are dealing with.
Below is a graphical illustration of a Bregman divergence between two points $p$ and $q$ with respect to some convex potential $H$.
Here's the story: Poppy and Quentin are about to get married. It is bad luck to see each other all dressed up before the wedding. Therefore, they decide to spend the morning on a convex hill (convex from the top). Poppy is at coordinates $p$ and Quentin at $q\neq p$. The surface of the hill is described by $-H$, where $H$ is a convex function. Precisely because the hill is convex, Quentin can't see Poppy, not unless she starts jumping. The Bregman divergence describes the safe height Poppy is allowed to jump without Quentin seeing her. The more hilly the hill between them - the higher the curvature - the higher Poppy can jump without being seen by Quentin.
For this figure I used a boring parabolic hill, $-H(p) = p(1-p)$. The resulting divergence actually ends up symmetric and is simply proportional to the squared Euclidean distance $(p-q)^2$:
But this is an exception, rather than the rule: most Bregman divergences are asymmetric. For the hill $-H(p) = p(1-p^3)$, i.e. the convex function $H(p) = p(p^3 - 1)$, this is what the picture looks like:
OK, so what did we just prove before? Consider we have a sequence of bridesmaids equally placed between Poppy and Quentin... Actually, let's stop the wedding analogy before it's too far.
Consider a sequence of points $p_0,\ldots,p_K$ equally spaced between $p=p_0$ and $q=p_K$. Now look at the sum of divergences between subsequent points, $\sum_k D[p_k|p_{k+1}]$. The divergences say how high each of the $p_k$ has to jump to see $p_{k+1}$. For $K=3$ we are interested in the sum of the red line segments:
For $K=5$:
For $K=12$:
It is pretty clear that the line segments get shorter and shorter, and indeed their sum converges to $0$. What we have proved is that asymptotically this sum behaves like $\frac{1}{2K}\left(D[p|q] + D[q|p]\right)$. Why this is the case is still a mystery to me, and I don't know how to even visualise it. But I wouldn't be surprised if there turned out to be a good reason for it.
After a bit of an excursion into visualising Bregman divergences in one dimension, let's go back to the question of why linear interpolation in natural-parameter space or in moment space gives the same $\mathcal{F}(\gamma)$. The answer lies in duality.
Any convex function $H$ defines a Bregman divergence on a convex domain. Any convex function $H$ also has a convex conjugate $H^\ast$. This convex conjugate, $H^\ast$, also defines a Bregman divergence on its own domain, which is generally different from the domain of $H$. Furthermore, this dual divergence is equivalent to the original in the following sense:
$$
D_{H^{\ast}}[q^{\ast}|p^{\ast}] = D_{H}[p|q],
$$
where $p^{\ast} = \nabla H(p)$ and $q^{\ast} = \nabla H(q)$ are called the dual parameters corresponding to $p$ and $q$, respectively (note that the arguments appear in swapped order on the dual side). The mapping between parameters and dual parameters is one-to-one, thanks to convexity. $p$ and $q$ are sometimes called the primal parameters, but this distinction is rather arbitrary as conjugate duality is a symmetric relationship.
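Here is a small numerical check of this duality (my own scalar example): take $H(x) = e^x$, whose conjugate is $H^\ast(y) = y\log y - y$ and whose gradient maps are $\nabla H = \exp$ and $\nabla H^\ast = \log$, inverses of each other.

```python
import numpy as np

H = np.exp                               # a convex function on the real line
grad_H = np.exp
H_star = lambda y: y * np.log(y) - y     # its convex conjugate
grad_H_star = np.log                     # = the inverse of grad H

# generic scalar Bregman divergence D_F(a|b)
D = lambda F, gF, a, b: F(a) - F(b) - gF(b) * (a - b)

p, q = 0.3, 1.4
p_star, q_star = grad_H(p), grad_H(q)    # dual parameters
assert np.isclose(D(H, grad_H, p, q), D(H_star, grad_H_star, q_star, p_star))
```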
With an understanding of duality, let's look at the formula that we obtained before:
\begin{align}
K \sum_{k=0}^{K-1}D_H(p_k|p_{k+1}) &\rightarrow \frac{1}{2}\int_0^1 \dot{\nabla H}(\beta)^\top \dot{p}(\beta) d\beta,
\\
&= \frac{1}{2}\int_0^1 \dot{p^\ast}(\beta)^\top \dot{p}(\beta) d\beta\\
&= \frac{1}{2}\left(p^\ast(1) - p^\ast(0)\right)^\top\left(p(1) - p(0)\right),
\end{align}
where we can observe the bilinearity of the formula with respect to the primal and dual parameters.
Let's now consider an exponential family of the form:
$$
p(x\vert \eta) = h(x)\exp(\eta^\top g(x) - A(\eta))
$$
where $A(\eta)=\log\mathcal{Z}(\eta)$ is the log partition function. As $A$ is a convex function of $\eta$, we can define a Bregman divergence in the coordinate system of natural parameters $\eta$ induced by the (convex) log-partition function $A$:
$$
D_A[\eta'|\eta] = A(\eta') - A(\eta) - \nabla A(\eta)^T\left( \eta' - \eta\right).
$$
(notice how similar this is to the KL divergence expression I used before)
The dual parameter corresponding to $\eta$ turns out to be the moment parameter $s$:
$$
\eta^\ast = \nabla A (\eta) = \mathbb{E}_\eta[g(x)] = s,
$$
The equality in the middle is a well-known property of exponential families, the proof is similar to how you would obtain the REINFORCE gradient estimator, for example. Using this equality we can rewrite the divergence above as:
$$
D_A[\eta'|\eta] = A(\eta') - A(\eta) - s^T\left( \eta' - \eta\right).
$$
If we consider the natural parameters $\eta$ the primal parameters, the moments $s$ are the dual parameters. As always, there is a one-to-one mapping between primal and dual parameters. Using the convex conjugate $A^{\ast}$, we can therefore define another Bregman divergence in the coordinate system of moments $s$ as follows:
\begin{align}
D_{A^\ast}[s'|s] &= A^\ast(s') - A^\ast(s) - \nabla A^\ast(s)^T\left(s' - s\right)\\
&= A^\ast(s') - A^\ast(s) - \eta^T\left(s' - s\right)
\end{align}
As it turns out, this convex conjugate $A^\ast$ is the negative Shannon entropy of the distribution, parametrized by its moments $s$ (it's beyond the scope of this post to show this; see e.g. this book chapter).
So now we have two Bregman divergences: one in the primal space of $\eta$ and one in the dual space of $s$. These two divergences are equivalent, and in this case they are also both equal to the Kullback-Leibler divergence between the corresponding distributions:
$$
D_{A}[\eta'|\eta] = D_{A^\ast}[s'|s] = D_{KL}[p'|p]
$$
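For the Bernoulli family, all three quantities can be checked numerically: $A^\ast(s) = s\log s + (1-s)\log(1-s)$ is the negative entropy, and $\nabla A^\ast$ is the logit function. A small sketch, with my own function names:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def A(eta):                      # Bernoulli log-partition function
    return math.log1p(math.exp(eta))

def A_star(s):                   # convex conjugate: negative Shannon entropy
    return s * math.log(s) + (1 - s) * math.log(1 - s)

def D_A(e1, e0):                 # Bregman divergence in natural parameters
    return A(e1) - A(e0) - sigmoid(e0) * (e1 - e0)

def D_A_star(s1, s0):            # Bregman divergence in moment parameters
    logit = math.log(s0 / (1 - s0))          # grad A*(s0) = eta0
    return A_star(s1) - A_star(s0) - logit * (s1 - s0)

def kl(p, q):                    # KL between Bernoulli(p) and Bernoulli(q)
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))

eta0, eta1 = -0.5, 1.2
s0, s1 = sigmoid(eta0), sigmoid(eta1)
# The three divergences agree; the Bregman arguments swap under duality.
assert abs(D_A(eta1, eta0) - D_A_star(s0, s1)) < 1e-12
assert abs(D_A(eta1, eta0) - kl(s0, s1)) < 1e-12
```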
So, putting these things together, we can now understand why Eqns. (18) and (19) give the same result. If we interpolate linearly in primal space between $\eta(0)$ and $\eta(1)$ we get that the loss of the path is:
$$
\mathcal{F}(\gamma_{GA}) = \frac{1}{2}\left( D_A[\eta(1)|\eta(0)] + D_A[\eta(0)|\eta(1)]\right)
$$
Similarly, interpolating between $s(0)$ and $s(1)$ in moment space we obtain:
$$
\mathcal{F}(\gamma_{MA}) = \frac{1}{2}\left( D_{A^\ast}[s(1)|s(0)] + D_{A^\ast}[s(0)|s(1)]\right).
$$
And by the duality of $A$ and $A^\ast$, the two are equal, and also equal to the symmetrized KL divergence between the two distributions.
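Continuing the Bernoulli example, this equality of the two symmetrized path losses is easy to verify numerically (a self-contained sketch; function names are my own):

```python
import math

# Symmetrized Bregman divergences for the Bernoulli family, computed in
# natural parameters (via A) and in moment parameters (via the conjugate A*).
def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def A(eta):
    return math.log1p(math.exp(eta))

def A_star(s):
    return s * math.log(s) + (1 - s) * math.log(1 - s)

def D_A(e1, e0):
    return A(e1) - A(e0) - sigmoid(e0) * (e1 - e0)

def D_A_star(s1, s0):
    return A_star(s1) - A_star(s0) - math.log(s0 / (1 - s0)) * (s1 - s0)

eta0, eta1 = -0.5, 1.2
s0, s1 = sigmoid(eta0), sigmoid(eta1)
loss_natural = 0.5 * (D_A(eta1, eta0) + D_A(eta0, eta1))
loss_moment  = 0.5 * (D_A_star(s1, s0) + D_A_star(s0, s1))
assert abs(loss_natural - loss_moment) < 1e-12
```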
Interestingly, this equivalence also suggests that linear interpolation in the distribution space, $p(\beta) = \beta p_1 + (1 - \beta) p_0$, would also give the same result. This would interpolate between $p_0$ and $p_1$ via mixtures between the two distributions. This option may not be of practical relevance though.
This framework might allow us to study linear interpolation in parameter spaces $\theta = f(\eta)$ which are related to the natural parameters by an invertible reparametrization $f$. So long as the composition $A \circ f^{-1}$ remains convex in $\theta$, we can use it as the generating function of a Bregman divergence in $\theta$-space. I'm not sure if this would actually work out or whether it would yield any insights.
This was my attempt at providing more insight into Theorem 2 of Roger's paper, which he found very surprising. My post probably does little to make the result look less surprising; it merely shifts the surprise elsewhere. The story goes as follows:
In the spirit of thanksgiving, let me start by thanking all the active commenters on my blog: you are always very quick to point out typos, flaws and references to literature I overlooked.
This post is basically a follow-up to my earlier post, "GANs are broken in more than one way". In this one, I review and recommend some additional references people pointed me to after the post was published. It turns out, unsurprisingly, that there is a lot more work on these questions than I was aware of. Although GANs are kind of broken, they are also actively being fixed, in more than one way.
In my post I focussed on potential problems that arise when we apply simultaneous gradient descent on a non-conservative vector field, or equivalently, a non-cooperative game.
These problems have been pointed out by a number of authors before. As one example, (Salimans et al, 2017) talked about this and in fact used the same pathological example I used in my plot: the one which gives rise to the constant curl field:
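This curl-field behaviour can be reproduced with a few lines on the bilinear toy game $\min_x \max_y xy$, whose simultaneous-gradient field $(y, -x)$ is a pure rotation. A toy sketch of my own, not the exact example from either plot:

```python
# Simultaneous gradient descent on the bilinear game min_x max_y x*y.
# The combined gradient field (y, -x) has constant curl, so the iterates
# rotate around the equilibrium (0, 0) instead of converging to it.
def simultaneous_gd(x, y, lr=0.1, steps=100):
    for _ in range(steps):
        grad_x, grad_y = y, -x          # descent on x, ascent on y
        x, y = x - lr * grad_x, y - lr * grad_y
    return x, y

x, y = simultaneous_gd(1.0, 0.0)
# With a finite step size the radius grows by sqrt(1 + lr**2) every step,
# so the trajectory spirals outwards rather than approaching (0, 0).
assert x * x + y * y > 1.0
```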
Pfau and Vinyals (2016) pointed out connections between GANs and actor critic methods. GANs and actor-critic methods are both notoriously hard to optimize, and the authors argue that these difficulties arise because of the same reasons. This means that two sets of people have been struggling to solve the same underlying problems, so it is reasonable to hope that we can learn from each other's innovations.
Many people pointed out that when using different variants of GANs, we don't actually use simultaneous gradient descent. Notably, in WGANs, and also often in normal GANs when using instance noise, we always optimize the discriminator until convergence before taking a step with the generator. This nested-loop algorithm behaves very differently from simultaneous gradient descent. Indeed, you can argue that it converges, inasmuch as the outer loop which optimizes the generator can be interpreted as vanilla gradient descent on a lower bound. Of course, this convergence only happens in practice if the optimal discriminator is unique with respect to the generator, or at least if the solution is stable enough to assume that the outer loop indeed evaluates gradients of a deterministic scalar function.
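The nested-loop idea can be illustrated on a toy two-player objective of my own choosing (not an actual GAN loss): $f(x, y) = xy - y^2/2$, maximized over $y$ and minimized over $x$. The inner argmax is $y^\ast(x) = x$, so the outer loop is plain gradient descent on the deterministic function $f(x, y^\ast(x)) = x^2/2$, which converges:

```python
# Nested-loop optimization: the inner "discriminator" is optimized to
# convergence before each outer "generator" step. Toy objective:
#   f(x, y) = x*y - y**2 / 2, inner player maximizes y, outer minimizes x.
def inner_maximize(x, y, lr=0.1, steps=200):
    for _ in range(steps):
        y += lr * (x - y)               # gradient ascent on f: df/dy = x - y
    return y

x, y = 2.0, 0.0
for _ in range(200):
    y = inner_maximize(x, y)            # inner loop to convergence: y -> x
    x -= 0.1 * y                        # one outer step: df/dx = y = x
assert abs(x) < 1e-3                    # converges to the equilibrium x = 0
```

Contrast this with the bilinear game above, where the inner maximization is unbounded and simultaneous updates cycle; here the unique, stable inner optimum is exactly what makes the outer loop a well-defined descent.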
(Heusel et al, 2017) propose another optimization strategy relying on updating the discriminator and generator at different time scales/learning rates. The authors prove that this algorithm always converges to a local Nash equilibrium, and this property holds even when using the Adam optimizer on top. This algorithm was later used successfully for training Coulomb GANs.
In another great NIPS paper, (Nagarajan and Kolter, 2017) study convergence properties of simultaneous gradient descent for various GAN settings. Using similar apparatus to the Numerics of GANs paper, they show that simultaneous GD is actually stable and convergent in the local neighborhood of the Nash equilibrium. On the other hand, it is also shown that WGANs have non-convergent limit cycles around the equilibrium (these are vector fields that make you go around in circles, like the example above). So, in WGANs we have a real problem with gradient descent. Based on these observations the authors go on to define a regularization strategy that stabilizes gradient descent. Highly recommended paper.