tero.co.uk

Gated Recurrent Units

Published 11 April 2016

This article covers the mathematics of a neural network with gated recurrent units. It's sort of a continuation of the RNN article. There are lots of great articles covering LSTMs and GRUs from a graphical perspective, but these articles usually only give the maths and programming for the feed forward stage. I struggled to find the details of the equations for back propagation, so I decided to write my own. Not for the faint hearted!

Layer types

Basic feed forward networks have one set of weights between layers. Simple recurrent neural networks (RNNs) extend this to have two sets of weights, with an extra set of recurrent weights connecting a hidden layer to itself.

Gated Recurrent Units (GRUs) replace each of these sets of weights with three sets of weights. So instead of one set connecting the previous layer, there are three sets, and instead of one recurrent set connecting a layer to itself, there are three of those too. That makes six sets of weights in total.

Imagine a feed forward network with 10 nodes in the input layer and 15 nodes in the next hidden layer. There is one set of 10*15=150 weights connecting these two layers. First the network forms 15 weighted sums, one per hidden node, each combining the 10 inputs with 10 of the 150 weights. It then passes each sum through tanh or sigmoid to produce 15 activation values, which are passed on to the next layer.

A simple RNN with the same size layers would have 10*15=150 input-to-hidden and 15*15=225 hidden-to-hidden weights, 375 weights in total. Just like a feed forward network, it combines these to produce 15 activation values.

A GRU with the same sized layers has all that in triplicate, so 3*10*15=450 input-to-hidden and 3*15*15=675 hidden-to-hidden weights, 1125 weights in total. From these it also produces just 15 activation values.
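These counts are easy to get wrong by hand, so here is a small Python sketch that reproduces them. Like the text, it ignores bias weights; the function name `count_weights` is just for illustration.

```python
def count_weights(n_in, n_hidden):
    """Weight counts (biases ignored) for the three layer types discussed above."""
    feed_forward = n_in * n_hidden                         # one input-to-hidden set
    simple_rnn = n_in * n_hidden + n_hidden * n_hidden     # plus one recurrent set
    gru = 3 * n_in * n_hidden + 3 * n_hidden * n_hidden    # update, reset and memory
    return {"feed_forward": feed_forward, "simple_rnn": simple_rnn, "gru": gru}


print(count_weights(10, 15))
# {'feed_forward': 150, 'simple_rnn': 375, 'gru': 1125}
```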

In fact, the terms "feed forward", "simple recurrent" and "GRU" refer to the type of layer rather than the network, because it's possible to have all three in a single network. A GRU is just a more complicated type of recurrent layer, and the basic function of all three layer types is the same: in this case 10 values in and 15 values out.

Labelling the weights

Before presenting the equations for GRUs, we first need some labels. Previously, I labelled the input-to-hidden weights as \(w^{xh}\) and the recurrent weights as \(w^{hh}\). For instance, \(w^{xh}\) was for the weights connecting the input layer x with the hidden layer h.

GRUs have two "gates": one for updating and one for resetting. They also have the thing they may or may not remember, which I'll call the memory. These three things (updating, resetting, memory) each have their own weights. So I'll refer to them as \(w^{xu}\) and \(w^{hu}\) for updating, \(w^{xr}\) and \(w^{hr}\) for resetting, and \(w^{xm}\) and \(w^{hm}\) for the memory. For example, \(w^{xr}\) is the set of weights between the input x and the reset gate of the hidden GRU layer h, and \(w^{hm}\) represents the weights between this GRU hidden layer h and the memory.

Finally, I use x and h to represent the actual values. So there is \(x_1\), \(x_2\) ... \(x_{10}\) for the 10 input node values and \(h_1\) ... \(h_{15}\) for the 15 hidden node values. Because of the recurrent connections, we also have to consider the time step. Previously I represented this time step with a superscript, so \(h_3^{t=4}\) stores the value of the 3rd hidden node in the 4th time step. In the equations below, the generic labels will be \(x^t\) for the input node values in time step t, and \(h^{t-1}\) for the hidden node values in time step t-1.

Equations

Now I can build up to the equation which computes the hidden node values in a GRU. First of all, here is the equation for a simple feed forward network. This one doesn't involve any time steps. It multiplies the input vector x by the weight matrix (a vector-matrix product, so each hidden node gets its own weighted sum) and then applies the tanh activation function:

\(h = \tanh\ (x w^{xh})\)

For simple RNNs, we have to introduce the time step t and the hidden-to-hidden weights:

\(h^t = \tanh\ (x^t w^{xh} + h^{t-1} w^{hh}) \)

And now I can present the whole equation as borrowed from this excellent RNN tutorial. As with the equation above, this "simply" turns \(x^t\) and \(h^{t-1}\) into the new hidden node values \(h^t\). This equation includes the 6 sets of weights. The Greek lower case sigma \(\sigma\) is for the sigmoid activation function, and the symbol \(\odot\) means an element-wise vector multiplication:

\(h^t = \sigma\ (x^t w^{xu} + h^{t-1} w^{hu}) \odot \tanh\ (x^t w^{xm} + (\sigma\ (x^t w^{xr} + h^{t-1} w^{hr}) \odot h^{t-1}) w^{hm})\) \( + (1 - \sigma\ (x^t w^{xu} + h^{t-1} w^{hu})) \odot h^{t-1}\)

You can see that it is significantly more complex than a simple RNN! But it can be broken down into a few distinct steps. First it computes the reset gate, which we'll store in a variable called r. The sigmoid function \(\sigma\) produces a value between 0 and 1:

\(r = \sigma\ (x^t w^{xr} + h^{t-1} w^{hr})\)

The value of this reset gate, between 0 and 1, determines how much of the previous hidden node value should be remembered in a variable called m. For example, if the reset value for the second node \(r_2\) was 0, then none of the second hidden node's previous value \(h_2^{t-1}\) will be used to compute m. But if \(r_2\) equals 1, then all of \(h_2^{t-1}\) will be passed through and then multiplied by the weights \(w^{hm}\) to contribute to the values of m. The equation which does this is:

\(m = \tanh\ (x^t w^{xm} + (r \odot h^{t-1}) w^{hm})\)

Next it computes the update gate, which is also a value from 0 to 1:

\(u = \sigma\ (x^t w^{xu} + h^{t-1} w^{hu})\)

The update gate determines whether the new hidden node value \(h^t\) should retain the old value (if u=0), use the new value computed above (if u=1), or something in between if u is between 0 and 1. Note that some articles show this the other way around, with u=1 for the old value and u=0 for the new value. It doesn't really matter to the network as long as the equations are consistent, but doing it this way round saves having to compute and store 1-u a lot later:

\(h^t = u \odot m + (1 - u) \odot h^{t-1} \)

If you combine these 4 mini-equations, you get the bigger equation above. That is all that's needed for forward propagation.
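Here is a minimal NumPy sketch of one forward step, assuming row vectors, no biases, and the six weight matrices stored in a dict under hypothetical keys such as `"xu"` and `"hm"`. The names `gru_step` and `init_weights`, and the 0.1 initialisation scale, are illustrative choices rather than anything prescribed above.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gru_step(x, h_prev, W):
    """One GRU time step using this article's convention (u=1 keeps the new memory m).
    x: (n_x,), h_prev: (n_h,), W: dict of the six weight matrices, biases omitted."""
    r = sigmoid(x @ W["xr"] + h_prev @ W["hr"])         # reset gate
    m = np.tanh(x @ W["xm"] + (r * h_prev) @ W["hm"])   # memory
    u = sigmoid(x @ W["xu"] + h_prev @ W["hu"])         # update gate
    h = u * m + (1.0 - u) * h_prev                      # new hidden node values
    return h, (r, m, u)


def init_weights(n_x, n_h, rng, scale=0.1):
    W = {}
    for gate in ("u", "r", "m"):
        W["x" + gate] = rng.normal(0.0, scale, (n_x, n_h))  # input-to-gate
        W["h" + gate] = rng.normal(0.0, scale, (n_h, n_h))  # hidden-to-gate
    return W


rng = np.random.default_rng(0)
W = init_weights(10, 15, rng)
x = rng.normal(size=10)
h_prev = np.zeros(15)
h, (r, m, u) = gru_step(x, h_prev, W)
print(h.shape, sum(w.size for w in W.values()))  # (15,) 1125
```

Keeping the gates as separate lines (rather than one big expression) also means r, m and u are already to hand when the derivatives below need them.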

Back propagation

Whenever a neural network is run, it produces an "answer" in its output nodes. We compare this to the expected answer and compute the loss or cost, i.e. how wrong the network was. We then try to minimise this loss. This involves figuring out how each and every weight (all 1125 of them) affects the loss, which means computing the derivative of the loss with respect to each weight.

For feed forward networks, the derivative of the loss with respect to an input-to-hidden weight was computed by a big sum:

\(\frac {\partial J} {\partial w^{xh}_{mi}} = \sum_{j=1}^{n_y} (\frac {\partial J} {\partial p_j} \frac {\partial p_j} {\partial z_{yj}} \frac {\partial z_{yj}} {\partial h_i}) \frac {\partial h_i} {\partial z_{hi}} \frac {\partial z_{hi}} {\partial w^{xh}_{mi}} \)

This shows the sum in all its gory detail, using the chain rule to expand everything out. For our purposes, we can simplify it a bit by recombining the terms in the brackets: summed over all the output nodes j, they are just \(\frac {\partial J} {\partial h_i}\), so the explicit sum disappears. It is still the same equation:

\(\frac {\partial J} {\partial w^{xh}_{mi}} = \frac {\partial J} {\partial h_i} \frac {\partial h_i} {\partial w^{xh}_{mi}} \)

For RNNs the same sum involves time steps. For example, the derivative of the loss at time step t with respect to a hidden-to-hidden weight is:

\(\frac {\partial J^t} {\partial w^{hh}_{mi}} = \sum_{j=1}^{n_y} (\frac {\partial J^t} {\partial p^t_j} \frac {\partial p^t_j} {\partial z^t_{yj}} \frac {\partial z^t_{yj}} {\partial h^t_i}) \frac {\partial h^t_i} {\partial z^t_{hi}} \frac {\partial z^t_{hi}} {\partial w^{hh}_{mi}} \)

We can do the same recombining here too:

\(\frac {\partial J^t} {\partial w^{hh}_{mi}} = \frac {\partial J^t} {\partial h^t_i} \frac {\partial h^t_i} {\partial w^{hh}_{mi}} \)

And this sum keeps growing as we go further back in time, because \(h^t\) also depends on the weights through \(h^{t-1}\):

\(\frac {\partial J^{t-1}} {\partial w^{hh}_{mk}} = \sum_{i=1}^{n_h} (\frac {\partial J^t} {\partial h^t_i} \frac {\partial h^t_i} {\partial h^{t-1}_k}) \frac {\partial h^{t-1}_k} {\partial w^{hh}_{mk}} \)

And so on and so on...

\(\frac {\partial J^{t-2}} {\partial w^{hh}_{ml}} = \sum_{k=1}^{n_h} (\sum_{i=1}^{n_h} (\frac {\partial J^t} {\partial h^t_i} \frac {\partial h^t_i} {\partial h^{t-1}_k}) \frac {\partial h^{t-1}_k} {\partial h^{t-2}_l}) \frac {\partial h^{t-2}_l} {\partial w^{hh}_{ml}} \)

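Notice that for every step back, the second-to-last term repeats the same pattern: the derivative of the hidden node values in one time step with respect to those in the previous time step. This is the basis of the back propagation of the error. We can rewrite each of those steps as an error value e which keeps getting updated: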

\( e^{t-1}_{hk} = \sum_{i=1}^{n_h} (e^t_{hi} \frac {\partial h^t_i} {\partial h^{t-1}_k}) \)

And then the derivative is just this error multiplied by the very last derivative from the long sum:

\(\frac {\partial J^t} {\partial w^{hh}_{mi}} = e^t_{hi} \frac {\partial h^t_i} {\partial w^{hh}_{mi}} \)

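We can drop the sums and indices on these equations. The sums are absorbed by matrix multiplication, where \(\frac {\partial h^t} {\partial h^{t-1}}\) is the matrix whose entry in row i and column k is \(\frac {\partial h^t_i} {\partial h^{t-1}_k}\):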

\( e^{t-1}_h = e^t_h \frac {\partial h^t} {\partial h^{t-1}}\)

\(\frac {\partial J^t} {\partial w^{hh}} = e^t_h \frac {\partial h^t} {\partial w^{hh}} \)

For an RNN, both of these equations are fairly easy to work out:

\( e^{t-1}_h = (e^t_h \odot (1 - {h^t}^2)) {w^{hh}}^T \)

\(\frac {\partial J^t} {\partial w^{hh}} = (e^t_h \odot (1 - {h^t}^2))\ h^{t-1} \)
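To see both RNN equations working together over several time steps, here is a hedged NumPy sketch of back propagation through time for a simple tanh RNN. It assumes a stand-in loss \(J = \sum_t c^t \cdot h^t\) with random vectors \(c^t\), chosen only so that the error arriving from the output at each step is exactly \(c^t\); a real network would get this from its output layer. The error e carried backwards accumulates the contributions from later time steps, and the result is checked against central differences. All function names are hypothetical.

```python
import numpy as np


def rnn_forward(xs, h0, Wxh, Whh):
    """h^t = tanh(x^t Wxh + h^{t-1} Whh); hs[t + 1] is the state after input xs[t]."""
    hs = [h0]
    for x in xs:
        hs.append(np.tanh(x @ Wxh + hs[-1] @ Whh))
    return hs


def loss(xs, h0, Wxh, Whh, cs):
    # Stand-in loss J = sum_t c^t . h^t, so the output error at step t is exactly c^t.
    hs = rnn_forward(xs, h0, Wxh, Whh)
    return sum(c @ h for c, h in zip(cs, hs[1:]))


def rnn_backward(xs, h0, Wxh, Whh, cs):
    hs = rnn_forward(xs, h0, Wxh, Whh)
    dWxh, dWhh = np.zeros_like(Wxh), np.zeros_like(Whh)
    e = np.zeros_like(h0)                    # error arriving from later time steps
    for t in reversed(range(len(xs))):
        e = e + cs[t]                        # plus the error from this step's output
        delta = e * (1.0 - hs[t + 1] ** 2)   # back through the tanh
        dWxh += np.outer(xs[t], delta)       # outer product: same shape as Wxh
        dWhh += np.outer(hs[t], delta)       # hs[t] is h^{t-1} for this step
        e = delta @ Whh.T                    # e^{t-1}, using the transposed weights
    return dWxh, dWhh


rng = np.random.default_rng(1)
n_x, n_h, T = 4, 5, 6
Wxh = rng.normal(0.0, 0.5, (n_x, n_h))
Whh = rng.normal(0.0, 0.5, (n_h, n_h))
xs, cs = rng.normal(size=(T, n_x)), rng.normal(size=(T, n_h))
h0 = rng.normal(size=n_h)
dWxh, dWhh = rnn_backward(xs, h0, Wxh, Whh, cs)

# Central-difference check of every hidden-to-hidden weight
eps, num = 1e-6, np.zeros_like(Whh)
for idx in np.ndindex(Whh.shape):
    W_plus, W_minus = Whh.copy(), Whh.copy()
    W_plus[idx] += eps
    W_minus[idx] -= eps
    num[idx] = (loss(xs, h0, Wxh, W_plus, cs) - loss(xs, h0, Wxh, W_minus, cs)) / (2 * eps)
print(np.max(np.abs(num - dWhh)))  # largest mismatch; should be close to zero
```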

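In the weight derivative, multiplying the bracketed row vector by \(h^{t-1}\) means taking an outer product, \({h^{t-1}}^T (e^t_h \odot (1 - {h^t}^2))\), which gives a matrix the same shape as \(w^{hh}\); the same shorthand is used for the weight derivatives below. For a GRU, these equations are not easy at all: the derivatives are really involved and nasty. The next section will figure them out.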

Derivatives

For a GRU layer there are six sets of weights instead of the RNN's two. So we need to compute how the loss changes with respect to each of those six sets of weights. Since GRUs are recurrent, we also compute this for each time step separately, so we use \(h^t\) instead of just \(h\).

Let's start with \(w^{xu}\), the input-to-update-gate weights, as they are the easiest. We need to apply the chain rule. For this I will use the letter z to represent the sums before the activation function is applied, so \(z_u\) is the sum inside the \(\sigma\) for the update gate. The derivative breaks down into a chain of four parts, where \(m - h^{t-1}\) is the derivative of \(h^t\) with respect to u, and \(u - u^2\) is the derivative of the sigmoid function:

\(\frac {\partial J^t} {\partial w^{xu}} = \frac {\partial J^t} {\partial h^t} \frac {\partial h^t} {\partial u} \frac {\partial u} {\partial z_u} \frac {\partial z_u} {\partial w^{xu}} = (e^t_h \odot (m - h^{t-1}) \odot (u - u^2))\ x^t\)

The corresponding hidden-to-update-gate derivative is quite similar:

\(\frac {\partial J^t} {\partial w^{hu}} = \frac {\partial J^t} {\partial h^t} \frac {\partial h^t} {\partial u} \frac {\partial u} {\partial z_u} \frac {\partial z_u} {\partial w^{hu}} = (e^t_h \odot (m - h^{t-1}) \odot (u - u^2))\ h^{t-1}\)

The input-to-memory derivatives are computed as follows, where \(1 - m^2\) is the derivative of the tanh function:

\(\frac {\partial J^t} {\partial w^{xm}} = \frac {\partial J^t} {\partial h^t} \frac {\partial h^t} {\partial m} \frac {\partial m} {\partial z_m} \frac {\partial z_m} {\partial w^{xm}} = (e^t_h \odot u \odot (1 - m^2))\ x^t\)

And the hidden-to-memory derivatives are similar:

\(\frac {\partial J^t} {\partial w^{hm}} = \frac {\partial J^t} {\partial h^t} \frac {\partial h^t} {\partial m} \frac {\partial m} {\partial z_m} \frac {\partial z_m} {\partial w^{hm}} = (e^t_h \odot u \odot (1 - m^2))\ (r \odot h^{t-1}) \)

Continuing with that theme, you might think that the input-to-reset-gate could be worked out the same way:

\(\frac {\partial J^t} {\partial w^{xr}} = \frac {\partial J^t} {\partial h^t} \frac {\partial h^t} {\partial m} \frac {\partial m} {\partial z_m} \frac {\partial z_m} {\partial r} \frac {\partial r} {\partial z_r} \frac {\partial z_r} {\partial w^{xr}}\)

However, it's more subtle than that, because it lays bare an assumption I didn't tell you I made. The assumption was that the ith hidden node \(h^t_i\) only depends on the ith component of each gate, such as \(u_i\), which in turn only depends on the ith weighted sum \(z_{ui}\). That assumption has worked for us so far, but no longer! The reset gates are much more complicated, because \(r \odot h^{t-1}\) is multiplied by the full hidden-to-memory weight matrix, so each \(r_k\) contributes to every memory value \(m_i\) and it all gets jumbled up. It means that each weight doesn't have just one nice clear path to the ith node value through a few \(\odot\) multiplications. Instead there are many paths, and we'll have to expand the derivative in all its goriness. So let's first look at the input-to-reset-gate derivative for a single node with respect to a single weight. I have carefully worked this out for the 3rd hidden node with respect to the weight \(w^{xr}_{12}\), which connects input \(x_1\) to reset gate \(r_2\):

\(\frac {\partial h^t_3} {\partial w^{xr}_{12}} = u_3\ (1 - m_3^2)\ (r_2 - r_2^2)\ h^{t-1}_2\ w^{hm}_{23}\ x^t_1\)

We can make that more generic by replacing the numbers with indices:

\(\frac {\partial h^t_i} {\partial w^{xr}_{jk}} = u_i\ (1 - m_i^2)\ (r_k - r_k^2)\ h^{t-1}_k\ w^{hm}_{ki}\ x^t_j\)

Now, the total effect of that one weight on the loss is the accumulation of how it affects all the hidden nodes, each weighted by that node's error \(e^t_{hi}\), so we have to add that up:

\(\frac {\partial J^t} {\partial w^{xr}_{jk}} = \sum_{i=1}^{n_h} (e^t_{hi}\ u_i\ (1 - m_i^2)\ (r_k - r_k^2)\ h^{t-1}_k\ w^{hm}_{ki}\ x^t_j) \)

We can take some of the terms out of the sum as they don't involve the index i:

\(\frac {\partial J^t} {\partial w^{xr}_{jk}} = \sum_{i=1}^{n_h} (e^t_{hi}\ u_i\ (1 - m_i^2)\ w^{hm}_{ki})\ (r_k - r_k^2)\ h^{t-1}_k\ x^t_j \)

The important thing to realise here is that the sum is actually a vector-times-matrix multiplication, because that's what matrix multiplication does: it multiplies a bunch of components and adds up the results. However, the hidden-to-memory weights are indexed the other way around (\(w^{hm}_{ki}\) rather than \(w^{hm}_{ik}\)), so we have to transpose that matrix. This is the final result. And please don't think I got this right the first time. It took me a lot of trial and error and testing (using numerical gradient checking) to reach this equation:

\(\frac {\partial J^t} {\partial w^{xr}} = (((e^t_h \odot u \odot (1 - m^2)) {w^{hm}}^T) \odot (r - r^2) \odot h^{t-1})\ x^t \)

And the hidden-to-reset-gate derivatives are similar:

\(\frac {\partial J^t} {\partial w^{hr}} = (((e^t_h \odot u \odot (1 - m^2)) {w^{hm}}^T) \odot (r - r^2) \odot h^{t-1})\ h^{t-1} \)
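A quick way to convince yourself that the sum over i really is a vector times the transposed \(w^{hm}\) is to evaluate both forms numerically. This NumPy sketch uses random stand-in values for e, u, r, m, \(h^{t-1}\) and \(x^t\) (any values work, since it only compares two ways of computing the same expression) and treats the final multiplication by \(x^t\) as an outer product, so the result has the same shape as \(w^{xr}\).

```python
import numpy as np

rng = np.random.default_rng(2)
n_x, n_h = 3, 4
x, h_prev, e = rng.normal(size=n_x), rng.normal(size=n_h), rng.normal(size=n_h)
# Stand-in gate values in (0, 1) and memory values in (-1, 1)
u, r = rng.uniform(size=n_h), rng.uniform(size=n_h)
m = np.tanh(rng.normal(size=n_h))
Whm = rng.normal(size=(n_h, n_h))

# Explicit sums: dJ/dw^{xr}_{jk} = sum_i e_i u_i (1 - m_i^2) w^{hm}_{ki} (r_k - r_k^2) h_k x_j
by_sums = np.zeros((n_x, n_h))
for j in range(n_x):
    for k in range(n_h):
        s = sum(e[i] * u[i] * (1 - m[i] ** 2) * Whm[k, i] for i in range(n_h))
        by_sums[j, k] = s * (r[k] - r[k] ** 2) * h_prev[k] * x[j]

# Matrix form: the sum over i becomes a row vector times the transposed w^{hm}
delta_m = e * u * (1 - m ** 2)
delta_r = (delta_m @ Whm.T) * (r - r ** 2) * h_prev
by_matrix = np.outer(x, delta_r)
print(np.allclose(by_sums, by_matrix))  # True
```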

We have now computed the six partial derivatives that tell us how the loss changes with respect to all 1125 of the weights. There is just one last thing to compute, which is the error e used above.
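All six weight derivatives can be checked together against a numerical gradient. The sketch below assumes NumPy, row vectors, no biases, weights in a dict under hypothetical keys like `"xr"`, and a hypothetical single-step loss \(J = c \cdot h^t\) with a random vector c, so that \(e^t_h = c\). The products with \(x^t\), \(h^{t-1}\) and \(r \odot h^{t-1}\) are outer products, giving matrices the same shape as the corresponding weights.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gru_step(x, h_prev, W):
    r = sigmoid(x @ W["xr"] + h_prev @ W["hr"])
    m = np.tanh(x @ W["xm"] + (r * h_prev) @ W["hm"])
    u = sigmoid(x @ W["xu"] + h_prev @ W["hu"])
    return u * m + (1 - u) * h_prev, r, m, u


def gru_weight_grads(x, h_prev, W, e):
    """The six weight derivatives for one time step, given e = dJ/dh^t."""
    _, r, m, u = gru_step(x, h_prev, W)
    delta_u = e * (m - h_prev) * (u - u ** 2)                 # error at z_u
    delta_m = e * u * (1 - m ** 2)                            # error at z_m
    delta_r = (delta_m @ W["hm"].T) * (r - r ** 2) * h_prev   # error at z_r
    return {
        "xu": np.outer(x, delta_u), "hu": np.outer(h_prev, delta_u),
        "xm": np.outer(x, delta_m), "hm": np.outer(r * h_prev, delta_m),
        "xr": np.outer(x, delta_r), "hr": np.outer(h_prev, delta_r),
    }


rng = np.random.default_rng(3)
n_x, n_h = 4, 5
W = {k: rng.normal(0.0, 0.5, (n_x if k[0] == "x" else n_h, n_h))
     for k in ("xu", "hu", "xm", "hm", "xr", "hr")}
x, h_prev, c = rng.normal(size=n_x), rng.normal(size=n_h), rng.normal(size=n_h)


def loss(W):
    # Stand-in loss J = c . h^t, so that e = dJ/dh^t = c exactly.
    return c @ gru_step(x, h_prev, W)[0]


grads = gru_weight_grads(x, h_prev, W, c)
eps, worst = 1e-6, 0.0
for name in W:
    for idx in np.ndindex(W[name].shape):
        plus, minus = dict(W), dict(W)
        plus[name], minus[name] = W[name].copy(), W[name].copy()
        plus[name][idx] += eps
        minus[name][idx] -= eps
        numeric = (loss(plus) - loss(minus)) / (2 * eps)
        worst = max(worst, abs(numeric - grads[name][idx]))
print(worst)  # largest mismatch over all six weight sets; should be close to zero
```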

Back propagating the error

One of the great properties of neural networks is that each layer can be considered separately when computing the derivative. To enable this, we just have to run the network in reverse, back propagating the error through the network according to the following equation from above:

\( e^{t-1}_h = e^t_h \frac {\partial h^t} {\partial h^{t-1}}\)

This means that, given the error at one time step \(e^t\), we have to figure out the error at the previous time step \(e^{t-1}\) by computing the partial derivative of the node values in one time step \(h^t\) with respect to those in the previous time step \(h^{t-1}\). To do this, notice that the term \(h^{t-1}\) appears four times in the forward propagation equations: once next to each of the three sets of recurrent weights (as \(r \odot h^{t-1}\) in the case of \(w^{hm}\)), and once next to the term 1-u. The derivative therefore ends up as a four-part sum. We can reuse the derivatives computed above, replacing the final multiplication by \(h^{t-1}\) with a multiplication by the corresponding transposed weight matrix:

\( e^{t-1}_h = (e^t_h \odot (m - h^{t-1}) \odot (u - u^2)) {w^{hu}}^T + \) \(((e^t_h \odot u \odot (1 - m^2)) {w^{hm}}^T) \odot r + \) \((((e^t_h \odot u \odot (1 - m^2)) {w^{hm}}^T) \odot (r - r^2) \odot h^{t-1}) {w^{hr}}^T) + e^t_h \odot (1-u)\)
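The four-part sum can be checked numerically by nudging each component of \(h^{t-1}\) in turn. This NumPy sketch assumes row vectors, no biases, the six weight matrices in a dict under hypothetical keys like `"hu"`, and a stand-in loss \(J = c \cdot h^t\) for a random vector c, so that \(e^t_h = c\). The function name `gru_error_prev` is hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gru_step(x, h_prev, W):
    r = sigmoid(x @ W["xr"] + h_prev @ W["hr"])
    m = np.tanh(x @ W["xm"] + (r * h_prev) @ W["hm"])
    u = sigmoid(x @ W["xu"] + h_prev @ W["hu"])
    return u * m + (1 - u) * h_prev, r, m, u


def gru_error_prev(x, h_prev, W, e):
    """Back-propagate e^t = dJ/dh^t one step, giving e^{t-1} = dJ/dh^{t-1}."""
    _, r, m, u = gru_step(x, h_prev, W)
    delta_u = e * (m - h_prev) * (u - u ** 2)
    delta_m = e * u * (1 - m ** 2)
    g = delta_m @ W["hm"].T                             # error reaching r * h^{t-1}
    return (delta_u @ W["hu"].T                         # via the update gate
            + g * r                                     # via the memory
            + (g * (r - r ** 2) * h_prev) @ W["hr"].T   # via the reset gate
            + e * (1 - u))                              # via (1 - u) * h^{t-1}


rng = np.random.default_rng(5)
n_x, n_h = 4, 5
W = {k: rng.normal(0.0, 0.5, (n_x if k[0] == "x" else n_h, n_h))
     for k in ("xu", "hu", "xm", "hm", "xr", "hr")}
x, h_prev, c = rng.normal(size=n_x), rng.normal(size=n_h), rng.normal(size=n_h)

e_prev = gru_error_prev(x, h_prev, W, c)   # stand-in loss J = c . h^t, so e^t = c
eps = 1e-6
numeric = np.array([
    (c @ gru_step(x, h_prev + eps * d, W)[0] - c @ gru_step(x, h_prev - eps * d, W)[0])
    / (2 * eps)
    for d in np.eye(n_h)
])
print(np.max(np.abs(numeric - e_prev)))   # should be close to zero
```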

That back propagates the error through time steps, from time t to t-1. At the same time, we also need to back propagate the error from this GRU layer to the previous layer. This can be done for all time steps at once (after the error has been back propagated above):

\( e_x = e_h \frac {\partial h} {\partial x}\)

But in practice we'll also do this for each time step separately:

\( e^{t}_x = e^t_h \frac {\partial h^t} {\partial x^t}\)

This equation is similar to the one above, except that it uses the input weights and doesn't include the 1-u term:

\( e^{t}_x = (e^t_h \odot (m - h^{t-1}) \odot (u - u^2)) {w^{xu}}^T + \) \((e^t_h \odot u \odot (1 - m^2)) {w^{xm}}^T + \) \((((e^t_h \odot u \odot (1 - m^2)) {w^{hm}}^T) \odot (r - r^2) \odot h^{t-1}) {w^{xr}}^T\)
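The error passed down to the previous layer can be checked the same way, this time nudging each input value. In this NumPy sketch (row vectors, no biases, weights in a dict under hypothetical keys, a stand-in loss \(J = c \cdot h^t\) so that \(e^t_h = c\)), note that \(x^t\) reaches \(z_m\) directly through \(w^{xm}\), so the memory term has no factor of r. The function name `gru_error_input` is hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gru_step(x, h_prev, W):
    r = sigmoid(x @ W["xr"] + h_prev @ W["hr"])
    m = np.tanh(x @ W["xm"] + (r * h_prev) @ W["hm"])
    u = sigmoid(x @ W["xu"] + h_prev @ W["hu"])
    return u * m + (1 - u) * h_prev, r, m, u


def gru_error_input(x, h_prev, W, e):
    """Back-propagate e^t = dJ/dh^t to the layer below, giving e_x^t = dJ/dx^t."""
    _, r, m, u = gru_step(x, h_prev, W)
    delta_u = e * (m - h_prev) * (u - u ** 2)
    delta_m = e * u * (1 - m ** 2)
    delta_r = (delta_m @ W["hm"].T) * (r - r ** 2) * h_prev
    # x^t reaches z_m directly through w^{xm}, so the memory term has no factor of r
    return delta_u @ W["xu"].T + delta_m @ W["xm"].T + delta_r @ W["xr"].T


rng = np.random.default_rng(8)
n_x, n_h = 4, 5
W = {k: rng.normal(0.0, 0.5, (n_x if k[0] == "x" else n_h, n_h))
     for k in ("xu", "hu", "xm", "hm", "xr", "hr")}
x, h_prev, c = rng.normal(size=n_x), rng.normal(size=n_h), rng.normal(size=n_h)

e_x = gru_error_input(x, h_prev, W, c)   # stand-in loss J = c . h^t, so e^t = c
eps = 1e-6
numeric = np.array([
    (c @ gru_step(x + eps * d, h_prev, W)[0] - c @ gru_step(x - eps * d, h_prev, W)[0])
    / (2 * eps)
    for d in np.eye(n_x)
])
print(np.max(np.abs(numeric - e_x)))   # should be close to zero
```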

Programming tips

This has been a significant intellectual and programming challenge. The most difficult part is that it won't work unless all the equations are correct. And if just one thing is wrong, it will all fail completely.

So if you are attempting this yourself, first implement a numerical gradient check. This is a function which varies each weight by a small amount and measures the effect on the overall cost J. It allows you to check whether you are getting the derivatives right. The only other option is to program the network in full and then run it to see if it produces sensible answers, but it's very hard to determine what is and isn't a sensible answer. Numerical gradient checking provides clear-cut evidence.
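A numerical gradient check can be as small as the NumPy sketch below. It uses central differences (nudging each weight up and down by eps) and a relative-error summary; the function names, the eps of 1e-6 and the rough "around 1e-7 or below" rule of thumb are assumptions for illustration, not fixed requirements.

```python
import numpy as np


def numerical_gradient(f, w, eps=1e-6):
    """Central-difference estimate of df/dw for a scalar function f of array w.
    Each entry of w is nudged up and down by eps in turn, then restored."""
    grad = np.zeros_like(w, dtype=float)
    for idx in np.ndindex(w.shape):
        old = w[idx]
        w[idx] = old + eps
        f_plus = f(w)
        w[idx] = old - eps
        f_minus = f(w)
        w[idx] = old                       # restore the original weight
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad


def relative_error(analytic, numeric):
    """Largest relative difference; as a rough rule, around 1e-7 or below is a match."""
    denom = np.maximum(np.abs(analytic) + np.abs(numeric), 1e-12)
    return np.max(np.abs(analytic - numeric) / denom)


# Usage on a function whose gradient is known: J(w) = sum(tanh(x w)),
# so dJ/dw = outer(x, 1 - tanh(x w)^2).
rng = np.random.default_rng(4)
x, w = rng.normal(size=3), rng.normal(size=(3, 2))


def J(w):
    return np.sum(np.tanh(x @ w))


analytic = np.outer(x, 1 - np.tanh(x @ w) ** 2)
print(relative_error(analytic, numerical_gradient(J, w)))
```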

Secondly, add the equations one at a time. So first program a simple feed forward network using only one set of weights, such as \(w^{xu}\). The node values for time step t are then:

\(h^t = \sigma (x^t w^{xu})\)

Implement the feed forward, weight updates and back propagation for this network first. Then change it into a simple RNN and check everything works:

\(h^t = \sigma (x^t w^{xu} + h^{t-1} w^{hu} )\)

When testing the RNN, first test it with just 1 time step and the hidden nodes initialised to zero. That way you can verify the weight updates are working without worrying about back propagating the error. Then initialise the hidden nodes to random values (which adds a bit more complexity as you can no longer assume that \(h^{t-1} = 0\)). And then test it over multiple time steps to make sure back propagation works.

Then implement a feed forward network which combines two different functions, for example:

\(u = \sigma\ (x^t w^{xu})\)
\(m = \tanh\ (x^t w^{xm}) \)
\(h^t = u \odot m = \sigma\ (x^t w^{xu}) \odot \tanh\ (x^t w^{xm}) \)

The derivatives and back propagation for this will be significantly more involved than above. For example:

\(\frac {\partial J^t} {\partial w^{xu}} = (e^t_h \odot m \odot (u - u^2))\ x^t\)
\(\frac {\partial J^t} {\partial w^{xm}} = (e^t_h \odot u \odot (1 - m^2))\ x^t\)
\( e^{t}_x = (e^t_h \odot m \odot (u - u^2))\ {w^{xu}}^T + (e^t_h \odot u \odot (1 - m^2))\ {w^{xm}}^T \)
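This two-function network is small enough to verify completely. Here is a NumPy sketch of its forward pass and the three equations above, checked against central differences; it assumes row vectors, no biases, hypothetical function names, and a stand-in loss \(J = c \cdot h\) with a random vector c so that \(e_h = c\).

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def forward(x, Wxu, Wxm):
    u = sigmoid(x @ Wxu)
    m = np.tanh(x @ Wxm)
    return u * m, u, m


def backward(x, Wxu, Wxm, e):
    """Derivatives of the two-function network h = u * m, given e = dJ/dh."""
    _, u, m = forward(x, Wxu, Wxm)
    delta_u = e * m * (u - u ** 2)       # error at the sigmoid's input
    delta_m = e * u * (1 - m ** 2)       # error at the tanh's input
    e_x = delta_u @ Wxu.T + delta_m @ Wxm.T
    return np.outer(x, delta_u), np.outer(x, delta_m), e_x


rng = np.random.default_rng(7)
n_x, n_h = 3, 4
x, c = rng.normal(size=n_x), rng.normal(size=n_h)
Wxu, Wxm = rng.normal(size=(n_x, n_h)), rng.normal(size=(n_x, n_h))
dWxu, dWxm, e_x = backward(x, Wxu, Wxm, c)


def loss(x, Wxu, Wxm):
    return c @ forward(x, Wxu, Wxm)[0]   # stand-in loss, so e = c


eps = 1e-6
num_ex = np.array([(loss(x + eps * d, Wxu, Wxm) - loss(x - eps * d, Wxu, Wxm)) / (2 * eps)
                   for d in np.eye(n_x)])
num_dWxm = np.zeros_like(Wxm)
for idx in np.ndindex(Wxm.shape):
    step = np.zeros_like(Wxm)
    step[idx] = eps
    num_dWxm[idx] = (loss(x, Wxu, Wxm + step) - loss(x, Wxu, Wxm - step)) / (2 * eps)
print(np.allclose(num_ex, e_x), np.allclose(num_dWxm, dWxm))  # True True
```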

Then add the two sets of recurrent weights to this network. Only then introduce the reset gates into the m equation (which makes it much more complex because of that internal hidden-to-memory weight set), and finally add in the \((1-u) \odot h^{t-1}\) term. Run the numerical gradient check at each stage to make sure it's all working.

Also notice that some bits of maths are used repeatedly, such as \(e^t_h \odot u \odot (1 - m^2)\). You can of course keep these in temporary variables so you don't have to keep recomputing them.
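Putting everything together, here is a hedged NumPy sketch of back propagation through time for a GRU layer over a short sequence, following the tip about temporaries. It assumes row vectors, no biases, weights in a dict under hypothetical keys like `"xr"`, and a stand-in loss \(J = \sum_t c^t \cdot h^t\) with random vectors \(c^t\) in place of a real output layer. The forward pass caches r, m, u and \(h^{t-1}\) for each step; the backward pass computes `delta_u`, `delta_m` and `delta_r` once per step and reuses them for the weight derivatives, \(e_x\) and \(e^{t-1}\).

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gru_forward(xs, h0, W):
    """Run a sequence, caching (x, h^{t-1}, r, m, u) for every step."""
    h, hs, cache = h0, [], []
    for x in xs:
        r = sigmoid(x @ W["xr"] + h @ W["hr"])
        m = np.tanh(x @ W["xm"] + (r * h) @ W["hm"])
        u = sigmoid(x @ W["xu"] + h @ W["hu"])
        cache.append((x, h, r, m, u))
        h = u * m + (1 - u) * h
        hs.append(h)
    return hs, cache


def gru_backward(cache, W, es):
    """es[t] is the error arriving at h^t from the output at step t."""
    grads = {k: np.zeros_like(v) for k, v in W.items()}
    e = np.zeros(W["hu"].shape[0])               # error from later time steps
    e_xs = [None] * len(cache)
    for t in reversed(range(len(cache))):
        x, h_prev, r, m, u = cache[t]
        e = e + es[t]
        delta_u = e * (m - h_prev) * (u - u ** 2)   # temporaries reused below
        delta_m = e * u * (1 - m ** 2)
        g = delta_m @ W["hm"].T
        delta_r = g * (r - r ** 2) * h_prev
        grads["xu"] += np.outer(x, delta_u)
        grads["hu"] += np.outer(h_prev, delta_u)
        grads["xm"] += np.outer(x, delta_m)
        grads["hm"] += np.outer(r * h_prev, delta_m)
        grads["xr"] += np.outer(x, delta_r)
        grads["hr"] += np.outer(h_prev, delta_r)
        e_xs[t] = delta_u @ W["xu"].T + delta_m @ W["xm"].T + delta_r @ W["xr"].T
        e = delta_u @ W["hu"].T + g * r + delta_r @ W["hr"].T + e * (1 - u)  # e^{t-1}
    return grads, e_xs


rng = np.random.default_rng(6)
n_x, n_h, T = 3, 4, 5
W = {k: rng.normal(0.0, 0.5, (n_x if k[0] == "x" else n_h, n_h))
     for k in ("xu", "hu", "xm", "hm", "xr", "hr")}
xs, cs, h0 = rng.normal(size=(T, n_x)), rng.normal(size=(T, n_h)), rng.normal(size=n_h)


def loss(W):
    hs, _ = gru_forward(xs, h0, W)
    return sum(c @ h for c, h in zip(cs, hs))


_, cache = gru_forward(xs, h0, W)
grads, e_xs = gru_backward(cache, W, cs)

eps, worst = 1e-6, 0.0
for name in W:
    for idx in np.ndindex(W[name].shape):
        plus, minus = dict(W), dict(W)
        plus[name], minus[name] = W[name].copy(), W[name].copy()
        plus[name][idx] += eps
        minus[name][idx] -= eps
        worst = max(worst, abs((loss(plus) - loss(minus)) / (2 * eps) - grads[name][idx]))
print(worst)  # should be close to zero if every equation is right
```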

Conclusion

That covers the maths behind back propagation for a GRU network. I implemented this in code at the same time to verify that the equations were correct, using numerical gradient checking. It took a long time to get it all working and correct. I hope this article helps.