Recurrent Neural Networks for sequential data

#RNN #LSTM #Attention

Motivation

Given the current state of an object, can you predict its next state?

Without any prior information, any guess is just a random guess.

Given information about previous states, we can greatly improve the prediction.

Sequential data examples

Text

Audio

DNA

Stochastic forecast

Video

Any data recordings

Thus, to model sequences we need a model that:

  • handles arbitrary-length input

  • tracks long-term dependencies

  • takes the order of the elements into account

Network Architecture

Forward model:

For simplicity, we collapse the items into a simplified forward structure:

A Recurrent Neural Network calls itself recursively as follows:

...or in short form

Simple RNN

  • xt : input vector (m ⨉ 1)

  • ht : hidden layer vector (n ⨉ 1)

  • ot : output vector (n ⨉ 1)

  • bh : bias vector (n ⨉ 1)

  • Uh, Vh, Wy : parameter matrices (Uh : n ⨉ m, Vh : n ⨉ n, Wy : n ⨉ n)

  • σ : activation functions.
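
Putting these symbols together, the short form of the recurrence above is plausibly the standard Elman-style update (a reconstruction, assuming Uh maps input to hidden, Vh maps hidden to hidden, and Wy maps hidden to output):

    h_t = \sigma(U_h x_t + V_h h_{t-1} + b_h)
    o_t = W_y h_t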

Forward-pass

  1. Start from the input vector xt

  2. Compute the hidden state ht : the memory / context

    • the current context is a non-linear function of the past context ht-1 and the current input xt

  3. The output ot is a linear transformation of the current ht

✔️ For an RNN, the input length is not fixed and can be arbitrary:

  • at every step, the hidden state is updated with the current input xt

An RNN has the same number of weights as the collapsed forward model above: the weights are shared across time steps.

To predict, the RNN applies the same weights at every time step (see the sketch below).
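
A minimal NumPy sketch of this forward pass (the names Uh, Vh, Wy, bh follow the list above; the tanh non-linearity and the zero initial state are assumptions, not taken from the notes). It shows both points: the same weights are applied at every step, and the sequence length T can be anything.

    import numpy as np

    def rnn_forward(xs, Uh, Vh, Wy, bh, h0=None):
        """Minimal Elman-style RNN forward pass (sketch).
        xs : (T, m) sequence of input vectors, Uh : (n, m) input-to-hidden,
        Vh : (n, n) hidden-to-hidden, Wy : (n, n) hidden-to-output, bh : (n,) bias."""
        h = np.zeros(Uh.shape[0]) if h0 is None else h0
        hs, outs = [], []
        for x in xs:                            # the SAME weights are reused at every time step
            h = np.tanh(Uh @ x + Vh @ h + bh)   # hidden state: non-linear fn of past context and current input
            outs.append(Wy @ h)                 # output: linear transformation of the current hidden state
            hs.append(h)
        return np.array(hs), np.array(outs)

    # works for any sequence length T -- the input length is not fixed
    rng = np.random.default_rng(0)
    m, n, T = 3, 4, 7
    xs = rng.normal(size=(T, m))
    Uh, Vh, Wy = rng.normal(size=(n, m)), rng.normal(size=(n, n)), rng.normal(size=(n, n))
    hs, outs = rnn_forward(xs, Uh, Vh, Wy, np.zeros(n))
    print(hs.shape, outs.shape)   # (7, 4) (7, 4)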


The total loss function L is the sum of the losses at every time step.
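
In symbols (pairing each output ot with a target yt is a notation assumption, not spelled out in the notes):

    L = \sum_{t} L_t(o_t, y_t)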

Learning

In the case of an RNN, we backpropagate through time (BPTT):

  • the total loss L is backpropagated through every time-step loss Lt

  • AND each time-step loss Lt is backpropagated through the earlier hidden states ht-1, ht-2 ... h0
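
A sketch of where the trouble comes from (standard BPTT algebra, added here for the reasoning; W stands for any shared weight matrix):

    \frac{\partial L}{\partial W} = \sum_t \frac{\partial L_t}{\partial W},
    \qquad
    \frac{\partial L_t}{\partial h_k} = \frac{\partial L_t}{\partial h_t} \prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}}

The long product of per-step Jacobians ∂hj/∂hj-1 is the problem: if its norm is consistently above 1 the gradient blows up, if it is consistently below 1 the gradient shrinks towards zero.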

Potential issues

Gradients are too big (aka Gradient Explosion) : we over-react to every single input, destabilizing what was already learned

Gradients are too small (aka Gradient Vanishing) : we learn only short-term dependencies and forget the long-term context

  • use ReLU : for x > 0 its derivative is 1, so it does not shrink the gradient

  • use gates : explicitly control what information is relevant by adding it to / removing it from the state
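
A toy numeric illustration in Python (a scalar stand-in for the Jacobian product above, not real gradients):

    # the gradient through T steps scales roughly like a product of T per-step factors
    T = 100
    print(0.9 ** T)   # ~2.7e-05 -> vanishing: long-term context barely affects the update
    print(1.1 ** T)   # ~1.4e+04 -> exploding: the update blows up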

RNN architectures

Simple RNN

short-term memory only, due to Gradient Vanishing

GRU : gated recurrent unit

  • zt : update gate (add the new candidate to the state, or keep the old one?)

  • rt : reset gate (use the past state when forming the candidate, or skip it?)

✔️ partially mitigates Gradient Vanishing
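
For reference, the standard GRU update (Cho et al., 2014); the per-gate weight names Wz, Uz, ... are the usual convention, not taken from these notes:

    z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z)                        % update gate
    r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r)                        % reset gate
    \tilde{h}_t = \tanh(W_h x_t + U_h (r_t \odot h_{t-1}) + b_h)     % candidate state
    h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t            % mix old state and candidate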

LSTM : long short-term memory

  • f - forget gate : keep the previous context or drop it?

  • i - input gate : add the current input xt to the context?

  • o - output gate : what part of the context to expose as the prediction

✔️ partially mitigates Gradient Vanishing
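
For reference, the standard LSTM equations; the cell state ct is the long-term "context" that the gates add information to or remove it from (weight names are the usual convention, not from these notes):

    f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)            % forget gate
    i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)            % input gate
    o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)            % output gate
    \tilde{c}_t = \tanh(W_c x_t + U_c h_{t-1} + b_c)     % candidate context
    c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t      % forget old context, add new
    h_t = o_t \odot \tanh(c_t)                           % exposed output / hidden state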


Advanced improvements

Deep RNN

Bi-directional RNN

Deep recurrent networks :

  • level 0 : a simple RNN (e.g. GRU) over the raw input

  • level 1+ : the input X is the hidden-state sequence ht of the level below (see the sketch after this list)
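
A minimal NumPy sketch of both ideas (layer sizes, the tanh cell and the random weights are illustrative assumptions): a level-1 layer treats the hidden-state sequence of level 0 as its input, and a bi-directional layer runs a second pass over the reversed sequence and concatenates the two hidden states.

    import numpy as np

    def rnn_layer(xs, U, V, b):
        """One recurrent (tanh Elman) layer; returns the full hidden-state sequence."""
        h = np.zeros(U.shape[0])
        hs = []
        for x in xs:
            h = np.tanh(U @ x + V @ h + b)
            hs.append(h)
        return np.array(hs)

    rng = np.random.default_rng(1)
    T, m, n0, n1 = 6, 3, 5, 4
    xs = rng.normal(size=(T, m))

    # level 0: a simple RNN over the raw input sequence
    U0, V0 = rng.normal(size=(n0, m)), rng.normal(size=(n0, n0))
    hs0 = rnn_layer(xs, U0, V0, np.zeros(n0))

    # level 1+: the input is the hidden-state sequence of the level below
    U1, V1 = rng.normal(size=(n1, n0)), rng.normal(size=(n1, n1))
    hs1 = rnn_layer(hs0, U1, V1, np.zeros(n1))

    # bi-directional variant of level 0: second pass over the reversed sequence, then concatenate
    U0b, V0b = rng.normal(size=(n0, m)), rng.normal(size=(n0, n0))
    hs0_bwd = rnn_layer(xs[::-1], U0b, V0b, np.zeros(n0))[::-1]
    hs0_bi = np.concatenate([hs0, hs0_bwd], axis=1)

    print(hs0.shape, hs1.shape, hs0_bi.shape)   # (6, 5) (6, 4) (6, 10)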

Limitations of RNN:

  1. Information bottleneck

  2. Slow : no parallelization across time, since each step depends on the previous one

  3. No truly long-term memory

    • the RNN suffers from Gradient Vanishing for very long sequences

💡 We would like to have:

  1. Continuous stream of information

  2. Parallelization

  3. Long memory: don't forget what is important

Solution : Attention