LSTM cell: Understanding architecture from scratch with code

Ordinary neural networks don’t perform well on tasks where the order of the data matters, for example language translation, sentiment analysis, and time series. To overcome this limitation, RNNs were invented. RNN stands for “Recurrent Neural Network”. An RNN cell considers not only its present input but also the output of the RNN cells preceding it when computing its present output.

In its simplest form, the present state of a vanilla RNN can be represented as h(t) = tanh(W_hh · h(t-1) + W_xh · x(t)):

Representation of a simple vanilla RNN cell, source: Stanford

RNNs perform very well on sequential data and on tasks where order is important.

Ordinary RNNs come with some problems:

Vanishing gradients problem:

Vanishing gradient problem: (1) tanh, (2) derivative of tanh

Hyperbolic tangent (tanh) is the activation function most commonly used in RNNs; its output lies in [-1, 1] and its derivative lies in (0, 1]. During backpropagation the gradient is calculated by the chain rule, which has the effect of multiplying these small numbers together n times (once for each tanh in the unrolled architecture). This squeezes the final gradient toward almost zero, so subtracting the gradient from the weights barely changes them, and training stalls.
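A quick numerical illustration of this (my own sketch, not from the original article) is to chain the derivative of tanh over many timesteps, as the chain rule does:

```python
import numpy as np

# Derivative of tanh: 1 - tanh(x)^2, which lies in (0, 1]
def dtanh(x):
    return 1.0 - np.tanh(x) ** 2

# 50 timesteps with arbitrary pre-activation values
pre_activations = np.linspace(-2.0, 2.0, 50)

# Chain rule: the gradient picks up a factor dtanh(x) at every step
grad = np.prod(dtanh(pre_activations))
# grad ends up vanishingly small, so weight updates effectively stop
```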

Exploding gradients problem:

Opposite to the vanishing gradient problem: while following the chain rule, we also multiply by the (transposed) weight matrix W at each step. If its values are larger than 1, multiplying a large number by itself many times yields a very large number, and the gradient explodes.
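The same kind of toy sketch (again mine, not the article’s) shows the explosion; here W is simply a scaled identity so that every singular value is 1.5 > 1:

```python
import numpy as np

# A weight matrix whose singular values all exceed 1
W = 1.5 * np.eye(4)

grad = np.ones(4)
for _ in range(50):        # backprop through 50 timesteps
    grad = W.T @ grad      # gradient norm grows by a factor 1.5 every step
# np.linalg.norm(grad) is astronomically large: the gradient has exploded
```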

Exploding and vanishing gradients in vanilla RNN gradient flow, source: Stanford CS231n

Long-Term Dependencies problem

Long-term dependency problem; each node represents an RNN cell. Source: Google

RNNs are good at handling sequential data, but they run into problems when the relevant context is far away. Example: “I live in France and I know ____.” The answer must be “French” here, but if there are more words between “I live in France” and “I know ____”, it becomes difficult for an RNN to predict “French”. This is the problem of long-term dependencies, and it is what brings us to LSTMs.

Want to know the mathematics behind it? Have a quick peek here.

Long Short Term Memory Networks

LSTMs are a special kind of RNN capable of handling long-term dependencies. They also provide a solution to the vanishing/exploding gradient problem, which we’ll discuss later in this article.

A simple LSTM cell looks like this:

RNN vs LSTM cell representation, source: Stanford

At the start, we need to initialize the weight matrices and bias terms for the gates.
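A minimal initialisation sketch in NumPy follows; the function and parameter names here are my own choices, not necessarily the article’s. Each gate operates on the previous hidden state and the current input concatenated together:

```python
import numpy as np

def init_params(input_dim, hidden_dim, seed=0):
    """Create weights and biases for the four LSTM gates."""
    rng = np.random.default_rng(seed)
    params = {}
    for gate in ("f", "i", "g", "o"):  # forget, input, candidate ("gate gate"), output
        # each gate sees [h_prev, x] concatenated, hence hidden_dim + input_dim columns
        params["W" + gate] = 0.01 * rng.standard_normal((hidden_dim, hidden_dim + input_dim))
        params["b" + gate] = np.zeros(hidden_dim)
    # common trick: positive forget-gate bias so f starts close to 1
    params["bf"] += 1.0
    return params
```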

LSTM cell structure

A simple LSTM cell consists of 4 gates:

Three LSTM cells connected to each other, source: Google
LSTM cell visual representation, source: Google
Handy information about the gates, source: Stanford CS231n

Let’s discuss the gates:

• Forget gate: Given the output of the previous state, h(t-1), the forget gate helps us decide what must be removed from h(t-1), keeping only the relevant information. The gate is wrapped in a sigmoid function, which squashes its input into [0, 1]. It is represented as:

Forget gate, source: Google

We multiply the forget gate with the previous cell state to drop whatever is no longer needed from the previous state.
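In code, the forget gate can be sketched as below (variable names such as `Wf`, `bf`, `h_prev`, and `x_t` are illustrative):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forget_gate(Wf, bf, h_prev, x_t):
    """Forget gate: every output entry lies in (0, 1)."""
    z = np.concatenate([h_prev, x_t])
    return sigmoid(Wf @ z + bf)

# the forgetting itself: entries of c_prev with f_t near 0 are dropped
# c_kept = forget_gate(Wf, bf, h_prev, x_t) * c_prev
```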

• Input gate: In the input gate, we decide which new values from the present input to add to our present cell state, scaled by how much we wish to add them.

Input gate + gate gate, photo credits: Christopher Olah

In the photo above, the sigmoid layer decides which values to update, and the tanh layer creates a vector of new candidate values to be added to the present cell state.
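These two layers can be sketched together like this (names are illustrative):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def input_and_candidate(Wi, bi, Wg, bg, h_prev, x_t):
    """Input gate plus candidate ('gate gate') values."""
    z = np.concatenate([h_prev, x_t])
    i_t = sigmoid(Wi @ z + bi)  # sigmoid layer: which entries to update, in (0, 1)
    g_t = np.tanh(Wg @ z + bg)  # tanh layer: candidate values, in (-1, 1)
    return i_t, g_t
```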

To calculate the present cell state, we add (input_gate × gate_gate) to (forget_gate × previous cell state).
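As a one-line sketch (elementwise throughout; names are illustrative):

```python
import numpy as np

def cell_state(f_t, c_prev, i_t, g_t):
    """New cell state: forget part of the old state, add the gated candidates."""
    return f_t * c_prev + i_t * g_t
```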


• Output gate: Finally, we decide what to output from our cell state, which is controlled by a sigmoid function.

We squash the cell state with tanh to bring its values into (-1, 1), and then multiply it by the output of the sigmoid function so that we only output what we want to.

Output gate, source: Google
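Sketched in code (names are illustrative):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def output_gate(Wo, bo, h_prev, x_t, c_t):
    """o_t picks which parts of the tanh-squashed cell state become h_t."""
    z = np.concatenate([h_prev, x_t])
    o_t = sigmoid(Wo @ z + bo)
    return o_t * np.tanh(c_t)  # the new hidden state h_t
```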

Here’s an overall view of what we did, in code.
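The sketch below combines the four gates into one full LSTM step; the names (a `params` dict holding `Wf`/`bf`, `Wi`/`bi`, `Wg`/`bg`, `Wo`/`bo`) are my own convention, not necessarily the article’s exact code:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(params, x_t, h_prev, c_prev):
    """One LSTM cell: returns the new hidden state and cell state."""
    z = np.concatenate([h_prev, x_t])
    f_t = sigmoid(params["Wf"] @ z + params["bf"])  # forget gate
    i_t = sigmoid(params["Wi"] @ z + params["bi"])  # input gate
    g_t = np.tanh(params["Wg"] @ z + params["bg"])  # gate gate (candidates)
    o_t = sigmoid(params["Wo"] @ z + params["bo"])  # output gate
    c_t = f_t * c_prev + i_t * g_t                  # new cell state
    h_t = o_t * np.tanh(c_t)                        # new hidden state
    return h_t, c_t
```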

The LSTM addresses the vanishing and exploding gradient problems in the following way; its backprop is much cleaner than a vanilla RNN’s:

Gradient flows smoothly during backprop, source: Stanford CS231n
  1. During backprop through the cell state there is no multiplication with the matrix W; it’s an elementwise multiplication with f (the forget gate), which is also cheaper to compute.
  2. During backprop, each LSTM cell multiplies the gradient by a different value of the forget gate, which makes it less prone to vanishing/exploding gradients. If the values of all the forget gates are less than 1, it may still suffer from vanishing gradients, but in practice people tend to initialise the forget-gate bias terms with a positive number, so at the beginning of training f (the forget gate) is very close to 1, and as training progresses the model can learn these bias terms.
  • Still, the model may suffer from the vanishing gradient problem, but the chances are much smaller.
  • This article was limited to the architecture of the LSTM cell, but you can see the complete code HERE. The code also implements an example of generating a simple sequence from random inputs using LSTMs.
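The effect of the positive forget-gate bias in point 2 can be sketched numerically (my own illustration, with arbitrary pre-activation values):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Along the cell-state path, backprop multiplies elementwise by f_t at each
# of the 100 steps. A positive forget bias keeps f_t near 1 early in training,
# so the product decays far more slowly.
f_biased = sigmoid(np.full(100, 1.0))  # pre-activation +1 -> f is about 0.73
f_plain = sigmoid(np.zeros(100))       # pre-activation 0  -> f = 0.5
# np.prod(f_biased) is many orders of magnitude larger than np.prod(f_plain)
```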

This is an updated version of my previous article.
