Word Embeddings¶
I've recently taken on a project that could benefit from word similarity checking, so I'd like to build some word-vectorization system. That's been done plenty of times before, so I'm essentially just going to steal the old skip-gram and implement it from "scratch" (I use that term lightly). There are really only two constraints I have here:
- The model has to be able to do a forwards and backwards pass on my laptop.
- The embedding cannot be too large, because I would like to use it on my website.
The only thing I need this to actually do is check the similarity of words, but it would be fun if we got a full vector space out of it.
from functools import cache
from random import randint
import pickle
import jax
import jax.numpy as jnp
from jax import jit
import numpy as np
from matplotlib import pyplot as plt
Loading Training Data¶
Like a lot of word-vectorization projects, we're going to be making use of text8, because the pre-processing is already done and it's more than big enough for this. The data comes all-lower-case with no punctuation, so the only thing we have to do is build a vocabulary out of it. I'll do that the slow way with Python's built-in set class.
def make_vocab(text):
vocab = list(set(text))
return vocab
We'll be using matrix multiplication later to select the word embeddings. It's often much faster just to use the index of the token, i.e. word_embeddings[token_id], but that seems to mess with the JAX autograd implementation and will throw off our gradient. So we'll need this:
def token_to_one_hot(token: int, vocab_size: int):
z = np.zeros([vocab_size])
z[token] = 1.0
return jnp.array(z)
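As a quick aside (a throwaway check of my own, not part of the original pipeline), the one-hot trick really is just a row lookup, and jax.grad happily differentiates through the matmul:
# Multiplying a one-hot vector by an embedding matrix selects the matching row.
demo_embeddings = jnp.arange(12.0).reshape(4, 3)   # 4 "words", 3-dimensional embeddings
one_hot = token_to_one_hot(2, 4)

selected = jnp.dot(one_hot, demo_embeddings)
print(jnp.allclose(selected, demo_embeddings[2]))  # True: same as indexing row 2

# The gradient of anything built on the selected row flows back into row 2 only.
grad_fn = jax.grad(lambda e: jnp.sum(jnp.dot(one_hot, e)))
print(grad_fn(demo_embeddings))                    # ones in row 2, zeros elsewhere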
def load_text(path: str = "./text8"):
with open(path, 'r') as opn:
l = opn.read()
return l.split()
sample_data = load_text()
sample_data = sample_data[:int(0.1*len(sample_data))]
sample_vocab = make_vocab(sample_data)
I was trying not to use too many globals, so that you could just copy-paste the functions out of this notebook and end up with a decent library. But @cache only works when a function's arguments are hashable, so I can't pass the vocab list in as a parameter, and without the cache this lookup would be worst-case $O(n)$, where $n$ is the size of our vocab.
@cache
def tokenize_word(word):
return sample_vocab.index(word)
def tokenize_data(input_data):
    total = len(input_data)
    for i in range(total):
        # Print the occasional progress update; this loop takes a while.
        if i % 100000 == 0:
            print(f"{i}/{total}")
        input_data[i] = tokenize_word(input_data[i])
    return input_data
tokenized_text = tokenize_data(sample_data)
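If you'd rather avoid the global-plus-@cache combination entirely, a plain dictionary from word to index does the same job in $O(1)$ per lookup and can be passed around explicitly. A minimal sketch (these function names are my own and aren't used elsewhere in the notebook):
# Hypothetical alternative to @cache + a global vocab list: an explicit
# word -> index dictionary, built once, gives O(1) lookups with no globals.
def make_token_map(vocab):
    return {word: i for i, word in enumerate(vocab)}

def tokenize_with_map(words, token_map):
    return [token_map[word] for word in words]

# This would stand in for tokenize_word/tokenize_data above, e.g.:
# token_map = make_token_map(sample_vocab)
# tokens = tokenize_with_map(list_of_words, token_map)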
Let's just make sure our tokenization/detokenization is working properly...
for token in tokenized_text[0:20]:
    print(sample_vocab[token], end=' ')
anarchism originated as a term of abuse first used against early working class radicals including the diggers of the english
Pre-Processing¶
This is actually where most of the work is. We're going to train the network to take in some word and predict the probabilities of all the other words appearing next to it. That means for each word we need to go through the dataset and count the occurrence of other words near it. Because we'll be passing this through the network with batch training, we'll want it in the format of an n_vocab by n_vocab matrix. Unfortunately n_vocab is very big, and if we made one of those my Fujitsu laptop would run out of RAM (ask me how I know). Instead we'll store the counts in a hashmap of the form hashmap[token] = # of times that token was next to the word (and make a list of those, one per vocab word). Then we'll pass it through a different function to convert it back into a vector, right before we use it.
This hashmap is actually very useful in its own right, and I could use it to check the similarity of words for my project. Unfortunately it's also very big, and so doesn't adhere to the second requirement.
def create_word_frequency_count(input_data, vocab):
# How far we'll look to the left/right of the word for other words
window_size = 5
    # Let's start by making an empty list of dictionaries. If we used an n_vocab by n_vocab matrix
# we'd run out of memory instantly.
word_frequencies = []
for _ in range(len(vocab)):
word_frequencies.append({})
for i in range(window_size+1, len(input_data) - window_size-1):
token = input_data[i]
for j in range(-window_size, window_size+1):
if j == 0: continue
# Get the token of the words in our window
neighboring_word = input_data[i + j]
# In Python, it's faster to try and crash-out than to check first
# isn't that funny?
try:
word_frequencies[token][neighboring_word] += 1.0
except KeyError as e:
word_frequencies[token][neighboring_word] = 1.0
return word_frequencies
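To make the structure concrete, here's a toy run of mine on a hand-tokenized fifteen-token "corpus" over a made-up three-word vocab (with window_size hardcoded to 5 above, only the middle few positions actually get counted, but the shape of the result is the point):
# Toy demonstration: the result is a list indexed by token id, where each entry
# is a dict mapping neighboring token ids to how often they appeared in the window.
toy_vocab = ["cats", "chase", "dogs"]
toy_corpus = [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]  # token ids, not words

toy_frequencies = create_word_frequency_count(toy_corpus, toy_vocab)
for token_id, neighbors in enumerate(toy_frequencies):
    print(toy_vocab[token_id], neighbors)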
def softmax(x):
    # Subtract the max for numerical stability, and normalize along the last axis
    # so this works both for single vectors and for (batch, vocab) matrices.
    e_x = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
    return e_x / e_x.sum(axis=-1, keepdims=True)
And this is that function that turns the hashmap back into a vector (representing a probability distribution) so we can put it through our network.
def dict_to_arr(d: dict, vocab_size: int):
z = np.zeros([vocab_size])
for i in d:
z[i] = d[i]
z = softmax(z)
return z
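Quick sanity check (a throwaway example of my own): a counts dictionary over a pretend five-word vocab comes back as a dense distribution that sums to one.
# A sparse counts dict becomes a dense vector, and softmax turns it into probabilities.
toy_counts = {0: 3.0, 2: 1.0}
toy_dist = dict_to_arr(toy_counts, 5)
print(toy_dist)
print(toy_dist.sum())  # ~1.0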
# Note: token id 0 gets skipped and any final partial batch is dropped; good enough here.
def batch_data(word_frequencies: list, vocab_size: int, batch_size=10):
batches = []
batch = []
batch_tokens = []
for i in range(1, len(word_frequencies)):
batch.append(dict_to_arr(word_frequencies[i], vocab_size))
batch_tokens.append(token_to_one_hot(i, vocab_size))
if i % batch_size == 0:
batches.append([
jnp.vstack(batch_tokens),
jnp.vstack(batch)
])
batch = []
batch_tokens = []
return batches
word_frequency_count = create_word_frequency_count(tokenized_text, sample_vocab)
Now that we've made the frequency count dictionary, we can deallocate the original sample data (which is good, because it's pretty big).
# RECLAIM WHAT WAS ONCE OURS (RAM mostly). Both names point at the same list,
# so both have to go before the memory can actually be freed.
del sample_data, tokenized_text
batched_sample_data = batch_data(word_frequency_count, len(sample_vocab))
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
batched_sample_data[0][1].shape
(10, 70889)
Forwards Pass¶
The forwards pass of this network is also pretty simple. All we have to do is:
- Take some token.
- One-hot encode it (done outside the function to make differentiation easier).
- Use the one-hot encoding to look up the embedding.
- Project the embeddings up to the size of the original vocabulary.
- Use softmax to turn that projection into a probability distribution.
def forward(one_hot_encoded_inputs: jnp.array, embeddings: jnp.array, weights: jnp.array):
x = jnp.dot(one_hot_encoded_inputs, embeddings)
x = jnp.dot(x, weights)
return softmax(x)
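Before wiring this up to the real vocab, here's a tiny check of my own with throwaway random parameters: feeding in one one-hot token gives back a probability distribution over the pretend vocab.
# Throwaway forward-pass check: pretend 6-word vocab, 3-dimensional embeddings.
rng = np.random.default_rng(0)
demo_embed = jnp.array(rng.normal(size=(6, 3)))
demo_weights = jnp.array(rng.normal(size=(3, 6)))

demo_out = forward(token_to_one_hot(0, 6), demo_embed, demo_weights)
print(demo_out.shape)  # (6,): one score per vocab word
print(demo_out.sum())  # ~1.0, so it really is a probability distribution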
Parameters¶
All we'll need is two matrices, one for holding the embeddings of each token (think of it less like a matrix, and more like a list of vectors), and one for projecting those embeddings up to the size of vocabulary. We won't add a separate matrix for biasing, because we want any biasing to be done in our embeddings (since thats the only part of the network that will make it out of this notebook).
def gen_params(vocab_size: int, embedding_size: int):
embeddings = np.random.uniform(-1.0, 1.0, size=[vocab_size, embedding_size])
weights = np.random.uniform(-1.0, 1.0, size=[embedding_size, vocab_size])
return {"weights": weights, "embeddings": embeddings}
def update_params(p1, grad_p1, lr=0.01):
for param in p1:
p1[param] -= lr * grad_p1[param]
return p1
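update_params mutates the dict in place, which is fine for this notebook. Just as an aside, the same SGD step can be written without mutation using jax.tree_util.tree_map (a sketch, not what the training loop below uses):
# Functional version of the same update: apply p - lr * g to every leaf of the
# params pytree and return a fresh dict instead of mutating the old one.
def update_params_functional(params, grads, lr=0.01):
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)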
Let's generate some parameters to test with. The embedding size here is 100, which is what the 10-by-10 gradient reshape further down assumes.
test_params = gen_params(len(sample_vocab), 100)
forward(batched_sample_data[0][0], test_params["embeddings"], test_params["weights"]).shape
(10, 70889)
Loss¶
Technically the target here is a probability distribution over discrete classes, so cross-entropy would be a better loss function. Unfortunately, JAX handles it somewhat poorly (or I've been using it wrong), and the jnp.log2 function shows zero regard for floating-point numbers. Because of that we'll use MSE, which frankly works just fine on a network of this size and simplicity. Honestly, I get the feeling I could use almost any loss function, as long as it expresses some kind of distance between the predicted distribution and the actual distribution.
def batch_loss(params: dict, batch_x: jnp.array, batch_y: jnp.array):
prediction = forward(batch_x, params["embeddings"], params["weights"])
se = jnp.sum((batch_y - prediction)**2, axis=1)
mse = jnp.mean(se)
return mse
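For what it's worth, if you do want cross-entropy, the usual trick is to stay in log space instead of taking the log of a softmax output; jax.nn.log_softmax sidesteps the floating-point trouble mentioned above. A hedged sketch, not the loss actually used in this notebook:
# Hypothetical cross-entropy alternative: compute log-probabilities directly
# from the logits so we never take the log of a value that has underflowed to zero.
def batch_cross_entropy(params: dict, batch_x: jnp.array, batch_y: jnp.array):
    logits = jnp.dot(jnp.dot(batch_x, params["embeddings"]), params["weights"])
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(batch_y * log_probs, axis=1))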
batch_loss(test_params, batched_sample_data[0][0], batched_sample_data[0][1])
Array(3982.1328, dtype=float32)
batch_loss_grad = jax.grad(batch_loss)
test_g = batch_loss_grad(test_params, batched_sample_data[0][0], batched_sample_data[0][1])
This plot of the gradient for a single token's embedding (reshaped to 10 by 10) is pretty much useless, but it looks cool, and it's always fun to see what's going on inside the ol' blackbox.
plt.imshow(np.array(test_g['embeddings'][1].reshape([10, 10])))
plt.show()
Train¶
See, this is why we jumped through a couple of hoops earlier. Now we can just use jax.grad to differentiate the loss function, meaning all we have to worry about is calling it and applying the gradient. Handy, right?
def batch_train(params: dict, batches: list, epochs=10):
for _ in range(epochs):
for i, batch in enumerate(batches):
batch_tokens = batch[0]
batch_prob_dist = batch[1]
g = batch_loss_grad(params, batch_tokens, batch_prob_dist)
update_params(params, g, lr=0.01)
batch_train(test_params, batched_sample_data, epochs=20)
batch_loss(test_params, batched_sample_data[0][0], batched_sample_data[0][1])
Array(722.1161, dtype=float32)
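Training like this on the CPU takes a while, and the jit imported at the top never actually gets used. One likely speed-up (a sketch under that assumption, not what produced the numbers above) is to fuse the gradient computation and the update into a single compiled step:
# Optional: a jit-compiled training step. jax.value_and_grad returns the loss
# alongside the gradients, so progress can be monitored for free.
@jit
def train_step(params, batch_tokens, batch_prob_dist, lr=0.01):
    loss, grads = jax.value_and_grad(batch_loss)(params, batch_tokens, batch_prob_dist)
    new_params = {name: params[name] - lr * grads[name] for name in params}
    return new_params, loss
The inner loop of batch_train would then just be params, loss = train_step(params, batch[0], batch[1]).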
Playing Around¶
What use is making word vectors if we don't use them for anything? Let's start by writing some quick utilities to vectorize words more easily; I'm starting to get carpal tunnel from all these notes, you know?
def embed_word(word):
return test_params["embeddings"][sample_vocab.index(word)]
def unembed_word(vector):
return sample_vocab[np.argmax(np.dot(vector, test_params["embeddings"].T))]
man = embed_word("man")
apple = embed_word("apple")
woman = embed_word("woman")
dog = embed_word("dog")
cat = embed_word("cat")
computer = embed_word("computer")
science = embed_word("science")
chemistry = embed_word("chemistry")
physics = embed_word("physics")
Then let's define a simple test to see how similar two words are. The classic choice is of course cosine similarity, so I'll use that, then scale it up by a factor of 100 to make it easier for me to read.
def cosine_similarity(A, B):
return np.dot(A,B)/(np.linalg.norm(A)*np.linalg.norm(B))
def test(a, b):
return cosine_similarity(a, b) * 100
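One more utility of my own (brute force over the whole vocab, so it's not fast): the nearest neighbours of a word by cosine similarity.
# Hypothetical helper: the k vocab words whose embeddings are most similar
# (by cosine similarity) to the query word's embedding.
def nearest_words(word, k=5):
    query = embed_word(word)
    embeddings = np.asarray(test_params["embeddings"])
    sims = embeddings @ query / (np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query))
    best = np.argsort(sims)[::-1][:k + 1]  # k+1 so we can drop the word itself
    return [(sample_vocab[i], float(sims[i])) for i in best if sample_vocab[i] != word][:k]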
Okay, let's see how related certain words are:
test(chemistry, physics)
21.926777064800262
test(apple, computer)
-1.7913002520799637
The dataset is pretty old, so that correlation might not come across just yet. Let's try something more universal: it was always my impression that women were more likely to have cats, and men were more likely to have dogs. Let's see if the network agrees:
test(man, cat)
-0.5196941550821066
test(woman, cat)
-0.5642959382385015
test(man, dog)
-11.501768976449966
test(dog, woman)
-4.738159850239754
test(woman, man)
-2.9907846823334694
Of course, the network looks for word associations, so really what this means is that people talk about cats and dogs in the same context as "woman" more often than they talk about them in the context of "man", and across the board they talk about cats in relation to the words "woman" or "man" more than they talk about dogs.
test(cat, dog)
3.7614338099956512
test(apple, man)
-0.04189200117252767
These correlations aren't entirely perfect of course, but you can see that they do have some genuine sense to them! Considering the size and simplicity of the network, I'm actually surprised. I figured that with so little data and such a small embedding size we wouldn't get much of an end-user experience, but this can actually be applied to something. Which is good, because I cooked it up with a certain something in mind.
Let's test it for some mathematical purity. If the embeddings really form a tidy vector space, two opposing words should point in opposite directions, so the difference of their normalized vectors should have a length of about two (and two near-synonyms should give something close to zero).
w1, w2 = "big", "small"
v1 = embed_word(w1)/np.linalg.norm(embed_word(w1))
v2 = embed_word(w2)/np.linalg.norm(embed_word(w2))
np.linalg.norm(v1 - v2)
1.427628
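Since both vectors have unit length, $\lVert v_1 - v_2 \rVert = \sqrt{2 - 2\cos\theta}$, so we can read the cosine similarity straight back off that number:
# Recover the cosine similarity from the distance between the two unit vectors:
# ||v1 - v2||^2 = 2 - 2*cos(theta) when both vectors are normalized.
d = np.linalg.norm(v1 - v2)
print(1 - d**2 / 2)  # should match cosine_similarity(v1, v2)
A distance of about 1.43 therefore means "big" and "small" come out roughly orthogonal rather than pointing in opposite directions.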
Well, that might have been a long shot; not that I had any practical use for it.