LSTMs without (too much) maths

Aarushi Kansal ・ 3 min read

Assumed Knowledge

High level understanding of:

  • Neural networks
  • Linear algebra

Maths refresher

Sigmoid function: squishes values to between 0 and 1
Tanh function (hyperbolic tangent): squishes values to between -1 and 1
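A quick sketch of both, using NumPy (`sigmoid` is a hand-rolled helper here; `tanh` ships with NumPy):

```python
import numpy as np

def sigmoid(x):
    # Maps any real value into the open interval (0, 1)
    return 1.0 / (1.0 + np.exp(-x))

x = np.array([-5.0, 0.0, 5.0])
print(sigmoid(x))   # large negatives land near 0, zero maps to 0.5, large positives near 1
print(np.tanh(x))   # large negatives land near -1, zero maps to 0, large positives near 1
```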

Neural Networks

Traditional (feedforward) neural networks are useful in a variety of applications, such as classification problems. However, one of their major shortcomings is the inability to remember past information.

For example, as you're reading this sentence, you've retained the information from the previous sentence and you know we're talking about neural networks right now. You don't keep forgetting which means you can continue to base your understanding within the context of neural networks and their shortcomings.

Neural networks don't have this capability because they map inputs to outputs in a one-to-one manner, with no state carried from one input to the next.

Recurrent Neural Networks

Recurrent neural networks are a way to address this memory issue. They do this by providing a feedback loop which allows previous information to be persisted. This can be very valuable in applications where we only need to connect recent information, such as language translation, image captioning, and speech recognition.

Unrolled RNN architecture

Here we have an unrolled RNN, where we can see that we process information in one cell and pass relevant information on to the next cell, which allows the cell to 'remember' information from the past.

However, because we're only passing information to the next immediate cell, RNNs have a tendency to lose information the further along the sequence we go.

Essentially, current implementations of RNNs work well for short dependencies but cannot hold information for long term.

Long Short Term Memory

The short term memory problem is addressed by using LSTMs, which have a more complicated structure, but function more similarly to how a human might read a book or hold a conversation.

In the previous RNN, we can see that our network is able to 'remember' previous information because we pass the previous hidden state (h) into the current cell. Continuing on from this observation, we can also see how hidden states from cells further back become diluted; essentially, the information in those states vanishes.

Let's see how LSTMs address this:

LSTM cell architecture

One of the core observations here is the top horizontal line, which transfers the cell state vector straight through the cell and through the entire network. This means that information can flow through the sequence essentially unchanged, so our network has the capability to remember information from further back in the sequence. It's kind of like a sushi train: food keeps passing along, and we can remove, modify, or leave the sushi as is.

However, we also don't want to just pass information along with no modifications.
The way that humans understand and process information is based on our ability to place more or less emphasis on different parts of a sentence or paragraph, based on context or prior knowledge.

To reproduce this kind of behaviour, LSTMs use gates.

Firstly, we have the forget gate, which uses a sigmoid function to squish all our values to between 0 and 1 - 0 = forget, 1 = remember.
This new vector is passed on to our sushi train, where we multiply it with the incoming cell state. This means we're deciding what incoming information we want to forget.
For example, the incoming information might include a language based on the country we're talking about, but when the context changes to a new country, we want to forget about the previous language.
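As a rough sketch, using made-up, untrained weights (a real LSTM learns `W_f` and `b_f` during training), the forget gate might look like:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
hidden, inputs = 4, 3

# Hypothetical, untrained parameters - for illustration only
W_f = rng.normal(size=(hidden, hidden + inputs))
b_f = np.zeros(hidden)

h_prev = rng.normal(size=hidden)   # previous hidden state
x_t = rng.normal(size=inputs)      # current input
c_prev = rng.normal(size=hidden)   # incoming cell state (the "sushi train")

# Forget gate: sigmoid squishes to (0, 1); near 0 = forget, near 1 = remember
f_t = sigmoid(W_f @ np.concatenate([h_prev, x_t]) + b_f)
c_after_forget = f_t * c_prev      # element-wise: shrink what we want to forget
```

Because every entry of `f_t` is strictly between 0 and 1, each element of the cell state can only shrink (or stay roughly the same) at this step - nothing new is added yet.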

Next we have the remember or input gate, which determines what new information we want to store. This has two parts to it:

  1. Sigmoid function - so we know which values we need to update (0 = don't update, 1 = update)
  2. Tanh function - to regulate our values by squishing to between -1 and 1

Finally, we multiply the two to get our new candidate values, and push those onto our sushi train, where they're added to the incoming state from the forget gate.
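The two parts above can be sketched like so (again with hypothetical, untrained weights; `f_t` and `c_prev` stand in for the forget gate's output and the incoming cell state):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(1)
hidden, inputs = 4, 3
z = rng.normal(size=hidden + inputs)    # concatenated [h_prev, x_t]
c_prev = rng.normal(size=hidden)        # incoming cell state
f_t = sigmoid(rng.normal(size=hidden))  # stand-in forget gate output

# Hypothetical, untrained parameters
W_i, b_i = rng.normal(size=(hidden, hidden + inputs)), np.zeros(hidden)
W_c, b_c = rng.normal(size=(hidden, hidden + inputs)), np.zeros(hidden)

i_t = sigmoid(W_i @ z + b_i)    # 1. which values to update (between 0 and 1)
c_hat = np.tanh(W_c @ z + b_c)  # 2. candidate values, squished to (-1, 1)

# New cell state: what we kept from before, plus what we add now
c_t = f_t * c_prev + i_t * c_hat
```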

Lastly, we need to decide the final output. So again, we run the incoming information through a sigmoid function, pick up the cell state from the train, run it through a tanh function, and multiply the two results. This is the output, or hidden state, of the cell.
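Put together, the output step might look like this (hypothetical weights again; `c_t` stands in for the updated cell state from the train):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(2)
hidden, inputs = 4, 3
z = rng.normal(size=hidden + inputs)  # concatenated [h_prev, x_t]
c_t = rng.normal(size=hidden)         # updated cell state from the sushi train

# Hypothetical, untrained parameters
W_o, b_o = rng.normal(size=(hidden, hidden + inputs)), np.zeros(hidden)

o_t = sigmoid(W_o @ z + b_o)   # which parts of the cell state to expose
h_t = o_t * np.tanh(c_t)       # output / hidden state passed to the next cell
```

Note that `h_t` is always strictly between -1 and 1, since it's a product of a sigmoid output and a tanh output.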

By doing this we're forgetting irrelevant information, updating with new information and coming up with the next prediction which could be the next word in a sentence, translation of a word etc.
