Composite Objective Functions for Training Autoregressive Models

Setup: autoregressive models

We want to learn the joint distribution of sequences of variables $x_{1:N}$. We can do this via autoregressive models, where we learn the probability of the next symbol given the previous symbols:

$$ q(x_{1:N}; \theta) = \prod_{n=1}^{N} q(x_n\vert x_{1:n-1}; \theta)$$

Expressing the joint distribution via the chain rule typically allows us to build models whose log-likelihood is easy to evaluate. This is because we only have to define distributions over single symbols, so it's easier to choose complicated distributions whose normalisation constant is still tractable.
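To make the chain rule concrete, here is a minimal sketch of evaluating a sequence log-likelihood as a sum of per-symbol conditional log-probabilities. The bigram table `COND`, the `START` token and the two-letter alphabet are made up for illustration; a real RNN would condition on the whole prefix rather than just the previous symbol.

```python
import math

# Hypothetical toy model: a first-order (bigram) conditional table over the
# alphabet {"a", "b"}. Each row is a normalised distribution over one symbol.
START = "<s>"
COND = {
    START: {"a": 0.6, "b": 0.4},
    "a": {"a": 0.7, "b": 0.3},
    "b": {"a": 0.2, "b": 0.8},
}

def log_likelihood(seq):
    """Sum of log q(x_n | previous symbol) over the sequence (chain rule)."""
    total = 0.0
    prev = START
    for symbol in seq:
        total += math.log(COND[prev][symbol])
        prev = symbol
    return total
```

Because every conditional is normalised over a single symbol, the joint is normalised automatically: exponentiating `log_likelihood` over all sequences of a fixed length sums to one.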

Example 1: sequence modelling, RNNs for text


Example 2: spatial LSTMs, MoGSMs


Example 3: NADE and MADE


I'm going to use RNNs for character-by-character text modelling as it's intuitive to talk about, but my general argument applies to all examples of autoregressive models.

What's wrong with maximum likelihood training?

The scheduled sampling paper highlighted important issues with maximum likelihood training for autoregressive text models. In a nutshell, RNNs trained via maximum likelihood tend to go astray. After training, when one starts generating sentences with them, every once in a while they generate prefixes that may not occur with high frequency in the training data. Once you have a prefix like that, the RNN will not know how to complete the sentence, because it has never seen training examples with this prefix. Therefore, it can't recover from the mistake it made early on and will continue generating an arbitrary postfix.

In a way, this behaviour is not surprising. As I pointed out in the previous post, this is partially explained by log-likelihood being a local scoring rule. But there is another explanation: Asking a generative model $q$ to not generate samples that have low probability under the true data distribution $p$ is in a way asking the model to minimise $KL[q|p]$. But maximum likelihood training minimises $KL[p|q]$, which is not the same. It is known that $\operatorname{argmin}_q KL[p|q]$ leads to a behaviour where $q$ will aim to cover all modes of $p$, at the expense of also putting mass where $p$ has no mass. On the other hand $\operatorname{argmin}_q KL[q|p]$ will push $q$ to model the most prominent mode of $p$ well, and ignore the rest of $p$. In a way, optimising $KL[q|p]$ would be much more conservative in distributing probability mass where $p$ has none. Unfortunately, minimising $KL[q|p]$ has theoretical and practical problems, so that's not an option.
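The asymmetry between the two divergences is easy to see on a small discrete example. The sketch below uses made-up numbers: a bimodal $p$ over four outcomes, a spread-out candidate `q_cover` and a single-mode candidate `q_mode`. None of these come from a real model.

```python
import math

def kl(a, b):
    """KL[a|b] for two discrete distributions given as lists of probabilities."""
    return sum(pa * math.log(pa / pb) for pa, pb in zip(a, b) if pa > 0)

# Hypothetical target with two modes (outcomes 0 and 2) and little mass elsewhere.
p = [0.49, 0.01, 0.49, 0.01]

# Two hypothetical candidates for q.
q_cover = [0.25, 0.25, 0.25, 0.25]  # covers both modes, wastes mass in between
q_mode = [0.97, 0.01, 0.01, 0.01]   # locks onto one mode, ignores the other

forward = {"cover": kl(p, q_cover), "mode": kl(p, q_mode)}   # KL[p|q]
reverse = {"cover": kl(q_cover, p), "mode": kl(q_mode, p)}   # KL[q|p]
```

With these numbers `forward` is smaller for `q_cover`, while `reverse` is smaller for `q_mode`, which is the mode-covering versus mode-seeking behaviour described above.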

So, perhaps we should come up with an objective function that, while preserving the nice computational and theoretical properties of maximum likelihood, works a bit more like $KL[q|p]$ in terms of being conservative about covering low-probability areas.

How to fix it?

So we've established that maximising likelihood, and hence minimising $KL[p|q]$, will introduce undesired behaviour. A solution would be to use an alternative objective function. Examples of objective functions that would work well in theory are maximum mean discrepancy (MMD) or adversarial training. These objective functions ensure consistency, and at the same time can emulate behaviour that is more similar to $KL[q|p]$. However, these objective functions are intractable to calculate for most models and are typically based on sampling, which scales very poorly with the dimensionality of the data, in this case the length of the sentence $N$. So using a full sampling-based adversarial objective or MMD seems impractical, particularly as we are talking about a case where the likelihood is in fact tractable.

So what can we do, when full MMD training is not an option? We can use a composite scoring function: we will use MMD for some parts of the objective, but keep KL-divergence for the rest for its nice computational properties.

Let's look at the training objective for $q$ in relation to the prefix and postfix of a sentence. In this shorthand notation I use prefix to denote $x_{1:k}$ and postfix to denote $x_{k+1:N}$.

$$ KL[p_{x_{1:N}} | q_{x_{1:N};\theta}] = KL[p_{prefix} | q_{prefix}] + \mathbb{E}_{p_{prefix}} KL[p_{postfix\vert prefix} | q_{postfix \vert prefix}] $$
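This decomposition is the chain rule for KL divergence, and it can be checked numerically. In the sketch below the prefix and the postfix are each a single binary symbol, and the two joint tables `p` and `q` are arbitrary numbers chosen only for illustration.

```python
import math

# Hypothetical joint distributions over (prefix, postfix) pairs.
p = {(0, 0): 0.4, (0, 1): 0.1, (1, 0): 0.2, (1, 1): 0.3}
q = {(0, 0): 0.25, (0, 1): 0.25, (1, 0): 0.1, (1, 1): 0.4}

def kl(a, b):
    """KL[a|b] for two dicts of probabilities over the same keys."""
    return sum(pa * math.log(pa / b[key]) for key, pa in a.items() if pa > 0)

def marginal(joint):
    """Distribution of the prefix, obtained by summing out the postfix."""
    out = {}
    for (prefix, _), prob in joint.items():
        out[prefix] = out.get(prefix, 0.0) + prob
    return out

def conditional(joint, prefix):
    """Distribution of the postfix given a fixed prefix."""
    norm = marginal(joint)[prefix]
    return {post: prob / norm for (pre, post), prob in joint.items() if pre == prefix}

p_prefix, q_prefix = marginal(p), marginal(q)
joint_kl = kl(p, q)
prefix_kl = kl(p_prefix, q_prefix)
# Expectation under p_prefix of the KL between the conditionals.
expected_postfix_kl = sum(
    p_prefix[pre] * kl(conditional(p, pre), conditional(q, pre)) for pre in p_prefix
)
```

Up to floating-point error, `joint_kl` equals `prefix_kl + expected_postfix_kl`, matching the two terms on the right-hand side.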

Looking at the objective above, it is clear that the undesirable behaviour is mainly caused by the first term:

  1. Generally, we can expect that $p_{prefix}$ is highly multimodal, has high entropy and high complexity: you can start a sentence in a lot of different ways. On the other hand, $p_{postfix\vert prefix}$ probably has lower entropy and far fewer modes: once you have started the sentence, there are a lot fewer possible postfixes that complete it well.
  2. $q_{postfix \vert prefix}$ is trained to mimic $p_{postfix\vert prefix}$. Because we can assume $p_{postfix\vert prefix}$ is less complicated, the KL divergence probably works fine and the objective will find a good model. But $q_{postfix \vert prefix}$ only receives training signal where $p_{prefix}$ is high, so we have to make sure $q_{prefix}$ does not sample unlikely prefixes where $q_{postfix \vert prefix}$ would not know what to do.

So, if we can make sure that $q_{prefix}$ is trained conservatively - in a way that resembles $KL[q|p]$-like behaviour - we can probably be sure that the full model will learn how to complete sentences.

So my suggestion is a composite training objective:

$$ \mathcal{L}[p_{x_{1:N}} | q_{x_{1:N};\theta}] = MMD[p_{prefix} | q_{prefix}] + \mathbb{E}_{p_{prefix}} KL[p_{postfix\vert prefix} | q_{postfix \vert prefix}] $$

This works because as long as the prefix is chosen to be short enough, $MMD$ can be approximated accurately via sampling. It also tends to be more conservative than $KL[p|q]$, and will try to avoid placing mass where $p_{prefix}$ has none.
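As a rough sketch of what the sampling-based $MMD$ term could look like for short prefixes, the code below computes the biased (V-statistic) estimate of squared MMD between two sets of equal-length strings. The kernel, an exponentiated negative Hamming distance with a hypothetical bandwidth `gamma`, is my assumption for illustration, and the prefix lists are invented; other kernels on sequences would work too.

```python
import math

def hamming_kernel(a, b, gamma=1.0):
    """exp(-gamma * Hamming distance) between two equal-length strings."""
    distance = sum(ca != cb for ca, cb in zip(a, b))
    return math.exp(-gamma * distance)

def mmd_squared(xs, ys, kernel=hamming_kernel):
    """Biased (V-statistic) estimate of squared MMD between two sample sets."""
    def mean_kernel(us, vs):
        return sum(kernel(u, v) for u in us for v in vs) / (len(us) * len(vs))
    return mean_kernel(xs, xs) + mean_kernel(ys, ys) - 2.0 * mean_kernel(xs, ys)

# Hypothetical three-character prefixes: samples from the data and from two models.
data_prefixes = ["the", "the", "and", "she"]
good_model_prefixes = ["the", "and", "she", "the"]
bad_model_prefixes = ["xqz", "zzq", "qxz", "the"]

mmd_good = mmd_squared(data_prefixes, good_model_prefixes)
mmd_bad = mmd_squared(data_prefixes, bad_model_prefixes)
```

Here `mmd_good` is zero up to floating-point error because the two sample sets contain the same strings, while `mmd_bad` is clearly positive. The estimate costs a number of kernel evaluations that is quadratic in the sample count, which is one reason to keep the prefix short and the sample sets small.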

Possible problem with objective function:

If $q_{prefix}$ shares parameters with $q_{postfix \vert prefix}$, we have effectively coupled the $MMD$ part of the objective with a $KL$ divergence. As these objectives live on a different scale, this might push an RNN to put disproportionately more expressive power into modelling the prefix or the postfix. This could be mitigated by introducing a hyperparameter that controls the relative weight of the two terms in the objective function.
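One way to write the weighted version down, with $\lambda > 0$ as a hypothetical weighting hyperparameter of my own naming:

$$ \mathcal{L}_{\lambda}[p_{x_{1:N}} | q_{x_{1:N};\theta}] = \lambda \, MMD[p_{prefix} | q_{prefix}] + \mathbb{E}_{p_{prefix}} KL[p_{postfix\vert prefix} | q_{postfix \vert prefix}] $$

Setting $\lambda = 1$ recovers the unweighted composite objective. In practice $\lambda$ would presumably have to be tuned, for instance on held-out data.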

Alternatively, one could decouple $q_{prefix}$ and $q_{postfix \vert prefix}$, and use completely different models and parameters to learn each. For example, $q_{prefix}$ could be a fully connected generative stochastic network, or just a simple regularised histogram.

Key takeaway points