Training modern deep networks can take an inordinate amount of time, even on the best GPU hardware available. Training Inception-v3 on ImageNet (1,000 classes) with 8 NVIDIA Tesla K40s takes about 2 weeks (Google Research Blog).
Even when a large network is trained successfully, its memory footprint and prediction latency, both driven by its large number of parameters, can make it challenging to put into production.
One way to keep the predictive accuracy of a large network while reducing the number of parameters is a training paradigm called "distillation". Distillation, introduced by Hinton, Vinyals, and Dean, transfers knowledge from a larger network to a much smaller network for deployment.
So what exactly is distillation? In supervised learning, you usually train a model on hard targets. If you’re training a network to recognize images of handwritten digits, the hard target for an image of the digit 1 looks like the following:
Digit:  0 1 2 3 4 5 6 7 8 9
Target: 0 1 0 0 0 0 0 0 0 0
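Soft targets, by contrast, are the class probabilities predicted by a large, cumbersome model that has already been trained. Hinton et al. suggest that the probabilities assigned to the incorrect categories still carry valuable information for training a smaller network: one image of a 2 may be given a probability of $10^{-6}$ of being a 3 and $10^{-9}$ of being a 7, while another image of a 2 may get the reverse, which says something about how similar the classes look. When the cumbersome model is a large ensemble of simpler models, we can use an arithmetic or geometric mean of their individual predictive distributions as the soft targets.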
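The two kinds of targets can be made concrete with a short sketch. It is written in plain Python (standard library only) and assumes a small, hypothetical three-member ensemble over three classes; the helper names `one_hot`, `arithmetic_mean`, and `geometric_mean` are illustrative, not from any library. The geometric mean is renormalized so that it sums to 1, and it assumes every member assigns a nonzero probability to every class.

```python
import math


def one_hot(label, num_classes=10):
    """Hard target: 1.0 for the correct class and 0.0 everywhere else."""
    return [1.0 if k == label else 0.0 for k in range(num_classes)]


def arithmetic_mean(dists):
    """Class-by-class average of the ensemble members' distributions."""
    n = len(dists)
    return [sum(d[k] for d in dists) / n for k in range(len(dists[0]))]


def geometric_mean(dists):
    """Class-by-class geometric mean, renormalized so the result sums to 1.

    Requires every probability to be strictly positive (math.log(0) fails).
    """
    n = len(dists)
    g = [math.exp(sum(math.log(d[k]) for d in dists) / n)
         for k in range(len(dists[0]))]
    total = sum(g)
    return [x / total for x in g]


# Hypothetical predictions of a three-member ensemble over three classes.
ensemble = [
    [0.7, 0.2, 0.1],
    [0.6, 0.3, 0.1],
    [0.8, 0.1, 0.1],
]

print(one_hot(1))                 # hard target for the digit 1
print(arithmetic_mean(ensemble))  # soft targets via the arithmetic mean
print(geometric_mean(ensemble))   # soft targets via the geometric mean
```

Unlike the one-hot target, both averaged distributions keep nonzero mass on the incorrect classes, which is exactly the information the soft targets are meant to carry.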
Training a much smaller network on these soft targets, even with less data, shows promise. It means that once a large model has finished its long training run, we can distill its knowledge, via the soft targets, into a smaller model for deployment in a reasonable amount of time and still obtain respectable results.
Just as a review, softmax is an activation function you can apply to the final layer of your network to turn scores into categorical probabilities. More precisely, given the logits $z_i$ (the inputs to the softmax function), the softmax is defined as follows:

$$q_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$$
It’s easiest to think of the softmax as the normalized exponential function: the outputs $q_i$ are positive and add up to 1, which turns the scores $z_i$ into categorical probabilities.
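A minimal sketch of this softmax in plain Python, assuming the logits arrive as a list of floats. Subtracting the largest logit before exponentiating is a common numerical-stability trick rather than part of the definition: it leaves the output unchanged but keeps `math.exp` from overflowing on large logits.

```python
import math


def softmax(logits):
    """Standard softmax: q_i = exp(z_i) / sum_j exp(z_j)."""
    # Subtracting the largest logit prevents overflow in math.exp; the shift
    # cancels between numerator and denominator, so the result is unchanged.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


q = softmax([2.0, 1.0, 0.1])
print(q)       # roughly [0.659, 0.242, 0.099]
print(sum(q))  # 1 up to floating-point rounding
```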
For distillation, we use the generalized softmax function, which has a temperature parameter $T$:

$$q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}$$
When $T = 1$, this is just the regular softmax function described above. Raising the temperature produces a softer, less peaked probability distribution over the classes, in essence “softening” the target categories. You can see this in the example of classifying animal images, where the soft targets have a less extreme probability distribution than the ensemble targets or the hard targets.
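To see the softening effect numerically, here is a sketch of the temperature softmax applied to one set of hypothetical logits at a few temperatures. The function name `softmax_t` and the logit values are assumptions made for illustration.

```python
import math


def softmax_t(logits, T=1.0):
    """Generalized softmax: q_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # shift by the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [5.0, 2.0, -1.0]  # hypothetical logits for three classes
for T in (1.0, 5.0, 20.0):
    print(T, [round(p, 3) for p in softmax_t(logits, T)])
```

The ranking of the classes is the same at every temperature, but the top probability falls from about 0.95 at T = 1 to about 0.38 at T = 20, so the relative sizes of the small probabilities become much more visible.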
When transferring knowledge via distillation from a larger network to a smaller one, you create soft targets from the larger network at some high temperature and use the same temperature to train the smaller network. Let $v_i$ be the logits of the larger network and $p_i$ the soft targets produced from them by the same temperature-$T$ softmax, and let $q_i$ be the smaller network's temperature-$T$ softmax outputs computed from its logits $z_i$. Learning happens by minimizing the cross-entropy loss between $p$ and $q$:

$$C = -\sum_i p_i \log q_i$$
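The gradient of this loss with respect to each logit $z_i$ of the smaller network is:

$$\frac{\partial C}{\partial z_i} = \frac{1}{T}\left(q_i - p_i\right) = \frac{1}{T}\left(\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}\right)$$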
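The closed-form gradient $(q_i - p_i)/T$ of the loss $C = -\sum_i p_i \log q_i$ can be sanity-checked against central finite differences. This is a sketch under stated assumptions: the student logits `z`, teacher logits `v`, and temperature `T = 4.0` are hypothetical, and the helper names are illustrative.

```python
import math


def softmax_t(logits, T):
    """Generalized softmax with temperature T (max-shifted for stability)."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_cross_entropy(z, v, T):
    """C = -sum_i p_i log q_i, p from teacher logits v, q from student logits z."""
    p = softmax_t(v, T)
    q = softmax_t(z, T)
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q))


def analytic_grad(z, v, T):
    """Closed-form gradient dC/dz_i = (q_i - p_i) / T."""
    p = softmax_t(v, T)
    q = softmax_t(z, T)
    return [(qi - pi) / T for qi, pi in zip(q, p)]


def numeric_grad(z, v, T, h=1e-6):
    """Central finite-difference estimate of dC/dz_i for each student logit."""
    grads = []
    for i in range(len(z)):
        up = list(z)
        up[i] += h
        down = list(z)
        down[i] -= h
        diff = soft_cross_entropy(up, v, T) - soft_cross_entropy(down, v, T)
        grads.append(diff / (2 * h))
    return grads


z = [1.0, -0.5, 0.3]  # hypothetical student (smaller network) logits
v = [2.0, 0.1, -1.0]  # hypothetical teacher (larger network) logits
T = 4.0
print(analytic_grad(z, v, T))
print(numeric_grad(z, v, T))  # should agree closely with the line above
```

Because the $q_i$ and the $p_i$ each sum to 1, the gradient components also sum to zero, which is another quick check.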
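If the temperature is high compared with the magnitude of the logits, we can use the approximation $e^{\epsilon} \approx 1 + \epsilon$ for small enough $\epsilon$, and the expression above becomes

$$\frac{\partial C}{\partial z_i} \approx \frac{1}{T}\left(\frac{1 + z_i/T}{N + \sum_j z_j/T} - \frac{1 + v_i/T}{N + \sum_j v_j/T}\right)$$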
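where $N$ is the number of categories. Because adding a constant to every logit leaves a softmax unchanged, we can assume without loss of generality that the logits of the larger network and of the smaller network each have mean 0, so that $\sum_j z_j = \sum_j v_j = 0$. Then the expression above simplifies to

$$\frac{\partial C}{\partial z_i} \approx \frac{1}{N T^2}\left(z_i - v_i\right)$$

In other words, in the high-temperature limit distillation amounts to matching the logits of the two networks: this is the gradient of $\frac{1}{2NT^2}\sum_i (z_i - v_i)^2$.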
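The quality of the high-temperature approximation $\partial C/\partial z_i \approx (z_i - v_i)/(NT^2)$ can be probed numerically. The sketch below assumes hypothetical four-class logits, centers them so they sum to zero as the derivation requires, and compares the exact gradient $(q_i - p_i)/T$ with the approximation at several temperatures. All names and values are illustrative.

```python
import math


def softmax_t(logits, T):
    # Generalized softmax with temperature T, max-shifted for stability.
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def exact_grad(z, v, T):
    """Exact gradient dC/dz_i = (q_i - p_i) / T."""
    p = softmax_t(v, T)
    q = softmax_t(z, T)
    return [(qi - pi) / T for qi, pi in zip(q, p)]


def high_t_grad(z, v, T):
    """High-temperature approximation (z_i - v_i) / (N T^2), zero-mean logits."""
    n = len(z)
    return [(zi - vi) / (n * T * T) for zi, vi in zip(z, v)]


def center(xs):
    """Subtract the mean so the logits sum to zero (the softmax is unchanged)."""
    mean = sum(xs) / len(xs)
    return [x - mean for x in xs]


z = center([1.0, -0.5, 0.3, 0.8])  # hypothetical student logits
v = center([2.0, 0.1, -1.0, 0.4])  # hypothetical teacher logits

rel_err = {}
for T in (1.0, 5.0, 50.0):
    exact = exact_grad(z, v, T)
    approx = high_t_grad(z, v, T)
    worst = max(abs(e - a) for e, a in zip(exact, approx))
    rel_err[T] = worst / max(abs(e) for e in exact)
    print(T, round(rel_err[T], 4))
```

The relative error should shrink as T grows (roughly in proportion to 1/T for these logits), and it is large at T = 1, where $e^\epsilon \approx 1 + \epsilon$ is no longer a good approximation.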
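Note that this gradient shrinks as $1/T^2$. When the soft-target loss is combined with an ordinary hard-target loss, Hinton et al. therefore multiply the soft-target term by $T^2$, which keeps the relative contributions of the two terms roughly unchanged if the temperature is changed.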
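Because the soft-target gradient shrinks as $1/T^2$, one way to combine it with the hard-label loss, following Hinton et al., is to multiply the soft term by $T^2$. The sketch below computes such a combined objective: the soft-target cross entropy at temperature T, scaled by $T^2$, plus the ordinary cross entropy with the true label at T = 1. The mixing weight `alpha` and the function names are assumptions for illustration, not values or APIs from the paper; in practice the loss would be minimized with an autodiff framework rather than computed by hand.

```python
import math


def softmax_t(logits, T=1.0):
    # Generalized softmax with temperature T, max-shifted for stability.
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(z, v, label, T, alpha):
    """alpha * T^2 * CE(soft targets) + (1 - alpha) * CE(hard label).

    z: student logits, v: teacher logits, label: index of the true class.
    alpha is a hypothetical mixing weight, not a value from the paper.
    """
    p = softmax_t(v, T)         # teacher soft targets at temperature T
    q_soft = softmax_t(z, T)    # student at the same temperature
    soft_ce = -sum(pi * math.log(qi) for pi, qi in zip(p, q_soft))
    q_hard = softmax_t(z, 1.0)  # student at T = 1 for the hard label
    hard_ce = -math.log(q_hard[label])
    return alpha * T * T * soft_ce + (1 - alpha) * hard_ce


def soft_grad(z, v, T):
    """Unscaled soft-target gradient (q_i - p_i) / T; it shrinks like 1/T^2."""
    p = softmax_t(v, T)
    q = softmax_t(z, T)
    return [(qi - pi) / T for qi, pi in zip(q, p)]


z = [1.0, -0.5, 0.3]  # hypothetical student logits
v = [2.0, 0.1, -1.0]  # hypothetical teacher logits

print(distillation_loss(z, v, label=0, T=20.0, alpha=0.9))

for T in (10.0, 40.0):
    g = max(abs(x) for x in soft_grad(z, v, T))
    print(T, g, T * T * g)  # raw magnitude vs. T^2-scaled magnitude
```

The raw soft-target gradient at T = 40 is roughly 16 times smaller than at T = 10, whereas the $T^2$-scaled magnitudes come out close to each other, which is the effect the $T^2$ factor is meant to have.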
How do soft targets perform in practice? Let’s start with the first problem: distilling a large, cumbersome model’s knowledge into a smaller one for easier deployment. Of the examples discussed in the distillation paper, we focus on MNIST. It turns out that soft targets are so effective that the smaller model can even learn to recognize classes it never saw during training.
As a baseline, a single large network with two hidden layers of 1,200 rectified linear units (ReLUs), trained on all 60,000 training cases with dropout, weight constraints, and jittered inputs, makes 67 test errors. A smaller network with no regularization, a 784->800->800->10 network with ReLU activations trained with vanilla backprop, makes 146 test errors. If the smaller network is instead regularized solely by adding the task of matching the large network's soft targets at a temperature of 20, with no input jittering or dropout, it makes only 74 test errors.
Distillation is a technique that tries to reproduce the output of a large, cumbersome model with a simpler model. Typically, the cumbersome model is a deep network with several layers and thousands of units, while the simpler model is an order of magnitude smaller in its number of layers and neurons. While this technique lets us deploy simpler models in production systems, the simpler model usually has a higher error than the cumbersome model. In the classification case, distillation is achieved by training the simpler model on the class probabilities output by the cumbersome model.