November 14, 2019

Meta-Learning Millions of Hyper-parameters using the Implicit Function Theorem

Last night on the train I read this nice paper by David Duvenaud and colleagues. Around midnight I got a calendar notification "it's David Duvenaud's birthday". So I thought it's time for a David Duvenaud birthday special (don't get too excited David, I won't make it an annual tradition...)


I recently covered iMAML: the meta-learning algorithm that makes use of implicit gradients to sidestep backpropagating through the inner loop optimization in meta-learning/hyperparameter tuning. The method presented in (Lorraine et al., 2019) uses the same high-level idea, but introduces a different - on the surface less fiddly - approximation to the crucial inverse Hessian. I won't spend a lot of time introducing the whole meta-learning setup from scratch; you can use the previous post as a starting point.

Implicit Function Theorem

Many - though not all - meta-learning or hyperparameter optimization problems can be stated as nested optimization problems. If we have some hyperparameters $\lambda$ and some parameters $\theta$ we are interested in

\operatorname{argmin}_\lambda \mathcal{L}_V (\operatorname{argmin}_\theta \mathcal{L}_T(\theta, \lambda)),

where $\mathcal{L}_T$ is some training loss and $\mathcal{L}_V$ a validation loss. The optimal parameter of the training problem, $\theta^\ast$, implicitly depends on the hyperparameters $\lambda$:

\theta^\ast(\lambda) = \operatorname{argmin}_\theta \mathcal{L}_T(\theta, \lambda)

If this implicit function mapping $\lambda$ to $\theta^\ast$ is differentiable, and subject to some other conditions, the implicit function theorem states that its derivative is

\left.\frac{\partial\theta^{\ast}}{\partial\lambda}\right\vert_{\lambda_0} = \left.-\left[\frac{\partial^2 \mathcal{L}_T}{\partial \theta \partial \theta}\right]^{-1}\frac{\partial^2\mathcal{L}_T}{\partial \theta \partial \lambda}\right\vert_{\lambda_0, \theta^\ast(\lambda_0)}
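As a sanity check, the formula can be tried on a problem where $\theta^\ast(\lambda)$ has a closed form. The sketch below is not from the paper: it assumes a ridge regression training loss $\frac{1}{2}\|X\theta - y\|^2 + \frac{\lambda}{2}\|\theta\|^2$ with a single scalar hyperparameter, random toy data, and hypothetical helper names (`theta_star`, `ift_jacobian`). It compares the implicit-function-theorem derivative to a finite-difference estimate.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))   # toy design matrix (made-up data)
y = rng.normal(size=20)        # toy targets (made-up data)

def theta_star(lam):
    """Closed-form minimiser of 0.5*||X theta - y||^2 + 0.5*lam*||theta||^2."""
    return np.linalg.solve(X.T @ X + lam * np.eye(3), X.T @ y)

def ift_jacobian(lam):
    """d theta*/d lam according to the implicit function theorem."""
    theta = theta_star(lam)
    hessian = X.T @ X + lam * np.eye(3)   # second derivative of L_T in theta
    mixed = theta                         # mixed derivative of L_T in theta and lam
    return -np.linalg.solve(hessian, mixed)

lam0, eps = 0.5, 1e-5
# Central finite difference of the closed-form solution, for comparison.
finite_diff = (theta_star(lam0 + eps) - theta_star(lam0 - eps)) / (2 * eps)
max_abs_error = np.max(np.abs(ift_jacobian(lam0) - finite_diff))
print(max_abs_error)
```

In this toy case the Hessian is a 3 by 3 matrix, so solving the linear system exactly is trivial; the interesting regime is the one where $\theta$ has millions of entries and this solve has to be approximated.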

The formula we obtained for iMAML is a special case of this where the $\frac{\partial^2\mathcal{L}_T}{\partial \theta \partial \lambda}$ is the identity. This is because there, the hyperparameter controls a quadratic regularizer $\frac{1}{2}\|\theta - \lambda\|^2$, and indeed if you differentiate this with respect to both $\lambda$ and $\theta$ you are left with a constant times identity.
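
To spell this special case out (a short sketch, writing the un-regularized task loss as $\mathcal{L}(\theta)$, a symbol introduced here only for illustration), suppose

\mathcal{L}_T(\theta, \lambda) = \mathcal{L}(\theta) + \frac{1}{2}\|\theta - \lambda\|^2.

Then

\frac{\partial^2 \mathcal{L}_T}{\partial \theta \partial \lambda} = -I, \qquad \frac{\partial^2 \mathcal{L}_T}{\partial \theta \partial \theta} = \frac{\partial^2 \mathcal{L}}{\partial \theta \partial \theta} + I,

and plugging both into the implicit function theorem gives

\left.\frac{\partial\theta^{\ast}}{\partial\lambda}\right\vert_{\lambda_0} = \left.\left[\frac{\partial^2 \mathcal{L}}{\partial \theta \partial \theta} + I\right]^{-1}\right\vert_{\theta^\ast(\lambda_0)}

The constant in front of the identity is $-1$, which cancels the minus sign in the general formula.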

The primary difficulty of course is approximating the inverse Hessian, or indeed matrix-vector products involving this inverse Hessian. This is where iMAML and the method proposed by Lorraine et al. (2019) differ. iMAML uses a conjugate gradient method to iteratively approximate the gradient. In this work, they use a Neumann series approximation, which, for a matrix $U$, looks as follows:

U^{-1} = \sum_{i=0}^{\infty}(I - U)^i

This is basically a generalization of the better known sum of a geometric series: if you have a scalar $\vert q \vert<1$ then

\sum_{i=0}^\infty q^i = \frac{1}{1-q}.

Using a finite truncation of the Neumann series one can approximate the inverse Hessian in the following way:

\left[\frac{\partial^2 \mathcal{L}_T}{\partial \theta \partial \theta}\right]^{-1} \approx \sum_{i=0}^j \left(I - \frac{\partial^2 \mathcal{L}_T}{\partial \theta \partial \theta}\right)^i.

This Neumann series approximation, at least on the surface, seems significantly less hassle to implement than running a conjugate gradient optimization step.
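
To give a feel for how little code this takes, here is a minimal NumPy sketch of a truncated Neumann approximation to an inverse-Hessian-vector product. It is an illustration under assumptions, not the authors' implementation: the function name is hypothetical, the "Hessian" is a small random positive definite matrix, and I rescale by a step size `alpha` (using $U^{-1} = \alpha \sum_i (I - \alpha U)^i$) so that the series is guaranteed to converge. With a real network you would replace `hessian @ term` by a Hessian-vector product from automatic differentiation rather than ever forming the matrix.

```python
import numpy as np

def neumann_inverse_vector_product(hessian, v, num_terms, alpha):
    """Approximate hessian^{-1} v by alpha * sum_{i=0}^{num_terms} (I - alpha*hessian)^i v."""
    term = v.copy()    # current term (I - alpha*H)^i v, starting at i = 0
    total = v.copy()   # running sum of the terms
    for _ in range(num_terms):
        term = term - alpha * (hessian @ term)   # multiply by (I - alpha*H)
        total = total + term
    return alpha * total

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
H = A @ A.T + np.eye(5)              # symmetric positive definite toy "Hessian"
v = rng.normal(size=5)
alpha = 1.0 / np.linalg.norm(H, 2)   # eigenvalues of I - alpha*H then lie in [0, 1)

exact = np.linalg.solve(H, v)
errors = [
    np.linalg.norm(neumann_inverse_vector_product(H, v, j, alpha) - exact)
    for j in (5, 50, 500)
]
print(errors)
```

The error shrinks as more terms are kept, at a rate set by how close the largest eigenvalue of $I - \alpha H$ is to 1, so a badly conditioned Hessian needs many more terms.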


One of the fun bits of this paper is the interesting set of experiments the authors used to demonstrate the versatility of this approach. For example, in this framework, one can treat the training dataset as a hyperparameter. Optimizing pixel values in a small training dataset, one image per class, allowed the authors to "distill" a dataset into a set of prototypical examples. If you train your neural net on this distilled dataset, you get relatively good validation performance. The results are not quite as image-like as one would imagine, but for some classes, like bikes, you even get recognisable shapes.

In another experiment the authors trained a network to perform data augmentation, treating parameters of this network as a hyperparameter of a learning task. In both of these cases, the number of hyperparameters optimized was in the hundreds of thousands, way beyond the number we usually consider as hyperparameters.


This method inherits some of the limitations I already discussed with iMAML. Please also see the comments where various people gave pointers to work that overcomes some of these limitations.

Most crucially, methods based on implicit gradients assume that your learning algorithm (inner loop) finds a unique, optimal parameter that minimises some loss function. This is simply not a valid assumption for SGD, where different random seeds might produce very different and differently behaving optima.

Secondly, this assumption only allows for hyperparameters that control the loss function, but not for ones that control other aspects of the optimization algorithm, such as learning rates, batch sizes or initialization. For those kinds of situations, explicit differentiation may still be the most competitive solution. On that note, I also recommend reading this recent paper on generalized inner-loop meta-learning and the associated PyTorch package higher.
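
To make "explicit differentiation" concrete for a hyperparameter like the learning rate, here is a toy sketch that differentiates through an unrolled inner loop by hand. Everything in it is an assumption made for illustration: a one-dimensional quadratic training loss $\frac{1}{2} c\,\theta^2$, plain gradient descent, a made-up validation loss $\frac{1}{2}(\theta - 1)^2$, and a hypothetical function name. It does not use the higher package; it only shows the kind of computation such tools automate.

```python
def unrolled_gd_with_lr_gradient(theta0, lr, curvature, num_steps):
    """Run gradient descent on 0.5*curvature*theta^2 and track d theta / d lr."""
    theta, dtheta_dlr = theta0, 0.0
    for _ in range(num_steps):
        grad = curvature * theta
        # Differentiate the update theta <- theta - lr * grad with respect to lr.
        dtheta_dlr = dtheta_dlr - grad - lr * curvature * dtheta_dlr
        theta = theta - lr * grad
    return theta, dtheta_dlr

def final_validation_loss(lr):
    """Validation loss after training with learning rate lr (toy setup)."""
    theta, _ = unrolled_gd_with_lr_gradient(2.0, lr, 3.0, 10)
    return 0.5 * (theta - 1.0) ** 2

lr0, eps = 0.05, 1e-6
theta_T, dtheta_dlr = unrolled_gd_with_lr_gradient(2.0, lr0, 3.0, 10)
# Chain rule: d L_V / d lr = (d L_V / d theta_T) * (d theta_T / d lr).
hypergradient = (theta_T - 1.0) * dtheta_dlr
finite_diff = (final_validation_loss(lr0 + eps) - final_validation_loss(lr0 - eps)) / (2 * eps)
print(hypergradient, finite_diff)
```

Note that nothing here requires the inner loop to converge to a unique optimum, which is exactly why this route still works for hyperparameters of the optimizer itself; the price is that cost and memory grow with the number of unrolled steps.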


Happy birthday David. Nice work!
