Ordinary neural networks don’t perform well in cases where the order of the data is important, for example language translation, sentiment analysis, time series and more. To overcome this limitation, RNNs were invented. RNN stands for “Recurrent Neural Network”. To compute its present output, an RNN cell considers not only its present input but also the output of the RNN cell preceding it.
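In its simple (vanilla) form, the present state of an RNN can be represented as h(t) = tanh(W_hh · h(t-1) + W_xh · x(t)), where x(t) is the present input, h(t-1) is the previous state, and W_hh and W_xh are weight matrices.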
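A minimal sketch of the standard vanilla RNN recurrence h(t) = tanh(W_hh · h(t-1) + W_xh · x(t)) in Python, using plain lists so it needs only the standard library. The function name rnn_step and the toy weights are hypothetical, and the bias term is omitted to match the formula.

```python
import math

def rnn_step(W_hh, W_xh, h_prev, x):
    """One vanilla RNN step: h(t) = tanh(W_hh · h(t-1) + W_xh · x(t)).

    Matrices are lists of rows and vectors are plain lists (no bias term).
    """
    return [
        math.tanh(
            sum(w * h_j for w, h_j in zip(W_hh[i], h_prev))
            + sum(w * x_k for w, x_k in zip(W_xh[i], x))
        )
        for i in range(len(h_prev))
    ]

# Hypothetical toy weights: 2 hidden units, 1 input feature.
W_hh = [[0.5, -0.3], [0.2, 0.4]]
W_xh = [[1.0], [-1.0]]

# The same weights are reused at every time step; h(t) is carried forward.
h = [0.0, 0.0]
states = []
for x_t in ([1.0], [0.5], [-0.2]):
    h = rnn_step(W_hh, W_xh, h, x_t)
    states.append(h)
print(states)
```

Reusing the same W_hh at every step is what lets the cell carry information forward, and it is also the source of the gradient problems below.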
RNNs performed very well on sequential data, i.e. on tasks where the order of the data matters.
However, ordinary RNNs come with some problems:
Vanishing gradients problem:
The hyperbolic tangent (tanh) is the most commonly used activation function in RNNs. Its output lies in (-1, 1) and its derivative lies in (0, 1]. During backpropagation the gradient is computed with the chain rule, which multiplies these small derivatives together n times (once for every time tanh is applied in the unrolled RNN). This squeezes the final gradient to almost zero, so subtracting the gradient from the weights barely changes them, and the model effectively stops training.
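The shrinking effect can be checked numerically. The sketch below multiplies one tanh derivative per time step, assuming (hypothetically) the same pre-activation value z = 1.0 at every step and ignoring the weight factors, so it isolates the tanh part of the chain rule.

```python
import math

def tanh_derivative(z):
    """d/dz tanh(z) = 1 - tanh(z)**2, which lies in (0, 1]."""
    return 1.0 - math.tanh(z) ** 2

# Hypothetical: every time step sees the same pre-activation z = 1.0.
z = 1.0
steps = 50

gradient = 1.0
for _ in range(steps):
    gradient *= tanh_derivative(z)  # chain rule: one tanh derivative per time step

print(f"tanh'(1.0) = {tanh_derivative(z):.4f}")
print(f"product over {steps} steps = {gradient:.3e}")
```

Even though each factor is only moderately below 1, fifty of them multiply out to a value far too small to move the weights.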
Exploding gradients problem:
Conversely, while applying the chain rule we also multiply by the weight matrix W (transposed) at each step. If its values are larger than 1, multiplying by it many times produces a very large number, and the gradient explodes.
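The sketch below isolates the weight-matrix factor: it repeatedly multiplies a gradient vector by the transpose of a small 2×2 matrix, ignoring the tanh derivatives. The matrices and the 30-step horizon are hypothetical; halving the same matrix shows the gradient vanishing instead.

```python
import math

def backprop_through_W(W, grad, steps):
    """Multiply grad by W transposed `steps` times (tanh derivatives ignored)."""
    n = len(grad)
    for _ in range(steps):
        # (W^T · grad)[j] = sum_i W[i][j] * grad[i]
        grad = [sum(W[i][j] * grad[i] for i in range(n)) for j in range(n)]
    return grad

def norm(v):
    """Euclidean length of a vector given as a list."""
    return math.sqrt(sum(x * x for x in v))

W_large = [[1.5, 0.2], [0.1, 1.3]]                      # diagonal entries > 1
W_small = [[0.5 * w for w in row] for row in W_large]   # same matrix, halved

g0 = [1.0, 1.0]
print("large W:", norm(backprop_through_W(W_large, g0, 30)))
print("small W:", norm(backprop_through_W(W_small, g0, 30)))
```

Only the scale of W differs between the two runs, yet one gradient grows by orders of magnitude and the other collapses toward zero.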
Long-Term Dependencies problem:
RNNs are good at handling sequential data, but they run into problems when the relevant context is far away. Example: “I live in France and I know ____.” The answer here must be ‘French’, but if there are many more words between ‘I live in France’ and ‘I know ____’, it becomes difficult for an RNN to predict ‘French’. This is the problem of long-term dependencies, and it leads us to LSTMs.
Long Short-Term Memory Networks
LSTMs are a special kind of RNN capable of handling long-term dependencies. They also mitigate the vanishing/exploding gradient problem, as we’ll discuss later in this article.
A simple LSTM cell looks like this:
To start, we need to initialize the weight matrices and bias terms, as shown below.
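A minimal sketch of such an initialization, assuming one weight matrix and one bias vector per gate, with every weight matrix acting on the concatenation [h(t-1), x(t)]. The function name, dictionary keys, the Gaussian scale of 0.1 and the zero biases are illustrative choices, not necessarily the author’s.

```python
import random

def init_lstm_params(input_size, hidden_size, seed=0):
    """Hypothetical initialization: one weight matrix and bias per gate.

    Each weight matrix has hidden_size rows and hidden_size + input_size
    columns, because every gate reads the concatenation [h(t-1), x(t)].
    """
    rng = random.Random(seed)  # seeded so the example is reproducible
    concat = hidden_size + input_size
    params = {}
    for gate in ("forget", "input", "gate", "output"):
        params["W_" + gate] = [
            [rng.gauss(0.0, 0.1) for _ in range(concat)] for _ in range(hidden_size)
        ]
        params["b_" + gate] = [0.0] * hidden_size
    return params

params = init_lstm_params(input_size=3, hidden_size=4)
print(len(params["W_forget"]), len(params["W_forget"][0]))
```

Small random weights keep the early pre-activations near zero, so the sigmoid and tanh layers start out in their non-saturated range.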
LSTM cell structure
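A simple LSTM cell consists of 4 gates: the forget gate, the input gate, the “gate gate” (the tanh layer that produces candidate values, called gate_gate in the code) and the output gate.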
Let’s discuss the gates:
•Forget Gate: Given the output of the previous state, h(t-1), and the present input x(t), the forget gate decides what must be removed from the previous cell state, keeping only relevant information. It is wrapped in a sigmoid function, which squashes its values into (0, 1). It is represented as: f(t) = sigmoid(W_f · [h(t-1), x(t)] + b_f).
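A minimal sketch of the forget gate in plain Python, assuming the formula f(t) = sigmoid(W_f · [h(t-1), x(t)] + b_f), where [h(t-1), x(t)] is the concatenation of the previous output and the present input. The weights and inputs are hypothetical toy values.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forget_gate(W_f, b_f, h_prev, x):
    """f(t) = sigmoid(W_f · [h(t-1), x(t)] + b_f); every entry lies in (0, 1)."""
    concat = h_prev + x  # list concatenation gives [h(t-1), x(t)]
    return [
        sigmoid(sum(w * v for w, v in zip(row, concat)) + b)
        for row, b in zip(W_f, b_f)
    ]

# Hypothetical toy values: 2 hidden units, 1 input feature.
W_f = [[0.5, -0.5, 2.0], [-1.0, 0.3, -2.0]]
b_f = [0.0, 0.0]
f = forget_gate(W_f, b_f, h_prev=[0.1, -0.2], x=[1.0])
print(f)
```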
We multiply the forget gate element-wise with the previous cell state to discard the information that is no longer needed, as shown below:
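In code this is an element-wise product. With hypothetical forget-gate values, an entry near 1 keeps almost all of the old memory, an entry near 0 erases it, and 0.5 halves it:

```python
# Element-wise product f(t) * c(t-1); all values are hypothetical.
c_prev = [2.0, -1.5, 0.8]   # previous cell state c(t-1)
f = [0.99, 0.01, 0.5]       # forget-gate outputs, each in (0, 1)

# Keep (f ~ 1), erase (f ~ 0) or partially keep each entry of the old memory.
c_kept = [f_k * c_k for f_k, c_k in zip(f, c_prev)]
print(c_kept)
```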
•Input Gate: The input gate decides which new information from the present input to add to the present cell state, scaled by how much of it we wish to add.
In the figure above, the sigmoid layer decides which values will be updated, and the tanh layer creates a vector of new candidate values to be added to the present cell state. The code is shown below.
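A sketch of the two layers, assuming each reads [h(t-1), x(t)] through its own weights: a sigmoid layer for input_gate and a tanh layer for the candidates (gate_gate, following the naming used in this article). The helper names affine and input_and_candidate, and all weights, are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, b, v):
    """W · v + b for a matrix W given as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + b_i for row, b_i in zip(W, b)]

def input_and_candidate(W_i, b_i, W_g, b_g, h_prev, x):
    concat = h_prev + x                                            # [h(t-1), x(t)]
    input_gate = [sigmoid(z) for z in affine(W_i, b_i, concat)]    # how much to write, in (0, 1)
    gate_gate = [math.tanh(z) for z in affine(W_g, b_g, concat)]   # candidate values, in (-1, 1)
    return input_gate, gate_gate

# Hypothetical toy weights: 2 hidden units, 1 input feature.
W_i = [[0.2, 0.1, 1.0], [0.0, -0.4, 0.5]]
W_g = [[1.0, -1.0, 0.3], [0.5, 0.5, -0.7]]
zeros = [0.0, 0.0]
i_t, g_t = input_and_candidate(W_i, zeros, W_g, zeros, h_prev=[0.3, -0.1], x=[0.8])
print(i_t, g_t)
```

Using tanh for the candidates lets the cell add negative as well as positive values, while the sigmoid only scales them.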
To calculate the present cell state, we add (input_gate * gate_gate) to (forget_gate * previous cell state), as shown below.
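The update c(t) = forget_gate * c(t-1) + input_gate * gate_gate is purely element-wise. In this small sketch with hypothetical values, the first unit keeps its old memory and writes nothing, while the second forgets its old memory and writes the new candidate:

```python
def update_cell_state(f, c_prev, input_gate, gate_gate):
    """c(t) = f(t) * c(t-1) + input_gate * gate_gate, all element-wise."""
    return [
        f_k * c_k + i_k * g_k
        for f_k, c_k, i_k, g_k in zip(f, c_prev, input_gate, gate_gate)
    ]

c_t = update_cell_state(f=[1.0, 0.0], c_prev=[0.5, 0.5],
                        input_gate=[0.0, 1.0], gate_gate=[0.9, -0.9])
print(c_t)
```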
•Output Gate: Finally, we decide what to output from our cell state; this is done by a sigmoid layer.
We pass the cell state through tanh to squash its values into (-1, 1) and then multiply the result by the output of the sigmoid gate, so that we only output the parts we want.
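A minimal sketch of the output step, h(t) = output_gate * tanh(c(t)). To keep it short, the output gate’s pre-activation (W_o · [h(t-1), x(t)] + b_o) is passed in directly as a hypothetical value, chosen to show one nearly open and one nearly closed gate.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def hidden_state(o_preact, c_t):
    """h(t) = sigmoid(o_preact) * tanh(c(t)), element-wise."""
    output_gate = [sigmoid(z) for z in o_preact]
    return [o * math.tanh(c) for o, c in zip(output_gate, c_t)]

# o_preact stands for W_o · [h(t-1), x(t)] + b_o (hypothetical values).
h_t = hidden_state(o_preact=[10.0, -10.0], c_t=[3.0, 3.0])
print(h_t)
```

Both units hold the same cell value, but only the one whose output gate is open passes it on.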
Here is an overall view of what we did, in code.
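Putting the pieces together, here is a self-contained sketch of one LSTM forward step and a loop over a short sequence. It assumes separate weights per gate acting on [h(t-1), x(t)] and a seeded random initialization; it is an illustration under those assumptions, not the code linked from this article, and it covers only the forward pass.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, b, v):
    """W · v + b for a matrix W given as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + b_i for row, b_i in zip(W, b)]

def init_params(input_size, hidden_size, seed=0):
    """Hypothetical init: small Gaussian weights and zero biases per gate."""
    rng = random.Random(seed)
    n = hidden_size + input_size
    params = {}
    for gate in ("forget", "input", "gate", "output"):
        params["W_" + gate] = [[rng.gauss(0.0, 0.1) for _ in range(n)] for _ in range(hidden_size)]
        params["b_" + gate] = [0.0] * hidden_size
    return params

def lstm_cell(params, h_prev, c_prev, x):
    """One LSTM step: returns the new (h(t), c(t))."""
    concat = h_prev + x  # [h(t-1), x(t)]
    f = [sigmoid(z) for z in affine(params["W_forget"], params["b_forget"], concat)]
    i = [sigmoid(z) for z in affine(params["W_input"], params["b_input"], concat)]
    g = [math.tanh(z) for z in affine(params["W_gate"], params["b_gate"], concat)]
    o = [sigmoid(z) for z in affine(params["W_output"], params["b_output"], concat)]
    c = [f_k * c_k + i_k * g_k for f_k, c_k, i_k, g_k in zip(f, c_prev, i, g)]
    h = [o_k * math.tanh(c_k) for o_k, c_k in zip(o, c)]
    return h, c

params = init_params(input_size=2, hidden_size=3)
h, c = [0.0] * 3, [0.0] * 3
for x_t in ([1.0, 0.0], [0.0, 1.0], [0.5, -0.5]):
    h, c = lstm_cell(params, h, c, x_t)
print(h, c)
```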
LSTMs address the vanishing and exploding gradient problems in the following way, because their backprop is much cleaner than that of vanilla RNNs:
- Along the cell state there is no multiplication by the matrix W during backprop, only element-wise multiplication by f (the forget gate). Element-wise multiplication is also cheaper to compute than a matrix multiplication.
- During backprop through each LSTM cell, the gradient is multiplied by a different value of the forget gate at each step, which makes it less prone to vanishing/exploding gradients. If the values of all forget gates are less than 1, it may still suffer from vanishing gradients, but in practice people tend to initialize the forget gate’s bias terms with some positive number, so that at the beginning of training f (the forget gate) is very close to 1; as training proceeds, the model can learn these bias terms.
- The model may still suffer from the vanishing gradient problem, but the chances are much lower.
This article was limited to the architecture of the LSTM cell, but you can see the complete code HERE. The code also implements an example of generating a simple sequence from random inputs using LSTMs.
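The backprop points above can be sketched numerically. Along the cell-state path, c(t) = f(t) * c(t-1) + input_gate * gate_gate, so the derivative of c(t) with respect to c(t-1) is just f(t), and the gradient reaching c(0) from c(T) is a product of forget-gate values with no W involved. The sketch assumes, hypothetically, that early in training the forget gate’s pre-activation is dominated by its bias, so f is roughly sigmoid(bias) at every step; it ignores the indirect paths through h(t-1), and the bias values are illustrative.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def cell_path_gradient(forget_values):
    """d c(T) / d c(0) for one unit along the cell-state path.

    Each step contributes its forget-gate value f(t) as an element-wise factor;
    no weight matrix W appears. Indirect paths through h(t-1) are ignored.
    """
    grad = 1.0
    for f in forget_values:
        grad *= f
    return grad

steps = 20
for bias in (0.0, 1.0, 3.0):
    # Hypothetical: early in training the rest of the pre-activation is ~0,
    # so the forget gate outputs roughly sigmoid(bias) at every step.
    f = sigmoid(bias)
    grad = cell_path_gradient([f] * steps)
    print(f"forget bias {bias}: f = {f:.3f}, gradient after {steps} steps = {grad:.3e}")
```

With a zero bias the gradient shrinks by about six orders of magnitude over 20 steps, while a positive bias keeps f close to 1 and lets much more of it survive.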
This is an updated version of my previous article.