xrcyz

Toy LSTM in p5js

I made a toy LSTM to help me understand the concepts underlying Long Short-Term Memory recurrent neural networks. It solves a simple task: given a string of characters, predict the next character according to the Reber Grammar.

Reber Grammar diagram

The interesting part of this project is that I derived the weights by hand, without using back-propagation, to see if I could gain an insight into how an LSTM makes its predictions.

If you want to follow along in the animation, you can do it here.

LSTM diagram

The internal logic of this LSTM is remarkably simple. You could implement it in Microsoft Excel.

Input

The input array is a one-hot vector representing the last character of the string. At time zero, we initialise this to the letter B.

let input   = [1,0,0,0,0,0,0]; //current token [B,T,S,X,P,V,E]
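The gate formulas in the next sections use the bare names B, T, S, X, P, V and E for the elements of this vector. Here is a minimal sketch of one way to wire that up outside p5js. `encode` is a hypothetical helper, and the token order comes from the comment above.

```javascript
// Token order taken from the comment on `input`: [B,T,S,X,P,V,E]
const TOKENS = ["B", "T", "S", "X", "P", "V", "E"];

// Hypothetical helper: one-hot encode a single Reber character.
function encode(ch) {
  const i = TOKENS.indexOf(ch);
  if (i < 0) throw new Error(`not a Reber token: ${ch}`);
  return TOKENS.map((_, j) => (j === i ? 1 : 0));
}

const input = encode("B"); // [1,0,0,0,0,0,0], the time-zero input
// Unpack into the scalar names used by the eraser/writer/filter formulas
const [B, T, S, X, P, V, E] = input;
console.log(input, B, T);
```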

Hidden State

Throughout this section I will use node[i] to refer to the node with index i in the Reber Grammar graph.

Memory State

The memory array is a one-hot vector representing the state of the Reber Grammar (the index of the current node in the graph).

let memory  = [0,0,0,0,0,0]; //current node in the graph: [0,1,2,3,4,5]
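The node numbering only appears in the diagram. The transition table below reconstructs the graph in that numbering. It is inferred from the reader comments later in this post and the paths quoted in the next sections (for example T(SSS)X and X(TTT)VP into node[2]), so treat it as an assumption rather than a copy of the diagram. `EDGES` and `isReber` are hypothetical names.

```javascript
// Assumed transition table: node -> { character: next node }.
// B enters node 0; E leaves node 3 and ends the string ("end").
const EDGES = {
  0: { T: 1, P: 5 },
  1: { S: 1, X: 2 },
  2: { S: 3, X: 5 },
  3: { E: "end" },
  4: { P: 2, V: 3 },
  5: { T: 5, V: 4 },
};

// Returns true if `str` is a complete Reber string under EDGES.
function isReber(str) {
  if (str[0] !== "B") return false;
  let node = 0;
  for (const ch of str.slice(1)) {
    if (node === "end") return false;     // nothing may follow E
    const next = EDGES[node][ch];
    if (next === undefined) return false; // no such edge from this node
    node = next;
  }
  return node === "end";
}

console.log(isReber("BTSSXSE")); // true
console.log(isReber("BPTVVE"));  // true
console.log(isReber("BTE"));     // false
```

In this numbering, memory = [1,0,0,0,0,0] means "we are on node[0]", the state right after B.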

Memory Updates

The eraser, writer, and filter arrays are responsible for updating the memory vector. This is equivalent to crossing an edge in the graph and moving to a new state.

let eraser  = [0,0,0,0,0,0]; //what to erase in memory
let writer  = [0,0,0,0,0,0]; //what to write to memory
let filter  = [1,1,1,1,1,1]; //filter the writer when it returns multiple write values

Whenever the LSTM receives a new input, the eraser, writer, and filter perform a test on each node to see if its preconditions have been met. For example, here is the test for node[0]:

eraser[0] = 0;                  //always reset
writer[0] = Math.tanh(5 * B);   //if we get a B, then 100% we arrived at node 0
filter[0] = 1;                  //never filter

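From there we might move to node[1]. This node is interesting for two reasons: it must hold its state while looping on S and switch off when we leave on X, and the T that enters it is also the character that loops on node[5].

This is where the eraser and filter tests come into play: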

eraser[1] = 1 / (1 + exp(-10 * (0.5 - X)));           //reset on X (exit node 1)
writer[1] = Math.tanh(5 * T);                         //if we get a T, then ring the bell for node[1]
filter[1] = 1 / (1 + exp(-30 * (0.75 - memory[5])));  //but only if we are not on node 5
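Each eraser and filter above uses the same logistic curve, 1 / (1 + exp(-k * (c - x))), as a soft test for "x is below c", where k sets how sharp the step is. This small sketch uses a hypothetical `below` helper, written with `Math.exp` because `exp` is a p5js global, to evaluate eraser[1] and filter[1] at the inputs that matter:

```javascript
// Soft "x < c" test: near 1 when x is well below c, near 0 when well above.
function below(x, c, k) {
  return 1 / (1 + Math.exp(-k * (c - x)));
}

// eraser[1] = below(X, 0.5, 10): keep memory[1] unless the input is X
console.log(below(0, 0.5, 10).toFixed(4)); // "0.9933" (any other input: keep)
console.log(below(1, 0.5, 10).toFixed(4)); // "0.0067" (X: erase)

// filter[1] = below(memory[5], 0.75, 30): block T while node[5] is active
console.log(below(0, 0.75, 30).toFixed(4)); // "1.0000" (T from node[0] passes)
console.log(below(1, 0.75, 30).toFixed(4)); // "0.0006" (T looping on node[5] blocked)
```

The eraser never returns exactly 1, so node[1] still loses roughly 0.7% of its value on every input other than X.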

Note that I am taking moderate care to ensure that the firing threshold for each node in memory is approximately 1.0. This becomes important in the reader layer, when we need to test multiple nodes to predict a single character. A node may accumulate evidence when precedent edges are crossed, but is not considered ‘active’ until it crosses the firing threshold.

With that in mind, let us now consider node[2], which may be reached by T(SSS)X or X(TTT)VP. In the writer I use a factor of 0.55 so that crossing a precedent edge T, X, or P increments the memory by about half (tanh(0.55) ≈ 0.50). This covers the cases for TX (via node 1) and XP (via node 5). Finally, the filter is set to ignore T when it loops on node[5].

eraser[2] = 1 / (1 + exp(-30 * (0.65 - memory[2])));   //reset on exit
writer[2] = Math.tanh(0.55 * (T + X + P));             //breadcrumbs to node 2
filter[2] = 1 / (1 + exp(-30 * (0.65 - memory[5])));   //do not increment from node 5
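Here is a quick numerical check of the "half" above, and of why a single breadcrumb should not count as an active node. The sketch evaluates writer[2] on its own and passes the result through a test with the same shape as the reader layer, `tanh(5 * (m - 0.7))`. `crumb` and `fires` are illustrative names, and the sketch ignores the small decay from eraser[2] between the two breadcrumbs.

```javascript
const crumb = Math.tanh(0.55);         // writer[2] when one of T, X, P is seen
console.log(crumb.toFixed(4));         // "0.5005"

// Reader-style test on a memory value m: positive means "active"
const fires = (m) => Math.tanh(5 * (m - 0.7));
console.log(fires(crumb).toFixed(2));     // one breadcrumb:  "-0.76"
console.log(fires(2 * crumb).toFixed(2)); // two breadcrumbs: "0.91"
```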

With node[3], I apply weights to S and V to catch the precedent sequences S and VV. The eraser resets on P to avoid sequence VPS, and the filter blocks the looping S on node[1] to avoid sequence SSS.

eraser[3] = 1 / (1 + exp(-10 * (0.5 - P)));            //reset on P
writer[3] = Math.tanh(3.0 * S + 0.55 * V);             //breadcrumbs to node 3
filter[3] = 1 / (1 + exp(-10 * (0.7 - memory[1])));    //do not increment from node 1
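To see the node[3] filter at work, compare what writer[3] * filter[3] adds to memory[3] in two cases: an S arriving while node[1] is active (the SSS loop), and an S arriving while node[1] is inactive (S leaving node[2]). This sketch evaluates only those two terms, with memory[1] set to 1 and 0 by hand. `sIncrement` is a hypothetical name.

```javascript
const sigmoid = (z) => 1 / (1 + Math.exp(-z));

// Contribution of an S input to memory[3], given the current memory[1]
function sIncrement(memory1) {
  const writer3 = Math.tanh(3.0 * 1 + 0.55 * 0); // S = 1, V = 0
  const filter3 = sigmoid(10 * (0.7 - memory1));
  return writer3 * filter3;
}

console.log(sIncrement(1).toFixed(3)); // looping S on node[1]: "0.047"
console.log(sIncrement(0).toFixed(3)); // S from node[2]:       "0.994"
```

The blocked case is not exactly zero: each looping S still adds about 0.047 to memory[3].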

node[4] presents a refreshingly simple case, where we can increment on V and immediately erase/filter on exit.

eraser[4] = 1 / (1 + exp(-10 * (0.6 - memory[4])));     //reset on exit
writer[4] = Math.tanh(5 * V);                           //breadcrumbs to node 4
filter[4] = 1 / (1 + exp(-10 * (0.6 - memory[4])));     //filter V on exit

node[5] filters inputs from node[1], allowing it to trigger on sequences BP and X. The eraser resets the state on S or V, where V is the edge that exits node[5].

eraser[5] = 1 / (1 + exp(-10 * (0.5 - S - V)));         //reset on S,V
writer[5] = Math.tanh(0.55 * B + 0.7 * P + 5 * X);      //breadcrumbs to node 5
filter[5] = 1 / (1 + exp(-30 * (0.65 - memory[1])));    //do not increment from node 1

Finally, we update the memory:

for(let i = 0; i < memory.length; i++) 
{ 
    memory[i] = memory[i] * eraser[i] + writer[i] * filter[i]; 
}
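The six node tests and the update loop can be combined and run outside p5js. This is a sketch, not the original code. `Math.exp` replaces p5's global `exp`, and the names `step`, `sigmoid`, and `TOKENS` are illustrative. It assumes that every gate is computed from the previous memory before any slot is updated, and that the time-zero B goes through `step` like any other input.

```javascript
// Sketch of the full LSTM step outside p5js (Math.exp replaces p5's exp).
const TOKENS = ["B", "T", "S", "X", "P", "V", "E"];
const sigmoid = (z) => 1 / (1 + Math.exp(-z));

// Compute every gate from the previous memory, then update all six slots.
function step(memory, ch) {
  // One-hot input unpacked into the names used by the formulas (E is unused)
  const [B, T, S, X, P, V] = TOKENS.map((t) => (t === ch ? 1 : 0));

  const eraser = [
    0,                                // node 0: always reset
    sigmoid(10 * (0.5 - X)),          // node 1: reset on X
    sigmoid(30 * (0.65 - memory[2])), // node 2: reset on exit
    sigmoid(10 * (0.5 - P)),          // node 3: reset on P
    sigmoid(10 * (0.6 - memory[4])),  // node 4: reset on exit
    sigmoid(10 * (0.5 - S - V)),      // node 5: reset on S, V
  ];
  const writer = [
    Math.tanh(5 * B),
    Math.tanh(5 * T),
    Math.tanh(0.55 * (T + X + P)),
    Math.tanh(3.0 * S + 0.55 * V),
    Math.tanh(5 * V),
    Math.tanh(0.55 * B + 0.7 * P + 5 * X),
  ];
  const filter = [
    1,                                // node 0: never filter
    sigmoid(30 * (0.75 - memory[5])), // node 1: not from node 5
    sigmoid(30 * (0.65 - memory[5])), // node 2: not from node 5
    sigmoid(10 * (0.7 - memory[1])),  // node 3: not from node 1
    sigmoid(10 * (0.6 - memory[4])),  // node 4: filter V on exit
    sigmoid(30 * (0.65 - memory[1])), // node 5: not from node 1
  ];
  return memory.map((m, i) => m * eraser[i] + writer[i] * filter[i]);
}

// Feed a valid Reber prefix, starting from empty memory
let memory = [0, 0, 0, 0, 0, 0];
for (const ch of "BTSXS") {
  memory = step(memory, ch);
  console.log(ch, memory.map((m) => m.toFixed(2)).join(" "));
}
```

After the final S, memory[3] ends just above 1 while every other slot is close to 0. That matches node[3], where only E may follow.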

Readout

The reader layer outputs a vector with one element per character, in the same [B,T,S,X,P,V,E] order as the input. Each element performs a test of the memory vector to determine whether the current state could yield that character. Outputs lie between -1 and 1, and more than one element can be positive at once (node[0], for example, may yield T or P).

reader[0] = 0; //we never yield B
reader[1] = Math.tanh(5 * (memory[0] + memory[5] - 0.7)); //T may yield from 0 or 5
reader[2] = Math.tanh(5 * (memory[1] + memory[2] - 0.7)); //S may yield from 1 or 2
reader[3] = Math.tanh(5 * (memory[1] + memory[2] - 0.7)); //X may yield from 1 or 2
reader[4] = Math.tanh(5 * (memory[0] + memory[4] - 0.7)); //P may yield from 0 or 4 
reader[5] = Math.tanh(5 * (memory[4] + memory[5] - 0.7)); //V may yield from 4 or 5
reader[6] = Math.tanh(5 * (memory[3]             - 0.7)); //E may yield from 3
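As a sanity check on the reader, the sketch below feeds it a hand-built one-hot memory for each node and lists the characters with a positive score. `read` restates the seven lines above. `predict` and its threshold of 0 are an assumed decision rule for turning the tanh outputs into a set.

```javascript
const TOKENS = ["B", "T", "S", "X", "P", "V", "E"];

// The reader layer from the text, as a function of the memory vector m.
function read(m) {
  return [
    0,                                  // B: never yielded
    Math.tanh(5 * (m[0] + m[5] - 0.7)), // T from node 0 or 5
    Math.tanh(5 * (m[1] + m[2] - 0.7)), // S from node 1 or 2
    Math.tanh(5 * (m[1] + m[2] - 0.7)), // X from node 1 or 2
    Math.tanh(5 * (m[0] + m[4] - 0.7)), // P from node 0 or 4
    Math.tanh(5 * (m[4] + m[5] - 0.7)), // V from node 4 or 5
    Math.tanh(5 * (m[3] - 0.7)),        // E from node 3
  ];
}

// Assumed decision rule: predict a character when its reader output is > 0.
function predict(m) {
  const r = read(m);
  return TOKENS.filter((_, i) => r[i] > 0).join("");
}

for (let node = 0; node < 6; node++) {
  const memory = [0, 0, 0, 0, 0, 0];
  memory[node] = 1; // hand-built one-hot state
  console.log(`node[${node}] -> ${predict(memory)}`);
}
// node[0] -> TP, node[1] -> SX, node[2] -> SX, node[3] -> E, node[4] -> PV, node[5] -> TV
```

Nodes 1 and 2 read identically, so telling them apart is entirely the job of the eraser, writer, and filter.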

As mentioned earlier, the choice of weights in the eraser, writer, and filter comes into play when we start adding nodes in the reader. If two nodes could both be partially activated, we want their sum to be less than the value of a single fully activated node.

Discussion

This LSTM has some interesting features/flaws that are worthy of discussion.

  1. The internal logic assumes that sequences always start with B.
  2. Long loops on T from node[5] cause false positives on node[1].
  3. The program running on the LSTM is frustratingly obfuscated.
  4. Snapping to vertices may prevent drift in long sequences and help normalize inputs.

Initialising State

I find it interesting to consider that real-world data will start with the system in some unknown hidden state. The LSTM should be robust enough to recover from uncertainty when there are multiple activated nodes in memory, or when presented with an invalid input-memory pair.

Long Loops

When the system loops on T from node[5], a false positive is generated in writer[1]. This gets filtered by filter[1], but the filter returns a slightly positive value, so with enough loops node[1] will achieve false activation.
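The false activation can be reproduced in isolation. This minimal sketch starts from an assumed state with memory[5] = 1 and memory[1] = 0, feeds T repeatedly, and applies only the node[1] and node[5] rules, ignoring every other slot. `loopT` is a hypothetical name.

```javascript
const sigmoid = (z) => 1 / (1 + Math.exp(-z));

// Feed n T inputs, tracking only memory[1] and memory[5].
function loopT(n) {
  const B = 0, T = 1, S = 0, X = 0, P = 0, V = 0; // the input is always T
  let m1 = 0, m5 = 1;                              // assumed start: on node[5]
  for (let i = 0; i < n; i++) {
    const e1 = sigmoid(10 * (0.5 - X));               // eraser[1]
    const w1 = Math.tanh(5 * T);                      // writer[1]: fires on every T
    const f1 = sigmoid(30 * (0.75 - m5));             // filter[1]: small while m5 is high
    const e5 = sigmoid(10 * (0.5 - S - V));           // eraser[5]
    const w5 = Math.tanh(0.55 * B + 0.7 * P + 5 * X); // writer[5]: 0 on T
    const f5 = sigmoid(30 * (0.65 - m1));             // filter[5]
    [m1, m5] = [m1 * e1 + w1 * f1, m5 * e5 + w5 * f5];
  }
  return { m1, m5 };
}

for (const n of [1, 10, 40, 60]) {
  const { m1, m5 } = loopT(n);
  console.log(`${n} loops: memory[1]=${m1.toFixed(3)}, memory[5]=${m5.toFixed(3)}`);
}
```

In this sketch memory[1] stays below 0.05 for the first ten loops and is well above 1 by loop 60. The leak grows because eraser[5] multiplies memory[5] by about 0.993 on every T while nothing writes it back. As a result, memory[5] slides toward the 0.75 cut-off in filter[1], and the filter opens further with each loop.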

This is an interesting error because I can see it popping up in real-world applications. Improbable sequences are less likely to be represented in the training data, and so errors on said sequences are less likely to be caught and corrected during training.

Obfuscation

There were multiple times while writing this blog post that I thought, ‘neural nets would make a great obfuscation tool’.

It has me thinking, maybe instead of training for predictions, I should be training a model to output the parameters of a finite state machine. Then at least it might be possible to reverse engineer the program without having to read off the weights.

Snapping to vertices

Consider the following:

Critically, we don’t want the memory drifting in weird orbits if we go into a long loop (cycling PXV for example). We want to snap to the vertices of the hyperpolygon so that the reader correctly classifies the hidden state. With this in mind, I can see why it may make sense to preprocess the inputs into one-hot or binary vectors, to limit the range of outputs on the eraser, writer, and filter. There might even be a case for doubling the memory size and rounding the memory elements after every state change (one slot per breadcrumb).
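The rounding idea can be tried on the long T loop from the Long Loops section. This minimal sketch assumes a hypothetical `snap` helper that clamps each slot to [0, 1] and rounds it, applied after every update of memory[1] and memory[5]. It deliberately ignores breadcrumbs. A half-weight breadcrumb sits at about 0.5, where rounding would either erase it or promote it to a full activation, and a separate slot per breadcrumb would avoid that problem.

```javascript
const sigmoid = (z) => 1 / (1 + Math.exp(-z));

// Hypothetical snapping step: clamp to [0, 1], then round to the nearest vertex.
const snap = (v) => Math.round(Math.min(1, Math.max(0, v)));

// The same T loop under the node[1]/node[5] rules, snapping after every update.
function loopTSnapped(n) {
  let m1 = 0, m5 = 1; // assumed start: on node[5]
  for (let i = 0; i < n; i++) {
    // eraser[1] and eraser[5] are both sigmoid(5) on T; writer[5] is 0
    const next1 = m1 * sigmoid(5) + Math.tanh(5) * sigmoid(30 * (0.75 - m5));
    const next5 = m5 * sigmoid(5);
    [m1, m5] = [snap(next1), snap(next5)];
  }
  return { m1, m5 };
}

console.log(loopTSnapped(1000)); // { m1: 0, m5: 1 }
```

Without snapping, the same loop lets memory[5] decay and memory[1] grow. With snapping, the roughly 0.0006 leak through filter[1] is rounded away on every step, and memory[5] is pulled back to 1.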

Future Work

For the sake of completeness, I should compare this solution with a trained LSTM and examine the differences. TBA.