In [1]:
import math
import random
import pickle
import requests
In [2]:
import numpy as np
In [3]:
import jax.numpy as jnp
import jax

GPT From Scratch¶

This is a continuation of this blog by Jay Mody. In it, he bootstraps the forward pass of a GPT (using "only numpy"), then loads in the weights for GPT2 and tries it out. It's a very good blog, but I have two major gripes with it:

  • I think it's longer than it should be, and a bit convoluted.
  • He didn't include any functional training code, which is like giving someone a toy without batteries. In an attempt to fix this injustice, I'll try my hand at shortening his explanations and adding some training code.

Generating Parameters¶

You can load the original parameters in from GPT2 and play around with them. To avoid forcing the user to download tensorflow to convert the weights, I have included them as an (insecure) pickle file. If you're running my notebook anyway it's not much of a security difference, but you should still be wary of these things.

In [192]:
# Load in GPT parameters, we won't be training this guy though
with open("gpt.pickle", 'rb') as f:
    encoder, hparams, params = pickle.load(f)
In [5]:
# Show what those look like
def show_params(head, indent_level=0):
    if isinstance(head, dict):
        # Print each key, then recurse into its value
        for i in head:
            print("  "*indent_level, end='')
            print(i)
            show_params(head[i], indent_level=indent_level+1)
    elif isinstance(head, list):
        # Lists here are the transformer blocks; they all have the same structure,
        # so just show one of them (index 11, which assumes the 12-block GPT2 layout)
        i = head[11]
        print("  "*indent_level, end='')
        show_params(i, indent_level=indent_level+1)
    else:
        # The leaves are arrays, so print their shapes
        print("  "*indent_level, end='')
        print(head.shape)
In [6]:
# Or we could make our own
def generate_random_parameters(embedding_length, vocab_size, context_size, num_blocks, proj_width=4):
    # What sort of random are we feeling today?
    initializer = np.random.normal

    # These ones are nice and simple aren't they?
    wpe = initializer(size=[context_size, embedding_length])
    wte = initializer(size=[vocab_size, embedding_length])

    # A little trickier
    ln_f = {'b': initializer(size=[embedding_length]), 'g': initializer(size=[embedding_length])}

    # Okay here we go
    blocks = []
    for _ in range(num_blocks):
        # Build the multi-head attention
        c_attn = {'b': initializer(size=[3*embedding_length]), 'w': initializer(size=[embedding_length, 3*embedding_length])}
        c_proj = {'b': initializer(size=[embedding_length]), 'w': initializer(size=[embedding_length, embedding_length])}
        attn = {'c_attn': c_attn, 'c_proj': c_proj}

        # The layer norms that sit before the attention and before the MLP
        ln_1 = {'b': initializer(size=[embedding_length]), 'g': initializer(size=[embedding_length])}
        ln_2 = {'b': initializer(size=[embedding_length]), 'g': initializer(size=[embedding_length])}

        # Build a multilayer perceptron
        mlp_c_fc = {'b': initializer(size=[proj_width*embedding_length]), 'w': initializer(size=[embedding_length, proj_width*embedding_length])}
        mlp_c_proj = {'b': initializer(size=[embedding_length]), 'w': initializer(size=[proj_width*embedding_length, embedding_length]) }
        
        mlp = {'c_fc': mlp_c_fc, 'c_proj': mlp_c_proj }

        # Finally build the block
        block = {'attn': attn, 'ln_1': ln_1, 'ln_2': ln_2, 'mlp' : mlp }

        # And add it to our list of blocks
        blocks.append(block)

    # Finally we have everything we need to make a new set of parameters
    new_params = { 'blocks': blocks, 'ln_f': ln_f, 'wpe': wpe, 'wte': wte }
    return new_params
In [7]:
random_params = generate_random_parameters(hparams['n_embd'], hparams['n_vocab'], hparams['n_ctx'], hparams['n_layer'], proj_width=4)
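
Just to get a feel for the scale, here's a quick sketch that walks the same nested dict and adds up array sizes. Nothing here is load-bearing, it's just a sanity check on what generate_random_parameters built.

In [ ]:
# Rough parameter count, walking the nested dict/list/array structure
def count_params(head):
    if isinstance(head, dict):
        return sum(count_params(v) for v in head.values())
    if isinstance(head, list):
        return sum(count_params(v) for v in head)
    return head.size

count_params(random_params)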

Updating Parameters¶

Once we calculate the gradient we'll need to update the parameters somehow. The storage system we're using for them makes it a bit difficult... but nothing too bad. And yes, I did write this out as a recursive function first, but it was very slow.

In [8]:
# I want it on record that I *did not* design this param dict storage system...
# This walks one transformer block: 'attn' and 'mlp' nest their weights one
# level deeper than 'ln_1'/'ln_2', which is what the isinstance check is for.
def update_attention_block(og_attn_block, gd_attn_block, lr):
    for param1 in og_attn_block:
        for param2 in og_attn_block[param1]:
            if isinstance(og_attn_block[param1][param2], dict):
                for param3 in og_attn_block[param1][param2]:
                    og_attn_block[param1][param2][param3] -= gd_attn_block[param1][param2][param3]*lr
            else:
                og_attn_block[param1][param2] -= gd_attn_block[param1][param2]*lr
In [9]:
def update_parameters(og_params, gd_params, lr=0.1):
    for param1 in og_params:
        if isinstance(og_params[param1], list):
            for i in range(len(og_params[param1])):
                update_attention_block(og_params[param1][i], gd_params[param1][i], lr)
                
        elif isinstance(og_params[param1], dict):
            for param2 in og_params[param1]:
                og_params[param1][param2] -= gd_params[param1][param2]*lr
        else:
            og_params[param1] -= gd_params[param1]*lr
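
For the record, since the params are just a nested dict of arrays (a "pytree" in jax lingo), jax can walk the structure for us. Here's a minimal sketch of the same SGD step using jax.tree_util.tree_map, purely as an alternative to the loops above (we'll keep using update_parameters below):

In [ ]:
# Sketch only: same update as update_parameters, but jax walks the nested
# dict for us. Note this builds new arrays rather than updating in place.
def update_parameters_treemap(og_params, gd_params, lr=0.1):
    return jax.tree_util.tree_map(lambda p, g: p - g*lr, og_params, gd_params)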

Utility Functions¶

GELU¶

This one is explained best in this paper, which I might cover one day. But for now, you can treat it like any other activation function.

In [ ]:
def gelu(x):
    # The tanh approximation of GELU
    return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)))

Softmax¶

I hear softmax described as "a function that turns arbitrary real values into a probability distribution", but all that really means is "make sure all the numbers are between 0 and 1, and make them all add up to 1".

In [ ]:
def softmax(x):
    # Subtracting the max keeps exp() from overflowing and doesn't change the result
    exp_x = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
    return exp_x / jnp.sum(exp_x, axis=-1, keepdims=True)
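
A quick sanity check with some made-up numbers:

In [ ]:
# Everything should land between 0 and 1, and the whole thing should sum to 1
probs = softmax(jnp.array([2.0, 1.0, 0.1]))
print(probs)        # roughly [0.66, 0.24, 0.10]
print(probs.sum())  # 1.0, give or take float error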

Layer Normalization¶

This one's a bit more complicated: we normalize x along its last axis so the mean is 0 and the variance is 1 (so it looks like a standard normal), then scale and shift it with the learned parameters g and b.

In [ ]:
def layer_norm(x, g, b, eps: float = 1e-5):
    mean = jnp.mean(x, axis=-1, keepdims=True)
    variance = jnp.var(x, axis=-1, keepdims=True)
    return g * (x - mean) / jnp.sqrt(variance + eps) + b
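
And a quick check with g=1 and b=0 (made-up numbers again): the output should have mean roughly 0 and variance roughly 1 along the last axis.

In [ ]:
x = jnp.array([[ 1.0, 2.0,  3.0, 4.0],
               [10.0, 0.0, -5.0, 7.0]])
out = layer_norm(x, g=jnp.ones(4), b=jnp.zeros(4))
print(jnp.mean(out, axis=-1))  # roughly [0, 0]
print(jnp.var(out, axis=-1))   # roughly [1, 1]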

Linear Combination¶

This is just matrix multiplication + a bias. If you've worked with neural networks at all you've seen this guy.

In [ ]:
def linear(x, w, b):
    return x @ w + b

FFN¶

Now we have everything we need to create a simple 2-layer neural network, with GELU as our activation function.

In [ ]:
def ffn(x, c_fc, c_proj):
    return linear(gelu(linear(x, **c_fc)), **c_proj)

Attention¶

This can get pretty complicated despite being pretty small. We'll start with the basics.

Simple Attention¶

I'll follow Jay Mody's "intuition" for this, but I'll be keeping it a lot shorter.

  • Assume we have a matrix whose rows are word vectors; we call it $K$.

  • Assume we also have a vectorized word $q$.

  • Then we can get the similarity of $q$ to each row of $K$ by multiplying by $K$'s transpose $$qK^T$$

  • We want those similarities as a "probability distribution", so we'll shove them through $softmax$ like earlier, giving us $$softmax(qK^T)$$

  • Those probabilities make nice weights, so we use them to take a weighted sum of the rows of a value matrix $V$ because... well neural networks am I right? $$softmax(qK^T)V$$

  • That actually works on its own, but the values can get a bit big, and floating points are still shitty (please e-mail me if they stop being shitty). So we'll scale it by the square root of the key dimension $d_k$. $$ attention(q, K, V) = softmax(\frac{qK^T}{\sqrt{d_k}})V$$

  • If we have multiple queries we can do them all at once by just swapping the vectors for matrices $$ attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$$

In the code you'll see that because $Q$, $K$, $V$ are all matrices, every query gets handled in a single matrix multiplication. You'll also see we add a simple mask to the scores; that'll be useful in the next section.

In [ ]:
def attention(q, k, v, mask):
    return softmax(q @ k.T / jnp.sqrt(q.shape[-1]) + mask) @ v
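
A quick shape check with made-up sizes (4 tokens, head dimension 8), just to see it run: each output row is a weighted mix of the rows of v, and the mask stops row i from looking past position i.

In [ ]:
q = k = v = jnp.ones([4, 8])
causal_mask = (1 - jnp.tri(4)) * -1e10
print(attention(q, k, v, causal_mask).shape)  # (4, 8): one output vector per token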

Multi-Head Attention¶

We're gonna want to run this attention function several times in parallel inside each transformer block, with each "head" working on its own slice of the embedding. We call that "multi-head attention". It goes something like this:

  • First pass everything through a linear layer because why not (it's what produces $Q$, $K$, and $V$ all at once).
  • Split that result up into $Q, K, V$, then split each of those across the heads.
  • Use that mask to stop each token from attending to the tokens after it.
  • Apply attention separately for each head's $[Q, K, V]$.
  • Concatenate the heads back together.
  • Another linear layer lmao.
In [ ]:
def mha(x, c_attn, c_proj, n_head):
    # Project x to [n_seq, 3*n_embd], i.e. q, k and v all in one go
    x = linear(x, **c_attn)
    # Split into q, k, v, then split each of those into n_head heads
    qkv_heads = list(map(lambda x: jnp.split(x, n_head, axis=-1), jnp.split(x, 3, axis=-1)))
    # Positions above the diagonal get a huge negative score, so softmax sends them to ~0
    causal_mask = (1 - jnp.tri(x.shape[0], dtype=x.dtype)) * -1e10
    # Run attention independently for every head
    out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]
    # Concatenate the heads back together and project down to n_embd
    x = linear(jnp.hstack(out_heads), **c_proj)
    return x

Transformer Block¶

And you can see for yourself that a "transformer block" is just multi-head attention with a normal neural network slapped on top, each wrapped in a layer norm and a residual connection.

In [ ]:
def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):
    x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head)
    x = x + ffn(layer_norm(x, **ln_2), **mlp)
    return x

GPT¶

Here all we have to do is look up the token and position embeddings, shove them through the transformer blocks, normalize, project back onto the vocabulary, and call it a day.

In [12]:
def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
    # Token embeddings + position embeddings.
    # I know this line is ugly (and probably slow) but it's needed to make
    # this whole function differentiable, which is needed to, y'know
    # differentiate it
    x = jnp.array([ wte[i] + 0.0 for i in inputs ]) + wpe[0:len(inputs)]

    for block in blocks:
        x = transformer_block(x, **block, n_head=n_head)

    # One final layer norm, then project back onto the vocabulary to get logits
    return layer_norm(x, **ln_f) @ wte.T

Generation¶

In [13]:
def _generate(inputs, params, n_head, n_tokens_to_generate):
    from tqdm import tqdm
    for _ in tqdm(range(n_tokens_to_generate), "generating"):
        logits = gpt2(inputs, **params, n_head=n_head)
        # Greedy decoding: always pick the single most likely next token
        next_id = jnp.argmax(logits[-1])
        inputs.append(int(next_id))
    # Only return the newly generated ids
    return inputs[len(inputs) - n_tokens_to_generate :]
In [14]:
def generate(encoder, hparams, params, prompt: str, n_tokens_to_generate: int = 40):
    input_ids = encoder.encode(prompt)
    assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
    output_ids = _generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
    output_text = encoder.decode(output_ids)
    return output_text
In [15]:
generate(encoder, hparams, params, "who are you")
generating:   0%|          | 0/40 [00:00<?, ?it/s]An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
generating: 100%|██████████| 40/40 [00:29<00:00,  1.36it/s]
Out[15]:
'?"\n\n"I\'m not sure. I\'m not sure if I\'m going to be able to do it. I\'m not sure if I\'m going to be able to do it. I'

Training Data¶

I'll use this random copy of Neuromancer I found on GitHub, because it's topical and because it's short.

In [16]:
neuromancer_raw = requests.get("https://gist.githubusercontent.com/m-242/ecb3e130b76a3b12f7ef41b04f486405/raw/8a3e992841f55f33b9836631b62ac0250b5fe7f8/neuromancer.txt")
neuromancer_raw = neuromancer_raw.content
neuromancer_raw[700:800]
Out[16]:
b'was tending bar, his prosthetic arm jerking monoto-\nnously as he filled a tray of glasses with draft'
In [17]:
# Fix up the encoding
neuromancer = str(neuromancer_raw, encoding='ascii')
neuromancer[500:600]
Out[17]:
's massive drug defi-\nciency." It was a Sprawl voice and a Sprawl joke. The Chatsubo\nwas a bar for pr'
In [18]:
# We'll take it nice and slow
# because we're gonna do a lot of string shit and that's slow
def preprocess_text(text):
    # Yeah I'm an anti-capitalist how could you tell?
    text = text.lower().replace("\n", " ").replace("- ", "").replace("'", " ' ")

    # I love my puter all my friends are in it 
    friends = "abcdefghijklmnopqrstuvwxyz  .'"
    
    # I HATE MY PUTER all my ENEMIES are in it
    clean = ""
    for i in text:
        if i in friends:
            clean += i

    # Period will be his own token, I'm not THAT insane
    clean = clean.replace(". ", " . ")
    
    # I remain ambivalent about my puter, it contains multitudes, friends, enemies, love, hate
    return clean

And I'll keep using the GPT2 encoder, because it's already well supported and because "encoding" (or tokenizing) could honestly be its own notebook.
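
If you want to see what the encoder actually does, here's a quick round trip (the text is just an example, nothing load-bearing):

In [ ]:
# Text -> BPE token ids -> text
ids = encoder.encode("the sky above the port")
print(ids)
print(encoder.decode(ids))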

In [19]:
raw_training_data = preprocess_text(neuromancer)
tokenized_training_data = encoder.encode(raw_training_data)

Training¶

We're gonna keep it simple. None of this log-stuff, no batching, nothing complicated. We'll start by just saying "get the next token right", which is to say error = -output[the next token's id]. Of course this isn't the "right" way to do it; the right way is with cross-entropy and batch processing and GPU support, but baby steps.
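
(If you're curious, a cross-entropy version over a whole window might look roughly like the sketch below. It's not used anywhere in this notebook, it's just here so you can see where we're cutting corners.)

In [ ]:
# Sketch only: average cross-entropy over every next-token prediction in the
# window, instead of just scoring the last token like lm_loss below does.
def lm_loss_xent(p, inputs, n_heads=12):
    x, y = inputs[:-1], jnp.array(inputs[1:])
    logits = gpt2(x, **p, n_head=n_heads)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # pick out the log-probability assigned to each correct next token
    return -jnp.mean(log_probs[jnp.arange(len(y)), y])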

In [164]:
def lm_loss(p, inputs: list[int], n_heads=12) -> float:
    # Predict the very last token from everything before it
    x = list(inputs[:-1])
    y = inputs[-1]

    if len(x) < 9:
        print(x)
        raise ValueError("FUCK") # you don't wanna know

    output = gpt2(x, **p, n_head=n_heads)
    # The loss is just the negated logit of the correct next token
    loss = -output[-1][y]
    return loss
In [193]:
lm_loss(params, tokenized_training_data[500:550], n_heads=6)
Out[193]:
Array(69.26208, dtype=float32)

Then we'll use jax.grad to differentiate it. Jax is fucking great.

In [166]:
lm_loss_grad = jax.grad(lm_loss)
In [167]:
f = lm_loss_grad(params, tokenized_training_data[500:550], n_heads=6)

And we finally have everything we need to create a training function, a very simple and not very optimized one of course.

In [168]:
def train(p: dict, inputs: list[int], samples=20, min_sample_size=10, max_sample_size=20, lr=1e-5, n_heads=12) -> None:
    print("training", end='')
    for _ in range(samples):
        # make sure the user knows we're still alive. and that when they are dead
        # we will still be alive.
        print(".", end='')

        # Take a random sample
        i = random.randint(0, len(inputs) - max_sample_size - 1)
        sample_size = random.randint(min_sample_size, max_sample_size)
        x = inputs[i:i+sample_size]

        # Compute its gradient
        gd = lm_loss_grad(p, x, n_heads=n_heads)

        # Update the parameters in place
        update_parameters(p, gd, lr=lr)

    print("\n")

Testing It Out¶

We'll start by making a really tiny network. No train-test split, I just want to see if it's learning anything at all. So I'll try and get it to overfit a tiny 20-token slice of the text, training on 10-token samples from it.

In [181]:
n_hparams = {'n_vocab': 50257, 'n_ctx': 128, 'n_embd': 120, 'n_head': 6, 'n_layer': 3}
#n_hparams = hparams

params = generate_random_parameters(
    embedding_length=n_hparams["n_embd"],
    vocab_size=n_hparams["n_vocab"],
    context_size=n_hparams['n_ctx'],
    num_blocks=n_hparams['n_layer'],
    proj_width=4
)
                                    
generate(encoder, n_hparams, params, "who are you", n_tokens_to_generate=10)
generating: 100%|██████████| 10/10 [00:00<00:00, 40.80it/s]
Out[181]:
' Sith 77 teachwrap Alphwrap millennia Alph Alphwrap'

Let's check how well random values do.

In [182]:
lm_loss(params, tokenized_training_data[500:520], n_heads=6)
Out[182]:
Array(-1.8531859, dtype=float32)

Then we'll train it on 10-token samples from that slice.

In [189]:
train(params, tokenized_training_data[500:520], samples=20, min_sample_size=10, max_sample_size=10, lr=0.1, n_heads=6)
training....................

And see if that loss went down any.

In [190]:
lm_loss(params, tokenized_training_data[500:520], n_heads=6)
Out[190]:
Array(-2230.628, dtype=float32)

Well the loss went down! Let's see what he has to say about things:

In [191]:
generate(encoder, n_hparams, params, "he scratched his overhang of ", n_tokens_to_generate=10)
generating: 100%|██████████| 10/10 [00:00<00:00, 57.01it/s]
Out[191]:
' his whites his whites his whites his whites his whites'

For reference, the original sentence was this:

In [194]:
encoder.decode(tokenized_training_data[500:520])
Out[194]:
' ratz grunted the sound served him as laughter . he scratched his overhang of whiteshirted'

So it's clear the network saw "his" and "whites" and is trying to optimize towards them. That means I didn't mess up my parameter update function, yay!

Looking At The Vocabulary¶

In [131]:
encoder.decode(tokenized_training_data[500:520])
Out[131]:
' ratz grunted the sound served him as laughter . he scratched his overhang of whiteshirted'
In [31]:
# Count how often each token shows up in the training data
hist = {}
for i in tokenized_training_data:
    hist[i] = hist.get(i, 0) + 1

# Sort the (token id, count) pairs by count, most common first
histy = sorted(hist.items(), key=lambda x: x[1], reverse=True)
histy[:20]
Out[31]:
[(764, 7183),
 (262, 5349),
 (705, 2446),
 (286, 2327),
 (257, 2215),
 (290, 1638),
 (339, 1594),
 (284, 1328),
 (287, 1212),
 (345, 1185),
 (340, 1110),
 (264, 1082),
 (465, 1066),
 (220, 1032),
 (373, 971),
 (1312, 864),
 (1339, 845),
 (326, 720),
 (531, 665),
 (607, 641)]