Understanding Attention in AI (for Word Prediction) — In Simple Words
If you’ve ever wondered how models like ChatGPT or GPT-4 “understand” context and predict the next word so accurately — the secret sauce is attention. In this post, we’ll break it down step-by-step, with real examples and math, in plain language.
🧠 What’s the Big Problem Attention Solves?
When a model is reading a sentence and trying to predict the next word, it has to figure out:
“Which words matter the most for this prediction?”
Take this sentence:
“The cat sat on the mat, and then it…”
What does “it” refer to? Probably “cat”. The model needs to pay more attention to “cat” than “mat”. Attention helps with exactly that.
🧱 Step 1: Each Word Starts with an Embedding
Words are first converted into embeddings — fixed-length lists of numbers that represent their meaning.
| Word | Embedding (example) |
|---|---|
| The | [1, 0] |
| Cat | [0, 1] |
| Sat | [1, 1] |
These values are learned during training and stay fixed during inference.
🔄 Step 2: Creating Query, Key, and Value Vectors
Each word’s embedding is passed through 3 different matrices to produce:
| Vector | What it Represents |
|---|---|
| Query | What this word is looking for |
| Key | What this word offers |
| Value | What this word contains |
These are calculated like this:
Query = Wq × Embedding
Key = Wk × Embedding
Value = Wv × Embedding
where Wq, Wk, and Wv are matrices learned during training.
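In code, the three projections are just matrix–vector multiplications. Here’s a minimal pure-Python sketch using made-up 2-d embeddings and toy matrices (real models learn these values during training — the numbers below are placeholders):

```python
def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

# Hypothetical 2-d embeddings, as in the table above
embeddings = {"The": [1.0, 0.0], "Cat": [0.0, 1.0], "Sat": [1.0, 1.0]}

# Toy stand-ins for the learned Wq, Wk, Wv matrices
Wq = [[1.0, 0.0], [0.0, 1.0]]
Wk = [[0.0, 1.0], [1.0, 0.0]]
Wv = [[1.0, 1.0], [0.0, 1.0]]

Q = {w: matvec(Wq, e) for w, e in embeddings.items()}
K = {w: matvec(Wk, e) for w, e in embeddings.items()}
V = {w: matvec(Wv, e) for w, e in embeddings.items()}
```

Every word gets its own Query, Key, and Value vector — but note that the three matrices themselves are shared across all words.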
🧭 Step 3: Attention Scores (Who Should I Listen To?)
To figure out how much attention a word like “sat” gives to others:
- Use its Query
- Compare it with Keys of other words
This is done using a dot product:
score = dot(Query_sat, Key_other)
Then pass these scores through softmax to get weights:
attention weights = softmax(scores)
Example:
Scores: [1.0, 2.1, 1.0]
Softmax: [0.2, 0.6, 0.2]
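Softmax just exponentiates the scores and normalizes them so they sum to 1. Here’s a small sketch with example scores like [1.0, 2.1, 1.0]. (One detail real transformers add: the scores are divided by √d, the key dimension, before the softmax — that scaling is omitted here to keep the numbers simple.)

```python
import math

def softmax(scores):
    """Turn raw scores into weights that are positive and sum to 1."""
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

weights = softmax([1.0, 2.1, 1.0])
print([round(w, 2) for w in weights])  # → [0.2, 0.6, 0.2]
```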
📦 Step 4: Value Comes In — Blending the Information
Once the attention weights are known, we use them to combine the Value vectors of all words:
final_output_for_sat =
0.2 × Value(The) +
0.6 × Value(Cat) +
0.2 × Value(Sat)
✅ Note: You do not separately add the original embedding or value of “sat”. It’s already included through the weighted sum.
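The blending step is just a weighted sum. A sketch, using hypothetical weights and made-up 2-d Value vectors:

```python
def blend(weights, values):
    """Weighted sum of Value vectors → one context-aware output vector."""
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]

# Made-up 2-d Value vectors for "The", "Cat", "Sat"
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = blend([0.2, 0.6, 0.2], values)   # "Cat" contributes the most
```

Notice that “sat”’s own Value is already one of the terms in the sum — which is exactly why nothing needs to be added separately.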
🧠 Why Not Just Use Embeddings Directly?
Because:
- Query/Key allow us to compute dynamic importance per context
- Value lets us blend useful information from others
This way, “bank” in:
“I went to the bank to deposit money.”
…has a different meaning than in:
“We sat on the river bank.”
Even though both start from the same word embedding, attention updates each one based on its surrounding context.
🧾 Real-World Analogy
Imagine each word is a person in a meeting:
- Query = What I want to know
- Key = What I represent
- Value = What I can share
When it’s “sat”’s turn to understand its role, it asks:
“Who can help me?” (uses Query)
Others answer with their Key. Based on how relevant they are, “sat” blends their Values to update its understanding.
🔁 How This Works in Practice (Recap)
Let’s say our sentence is: “The cat sat”
You compute a new vector for each word like this:
| Word | Uses this Query | Compares with Keys of | Blends these Values | Gets new vector for |
|---|---|---|---|---|
| The | Q_The | [K_The, K_Cat, K_Sat] | [V_The, V_Cat, V_Sat] | new “The” embedding |
| Cat | Q_Cat | [K_The, K_Cat, K_Sat] | [V_The, V_Cat, V_Sat] | new “Cat” embedding |
| Sat | Q_Sat | [K_The, K_Cat, K_Sat] | [V_The, V_Cat, V_Sat] | new “Sat” embedding |
Each word is updated based on the full sentence context.
🔄 The model doesn’t calculate how much sat influences others when computing sat’s output. Instead, it computes how much others influence sat.
🔄 Residual Connections
After attention, the model often adds back the original embedding:
output = LayerNorm(original_embedding + attention_output)
This is called a residual connection, and it helps preserve the original identity of the word while adding context.
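Here’s a minimal sketch of that residual-plus-normalization step (this LayerNorm omits the learned scale and shift parameters that real models include):

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and (roughly) unit variance."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]

def residual_block(original_embedding, attention_output):
    # output = LayerNorm(original_embedding + attention_output)
    added = [a + b for a, b in zip(original_embedding, attention_output)]
    return layer_norm(added)

out = residual_block([1.0, 2.0], [0.0, 1.0])
```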
🧠 TL;DR
- Each word starts with a fixed embedding.
- From this, it produces a Query, Key, and Value.
- A word uses its Query to score all Keys.
- It uses the scores to blend all Values.
- The result is a new, contextual embedding.
- This repeats for every word in parallel.
🧱 What Happens Across 96 Layers?
In large models (like GPT-3, which has 96 layers), the attention process is repeated at every layer. Here’s what happens:
- Every layer has its own set of Q, K, V matrices (learned separately)
- Each word’s vector from the previous layer is used to create new Q/K/V
- New attention scores and weighted sums are computed
- This produces a new version of the word vector, which is passed to the next layer
- The vector size (e.g., 384, 1024) stays the same across layers
So in each layer, we’re refining the word’s representation:
| Layer | What the vector captures |
|---|---|
| 1 | Word and its local context |
| 10 | Grammar patterns, chunked phrases |
| 30 | Deeper relationships, semantic alignment |
| 96 | Task-specific reasoning, deep meaning |
Each of the 96 layers performs attention using its own Q/K/V and its own feedforward block.
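A toy loop makes the stacking concrete — 8 layers instead of 96, dimension 4 instead of 384, and random stand-in matrices instead of learned ones. Tracking a single token keeps the math trivial: with only one token, its softmax weight over itself is 1, so the layer’s attention output reduces to Wv × x:

```python
import random

d, n_layers = 4, 8   # stand-ins for dimension 384+ and 96 layers
random.seed(0)
rand_matrix = lambda: [[random.gauss(0.0, 0.1) for _ in range(d)] for _ in range(d)]

# Every layer owns its own, separately "learned" Wq / Wk / Wv
layers = [(rand_matrix(), rand_matrix(), rand_matrix()) for _ in range(n_layers)]

x = [1.0] * d                    # a word vector entering layer 1
for Wq, Wk, Wv in layers:
    # single token → its attention weight over itself is 1 → output is Wv × x
    x = [sum(w * xi for w, xi in zip(row, x)) for row in Wv]
    assert len(x) == d           # the d×d matrices keep the size fixed
```

The key takeaway: the vector changes at every layer, but its size never does.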
🔘 What Are Those Circles in Diagrams?
Those circles or nodes in neural network diagrams do not represent Q/K/V matrices. Instead, they usually represent:
The individual dimensions (neurons) of the vector being passed through the model.
For example:
- If the embedding dimension is 384 → you’ll see 384 neurons/nodes per layer
- These represent the vector features learned by the model
So:
- The Q, K, V matrices are each 384×384 in size (or higher)
- Each transforms the entire 384-dim vector of a word
You do this for every word at every layer. So if your input is a 10-word sentence:
- You compute Q/K/V for all 10 words in parallel, at each layer
🔂 Full Example: 2 Layers with Real Numbers
Let’s walk through a mini version using:
- Sentence: “The cat sat”
- Embedding size: 3 (just to keep math easy)
- 2 transformer layers
Embeddings:
The = [1, 0, 1]
Cat = [0, 1, 1]
Sat = [1, 1, 0]
Layer 1 Matrices:
Wq = [[1, 0, 1], [0, 1, 1], [1, 1, 0]]
Wk = Identity matrix
Wv = Identity matrix
Q for “sat” = Wq × [1, 1, 0] = [1, 1, 2]
K for all words = same as embeddings (since Wk = identity)
Attention scores for “sat”:
Q • K_The = 3
Q • K_Cat = 3
Q • K_Sat = 2
Softmax = [0.42, 0.42, 0.16]
Weighted sum of V:
= 0.42×[1,0,1] + 0.42×[0,1,1] + 0.16×[1,1,0]
= [0.58, 0.58, 0.84] → new vector for "sat"
Layer 2: Feed this into new Q/K/V and repeat the same steps!
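You can check the whole walkthrough in a few lines of Python — same embeddings, same Wq, identity Wk/Wv. Layer 2 here simply re-runs attention on layer 1’s output (a real layer 2 would have its own learned matrices):

```python
import math

def attend(X, Wq, Wk, Wv):
    """Unscaled single-head attention over token vectors X, as in the steps above."""
    mv = lambda W, x: [sum(a * b for a, b in zip(row, x)) for row in W]
    Q, K, V = [mv(Wq, x) for x in X], [mv(Wk, x) for x in X], [mv(Wv, x) for x in X]
    out = []
    for q in Q:
        exps = [math.exp(sum(a * b for a, b in zip(q, k))) for k in K]
        w = [e / sum(exps) for e in exps]
        out.append([sum(wi * v[d] for wi, v in zip(w, V)) for d in range(len(X[0]))])
    return out

X = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 0.0]]  # The, Cat, Sat
Wq = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 0.0]]
I = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # Wk = Wv = identity

layer1 = attend(X, Wq, I, I)
print([round(v, 2) for v in layer1[2]])  # → [0.58, 0.58, 0.84], the new "sat" vector
layer2 = attend(layer1, Wq, I, I)        # layer 2: same steps on layer 1's output
```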
✅ Final Summary
- You don’t have just 3 Q/K/V matrices total — you have them per layer
- Each circle in diagrams = a dimension of a word vector
- Attention happens in every layer, and word meaning is updated step by step
- What you saw above is exactly how “cat” or “sat” gains context like subject/object roles, references, and relationships