Guest blog by Yoel Zeldes.
In this post you will learn what the Gumbel-softmax trick is. Using this trick, you can sample from a discrete distribution and still let the gradients propagate to the weights that affect the distribution's parameters. This trick opens the door to many interesting applications. For a start, you can find an example of text generation in the paper GANS for Sequences of Discrete Elements with the Gumbel-softmax Dis....
Training deep neural networks usually boils down to defining your model's architecture and a loss function, and watching the gradients propagate.
However, sometimes it's not that simple: some architectures incorporate a random component, so the forward pass is no longer a deterministic function of the input and weights. The stochasticity comes from sampling from that random component.
When would that happen, you ask? Whenever we want to approximate an intractable sum or integral. Then, we can form a Monte Carlo estimate. A good example is the variational autoencoder. Basically, it's an autoencoder on steroids: the encoder's job is to learn a distribution over the latent space. The loss function contains an intractable expectation over that distribution, so we sample from it.
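To make the Monte Carlo idea concrete, here is a minimal sketch in plain Python. The function name `monte_carlo_expectation` and the choice of f(z) = z^2 under a Gaussian are assumptions made for illustration. This expectation is actually tractable, which is exactly why it is used here: the estimate can be checked against the closed form. In a variational autoencoder the expectation has no closed form, but the estimator looks the same.

```python
import random


def monte_carlo_expectation(f, sampler, num_samples=10_000):
    """Approximate E[f(z)] by averaging f over samples drawn by `sampler`."""
    total = 0.0
    for _ in range(num_samples):
        total += f(sampler())
    return total / num_samples


rng = random.Random(0)  # fixed seed so the run is repeatable
mu, sigma = 1.0, 0.5

# Estimate E[z^2] for z ~ N(mu, sigma^2) by sampling.
estimate = monte_carlo_expectation(
    lambda z: z ** 2,
    lambda: rng.gauss(mu, sigma),
    num_samples=50_000,
)

# Closed form for comparison: E[z^2] = mu^2 + sigma^2.
exact = mu ** 2 + sigma ** 2
print(f"estimate={estimate:.4f}, exact={exact:.4f}")
```

The estimate gets closer to the exact value as `num_samples` grows, at the usual one-over-square-root rate.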
As with any architecture, the gradients need to propagate to the weights of the model. Some of the weights are responsible for transforming the input into the parameters of the distribution from which we sample. Here we face a problem: the gradients can't propagate through random nodes! Hence, these weights won't be updated.
One solution to the problem is the reparameterization trick: you substitute the sampled random variable with a deterministic parameterized transformation of a parameterless random variable.
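For the Gaussian case, the substitution is z = mu + sigma * eps, where eps is drawn from a standard normal that has no parameters of its own. The sketch below, in plain Python, uses this to estimate the gradients of E[z^2] with respect to mu and sigma by applying the chain rule through z by hand. The function names and the choice of z^2 as the objective are assumptions for illustration; a real model would let an autodiff framework do the differentiation.

```python
import random


def sample_reparameterized(mu, sigma, eps):
    """Deterministic transformation of parameterless noise: z = mu + sigma * eps."""
    return mu + sigma * eps


def pathwise_gradients(mu, sigma, num_samples=50_000, seed=0):
    """Estimate d/dmu and d/dsigma of E[z^2] by differentiating through z."""
    rng = random.Random(seed)
    grad_mu = 0.0
    grad_sigma = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)  # the only random draw; independent of mu, sigma
        z = sample_reparameterized(mu, sigma, eps)
        # Chain rule: d(z^2)/dmu = 2z * dz/dmu = 2z * 1
        grad_mu += 2.0 * z
        # Chain rule: d(z^2)/dsigma = 2z * dz/dsigma = 2z * eps
        grad_sigma += 2.0 * z * eps
    return grad_mu / num_samples, grad_sigma / num_samples


grad_mu, grad_sigma = pathwise_gradients(mu=1.0, sigma=0.5)

# Since E[z^2] = mu^2 + sigma^2, the exact gradients are 2*mu and 2*sigma.
print(f"grad_mu={grad_mu:.3f} (exact 2.0), grad_sigma={grad_sigma:.3f} (exact 1.0)")
```

Because the randomness lives entirely in `eps`, the sample `z` is an ordinary differentiable function of `mu` and `sigma`, and the gradients can flow to whatever weights produced them.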
If you don't know this trick, I highly encourage you to read about it. I'll demonstrate it with the Gaussian case. The reparameterization trick works for many types of continuous distributions. But what do you do if you need the distribution to be over a discrete set of values?
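One answer, and the subject of this post, is the Gumbel-softmax trick. As a rough sketch in plain Python: add independent Gumbel(0, 1) noise to the log-probabilities, divide by a temperature, and apply a softmax. The function names, the example probabilities, and the temperature value below are assumptions for illustration. Plain Python computes no gradients here; the point is only that, given the noise, the output is a smooth function of the logits.

```python
import math
import random


def sample_gumbel(rng):
    """Draw from Gumbel(0, 1) via inverse transform: -log(-log(u)), u ~ Uniform(0, 1)."""
    u = max(rng.random(), 1e-12)  # guard against log(0)
    return -math.log(-math.log(u))


def gumbel_softmax_sample(logits, temperature, rng):
    """Return a relaxed one-hot vector: softmax((logits + Gumbel noise) / temperature)."""
    noisy = [(logit + sample_gumbel(rng)) / temperature for logit in logits]
    largest = max(noisy)  # subtract the max for numerical stability
    exps = [math.exp(x - largest) for x in noisy]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
probs = [0.1, 0.3, 0.6]                  # hypothetical discrete distribution
logits = [math.log(p) for p in probs]    # unnormalized log-probabilities

sample = gumbel_softmax_sample(logits, temperature=0.5, rng=rng)
print([round(s, 3) for s in sample])     # entries are non-negative and sum to 1
```

The position of the largest entry follows the original discrete distribution (this is the Gumbel-max property), and lowering the temperature pushes the vector closer to a true one-hot sample at the cost of noisier gradients.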
Read the full article with source code, here.