Before You Read the Transformer Paper — The Math You Actually Need

#ai #ml #math #prerequisites

I wrote this because I kept bumping into words in the transformer article — “matrix multiplication”, “loss function”, “backpropagation”, “neuron” — and either skipping over them or half-understanding them. This is me going back and filling every single gap. No assumed knowledge. No skipping steps.

Read this before the transformer article. Every concept here will be used there.


0. What even is a number in this context?

Before anything else — what does a number represent in ML?

In normal programming, a number is a count, a price, an ID. In ML, numbers represent meaning — properties of things.

When a model sees the word “cat”, it doesn’t see the letters C-A-T. It sees something like:

[0.2, 0.9, 0.1, 0.7, 0.3, ...]   ← 512 numbers

Each number encodes some aspect of what “cat” means. The whole field of ML is about learning which numbers best capture meaning, and doing math on those numbers to produce useful outputs.

Everything in this article is just explaining what kinds of math, and why.


1. What is a neural network — and what is “learning”?

This is the question nobody stops to answer directly before throwing you into matrices and gradients. Let’s answer it properly.

The problem that neural networks solve

Traditional programming works like this:

you write the rules

rules + input → output

You decide the logic. You write the if statements. The computer executes them.

But some problems don’t have rules you can write. How do you write a rule for “is this image a cat or a dog”? Or “what’s the right translation for this sentence”? Or “what’s the next word someone would say”?

These problems have patterns — but the patterns are too complex, too numerous, too fuzzy to express as hand-written rules.

Neural networks solve this with the opposite approach:

you give examples (input → correct output)

the network figures out the rules itself

You don’t write the logic. You show the network thousands or millions of examples, and it discovers its own internal representation of the patterns. The “rules” end up encoded as numbers — weights — inside the network.

What “learning” means mathematically

Learning = finding the set of weights that makes the network produce the right output for any given input.

More precisely: finding weights that minimize the error across all your training examples.

That’s it. The entire field of deep learning — billions of dollars of compute, thousands of researchers — is fundamentally about solving this one optimization problem: find the numbers (weights) that minimize the error (loss).

Everything else — layers, attention, transformers — is about building the right structure for those weights to live in, so they can represent the patterns you care about.


2. What is a neuron?

Before layers and matrices, understand what a single neuron does. Everything else builds on this.

A neuron takes multiple inputs, multiplies each by a weight, sums them up, adds a bias, then passes the result through an activation function.

inputs: x₁, x₂, x₃
weights: w₁, w₂, w₃
bias: b

neuron output = activation( w₁x₁ + w₂x₂ + w₃x₃ + b )

Concretely:

x₁ = 0.5  (maybe: "how royal is this word?")
x₂ = 0.9  (maybe: "how human is this word?")
x₃ = 0.2  (maybe: "how animate is this word?")

w₁ = 1.2, w₂ = 0.8, w₃ = -0.3, b = 0.1

sum = (1.2×0.5) + (0.8×0.9) + (-0.3×0.2) + 0.1
    = 0.6 + 0.72 - 0.06 + 0.1
    = 1.36

output = ReLU(1.36) = 1.36   ← positive, passes through
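The same computation in a few lines of Python, using the toy numbers from above:

```python
def relu(x):
    # ReLU activation: negatives become 0, positives pass through
    return max(0.0, x)

def neuron(inputs, weights, bias):
    # weighted sum of inputs, plus a bias, through the activation
    total = sum(w * x for w, x in zip(weights, inputs))
    return relu(total + bias)

out = neuron([0.5, 0.9, 0.2], [1.2, 0.8, -0.3], 0.1)
print(out)  # ≈ 1.36
```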

What does a neuron detect?

The weights determine what combination of inputs the neuron responds to. A neuron with high weights on “royal” and “human” fires strongly for words like “king” and “queen”. A neuron with a negative weight on “animate” fires less for objects.

Nobody programs this. The weights are learned. The network discovers on its own which combinations of inputs are useful to detect.

A layer is just many neurons running in parallel.

A layer with 512 neurons takes the same input vector and runs 512 of these computations simultaneously — each with its own set of weights — producing 512 outputs. That’s the entire “layer” concept.


3. Vectors — a list of numbers with meaning

A vector is just a list of numbers.

[3, 7]          ← 2-dimensional vector
[1, 4, 9]       ← 3-dimensional vector
[0.2, 0.9, 0.1] ← 3-dimensional vector

“Dimensional” just means how many numbers are in the list. A 512-dimensional vector is a list of 512 numbers.

Vectors as points in space

The intuition: a vector is a point in space.

A 2D vector [3, 7] is a point on a flat map — 3 right, 7 up. A 3D vector [1, 4, 9] is a point in a room — 1 right, 4 forward, 9 up.

A 512-dimensional vector is a point in 512-dimensional space. You can’t visualize it, but the math works exactly the same.

Why this matters

Words in a transformer are 512-dimensional vectors. Similar words end up as points that are close together in that space. “Cat” and “dog” are nearby. “Cat” and “democracy” are far apart.

This means you can do geometry on words:

  • king - man + woman ≈ queen — vector arithmetic preserves semantic relationships
  • Distance between vectors = how different two words are
  • Direction of a vector = what “kind” of thing it is
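You can check this geometry yourself. The vectors below are made-up 3-dimensional stand-ins for real 512-dimensional embeddings, just to show the operations:

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def distance(a, b):
    # Euclidean distance: how far apart two points are
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def cosine_similarity(a, b):
    # 1.0 = same direction, 0 = perpendicular, -1.0 = opposite
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))

# toy vectors, NOT real embeddings
cat = [0.9, 0.8, 0.1]
dog = [0.8, 0.9, 0.2]
democracy = [0.1, 0.1, 0.9]

distance(cat, dog)        # small: nearby points
distance(cat, democracy)  # large: far apart
```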

Why 512 dimensions and not 10 or 10,000?

More dimensions = the model can encode more distinct properties of a word — more nuance, richer meaning.

But more dimensions also means:

  • More weights to learn (more compute, more data needed)
  • More memory at inference time
  • Diminishing returns past a certain point

512 was the sweet spot the original paper found. Modern models use 4096+ dimensions because they have far more compute and data.


4. Matrices — a grid of numbers

A matrix is a rectangular grid of numbers. Rows and columns.

┌ 1  2  3 ┐
└ 4  5  6 ┘  ← a 2×3 matrix (2 rows, 3 columns)

Matrices store transformations — rules for converting one vector into another vector.

Every weight matrix in a transformer (Wq, Wk, Wv, W1, W2…) is a matrix. A big grid of numbers the model learned.

Matrix multiplication — the most important operation in ML

Matrix multiplication takes a vector and a matrix, produces a new vector. It’s the core operation that runs billions of times during training and inference.

For each output value: take the corresponding row of the matrix, multiply element-by-element with the input, sum it all up.

Input:  [2, 3]
Matrix: ┌ 1  0 ┐
        └ 0  2 ┘

output[0] = (2×1) + (3×0) = 2
output[1] = (2×0) + (3×2) = 6

Output: [2, 6]
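As code, the whole operation is one nested loop. A minimal plain-Python sketch (real frameworks do this on GPUs, but the arithmetic is identical):

```python
def matvec(matrix, vector):
    # one output value per matrix row: dot that row with the input
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]

matvec([[1, 0],
        [0, 2]], [2, 3])  # → [2, 6], matching the example above
```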

Every output is a weighted sum of all inputs — the same operation runs billions of times in a transformer. Step by step, for the example above:

row 0: output[0] = (1×2) + (0×3) = 2
row 1: output[1] = (0×2) + (2×3) = 6

Real transformer matrices are 512×512 — same operation, 512 dimensions instead of 2.

Why matrices are “learned”

When we say “Wq is a learned weight matrix”, we mean: during training, the model adjusts every number in that grid until the transformation it performs produces useful outputs. The matrix starts as random noise and becomes a meaningful transformation.

In the transformer:

  • Wq transforms a word’s embedding into its “Query” — what it’s looking for
  • Wk transforms it into its “Key” — what it contains
  • Wv transforms it into its “Value” — its actual information

Same input embedding, three different matrices, three completely different outputs.


5. Dot product — measuring similarity

The dot product of two vectors: multiply corresponding elements, sum everything up.

a = [1, 2, 3]
b = [4, 5, 6]

a · b = (1×4) + (2×5) + (3×6) = 4 + 10 + 18 = 32

What the number means

  • High positive → vectors point in the same direction → similar
  • Near zero → vectors are perpendicular → unrelated
  • Negative → vectors point opposite directions → dissimilar
[1, 0] · [1, 0]  =  1   ← identical direction
[1, 0] · [0, 1]  =  0   ← perpendicular
[1, 0] · [-1, 0] = -1   ← opposite

In attention: the transformer takes the Query vector of one word and dots it with the Key vector of every other word. That dot product IS the attention score — how relevant those two words are to each other.
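A sketch of that scoring step, with made-up 2-d vectors standing in for real Query and Key vectors:

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

query = [1.0, 0.5]                             # what this word is looking for
keys = [[1.0, 0.4], [0.0, 1.0], [-1.0, -0.5]]  # what each other word contains

scores = [dot(query, k) for k in keys]  # [1.2, 0.5, -1.25]
# highest score = most relevant word to attend to
```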


6. What is a loss function?

This is the concept that makes training possible. Without understanding it, backpropagation and gradient descent are magic.

The definition

A loss function takes the model’s prediction and the correct answer, and produces a single number — how wrong the model was.

  • Loss = 0 → perfect prediction
  • Loss = large → very wrong

Example for a simple classification:

model predicts: [0.7, 0.2, 0.1]   ← probabilities for [cat, dog, bird]
correct answer: [1.0, 0.0, 0.0]   ← it was actually a cat

loss = some formula measuring distance between these two distributions
     = a single number, e.g. 0.36

For language models, the loss is typically cross-entropy — how surprised was the model by the correct next token? If the model predicted “the” with 90% probability and “the” was indeed next, loss is low. If it predicted “banana” was most likely and the correct answer was “the”, loss is high.

Loss creates a landscape

Imagine every possible weight setting as a point on a map. The loss function gives every point an elevation — how wrong the model is with those weights.

Training = find the lowest point on this map.

The loss function creates a landscape of "how wrong the model is" for every possible weight value. Training = rolling a ball downhill to find the lowest point (minimum loss).

The gradient at the current weight points uphill, so training moves the weight opposite the gradient (downhill) on every step. In a real model, this happens for billions of weights simultaneously.

The key insight: the loss function is differentiable — you can compute the slope (gradient) at any point, which tells you which direction is downhill. That’s how the model knows which way to update the weights.


7. What is a derivative? What is a gradient?

A derivative — the slope at a point

The derivative of a function at a point tells you: if I shift the input slightly, how much does the output change?

f(x) = x²
f'(x) = 2x    ← the derivative

At x=3:  f'(3) = 6
         "shift x by 0.001 → output changes by ≈0.006"

A large derivative = output is very sensitive to input changes. A near-zero derivative = output barely changes. The weight is effectively stuck — it can’t learn.

A gradient — derivatives for many variables at once

A real neural network has millions of weights, not just one x. The gradient is just a collection of derivatives — one for each weight:

gradient = [∂L/∂w₁, ∂L/∂w₂, ∂L/∂w₃, ...]

Each entry answers: “if I increase this weight slightly, does the loss go up or down — and by how much?”

The gradient is a vector that points in the direction of steepest increase in loss. So to reduce the loss, you move opposite the gradient:

new_weight = old_weight - learning_rate × gradient

Do this for every weight, billions of times. That’s training.
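Here is that update rule running on a toy one-weight loss whose minimum sits at w = 3 (any differentiable function works the same way):

```python
def loss(w):
    return (w - 3) ** 2    # toy loss, lowest point at w = 3

def grad(w):
    return 2 * (w - 3)     # derivative of (w - 3)^2

w = -0.5                   # start somewhere wrong
learning_rate = 0.1
for step in range(100):
    w = w - learning_rate * grad(w)   # move opposite the gradient

# after 100 steps, w has rolled downhill to ≈ 3.0, the minimum
```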


8. Backpropagation — how gradients flow

You have a loss. You need a gradient for every single weight in the network — there could be billions. Computing each one individually would be impossibly slow.

Backpropagation is the algorithm that computes all gradients efficiently in one backwards pass through the network.

The intuition

Forward pass: input flows forward through every layer, producing a prediction.

input → layer 1 → layer 2 → layer 3 → prediction → loss

Backward pass: the error signal flows backwards, accumulating how much each layer contributed to the loss.

loss → ∂loss/∂layer3 → ∂loss/∂layer2 → ∂loss/∂layer1

The key mathematical tool is the chain rule: if z = f(g(x)), then dz/dx = f'(g(x)) × g'(x). In plain English — the gradient at an early layer = (gradient at the next layer) × (derivative of this layer’s own operation).

This means:

  1. Compute loss
  2. Walk backwards one layer at a time
  3. At each layer, multiply the incoming gradient by the local derivative
  4. That gives the gradient for the weights in this layer

The whole network’s gradients — potentially trillions of numbers — can be computed in exactly two passes: one forward, one backward.
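Here’s the chain rule doing exactly that on the smallest possible network: two one-weight “layers” and a squared-error loss. Every gradient is the incoming gradient times a local derivative:

```python
x, target = 2.0, 10.0
w1, w2 = 1.5, 0.5            # one weight per "layer"

# forward pass
h = w1 * x                   # layer 1 output: 3.0
y = w2 * h                   # layer 2 output (the prediction): 1.5
loss = (y - target) ** 2     # 72.25

# backward pass: walk back, multiplying by each local derivative
dloss_dy = 2 * (y - target)  # -17.0, gradient at the output
dloss_dw2 = dloss_dy * h     # -51.0, layer 2's weight gradient
dloss_dh = dloss_dy * w2     # -8.5,  gradient flowing back into layer 1
dloss_dw1 = dloss_dh * x     # -17.0, layer 1's weight gradient
```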

Why this matters for understanding the transformer

Every design decision in the transformer exists partly to make backpropagation work well:

  • Residual connections — give gradients a direct highway backwards, preventing vanishing
  • LayerNorm — keeps activations in a range where derivatives are nonzero
  • The √dₖ scaling in attention — prevents softmax saturation where derivatives collapse to zero

None of these make sense without understanding that gradients need to flow cleanly backwards.

Gradient vanishing — the failure mode

If any layer has a derivative near zero, the gradient gets multiplied by ≈0 at that point and effectively disappears. Layers before it receive no signal. Their weights can’t update. They’re permanently stuck.

This is gradient vanishing. It’s not just slow training — it’s entire layers becoming completely useless, regardless of how much more data you feed in.


9. What is a layer?

A layer is one transformation step. Input vector in, math, output vector out.

input vector (512 numbers)

  × weight matrix W    ← matrix multiplication

  + bias vector b      ← small learned offset

  activation function  ← introduces non-linearity

output vector (512 numbers)
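All of those operations fit in one small function. A minimal sketch with 2 dimensions standing in for 512:

```python
def layer(x, W, b):
    # matrix multiply + bias: one value per row of W
    z = [sum(w * xi for w, xi in zip(row, x)) + bi
         for row, bi in zip(W, b)]
    # activation: ReLU kills the negatives
    return [max(0.0, v) for v in z]

layer([2.0, 3.0],
      [[1.0, 0.0], [0.0, 2.0]],  # weight matrix W
      [0.5, -10.0])              # bias b → output [2.5, 0.0]
```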

Why stack layers?

One layer = one transformation = can only detect simple patterns.

Stack 6 layers: each builds on what the previous found. By the final layer, you’ve applied 6 transformations, each refining the representation further.

Think of it like editing a document:

  • Layer 1: catches spelling errors
  • Layer 2: fixes grammar
  • Layer 3: improves clarity
  • Layer 4: restructures arguments
  • Layer 5: tightens reasoning
  • Layer 6: polishes the whole thing

Nobody tells layer 1 what “spelling errors” are. The network discovers what’s useful to detect at each depth because it minimizes the loss.

One layer — four operations. A token vector goes in, a transformed vector comes out:

  • Input vector — the word embedding, e.g. for "cat"
  • × Weight matrix W — matrix multiplication
  • + Bias vector b — shift the result
  • ReLU activation — max(0, x), kill the negatives
  • Output vector — a richer representation

How many layers does a transformer have?

Original paper: 6 encoder + 6 decoder layers. Modern models:

  • GPT-3: 96 layers
  • GPT-4: estimated 120+ layers
  • Claude: unknown

Each additional layer = more processing steps, more capacity for complex patterns. But also more weights, more compute, more data needed.


10. Activation functions — introducing non-linearity

Here’s the problem: chain multiple matrix multiplications together and the whole thing collapses to one matrix multiplication.

Layer 1: y = W1 × x
Layer 2: z = W2 × y = (W2 × W1) × x = W_combined × x

One layer or a hundred — mathematically identical. Depth becomes meaningless.

Activation functions break this by adding non-linearity after each matrix multiply — a step that linear algebra can’t collapse away.

The shape of the activation function determines what the layer can learn. The ones you’ll keep meeting:

ReLU — the workhorse

ReLU(x) = max(0, x)

Negative inputs → 0. Positive inputs → unchanged. The “kink” at zero is the non-linearity.

Why it’s everywhere:

  • Computationally trivial (if x > 0: return x)
  • Doesn’t saturate for positive values — gradients flow cleanly
  • In practice, works better than more complex functions

Softmax — scores to probabilities

softmax(xᵢ) = eˣⁱ / Σeˣʲ

Takes a list of raw scores, outputs probabilities that sum to 1.0. Used in two places in the transformer:

  1. Attention weights — turns similarity scores into “how much to attend”
  2. Final output — turns logits into “what’s the next token”

Sigmoid — binary decisions

σ(x) = 1 / (1 + e⁻ˣ)

Squishes any number into 0–1. Useful for yes/no decisions. Saturates at extremes (gradient vanishing risk) which is why ReLU replaced it in hidden layers.
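All three are short functions. A plain-Python sketch (the max subtraction inside softmax is the standard numerical-stability trick; it doesn’t change the output):

```python
import math

def relu(x):
    return max(0.0, x)

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

def softmax(xs):
    m = max(xs)                          # subtract max: avoids overflow, same result
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

softmax([1.0, 2.0, 3.0])  # three probabilities that sum to 1.0
```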


11. LayerNorm — keeping numbers stable

As vectors flow through many layers, the numbers drift — some dimensions explode to huge values, others shrink to near zero.

LayerNorm fixes this after every sublayer. It normalizes each vector to have mean ≈ 0 and std ≈ 1:

1. mean μ = average of all values in the vector
2. std  σ = square root of average squared deviation from mean
3. normalize: x̂ᵢ = (xᵢ - μ) / σ
4. scale + shift: yᵢ = γ × x̂ᵢ + β   ← γ, β are learned

Why mean=0, std=1 specifically? That’s the “Goldilocks zone” where:

  • Softmax doesn’t collapse (no extreme values)
  • ReLU doesn’t kill everything (not all negative)
  • Derivatives stay nonzero (gradients flow)

Steps 1–3 are fixed math. Step 4 lets the model learn to rescale if needed — so LayerNorm doesn’t permanently constrain expressiveness.
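The four steps in code. The tiny eps below is my addition — the standard guard against dividing by zero when a vector has no variance:

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    mu = sum(x) / len(x)                          # 1. mean
    var = sum((v - mu) ** 2 for v in x) / len(x)  # 2. std squared
    x_hat = [(v - mu) / math.sqrt(var + eps)      # 3. normalize
             for v in x]
    return [g * v + b                             # 4. learned scale + shift
            for g, v, b in zip(gamma, x_hat, beta)]

# with gamma=1, beta=0 the output has mean ≈ 0 and std ≈ 1
layer_norm([2.0, 4.0, 6.0, 8.0], gamma=[1.0] * 4, beta=[0.0] * 4)
```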


12. Loss functions — types you’ll encounter

You already know what a loss function is. Here are the specific ones used in transformers:

Cross-entropy loss — for language models

Language models predict the next token from a vocabulary of ~50,000. Cross-entropy measures how surprised the model was by the correct answer:

loss = -log(probability assigned to the correct token)

If the model gave “the” a 90% probability and it was correct: loss = -log(0.9) = 0.105 (low).
If the model gave “the” only a 2% probability: loss = -log(0.02) = 3.91 (high).

Training minimizes this across billions of tokens. Low cross-entropy = model has learned the statistical patterns of language.

Why -log?

-log has useful properties:

  • When probability = 1.0 (perfect), loss = 0
  • When probability → 0 (totally wrong), loss → ∞
  • It’s differentiable everywhere — gradients flow cleanly
  • It heavily penalizes confident wrong answers
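Both worked examples above, in code:

```python
import math

def cross_entropy(probs, correct_index):
    # -log of the probability the model gave the correct token
    return -math.log(probs[correct_index])

probs = [0.9, 0.08, 0.02]  # model's distribution over 3 toy tokens

cross_entropy(probs, 0)    # ≈ 0.105: confident and right, low loss
cross_entropy(probs, 2)    # ≈ 3.91:  the right answer only got 2%, high loss
```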

13. Overfitting — why architecture choices exist

Overfitting is when a model memorizes the training data instead of learning general patterns.

Imagine training a model to recognize cats from 1,000 photos. An overfitted model might memorize “photo #47 is a cat” rather than learning “cats have pointed ears and whiskers”. It scores perfectly on training data but fails on any new photo it hasn’t seen.

How you know it’s happening

training loss:    0.02   ← very low, model "knows" the training data
validation loss:  4.8    ← high, model fails on unseen data

Validation loss = loss on data the model never trained on. If training loss is low but validation loss is high, you’re overfitting.

Why this matters for the transformer

Many transformer design choices exist specifically to prevent overfitting:

  • Dropout — randomly zero out neurons during training, forcing the model not to rely on any single pathway
  • Weight decay — penalizes large weight values, keeping the model from memorizing
  • The specific layer count and dimension sizes — chosen to have enough capacity to learn patterns but not so much that it memorizes

When you see d_model=512 or num_heads=8 in the paper — those aren’t arbitrary. They’re the sweet spot between underfitting (too small to learn) and overfitting (too large, memorizes training data).


14. Training vs inference — two completely different modes

The transformer behaves differently during training vs when you’re actually using it.

Training mode

  • You have input and the correct output
  • The model makes predictions, computes loss against the correct answers
  • Backpropagation runs, weights update
  • Happens billions of times on massive datasets
  • Slow — hours to months on thousands of GPUs
  • Masking in the decoder is a training trick — you feed the full target sequence but mask future positions so the model can’t cheat

input: "The cat sat"
target: "on the mat"
model predicts token by token, loss computed, weights updated

Inference mode

  • You have input only — no correct output
  • The model generates output one token at a time
  • No backpropagation, no weight updates
  • Fast — milliseconds per response
  • Masking is now structural — the model genuinely can’t see future tokens because they haven’t been generated yet

input: "The cat sat"
model generates: "on" → "the" → "mat" → <end>

Why this distinction matters

A lot of confusion in transformer explanations comes from mixing these up. When you read “the decoder generates tokens autoregressively” — that’s inference. When you read “teacher forcing” or “shifted right inputs” — that’s training. They’re different mechanisms even though they share the same weights.


15. Tokens — not words

The transformer doesn’t operate on words. It operates on tokens.

A token is a chunk of text — but not necessarily a whole word. The tokenizer (BPE — Byte Pair Encoding) splits text into subword units:

"transformers"  → ["transform", "ers"]       ← 2 tokens
"unbelievable"  → ["un", "believ", "able"]   ← 3 tokens
"cat"           → ["cat"]                    ← 1 token
"the"           → ["the"]                    ← 1 token

Why not just words?

  1. Vocabulary size — English has hundreds of thousands of words. With subwords, you cover all of English with ~50,000 tokens.
  2. Unknown words — whole-word tokenizers break on names, slang, code. BPE handles anything by splitting into known pieces.

Context window = token limit

Claude: 200k tokens. GPT-4: 128k tokens. These are token limits, not word limits. ~1 word ≈ 1.3 tokens on average.

Why a limit at all? Attention computes N² comparisons for N tokens. At 100k tokens that’s 10 billion attention scores. Memory and compute become the bottleneck.


16. Parameters vs hyperparameters

Parameters — what the model learns. Every number in every weight matrix.

Wq: 512 × 64 = 32,768 parameters
Wk: 512 × 64 = 32,768 parameters
...
GPT-4: estimated ~1.8 trillion parameters total

“A 7B model” = 7 billion parameters.

Hyperparameters — what you set before training. The model never changes these.

Hyperparameter   Original paper   What happens if you change it
d_model          512              Bigger = richer representations, more compute
num_heads        8                More heads = more “perspectives”, more compute
num_layers       6                More layers = deeper reasoning, more compute
learning_rate    0.0001           Too high = overshoots, diverges. Too low = never converges
batch_size       64               Bigger = more stable gradients, more memory

Finding good hyperparameters is expensive trial and error. Most production models have been tuned over hundreds of training runs.


17. The encoder-decoder split — BERT vs GPT vs the original paper

Encoder — bidirectional understanding

Reads the full input simultaneously. Every token can attend to every other token in both directions. Produces rich contextual representations.

Best for: Understanding tasks — classification, search, question answering. BERT is encoder-only.

Decoder — autoregressive generation

Generates one token at a time. When generating token N, can only see tokens 1 through N-1 (masked). Can’t see the future.

Best for: Generation tasks. GPT, Claude, Gemini, LLaMA are all decoder-only.

Encoder-decoder — sequence-to-sequence

Original paper architecture. Encoder reads source sequence (full bidirectional attention). Decoder generates target sequence (masked self-attention + cross-attention back to encoder).

Best for: Translation, summarization — tasks where input and output are different sequences. T5 is encoder-decoder.


18. How the decoder produces a word

After all layers, the decoder produces a 512-dimensional vector. How does that become “the next word is ‘Paris’”?

Step 1 — Linear projection

output_vector (512d)
      ↓ × W_final (512 × 50,000)
logits (50,000 numbers)   ← one score per token in vocabulary

Step 2 — Softmax

logits → softmax → probabilities (sum = 1.0)

Step 3 — Sample or argmax

  • Greedy — pick highest probability token every time
  • Sampling — randomly sample from the distribution (adds variety)
  • Temperature — controls how “peaked” the distribution is. Temperature=1.0 = normal sampling. Temperature→0 = always pick the max. Temperature>1 = more random

Step 4 — Repeat

Append the chosen token, feed back in, generate the next one. Repeat until <end> or max length.

This sequential generation is why inference can’t be parallelized across output positions — each new token depends on all the previous ones.
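Steps 2 and 3 in code, with a made-up three-token vocabulary and made-up logits:

```python
import math
import random

def softmax(logits, temperature=1.0):
    scaled = [l / temperature for l in logits]
    m = max(scaled)                       # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

vocab = ["Paris", "London", "banana"]     # toy vocabulary, not a real 50,000-token one
logits = [4.0, 2.5, -1.0]                 # made-up scores from the final linear layer

probs = softmax(logits)
greedy = vocab[probs.index(max(probs))]   # "Paris": always pick the top token
sampled = random.choices(vocab, weights=softmax(logits, temperature=1.5))[0]
```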


19. “Shifted right” — the training trick explained

In training, you have a full source-target pair:

Encoder input:  "The cat sat"
Decoder target: "Le chat s'est assis"

The decoder input is the target sequence shifted one position right:

Decoder input:  [<start>, "Le", "chat", "s'est"]
Decoder target: ["Le",    "chat", "s'est", "assis"]

Position 1 sees <start>, predicts “Le”. Position 2 sees “Le”, predicts “chat”. And so on.

Combined with masking (can’t see future positions), this trains the decoder to predict each token from only previous tokens — without cheating.

At inference, there’s no target sequence — you genuinely generate one token at a time. “Shifted right” is purely a training efficiency trick.
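The shift itself is one line of code:

```python
target = ["Le", "chat", "s'est", "assis"]

decoder_input = ["<start>"] + target[:-1]   # shifted one position right
decoder_target = target

for seen, predict in zip(decoder_input, decoder_target):
    print(f"sees {seen!r:>10} -> must predict {predict!r}")
```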


20. Test yourself

Before reading the transformer article, make sure you can answer all of these:

The foundation:

  • What problem do neural networks solve that regular programming can’t?
  • What does “learning” mean mathematically?
  • What does a single neuron compute?

The math:

  • What is a vector? What does 512-dimensional mean?
  • What does matrix multiplication produce? Step through a 2×2 example.
  • What is a dot product measuring?
  • What is a derivative telling you at a point on a curve?
  • What is a gradient? How is it different from a single derivative?

Loss and training:

  • What is a loss function? What does it output?
  • What is gradient descent doing geometrically?
  • What is backpropagation? Why does it matter?
  • What is gradient vanishing? Why does it happen?

Layers and activations:

  • What does one layer actually do, step by step?
  • Why can’t you just use one big layer?
  • What does ReLU do to negative numbers? Positive numbers?
  • Why do activation functions exist at all?
  • What does LayerNorm do and why is it necessary?

Training vs inference:

  • What’s the difference between training mode and inference mode?
  • Why is generation sequential at inference but not at training?

Overfitting:

  • What is overfitting?
  • How do you detect it?
  • Name two mechanisms in the transformer that help prevent it.

Tokens and architecture:

  • Why are tokens not the same as words?
  • What is a context window and why does it have a limit?
  • What’s the difference between a parameter and a hyperparameter?
  • What does “a 7B model” mean?
  • What’s the difference between encoder-only, decoder-only, and encoder-decoder?

If you can answer all of these, you’re ready. Go read the transformer article.

Transformers: From Zero to the Paper


Last updated: March 2026.