Training a simple bigram character-level model on Tiny Stories
Building a Bigram Language Model: A Step-by-Step Guide to Character-Level Text Generation
I wrote this small snippet as part of my learning process while following Andrej Karpathy’s video (link).
What is a Bigram Language Model?
A bigram language model predicts the next character in a sequence based solely on the current character. It’s called “bigram” because it considers pairs of characters (bi = two).
The model learns a probability distribution over all possible next characters given the current character, essentially building a lookup table that says “when I see character X, what’s the most likely next character?”
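To make the idea concrete, here is a tiny count-based sketch (my own illustration, not part of the notebook) that builds such a lookup table by counting which character follows which in a toy string:
from collections import Counter, defaultdict
sample = "hello world, hello there"   # toy string, just for illustration
counts = defaultdict(Counter)
for a, b in zip(sample, sample[1:]):   # every adjacent pair of characters
    counts[a][b] += 1
# normalize the counts into a probability distribution per character
probs = {a: {b: c / sum(nxt.values()) for b, c in nxt.items()} for a, nxt in counts.items()}
print(probs['l'])   # {'l': 0.4, 'o': 0.4, 'd': 0.2}
The neural model below learns the same kind of table, except its entries start out random and are adjusted by gradient descent instead of being counted directly.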
Dataset Preparation and Text Loading
Our journey begins with loading and examining our text data:
import pandas as pd
import numpy as np

# read the whole Tiny Stories text file into a single string
with open('stories.text', 'r', encoding='utf-8') as f:
    text = f.read()
print(text[:500])
The Tiny Stories dataset contains simple, child-friendly stories that are perfect for training language models. We load the entire text file into memory as a single string. This approach works well for smaller datasets, though larger datasets would require more sophisticated data loading strategies.
Character-Level Tokenization
Unlike word-based models, our character-level approach treats each individual character as a token. This has several advantages:
- Simplicity: No need for complex word segmentation
- Robustness: Can handle any text, including typos and rare words
- Fine-grained control: Learns spelling patterns and character relationships
Let’s build our character vocabulary:
chars=sorted(list(set(text)))
vocab_size=len(chars)
print(chars)
print(vocab_size)
print(''.join(chars))
['\n', ' ', '!', '"', '#', '$', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '¦', '©', '\xad', '±', '´', 'Â', 'Ã', 'â', 'ð', 'œ', 'Š', 'Ÿ', 'Ž', '˜', '“', '”', '‹', '€', '™']
101
!"#$&'()*+,-./0123456789:;<?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz¦©±´ÂÃâðœŠŸŽ˜“”‹€™
This code extracts all unique characters from our text and sorts them alphabetically. The vocabulary size tells us how many different characters our model needs to handle. Typically, this includes letters (both cases), numbers, punctuation, and whitespace characters.
Building the Tokenizer
Tokenization is the process of converting text into numerical representations that neural networks can process. We create two essential mappings:
stoi={ch:i for i,ch in enumerate(chars)} # string to integer
itos={i:ch for i,ch in enumerate(chars)} # integer to string
encoder = lambda s: [stoi[c] for c in s]
decoder = lambda l: ''.join([itos[i] for i in l])
The stoi (string-to-integer) dictionary maps each character to a unique integer ID, while itos (integer-to-string) provides the reverse mapping. Our encoder and decoder functions handle the conversion between text and numerical sequences.
Testing our tokenizer:
exa="My name is Harshwardhan"
output=encoder(exa)
print(output)
print(decoder(output))
[42, 80, 1, 69, 56, 68, 60, 1, 64, 74, 1, 37, 56, 73, 74, 63, 78, 56, 73, 59, 63, 56, 69]
My name is Harshwardhan
This verification step ensures our encoding and decoding process is lossless - we can convert text to numbers and back to the original text perfectly.
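A one-line round-trip check (my own sanity check, not from the video) makes this property explicit:
assert decoder(encoder(exa)) == exa   # encoding then decoding returns the original string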
Converting to PyTorch Tensors
Neural networks work with tensors, so we convert our encoded text into a PyTorch tensor:
import torch
data = torch.tensor(encoder(text))
print(data.shape)
print(data[:20])
torch.Size([19212308])
tensor([48, 71, 70, 75, 14, 1, 48, 71, 70, 75, 1, 74, 56, 78, 1, 75, 63, 60,
1, 74])
The resulting tensor contains integer indices representing each character in our text. The shape tells us the total length of our dataset, while examining the first 20 elements helps us verify the conversion worked correctly.
Dataset Splitting
Machine learning requires separate training and validation sets to properly evaluate model performance:
n=int(0.9*len(data))
train=data[:n]
validate=data[n:]
We use a 90-10 split, dedicating 90% of our data to training and 10% to validation. The validation set helps us monitor whether our model is learning genuine patterns or simply memorizing the training data (overfitting).
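Given the 19,212,308 characters loaded above, a quick length check confirms the split:
print(len(train), len(validate))   # roughly 17.29M training and 1.92M validation characters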
Understanding Context Windows
Language models don’t process entire texts at once. Instead, they work with fixed-size context windows.
To make this concrete, here is a small snippet you can run to see how one chunk of text yields multiple input-target pairs:
block_size=8
x = train[:block_size]
y = train[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")
when input is tensor([48]) the target: 71
when input is tensor([48, 71]) the target: 70
when input is tensor([48, 71, 70]) the target: 75
when input is tensor([48, 71, 70, 75]) the target: 14
when input is tensor([48, 71, 70, 75, 14]) the target: 1
when input is tensor([48, 71, 70, 75, 14, 1]) the target: 48
when input is tensor([48, 71, 70, 75, 14, 1, 48]) the target: 71
when input is tensor([48, 71, 70, 75, 14, 1, 48, 71]) the target: 70
This code demonstrates a crucial concept: from a single chunk of nine characters (block_size + 1), we can create 8 different training examples. Each example uses a progressively longer context to predict the next character:
- Given just the first character, predict the second
- Given the first two characters, predict the third
- And so on…
This approach maximizes the learning opportunities from our data and teaches the model to work with contexts of varying lengths.
Batch Processing for Efficient Training
Neural networks train more efficiently when processing multiple examples simultaneously. Our batch generation function creates random samples:
batch_size=4
block_size=8
def get_batch(split):
    data = train if split == 'train' else validate
    ix = torch.randint(len(data) - block_size, (batch_size,))   # random starting positions
    x = torch.stack([data[i:i+block_size] for i in ix])         # input sequences
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])     # targets, shifted by one
    return x, y
This function randomly selects starting positions in our dataset and extracts sequences of length block_size. The result is two tensors:
- x: Input sequences (what the model sees)
- y: Target sequences (what the model should predict)
The random sampling ensures our model sees different parts of the text in each batch, promoting better generalization.
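For example, a single call returns a pair of (batch_size, block_size) tensors, here (4, 8):
xb, yb = get_batch('train')
print(xb.shape, yb.shape)   # torch.Size([4, 8]) torch.Size([4, 8])
print(xb[0])                # one input sequence of 8 character indices
print(yb[0])                # the same sequence shifted one position to the right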
The Bigram Language Model Architecture
Now we build our neural network. Despite its simplicity, this model embodies key language modeling concepts:
import torch.nn as nn
import torch.nn.functional as F

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
The core of our model is an embedding table - essentially a learned lookup table where each character is associated with a vector of scores (logits), one for each possible next character. The embedding dimension equals our vocabulary size, creating a direct mapping from the current character to next-character scores.
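To see why an embedding works as a lookup table, here is a small standalone sketch (with a made-up vocabulary of 5 characters) showing that indexing an nn.Embedding simply picks out rows of its weight matrix:
import torch
import torch.nn as nn

table = nn.Embedding(5, 5)        # hypothetical 5-character vocabulary
idx = torch.tensor([[0, 2, 4]])   # a batch of one sequence with three characters
out = table(idx)
print(out.shape)                                    # torch.Size([1, 3, 5]) -> (B, T, C)
print(torch.allclose(out[0, 1], table.weight[2]))   # True: position 1 looked up row 2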
Forward Pass and Loss Calculation
The forward pass transforms input sequences into predictions and calculates the training loss:
    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)  # (B,T,C)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
The embedding table produces “logits” - raw prediction scores for each possible next character. When we have targets (during training), we calculate cross-entropy loss, which measures how well our predictions match the actual next characters.
The reshaping operations (view) are necessary because PyTorch’s cross-entropy function expects 2D inputs, but our model produces 3D tensors (batch, time, characters).
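A short shape walk-through with dummy tensors (assuming the notebook’s vocab_size is in scope) shows what the reshape does:
import torch
import torch.nn.functional as F

B, T, C = 4, 8, vocab_size            # batch, time, channels (vocabulary size)
logits = torch.randn(B, T, C)         # dummy logits with the shape the model produces
targets = torch.randint(C, (B, T))    # dummy integer targets
loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
print(loss.item())                    # one scalar loss over all B*T predictions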
Text Generation
The generation function is where our trained model becomes useful:
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1, :]                           # focus on last time step
            probs = F.softmax(logits, dim=-1)                   # convert to probabilities
            idx_next = torch.multinomial(probs, num_samples=1)  # sample
            idx = torch.cat((idx, idx_next), dim=1)             # append
        return idx
This function implements autoregressive generation:
- Get predictions for the current sequence
- Focus only on the last position (most recent character)
- Convert logits to probabilities using softmax
- Sample a character based on these probabilities
- Add the sampled character to our sequence
- Repeat
The sampling step is crucial - rather than always picking the most likely character (which would be deterministic and repetitive), we sample according to the probability distribution, introducing controlled randomness that makes the generated text more interesting and varied.
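The contrast is easy to see on a made-up distribution: argmax is deterministic, while multinomial sampling is not.
import torch

probs = torch.tensor([[0.6, 0.3, 0.1]])          # made-up next-character distribution
print(torch.argmax(probs, dim=-1))               # always tensor([0])
print(torch.multinomial(probs, num_samples=1))   # 0 most of the time, sometimes 1 or 2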
Training Loop
Training a neural network involves repeatedly showing it examples and adjusting its parameters to reduce prediction errors:
m = BigramLanguageModel(vocab_size)   # instantiate the model
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

batch_size = 32
for steps in range(10000):
    xb, yb = get_batch('train')               # sample a batch of training data
    logits, loss = m(xb, yb)                  # forward pass and loss
    optimizer.zero_grad(set_to_none=True)     # clear old gradients
    loss.backward()                           # backward pass
    optimizer.step()                          # update the weights
Each training step follows a standard pattern:
- Forward pass: Feed data through the model to get predictions
- Loss calculation: Compare predictions to actual targets
- Backward pass: Calculate gradients showing how to improve
- Parameter update: Adjust model weights to reduce loss
We use the AdamW optimizer, a variant of Adam that adapts the learning rate for each parameter and handles weight decay separately, leading to more stable and efficient training than plain stochastic gradient descent.
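The loop above never prints anything, so it helps to add a small evaluation helper. This is a hedged sketch of my own (the name estimate_val_loss is not from the video), assuming m and get_batch are already defined as above:
@torch.no_grad()
def estimate_val_loss(num_batches=200):
    # average the loss over several random validation batches for a smoother estimate
    losses = torch.zeros(num_batches)
    for k in range(num_batches):
        xb, yb = get_batch('validate')   # any split other than 'train' returns validation data
        _, loss = m(xb, yb)
        losses[k] = loss
    return losses.mean().item()

print(estimate_val_loss())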
Monitoring Progress
Before training, our model generates mostly gibberish:
print(decoder(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))
Sha inth ge jonin out, peroamy aveppedan s lld het
After 10,000 training steps, the same generation call produces much more coherent text. The loss value also decreases significantly, indicating that our model is learning the character patterns in our dataset.
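A useful sanity check (my addition, not from the notebook): a model that guesses uniformly over the 101 characters would score a cross-entropy of about -ln(1/101), so the untrained loss should start near this value and fall well below it during training.
import math
print(math.log(vocab_size))   # about 4.62 for a 101-character vocabulary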
Key Insights and Limitations
Our bigram model, while simple, demonstrates several important concepts:
Strengths:
- Simplicity: Easy to understand and implement
- Speed: Fast training and inference
- Foundational: Introduces core language modeling concepts
Limitations:
- Limited context: Only considers the immediately previous character
- No long-range dependencies: Cannot capture relationships between distant characters
- Basic patterns: Learns simple character transitions but misses complex linguistic structures
Here is the link to the Google Colab Notebook -