Before You Read the Transformer Paper — The Math You Actually Need
I wrote this because I kept bumping into words in the transformer article — “matrix multiplication”, “loss function”, “backpropagation”, “neuron” — and either skipping over them or half-understanding them. This is me going back and filling every single gap. No assumed knowledge. No skipping steps.
Read this before the transformer article. Every concept here will be used there.
0. What even is a number in this context?
Before anything else — what does a number represent in ML?
In normal programming, a number is a count, a price, an ID. In ML, numbers represent meaning — properties of things.
When a model sees the word “cat”, it doesn’t see the letters C-A-T. It sees something like:
[0.2, 0.9, 0.1, 0.7, 0.3, ...] ← 512 numbers
Each number encodes some aspect of what “cat” means. The whole field of ML is about learning which numbers best capture meaning, and doing math on those numbers to produce useful outputs.
Everything in this article is just explaining what kinds of math, and why.
1. What is a neural network — and what is “learning”?
This is the question nobody stops to answer directly before throwing you into matrices and gradients. Let’s answer it properly.
The problem that neural networks solve
Traditional programming works like this:
you write the rules
↓
rules + input → output
You decide the logic. You write the if statements. The computer executes them.
But some problems don’t have rules you can write. How do you write a rule for “is this image a cat or a dog”? Or “what’s the right translation for this sentence”? Or “what’s the next word someone would say”?
These problems have patterns — but the patterns are too complex, too numerous, too fuzzy to express as hand-written rules.
Neural networks solve this with the opposite approach:
you give examples (input → correct output)
↓
the network figures out the rules itself
You don’t write the logic. You show the network thousands or millions of examples, and it discovers its own internal representation of the patterns. The “rules” end up encoded as numbers — weights — inside the network.
What “learning” means mathematically
Learning = finding the set of weights that makes the network produce the right output for any given input.
More precisely: finding weights that minimize the error across all your training examples.
That’s it. The entire field of deep learning — billions of dollars of compute, thousands of researchers — is fundamentally about solving this one optimization problem: find the numbers (weights) that minimize the error (loss).
Everything else — layers, attention, transformers — is about building the right structure for those weights to live in, so they can represent the patterns you care about.
2. What is a neuron?
Before layers and matrices, understand what a single neuron does. Everything else builds on this.
A neuron takes multiple inputs, multiplies each by a weight, sums them up, adds a bias, then passes the result through an activation function.
inputs: x₁, x₂, x₃
weights: w₁, w₂, w₃
bias: b
neuron output = activation( w₁x₁ + w₂x₂ + w₃x₃ + b )
Concretely:
x₁ = 0.5 (maybe: "how royal is this word?")
x₂ = 0.9 (maybe: "how human is this word?")
x₃ = 0.2 (maybe: "how animate is this word?")
w₁ = 1.2, w₂ = 0.8, w₃ = -0.3, b = 0.1
sum = (1.2×0.5) + (0.8×0.9) + (-0.3×0.2) + 0.1
= 0.6 + 0.72 - 0.06 + 0.1
= 1.36
output = ReLU(1.36) = 1.36 ← positive, passes through
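The worked example above fits in a few lines of plain Python (a sketch — the weights and inputs are the made-up illustration values from above):

```python
def neuron(inputs, weights, bias):
    # weighted sum of inputs, plus bias, passed through a ReLU activation
    s = sum(w * x for w, x in zip(weights, inputs)) + bias
    return max(0.0, s)  # ReLU: negative -> 0, positive -> unchanged

out = neuron([0.5, 0.9, 0.2], [1.2, 0.8, -0.3], 0.1)
print(out)  # ≈ 1.36, matching the sum above
```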
What does a neuron detect?
The weights determine what combination of inputs the neuron responds to. A neuron with high weights on “royal” and “human” fires strongly for words like “king” and “queen”. A neuron with a negative weight on “animate” fires less for objects.
Nobody programs this. The weights are learned. The network discovers on its own which combinations of inputs are useful to detect.
A layer is just many neurons running in parallel.
A layer with 512 neurons takes the same input vector and runs 512 of these computations simultaneously — each with its own set of weights — producing 512 outputs. That’s the entire “layer” concept.
3. Vectors — a list of numbers with meaning
A vector is just a list of numbers.
[3, 7] ← 2-dimensional vector
[1, 4, 9] ← 3-dimensional vector
[0.2, 0.9, 0.1] ← 3-dimensional vector
“Dimensional” just means how many numbers are in the list. A 512-dimensional vector is a list of 512 numbers.
Vectors as points in space
The intuition: a vector is a point in space.
A 2D vector [3, 7] is a point on a flat map — 3 right, 7 up. A 3D vector [1, 4, 9] is a point in a room — 1 right, 4 forward, 9 up.
A 512-dimensional vector is a point in 512-dimensional space. You can’t visualize it, but the math works exactly the same.
Why this matters
Words in a transformer are 512-dimensional vectors. Similar words end up as points that are close together in that space. “Cat” and “dog” are nearby. “Cat” and “democracy” are far apart.
This means you can do geometry on words:
- king - man + woman ≈ queen — vector arithmetic preserves semantic relationships
- Distance between vectors = how different two words are
- Direction of a vector = what “kind” of thing it is
Why 512 dimensions and not 10 or 10,000?
More dimensions = the model can encode more distinct properties of a word — more nuance, richer meaning.
But more dimensions also means:
- More weights to learn (more compute, more data needed)
- More memory at inference time
- Diminishing returns past a certain point
512 was the sweet spot the original paper found. Modern models use 4096+ dimensions because they have far more compute and data.
4. Matrices — a grid of numbers
A matrix is a rectangular grid of numbers. Rows and columns.
┌ 1 2 3 ┐
└ 4 5 6 ┘ ← a 2×3 matrix (2 rows, 3 columns)
Matrices store transformations — rules for converting one vector into another vector.
Every weight matrix in a transformer (Wq, Wk, Wv, W1, W2…) is a matrix. A big grid of numbers the model learned.
Matrix multiplication — the most important operation in ML
Matrix multiplication takes a vector and a matrix, produces a new vector. It’s the core operation that runs billions of times during training and inference.
For each output value: take the corresponding row of the matrix, multiply element-by-element with the input, sum it all up.
Input: [2, 3]
Matrix: ┌ 1 0 ┐
└ 0 2 ┘
output[0] = (2×1) + (3×0) = 2
output[1] = (2×0) + (3×2) = 6
Output: [2, 6]
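The rule above — each output is one matrix row combined with the input — is a two-line function (a minimal sketch, no libraries):

```python
def matvec(matrix, vector):
    # each output value = one matrix row multiplied element-by-element
    # with the input vector, then summed
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]

print(matvec([[1, 0], [0, 2]], [2, 3]))  # [2, 6] — matches the worked example
```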
Every output is a weighted sum of all inputs — the same operation runs billions of times in a transformer. Real transformer matrices are 512×512 — same operation, 512 dimensions instead of 2.
Why matrices are “learned”
When we say “Wq is a learned weight matrix” — during training, the model adjusts every number in that grid until the transformation it performs produces useful outputs. The matrix starts as random noise and becomes a meaningful transformation.
In the transformer:
- Wq transforms a word’s embedding into its “Query” — what it’s looking for
- Wk transforms it into its “Key” — what it contains
- Wv transforms it into its “Value” — its actual information
Same input embedding, three different matrices, three completely different outputs.
5. Dot product — measuring similarity
The dot product of two vectors: multiply corresponding elements, sum everything up.
a = [1, 2, 3]
b = [4, 5, 6]
a · b = (1×4) + (2×5) + (3×6) = 4 + 10 + 18 = 32
What the number means
- High positive → vectors point in the same direction → similar
- Near zero → vectors are perpendicular → unrelated
- Negative → vectors point opposite directions → dissimilar
[1, 0] · [1, 0] = 1 ← identical direction
[1, 0] · [0, 1] = 0 ← perpendicular
[1, 0] · [-1, 0] = -1 ← opposite
In attention: the transformer takes the Query vector of one word and dots it with the Key vector of every other word. That dot product IS the attention score — how relevant those two words are to each other.
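All of the dot products above can be checked directly (plain Python, no libraries):

```python
def dot(a, b):
    # multiply corresponding elements, sum everything up
    return sum(x * y for x, y in zip(a, b))

print(dot([1, 2, 3], [4, 5, 6]))  # 32
print(dot([1, 0], [1, 0]))        # 1  — identical direction
print(dot([1, 0], [0, 1]))        # 0  — perpendicular
print(dot([1, 0], [-1, 0]))       # -1 — opposite
```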
6. What is a loss function?
This is the concept that makes training possible. Without understanding it, backpropagation and gradient descent are magic.
The definition
A loss function takes the model’s prediction and the correct answer, and produces a single number — how wrong the model was.
- Loss = 0 → perfect prediction
- Loss = large → very wrong
Example for a simple classification:
model predicts: [0.7, 0.2, 0.1] ← probabilities for [cat, dog, bird]
correct answer: [1.0, 0.0, 0.0] ← it was actually a cat
loss = some formula measuring distance between these two distributions
= a single number, e.g. 0.36
For language models, the loss is typically cross-entropy — how surprised was the model by the correct next token? If the model predicted “the” with 90% probability and “the” was indeed next, loss is low. If it predicted “banana” was most likely and the correct answer was “the”, loss is high.
Loss creates a landscape
Imagine every possible weight setting as a point on a map. The loss function gives every point an elevation — how wrong the model is with those weights.
Training = find the lowest point on this map.
Gradient descent is rolling a ball downhill on this landscape: at every step, compute the gradient at the current weight and move opposite it — downhill, toward minimum loss. In a real model, this happens for billions of weights simultaneously.
The key insight: the loss function is differentiable — you can compute the slope (gradient) at any point, which tells you which direction is downhill. That’s how the model knows which way to update the weights.
7. What is a derivative? What is a gradient?
A derivative — the slope at a point
The derivative of a function at a point tells you: if I shift the input slightly, how much does the output change?
f(x) = x²
f'(x) = 2x ← the derivative
At x=3: f'(3) = 6
"shift x by 0.001 → output changes by ≈0.006"
A large derivative = output is very sensitive to input changes. A near-zero derivative = output barely changes. The weight is effectively stuck — it can’t learn.
A gradient — derivatives for many variables at once
A real neural network has millions of weights, not just one x. The gradient is just a collection of derivatives — one for each weight:
gradient = [∂L/∂w₁, ∂L/∂w₂, ∂L/∂w₃, ...]
Each entry answers: “if I increase this weight slightly, does the loss go up or down — and by how much?”
The gradient is a vector that points in the direction of steepest increase in loss. So to reduce the loss, you move opposite the gradient:
new_weight = old_weight - learning_rate × gradient
Do this for every weight, billions of times. That’s training.
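Here is the update rule applied to the simplest possible “loss”, f(w) = w² (a toy sketch — the starting weight and learning rate are arbitrary illustration values):

```python
w = 5.0              # starting weight (arbitrary)
learning_rate = 0.1

for step in range(100):
    gradient = 2 * w                   # derivative of w² is 2w
    w = w - learning_rate * gradient   # move opposite the gradient

print(w)  # very close to 0.0, the bottom of the w² landscape
```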
8. Backpropagation — how gradients flow
You have a loss. You need a gradient for every single weight in the network — there could be billions. Computing each one individually would be impossibly slow.
Backpropagation is the algorithm that computes all gradients efficiently in one backwards pass through the network.
The intuition
Forward pass: input flows forward through every layer, producing a prediction.
input → layer 1 → layer 2 → layer 3 → prediction → loss
Backward pass: the error signal flows backwards, accumulating how much each layer contributed to the loss.
loss → ∂loss/∂layer3 → ∂loss/∂layer2 → ∂loss/∂layer1
The key mathematical tool is the chain rule: if z = f(g(x)), then dz/dx = f'(g(x)) × g'(x). In plain English — the gradient at an early layer = (gradient at the next layer) × (derivative of this layer’s own operation).
This means:
- Compute loss
- Walk backwards one layer at a time
- At each layer, multiply the incoming gradient by the local derivative
- That gives the gradient for the weights in this layer
The whole network’s gradients — potentially trillions of numbers — can be computed in exactly two passes: one forward, one backward.
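A toy network with two weights shows the whole mechanism (all values are made up for illustration; `h = w1·x` plays the role of layer 1 and `y = w2·h` of layer 2):

```python
x, t = 2.0, 10.0     # input and target (illustration values)
w1, w2 = 1.0, 3.0    # the network's two weights

# forward pass
h = w1 * x           # layer 1
y = w2 * h           # layer 2
loss = (y - t) ** 2  # squared-error loss

# backward pass: walk backwards, multiplying by each local derivative
dL_dy = 2 * (y - t)  # derivative of the loss w.r.t. the prediction
dL_dw2 = dL_dy * h   # gradient for w2
dL_dh = dL_dy * w2   # pass the signal back through layer 2
dL_dw1 = dL_dh * x   # gradient for w1

print(dL_dw1, dL_dw2)  # -48.0 -16.0
```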
Why this matters for understanding the transformer
Every design decision in the transformer exists partly to make backpropagation work well:
- Residual connections — give gradients a direct highway backwards, preventing vanishing
- LayerNorm — keeps activations in a range where derivatives are nonzero
- The √dₖ scaling in attention — prevents softmax saturation where derivatives collapse to zero
None of these make sense without understanding that gradients need to flow cleanly backwards.
Gradient vanishing — the failure mode
If any layer has a derivative near zero, the gradient gets multiplied by ≈0 at that point and effectively disappears. Layers before it receive no signal. Their weights can’t update. They’re permanently stuck.
This is gradient vanishing. It’s not just slow training — it’s entire layers becoming completely useless, regardless of how much more data you feed in.
9. What is a layer?
A layer is one transformation step. Input vector in, math, output vector out.
input vector (512 numbers)
↓
× weight matrix W ← matrix multiplication
↓
+ bias vector b ← small learned offset
↓
activation function ← introduces non-linearity
↓
output vector (512 numbers)
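The steps above, sketched for a 3-dimensional input and a 2-neuron layer (weights and biases are made-up illustration values; real layers are 512-wide):

```python
def layer(x, W, b):
    # matrix multiply + bias, then ReLU on every output
    z = [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]
    return [max(0.0, zi) for zi in z]

out = layer([0.5, 0.9, 0.2],                       # input vector
            [[1.2, 0.8, -0.3], [-1.0, 0.5, 0.2]],  # weight matrix W (2 neurons)
            [0.1, -0.2])                           # bias vector b
print(out)  # first neuron fires (≈ 1.36), second is zeroed by ReLU
```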
Why stack layers?
One layer = one transformation = can only detect simple patterns.
Stack 6 layers: each builds on what the previous found. By the final layer, you’ve applied 6 transformations, each refining the representation further.
Think of it like editing a document:
- Layer 1: catches spelling errors
- Layer 2: fixes grammar
- Layer 3: improves clarity
- Layer 4: restructures arguments
- Layer 5: tightens reasoning
- Layer 6: polishes the whole thing
Nobody tells layer 1 what “spelling errors” are. The network discovers what’s useful to detect at each depth because it minimizes the loss.
How many layers does a transformer have?
Original paper: 6 encoder + 6 decoder layers. Modern models:
- GPT-3: 96 layers
- GPT-4: estimated 120+ layers
- Claude: unknown
Each additional layer = more processing steps, more capacity for complex patterns. But also more weights, more compute, more data needed.
10. Activation functions — introducing non-linearity
Here’s the problem: chain multiple matrix multiplications together and the whole thing collapses to one matrix multiplication.
Layer 1: y = W1 × x
Layer 2: z = W2 × y = (W2 × W1) × x = W_combined × x
One layer or a hundred — mathematically identical. Depth becomes meaningless.
Activation functions break this by adding non-linearity after each matrix multiply — a step that can’t be folded into the neighboring matrices.
Activation functions are applied after each matrix multiplication. Their shape determines what the layer can learn.
ReLU — the workhorse
ReLU(x) = max(0, x)
Negative inputs → 0. Positive inputs → unchanged. The “kink” at zero is the non-linearity.
Why it’s everywhere:
- Computationally trivial (`if x > 0: return x`)
- Doesn’t saturate for positive values — gradients flow cleanly
- In practice, works better than more complex functions
Softmax — scores to probabilities
softmax(xᵢ) = exp(xᵢ) / Σⱼ exp(xⱼ)
Takes a list of raw scores, outputs probabilities that sum to 1.0. Used in two places in the transformer:
- Attention weights — turns similarity scores into “how much to attend”
- Final output — turns logits into “what’s the next token”
Sigmoid — binary decisions
σ(x) = 1 / (1 + e⁻ˣ)
Squishes any number into 0–1. Useful for yes/no decisions. Saturates at extremes (gradient vanishing risk) which is why ReLU replaced it in hidden layers.
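All three activations are short functions in plain Python (a sketch using only the standard library):

```python
import math

def relu(x):
    return max(0.0, x)             # kink at zero = the non-linearity

def sigmoid(x):
    return 1 / (1 + math.exp(-x))  # squishes any number into (0, 1)

def softmax(xs):
    exps = [math.exp(x) for x in xs]   # exponentiate every score
    total = sum(exps)
    return [e / total for e in exps]   # normalize so they sum to 1

print(relu(-2.0), relu(3.0))      # 0.0 3.0
print(sigmoid(0.0))               # 0.5
print(softmax([2.0, 1.0, 0.1]))   # three probabilities summing to ≈ 1.0
```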
11. LayerNorm — keeping numbers stable
As vectors flow through many layers, the numbers drift — some dimensions explode to huge values, others shrink to near zero.
LayerNorm fixes this after every sublayer. It normalizes each vector to have mean ≈ 0 and std ≈ 1:
1. mean μ = average of all values in the vector
2. std σ = square root of average squared deviation from mean
3. normalize: x̂ᵢ = (xᵢ - μ) / σ
4. scale + shift: yᵢ = γ × x̂ᵢ + β ← γ, β are learned
Why mean=0, std=1 specifically? That’s the “Goldilocks zone” where:
- Softmax doesn’t collapse (no extreme values)
- ReLU doesn’t kill everything (not all negative)
- Derivatives stay nonzero (gradients flow)
Steps 1–3 are fixed math. Step 4 lets the model learn to rescale if needed — so LayerNorm doesn’t permanently constrain expressiveness.
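The four steps map directly to code (a minimal sketch; `eps` is a tiny constant added to avoid division by zero, a standard implementation detail):

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    mu = sum(x) / len(x)                                        # 1. mean
    sigma = math.sqrt(sum((v - mu) ** 2 for v in x) / len(x))   # 2. std
    normed = [(v - mu) / (sigma + eps) for v in x]              # 3. normalize
    return [g * n + b for g, n, b in zip(gamma, normed, beta)]  # 4. scale + shift

y = layer_norm([2.0, 4.0, 6.0, 8.0], gamma=[1.0] * 4, beta=[0.0] * 4)
print(y)  # mean ≈ 0, std ≈ 1
```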
12. Loss functions — types you’ll encounter
You already know what a loss function is. Here are the specific ones used in transformers:
Cross-entropy loss — for language models
Language models predict the next token from a vocabulary of ~50,000. Cross-entropy measures how surprised the model was by the correct answer:
loss = -log(probability assigned to the correct token)
If the model gave “the” a 90% probability and it was correct: loss = -log(0.9) ≈ 0.105 (low).
If it gave “the” only a 2% probability: loss = -log(0.02) ≈ 3.91 (high).
Training minimizes this across billions of tokens. Low cross-entropy = model has learned the statistical patterns of language.
Why -log?
-log has useful properties:
- When probability = 1.0 (perfect), loss = 0
- When probability → 0 (totally wrong), loss → ∞
- It’s differentiable everywhere — gradients flow cleanly
- It heavily penalizes confident wrong answers
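The formula and its behavior, in a few lines (standard library only):

```python
import math

def cross_entropy(prob_of_correct_token):
    # loss = -log(probability the model assigned to the correct token)
    return -math.log(prob_of_correct_token)

print(cross_entropy(0.9))   # ≈ 0.105 — confident and right: low loss
print(cross_entropy(0.02))  # ≈ 3.91  — confident and wrong: high loss
print(cross_entropy(1.0))   # -log(1.0) = 0 — perfect prediction
```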
13. Overfitting — why architecture choices exist
Overfitting is when a model memorizes the training data instead of learning general patterns.
Imagine training a model to recognize cats from 1,000 photos. An overfitted model might memorize “photo #47 is a cat” rather than learning “cats have pointed ears and whiskers”. It scores perfectly on training data but fails on any new photo it hasn’t seen.
How you know it’s happening
training loss: 0.02 ← very low, model "knows" the training data
validation loss: 4.8 ← high, model fails on unseen data
Validation loss = loss on data the model never trained on. If training loss is low but validation loss is high, you’re overfitting.
Why this matters for the transformer
Many transformer design choices exist specifically to prevent overfitting:
- Dropout — randomly zero out neurons during training, forcing the model not to rely on any single pathway
- Weight decay — penalizes large weight values, keeping the model from memorizing
- The specific layer count and dimension sizes — chosen to have enough capacity to learn patterns but not so much that it memorizes
When you see d_model=512 or num_heads=8 in the paper — those aren’t arbitrary. They’re the sweet spot between underfitting (too small to learn) and overfitting (too large, memorizes training data).
14. Training vs inference — two completely different modes
The transformer behaves differently during training vs when you’re actually using it.
Training mode
- You have input and the correct output
- The model makes predictions, computes loss against the correct answers
- Backpropagation runs, weights update
- Happens billions of times on massive datasets
- Slow — hours to months on thousands of GPUs
input: "The cat sat"
target: "on the mat"
model predicts token by token, loss computed, weights updated
- Masking in the decoder is a training trick — you feed the full target sequence but mask future positions so the model can’t cheat
Inference mode
- You have input only — no correct output
- The model generates output one token at a time
- No backpropagation, no weight updates
- Fast — milliseconds per response
input: "The cat sat"
model generates: "on" → "the" → "mat" → <end>
- Masking is now structural — the model genuinely can’t see future tokens because they haven’t been generated yet
Why this distinction matters
A lot of confusion in transformer explanations comes from mixing these up. When you read “the decoder generates tokens autoregressively” — that’s inference. When you read “teacher forcing” or “shifted right inputs” — that’s training. They’re different mechanisms even though they share the same weights.
15. Tokens — not words
The transformer doesn’t operate on words. It operates on tokens.
A token is a chunk of text — but not necessarily a whole word. The tokenizer (BPE — Byte Pair Encoding) splits text into subword units:
"transformers" → ["transform", "ers"] ← 2 tokens
"unbelievable" → ["un", "believ", "able"] ← 3 tokens
"cat" → ["cat"] ← 1 token
"the" → ["the"] ← 1 token
Why not just words?
- Vocabulary size — English has hundreds of thousands of words. With subwords, you cover all of English with ~50,000 tokens.
- Unknown words — whole-word tokenizers break on names, slang, code. BPE handles anything by splitting into known pieces.
Context window = token limit
Claude: 200k tokens. GPT-4: 128k tokens. These are token limits, not word limits. ~1 word ≈ 1.3 tokens on average.
Why a limit at all? Attention computes N² comparisons for N tokens. At 100k tokens that’s 10 billion attention scores. Memory and compute become the bottleneck.
16. Parameters vs hyperparameters
Parameters — what the model learns. Every number in every weight matrix.
Wq: 512 × 64 = 32,768 parameters
Wk: 512 × 64 = 32,768 parameters
...
GPT-4: estimated ~1.8 trillion parameters total
“A 7B model” = 7 billion parameters.
Hyperparameters — what you set before training. The model never changes these.
| Hyperparameter | Original paper | What happens if you change it |
|---|---|---|
| d_model | 512 | Bigger = richer representations, more compute |
| num_heads | 8 | More heads = more “perspectives”, more compute |
| num_layers | 6 | More layers = deeper reasoning, more compute |
| learning_rate | 0.0001 | Too high = overshoots, diverges. Too low = learns painfully slowly |
| batch_size | 64 | Bigger = more stable gradients, more memory |
Finding good hyperparameters is expensive trial and error. Most production models have been tuned over hundreds of training runs.
17. The encoder-decoder split — BERT vs GPT vs the original paper
Encoder — bidirectional understanding
Reads the full input simultaneously. Every token can attend to every other token in both directions. Produces rich contextual representations.
Best for: Understanding tasks — classification, search, question answering. BERT is encoder-only.
Decoder — autoregressive generation
Generates one token at a time. When generating token N, can only see tokens 1 through N-1 (masked). Can’t see the future.
Best for: Generation tasks. GPT, Claude, Gemini, LLaMA are all decoder-only.
Encoder-decoder — sequence-to-sequence
Original paper architecture. Encoder reads source sequence (full bidirectional attention). Decoder generates target sequence (masked self-attention + cross-attention back to encoder).
Best for: Translation, summarization — tasks where input and output are different sequences. T5 is encoder-decoder.
18. How the decoder produces a word
After all layers, the decoder produces a 512-dimensional vector. How does that become “the next word is ‘Paris’”?
Step 1 — Linear projection
output_vector (512d)
↓ × W_final (512 × 50,000)
logits (50,000 numbers) ← one score per token in vocabulary
Step 2 — Softmax
logits → softmax → probabilities (sum = 1.0)
Step 3 — Sample or argmax
- Greedy — pick highest probability token every time
- Sampling — randomly sample from the distribution (adds variety)
- Temperature — controls how “peaked” the distribution is. Temperature=1.0 = normal sampling. Temperature→0 = always pick the max. Temperature>1 = more random
Step 4 — Repeat
Append the chosen token, feed back in, generate the next one. Repeat until <end> or max length.
This sequential generation is why inference can’t be parallelized — each token depends on all previous tokens.
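The generation loop, sketched end to end. `toy_model` is a stand-in for the real transformer — it just returns hand-made logits over a four-token vocabulary — but everything else is the real structure of the loop:

```python
import math

def toy_model(tokens):
    # Stand-in for the transformer: fake logits over a tiny vocabulary.
    vocab = ["on", "the", "mat", "<end>"]
    step = len(tokens) - 3  # the prompt "The cat sat" is 3 tokens
    logits = [1.0 if i == step else 0.0 for i in range(len(vocab))]
    return vocab, logits

def softmax(xs):
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

tokens = ["The", "cat", "sat"]
while True:
    vocab, logits = toy_model(tokens)            # step 1: linear projection -> logits
    probs = softmax(logits)                      # step 2: softmax -> probabilities
    next_token = vocab[probs.index(max(probs))]  # step 3: greedy argmax
    tokens.append(next_token)                    # step 4: feed back in, repeat
    if next_token == "<end>" or len(tokens) > 10:
        break

print(tokens)  # ['The', 'cat', 'sat', 'on', 'the', 'mat', '<end>']
```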
19. “Shifted right” — the training trick explained
In training, you have a full source-target pair:
Encoder input: "The cat sat"
Decoder target: "Le chat s'est assis"
The decoder input is the target sequence shifted one position right:
Decoder input: [<start>, "Le", "chat", "s'est"]
Decoder target: ["Le", "chat", "s'est", "assis"]
Position 1 sees <start>, predicts “Le”. Position 2 sees “Le”, predicts “chat”. And so on.
Combined with masking (can’t see future positions), this trains the decoder to predict each token from only previous tokens — without cheating.
At inference, there’s no target sequence — you genuinely generate one token at a time. “Shifted right” is purely a training efficiency trick.
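The shift itself is one line of list slicing (a sketch; `<start>` is the special begin-of-sequence token):

```python
target = ["Le", "chat", "s'est", "assis"]

decoder_input = ["<start>"] + target[:-1]  # shifted one position right
decoder_target = target                    # what each position must predict

for pos, (seen, predict) in enumerate(zip(decoder_input, decoder_target), 1):
    print(f"position {pos}: sees {seen!r}, must predict {predict!r}")
```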
20. Test yourself
Before reading the transformer article, make sure you can answer all of these:
The foundation:
- What problem do neural networks solve that regular programming can’t?
- What does “learning” mean mathematically?
- What does a single neuron compute?
The math:
- What is a vector? What does 512-dimensional mean?
- What does matrix multiplication produce? Step through a 2×2 example.
- What is a dot product measuring?
- What is a derivative telling you at a point on a curve?
- What is a gradient? How is it different from a single derivative?
Loss and training:
- What is a loss function? What does it output?
- What is gradient descent doing geometrically?
- What is backpropagation? Why does it matter?
- What is gradient vanishing? Why does it happen?
Layers and activations:
- What does one layer actually do, step by step?
- Why can’t you just use one big layer?
- What does ReLU do to negative numbers? Positive numbers?
- Why do activation functions exist at all?
- What does LayerNorm do and why is it necessary?
Training vs inference:
- What’s the difference between training mode and inference mode?
- Why is generation sequential at inference but not at training?
Overfitting:
- What is overfitting?
- How do you detect it?
- Name two mechanisms in the transformer that help prevent it.
Tokens and architecture:
- Why are tokens not the same as words?
- What is a context window and why does it have a limit?
- What’s the difference between a parameter and a hyperparameter?
- What does “a 7B model” mean?
- What’s the difference between encoder-only, decoder-only, and encoder-decoder?
If you can answer all of these, you’re ready. Go read the transformer article.
→ Transformers: From Zero to the Paper
Last updated: March 2026.