AlphaGo Zero: Minimal Policy Improvement, Expectation Propagation and other Connections
This is a post about the new reinforcement learning technique that enables AlphaGo Zero to learn Go from scratch via self-play. The paper has been out for a week, so I guess it's now considered old - sorry for the latency.
- D Silver, J Schrittwieser, K Simonyan, I Antonoglou, A Huang, Arthur Guez, T Hubert, L Baker, M Lai, Adrian Bolton, Y Chen, T Lillicrap, F Hui, L Sifre, G van den Driessche, T Graepel & D Hassabis (2017) Mastering the Game of Go without Human Knowledge
I'm no expert in RL, so I'm pretty sure many of you are going to come at me with pitchforks shouting "this is all trivial" or "this has been done before" or "this is no different from X". Please do, I'm here to learn.
Summary of this post
- the key idea behind AlphaGo Zero is interpreting Monte Carlo Tree Search (MCTS) as a policy improvement operator: you can think of MCTS as an operator that takes your current policy and turns it into a better policy.
- you can formulate the goal of AlphaGo Zero as seeking a fixed point with respect to the policy improvement operator: the best policy is one that cannot be further improved by MCTS.
- once such a policy is found, you no longer need MCTS when you play a game. As a result, AlphaGo Zero is a lot more efficient to run, once trained.
- I look at this idea a bit more generally, and connect it to the previous post on non-conservative vector fields and the Numerics of GANs.
- Finally, I highlight connections to other learning or inference algorithms that are similar in nature: expectation-propagation and contrastive divergence.
The Minimal Policy Improvement Technique
Background: The original AlphaGo used a combination of two neural networks - the policy and value networks - and a Monte Carlo Tree Search (MCTS) algorithm to play Go. For each move, the policy network is first evaluated to give an initial strategy $\pmb{p}$. This is then used to prime MCTS, which then performs rollouts - simulations of possible future game trajectories - evaluating the policy network and the value network multiple times in these imagined future scenarios. MCTS refines AlphaGo's initial intuitive strategy $\pmb{p}$, and outputs a better strategy $\pmb{\pi}$ which is usually concentrated on stronger moves in the current situation.
The key idea of AlphaGo Zero is this: you can think of MCTS generally as a policy improvement operator, $\pmb{\pi}(\pmb{p})$, which takes $\pmb{p}$ as input and outputs a better policy $\pmb{\pi}$. If you have such a policy improvement operator, it makes sense to define the goal of learning as finding a policy $\pmb{p}^{\ast}$ which the policy improvement operator cannot improve any further. In other words, we seek the following fixed point:
$$
\pmb{\pi}(\pmb{p}^{\ast}) = \pmb{p}^{\ast}.
$$
For now, let's assume two things:
- the policy improvement operator $\pmb{\pi}$ depends only on the output of the policy $\pmb{p}$, and
- the state of the game $s$ is fixed, in other words I am concerned with what the network does in a particular fixed state, say what it chooses as the starting move. I only assume this to avoid having to spell out expectations over states, etc.
Under these assumptions, we can write the policy improvement algorithm as seeking stationary points of the following vector field $\pmb{v}$:
$$
\pmb{v}(\pmb{p}) = \pmb{\pi}(\pmb{p}) - \pmb{p}
$$
Here we go again, looking for stationary points $\pmb{v}(\pmb{p}) = \pmb{0}$ of a vector field, just like in the previous post about GANs.
The most natural approach to finding such stationary points is Euler's method - which would be gradient descent if the vector field happened to correspond to the gradients of a scalar loss. Of course, there is no guarantee that $\pmb{v}$ is conservative, and therefore no guarantee that Euler's method will converge. But hey, it's worth a try. What would Euler's method do on this vector field?
\begin{align}
\theta_{t+1} &\leftarrow \theta_{t} + h\pmb{v}(\pmb{p})\frac{\partial \pmb{p}}{\partial \theta}\\
&= \theta_{t} + h \left(\pmb{\pi} - \pmb{p}\right)\frac{\partial \pmb{p}}{\partial \theta}
\end{align}
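To make the update above concrete, here is a minimal sketch under loud assumptions: the policy is a softmax over three logits $\theta$, and MCTS is replaced by a made-up operator `improve` that reweights $\pmb{p}$ by hypothetical action values `Q`. Neither `improve`, `Q` nor `BETA` comes from the paper; they only stand in for "something that returns a better policy".

```python
import math

def softmax(theta):
    """Numerically stable softmax over a list of logits."""
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    z = sum(exps)
    return [e / z for e in exps]

# Hypothetical action values standing in for what a search would discover.
Q = [1.0, 0.5, -0.2]
BETA = 1.0  # hypothetical sharpness of the toy operator

def improve(p):
    """Toy improvement operator pi(p): reweight p towards higher-value actions."""
    w = [pk * math.exp(BETA * qk) for pk, qk in zip(p, Q)]
    z = sum(w)
    return [wk / z for wk in w]

def residual_norm(p):
    """Euclidean norm of v(p) = pi(p) - p; zero exactly at a fixed point."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(improve(p), p)))

def euler_step(theta, h):
    """One Euler step: theta <- theta + h * v(p)^T dp/dtheta."""
    p = softmax(theta)
    v = [a - b for a, b in zip(improve(p), p)]
    # Softmax Jacobian: dp_i/dtheta_k = p_i * (delta_ik - p_k), hence
    # (v^T dp/dtheta)_k = p_k * (v_k - sum_i v_i * p_i).
    vp = sum(vk * pk for vk, pk in zip(v, p))
    return [t + h * pk * (vk - vp) for t, pk, vk in zip(theta, p, v)]

if __name__ == "__main__":
    theta = [0.0, 0.0, 0.0]
    for _ in range(500):
        theta = euler_step(theta, 1.0)
    print(softmax(theta), residual_norm(softmax(theta)))
```

The update in `euler_step` is the second line of the equation above, specialized to softmax logits. Note that the toy operator is treated as a black box: nothing is differentiated through `improve`.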
This is almost, but not quite, what AlphaGo Zero does. It differs in two significant ways:
Dependence of MCTS on parameters
There is a catch, which only becomes clear once we make the dependence on parameters explicit:
$$
\pmb{v}(\pmb{p}_\theta, \theta) = \pmb{\pi}(\pmb{p}_\theta, \theta) - \pmb{p}_\theta
$$
MCTS itself depends not only on the initial guess $\pmb{p}_\theta$ but also on the parameters $\theta$ directly. This is because the policy-value network $f_\theta$ is evaluated multiple times within the MCTS algorithm. What this means is that the vector field $\pmb{v}$ also depends on the model parameters $\theta$ beyond its dependence on $\pmb{p}_\theta$.
Thus, our interpretation as seeking the stationary point of a single vector field in $\pmb{p}$-space isn't quite accurate. The vector field itself changes each time we change the parameters $\theta$. In fact we are searching for $\theta^{\ast}$ such that
$$
\pmb{\pi}(\pmb{p}_{\theta^{\ast}}, \theta^{\ast}) = \pmb{p}_{\theta^{\ast}}
$$
If we wanted to solve this directly with gradient descent, we would have to take a gradient of MCTS with respect to $\theta$. Good luck with that.
The way AlphaGo Zero sidesteps this problem is by ignoring the dependence of $\pmb{\pi}$ on $\pmb{p}$ and $\theta$. We simply run MCTS with the current value of the parameters, note down the resulting $\pmb{\pi}$, and treat it as a fixed regression target. AlphaGo Zero's learning converges when
$$
\pmb{\pi}(\pmb{p}_{\theta_{\text{recent}}}, \theta_{\text{recent}}) = \pmb{p}_{\theta},
$$
where $\theta_{\text{recent}}$ is some recent value of the parameter $\theta$ from an earlier iteration of training.
Cross-entropy loss
AlphaGo Zero differs from the Euler method on the vector field in another way. Interestingly, this, too, is similar to something from the last post. Instead of following the improvement direction $\pmb{\pi} - \pmb{p}$, the method performs gradient descent on a scalar loss function
$$
\mathcal{L}(\theta) = \operatorname{KL}\left[\pmb{\pi}_\theta \| \pmb{p}_\theta\right]
$$
This is a generalization of trying to minimize the norm of $\pmb{\pi} - \pmb{p}$, $\mathcal{L}(\theta) = \| \pmb{v}(\theta)\|^2_2$, which is something that the Numerics of GANs paper recommended in the context of GANs. However, it replaces the Euclidean distance with the KL divergence, which makes more sense between probability distributions.
Again, the algorithm ignores the dependence of $\pmb{\pi}_\theta$ on $\theta$, and instead minimizes
$$
\mathcal{L}(\theta) = \operatorname{KL}\left[\pmb{\pi} \| \pmb{p}_\theta\right] = -\pmb{\pi}^T\log \pmb{p}_\theta + c,
$$
where $\pmb{\pi}$ was obtained using a previous value of the parameter $\theta_{\text{recent}}$, and $c$ is constant with respect to $\theta$. This is the policy term of the loss that AlphaGo Zero optimizes in Eqn. (1) of the paper.
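Here is a rough sketch of that training scheme, under the same toy assumptions as before: a softmax policy over three logits, and a made-up `search` function standing in for MCTS (it is not MCTS, and `Q` is a hypothetical set of action values). The point is only the structure: the target $\pmb{\pi}$ is computed from a frozen snapshot of the parameters, held fixed, and the policy is regressed onto it with the cross-entropy $-\pmb{\pi}^T\log \pmb{p}_\theta$, which differs from $\operatorname{KL}\left[\pmb{\pi} \| \pmb{p}_\theta\right]$ only by a constant.

```python
import math

def softmax(theta):
    """Numerically stable softmax over a list of logits."""
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    z = sum(exps)
    return [e / z for e in exps]

Q = [1.0, 0.5, -0.2]  # hypothetical action values, not from the paper

def search(theta_snapshot):
    """Hypothetical stand-in for MCTS, run with a frozen copy of the parameters."""
    p = softmax(theta_snapshot)
    w = [pk * math.exp(qk) for pk, qk in zip(p, Q)]
    z = sum(w)
    return [wk / z for wk in w]

def cross_entropy(pi, p):
    """-pi^T log p, i.e. KL[pi || p] up to a constant that does not depend on p."""
    return -sum(a * math.log(b) for a, b in zip(pi, p) if a > 0)

def fit_to_target(theta, pi, steps, lr):
    """Gradient descent on the cross-entropy with the target pi held fixed."""
    for _ in range(steps):
        p = softmax(theta)
        # For softmax logits, the gradient of -pi^T log p is (p - pi).
        theta = [t - lr * (pk - pik) for t, pk, pik in zip(theta, p, pi)]
    return theta

def train(theta, rounds=20, steps_per_round=50, lr=0.5):
    """Alternate between refreshing the target and regressing onto it."""
    for _ in range(rounds):
        pi = search(list(theta))  # target from theta_recent, then treated as data
        theta = fit_to_target(theta, pi, steps_per_round, lr)
    return theta

if __name__ == "__main__":
    print(softmax(train([0.0, 0.0, 0.0])))
```

Because `pi` is plain data by the time `fit_to_target` sees it, no gradient ever flows through the search, which is exactly the shortcut described above. In logit space the descent direction is simply $\pmb{\pi} - \pmb{p}_\theta$, without the softmax Jacobian that appears in the Euler update.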
Connections to Expectation-Propagation
The Expectation-Propagation (EP) algorithm can be understood from the Minimal Policy Improvement viewpoint. There, instead of an RL policy, we wish to optimize an approximate posterior $q$ which is factorized as $q(\pmb{x}) = \prod_i q_i(\pmb{x}_{S_i})$. The real posterior $p(\pmb{x}) \propto \prod_i p_i(\pmb{x}_{S_i})$ shares the same factorized structure, but its normalizing constant is intractable.
For EP, there are in fact multiple different improvement operators, one for each factor. For factor $j$, the improvement operator replaces the approximate factor $q_j$ with the real factor $p_j$ as follows:
$$
\pmb{\pi}^{j}(q) \propto \frac{p_j(\pmb{x}_{S_j})}{q_j(\pmb{x}_{S_j})}\prod_i q_i(\pmb{x}_{S_i})
$$
Within the class of distributions $q$ which share the same factor-graph structure as $p$, the following is true:
$$
\left(\forall j: \pmb{\pi}^{j}(q) = q\right) \iff p=q
$$
Therefore, it makes sense to seek stationary points with respect to the improvement operators, or in other words, to seek an approximate posterior which cannot be further improved by the above improvement operators.
When updating parameters of the $j^{\text{th}}$ factor $q_j$, we use the KL divergence to measure the size of the improvement - just like in AlphaGo Zero:
$$
\operatorname{KL}\left[\pmb{\pi}^{j}(q) \| q\right]
$$
Of course there is more to EP than this, but I think it's an interesting view on it. Thinking about it this way already opens up a whole new set of questions: what if we consider other improvement operators, for example, Markov-Chain Monte Carlo kernels?
Which brings us to the next connection I wanted to highlight: contrastive divergence.
Contrastive divergence
Contrastive divergence also uses something like this minimal improvement principle, but somehow upside-down, in the context of generative modelling.
Let's say we have some training data $x_i$ drawn i.i.d. from some true data distribution $p$, and we want to fit to it an intractable generative model $q_\theta$. We can't evaluate the likelihood of the model $q_\theta$, but we can derive an MCMC update such that, irrespective of the starting state, the distribution of samples converges to $q_\theta$ as we apply the MCMC kernel repeatedly.
We can therefore interpret each step of an MCMC sampling algorithm, too, as an improvement operator: Given samples from an arbitrary distribution $p$, if we apply one step of MCMC to the samples, the resulting distribution will look a bit more like samples drawn from the stationary distribution of the MCMC algorithm. Contrastive divergence exploits this and minimizes a kind of divergence between the data distribution $p$ and the distribution we obtain by running a fixed number of MCMC steps on the data samples.
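As a small illustration of an MCMC step acting as an improvement operator, here is a sketch on a three-state toy problem. Everything in it is an assumption made for illustration: the model $q$ is a fixed table rather than an intractable $q_\theta$, the kernel is a Metropolis kernel with a uniform proposal, and distributions are pushed through the transition matrix exactly instead of being represented by samples.

```python
import math

def metropolis_kernel(q):
    """Transition matrix of a Metropolis chain with stationary distribution q.

    Proposal: pick one of the other states uniformly at random.
    Acceptance probability for a move i -> j: min(1, q[j] / q[i]).
    """
    n = len(q)
    T = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i != j:
                T[i][j] = (1.0 / (n - 1)) * min(1.0, q[j] / q[i])
        T[i][i] = 1.0 - sum(T[i])  # rejected proposals stay in state i
    return T

def step(p, T):
    """Distribution after one MCMC step, starting from distribution p."""
    n = len(p)
    return [sum(p[i] * T[i][j] for i in range(n)) for j in range(n)]

def kl(a, b):
    """KL[a || b] for discrete distributions given as lists."""
    return sum(x * math.log(x / y) for x, y in zip(a, b) if x > 0)

if __name__ == "__main__":
    q = [0.2, 0.3, 0.5]   # toy model distribution (hypothetical)
    p = [0.8, 0.1, 0.1]   # toy data distribution (hypothetical)
    T = metropolis_kernel(q)
    p1 = step(p, T)
    print(kl(p, q), kl(p1, q))
```

One step moves the data distribution closer to the model in KL, and the model itself is left unchanged by the kernel. Contrastive divergence uses the gap between `p` and `step(p, T)` as a training signal for the model: if the data distribution were already a fixed point of the kernel, the two would coincide.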
For more details (and also a great hand-drawn Euler-method illustration) I recommend Rich Turner's excellent notes from, well, probably at least ten years ago.
Summary
I really liked the idea of thinking about MCTS as an improvement operator, and of training the policy network to minimize the improvement achievable by MCTS. Although the authors demonstrate that other ideas - like using ResNets or using a joint network for policy and value - are important for good performance, I think this is the key idea that enables learning from pure self-play.
Sadly, I can't see this method being easy to apply beyond the scope of self-play though. Firstly, in Go the states are fully observed, which is not usually the case for RL. Secondly, MCTS only works so well because we can use our policy as a model of the opponent's policy, so we can do pretty accurate rollouts.
I'm not sure what an improvement operator would be for other RL tasks, even, say, for Atari games. If you have ideas, please share.