Recurrent Neural Networks for sequential data
#RNN #LSTM #Attention
Motivation
Given the current state of an object, can you predict its next state?
Without any prior information, any guess is a random guess.
Given information about the previous states, we can greatly improve our prediction.
Sequential data examples
Text
Audio
DNA
Time series (stochastic forecasts)
Video
Any data recordings
Thus, to model sequences we need a model that:
handles arbitrary-length input
tracks long-term dependencies
takes the order of elements into account
Network Architecture
Forward model:
For simplicity, we collapse the repeated items into a simplified forward structure:
A Recurrent Neural Network applies the same cell recursively as follows:
...or in short (rolled) form
Simple RNN
xt : input vector (m ⨉ 1).
ht : hidden layer vector (n ⨉ 1)
ot : output vector (n ⨉ 1)
bh : bias vector (n ⨉ 1)
Uh (n ⨉ m), Vh (n ⨉ n), Wy : parameter matrices
σ : activation function (e.g. tanh)
Forward pass
From the input vector xt,
compute the hidden state ht (memory / context): ht = σ(Uh xt + Vh ht-1 + bh)
The current context is a non-linear function of the past context ht-1 and the current input xt
The output is a linear transformation of the current ht: ot = Wy ht
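A minimal NumPy sketch of this forward pass, assuming tanh as the hidden activation; the variable names (U_h, V_h, W_y) follow the notation above, but the concrete sizes are illustrative assumptions:

```python
import numpy as np

m, n, T = 4, 8, 10                          # input size, hidden size, sequence length
rng = np.random.default_rng(0)

U_h = rng.normal(scale=0.1, size=(n, m))    # input -> hidden
V_h = rng.normal(scale=0.1, size=(n, n))    # hidden -> hidden (recurrent)
W_y = rng.normal(scale=0.1, size=(n, n))    # hidden -> output
b_h = np.zeros((n, 1))

xs = [rng.normal(size=(m, 1)) for _ in range(T)]   # input sequence x_1 .. x_T
h = np.zeros((n, 1))                                # initial context h_0

outputs = []
for x_t in xs:
    h = np.tanh(U_h @ x_t + V_h @ h + b_h)   # h_t = σ(U_h x_t + V_h h_{t-1} + b_h)
    outputs.append(W_y @ h)                  # o_t = W_y h_t
```

The same weight matrices are reused at every time step, which is how the model handles arbitrary-length input.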
Learning
In the case of an RNN, we backpropagate through time (BPTT):
The total loss L = Σt Lt is backpropagated through every time-step loss Lt
AND each time-step loss Lt is backpropagated through the earlier hidden states ht-1, ht-2 ... h0
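The forward loop can be replayed backwards to accumulate parameter gradients. Below is a compact NumPy sketch of BPTT for the simple RNN above, under assumed choices (tanh hidden activation, linear readout, squared-error loss per step); the gradient variable names are illustrative:

```python
import numpy as np

m, n, T = 4, 8, 10
rng = np.random.default_rng(0)
U_h, V_h, W_y = (rng.normal(scale=0.1, size=s) for s in [(n, m), (n, n), (n, n)])
b_h = np.zeros((n, 1))

xs = [rng.normal(size=(m, 1)) for _ in range(T)]   # inputs
ys = [rng.normal(size=(n, 1)) for _ in range(T)]   # targets

# forward pass, storing every hidden state for the backward pass
hs = [np.zeros((n, 1))]
os = []
for x_t in xs:
    hs.append(np.tanh(U_h @ x_t + V_h @ hs[-1] + b_h))
    os.append(W_y @ hs[-1])

# backward pass through time: gradients of L = Σ_t ½‖o_t − y_t‖² accumulate over all steps
dU_h, dV_h, dW_y, db_h = (np.zeros_like(p) for p in (U_h, V_h, W_y, b_h))
dh_next = np.zeros((n, 1))                  # gradient flowing in from step t+1
for t in reversed(range(T)):
    do = os[t] - ys[t]                      # dL_t/do_t
    dW_y += do @ hs[t + 1].T
    dh = W_y.T @ do + dh_next               # dL/dh_t (current step + all future steps)
    dz = dh * (1 - hs[t + 1] ** 2)          # through the tanh non-linearity
    dU_h += dz @ xs[t].T
    dV_h += dz @ hs[t].T
    db_h += dz
    dh_next = V_h.T @ dz                    # pass the gradient back to h_{t-1}
```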
Potential issues
Gradients are too big (gradient explosion): each update over-reacts to every single input (see the numeric illustration after this list)
Gradients are too small (gradient vanishing): we learn only short-term dependencies and forget the long-term context
use ReLU: its derivative is 1 for x > 0, which prevents the gradient from shrinking
use gates: control the relevant information by adding / removing it
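The root cause in both cases: through BPTT the gradient picks up a product of per-step factors, and a factor slightly below or above 1 makes that product collapse or blow up over long sequences. A tiny illustrative computation (the 0.9 / 1.1 factors are made-up example values):

```python
# product of 100 per-step factors, each slightly below or above 1
print(0.9 ** 100)   # ≈ 2.7e-05 -> the gradient effectively vanishes
print(1.1 ** 100)   # ≈ 13780.6 -> the gradient explodes
```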
RNN architectures
GRU: gated recurrent unit
zt : update gate (add new information or not?)
rt : reset gate (skip the past context or not?)
✔️ partially mitigates gradient vanishing
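A minimal sketch of one GRU step, assuming one common formulation with biases omitted; the weight names (W_z, U_z, ...) are illustrative, not from the notes:

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x, h_prev, W_z, U_z, W_r, U_r, W_h, U_h):
    z = sigmoid(W_z @ x + U_z @ h_prev)              # update gate: how much new info to add
    r = sigmoid(W_r @ x + U_r @ h_prev)              # reset gate: how much past context to keep
    h_cand = np.tanh(W_h @ x + U_h @ (r * h_prev))   # candidate context
    return (1 - z) * h_prev + z * h_cand             # blend old context with the candidate
```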
LSTM: long short-term memory
f - forget gate : is the old context still important?
i - input gate : add the current input x to the context?
o - output gate : what to expose for the prediction
✔️ partially mitigates gradient vanishing
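Analogously, a minimal sketch of one LSTM step, again with biases omitted and illustrative weight names:

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def lstm_step(x, h_prev, c_prev, W_f, U_f, W_i, U_i, W_o, U_o, W_c, U_c):
    f = sigmoid(W_f @ x + U_f @ h_prev)       # forget gate: keep the old cell state?
    i = sigmoid(W_i @ x + U_i @ h_prev)       # input gate: add the current input to the context?
    o = sigmoid(W_o @ x + U_o @ h_prev)       # output gate: what to expose for the prediction
    c_cand = np.tanh(W_c @ x + U_c @ h_prev)  # candidate cell update
    c = f * c_prev + i * c_cand               # new cell state (long-term memory)
    h = o * np.tanh(c)                        # new hidden state (short-term / output memory)
    return h, c
```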
Advanced improvements
Deep RNN
Bi-directional RNN
Deep recurrent networks:
level 0: a simple RNN (e.g. GRU)
level 1+: the input xt is the hidden state ht of the level below
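As an illustration of both ideas together, a deep bi-directional recurrent network can be built directly in PyTorch; the sizes below are arbitrary example values:

```python
import torch
import torch.nn as nn

# 2 stacked GRU layers: level 1 receives the hidden states of level 0 as its input.
# bidirectional=True runs a second pass over the sequence in reverse order.
rnn = nn.GRU(input_size=16, hidden_size=32, num_layers=2,
             bidirectional=True, batch_first=True)

x = torch.randn(8, 50, 16)    # (batch, sequence length, features)
outputs, h_n = rnn(x)         # outputs: (8, 50, 64) — forward and backward states concatenated
```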
❌Limitations of RNN:
Information bottleneck
Slow, no parallelization
No truly long-term memory
RNNs suffer from gradient vanishing on very long sequences
💡 We would like to have:
Continuous stream of information
Parallelization
Long memory: don't forget what is important
Solution: Attention