
Understanding Attention in AI (for Word Prediction) — In Simple Words

If you’ve ever wondered how models like ChatGPT or GPT-4 “understand” context and predict the next word so accurately, the secret sauce is attention. In this post, we’ll break it down step by step, with worked examples and simple math, in plain language.


🧠 What’s the Big Problem Attention Solves?

When a model is reading a sentence and trying to predict the next word, it has to figure out:

“Which words matter the most for this prediction?”

Take this sentence:

“The cat sat on the mat, and then it…”

What does “it” refer to? Probably “cat”. The model needs to pay more attention to “cat” than “mat”. Attention helps with exactly that.


🧱 Step 1: Each Word Starts with an Embedding

Words are first converted into embeddings — fixed-length lists of numbers that represent their meaning.

Word    Embedding (example)
The     [1, 0]
Cat     [0, 1]
Sat     [1, 1]

These are learned from data and fixed during inference.


🔄 Step 2: Creating Query, Key, and Value Vectors

Each word’s embedding is passed through 3 different matrices to produce:

Vector   What it Represents
Query    What this word is looking for
Key      What this word offers
Value    What this word contains

These are calculated like this:

Query = Wq × Embedding
Key   = Wk × Embedding
Value = Wv × Embedding

Here Wq, Wk, and Wv are matrices learned during training.
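Here is a minimal sketch of this step in numpy, using the toy 2-d embeddings from Step 1. The Wq, Wk, and Wv values below are made-up stand-ins; in a real model they come from training:

```python
import numpy as np

# Toy 2-d embeddings from the table in Step 1
embedding = {"The": np.array([1, 0]),
             "Cat": np.array([0, 1]),
             "Sat": np.array([1, 1])}

# Hypothetical "learned" matrices; in a real model these come from training
Wq = np.array([[1, 0], [0, 1]])
Wk = np.array([[1, 1], [0, 1]])
Wv = np.array([[1, 0], [1, 1]])

# Every word gets its own Query, Key, and Value vector
for word, e in embedding.items():
    Q, K, V = Wq @ e, Wk @ e, Wv @ e
    print(word, "Q =", Q, "K =", K, "V =", V)
```

Note that the same three matrices are shared by every word in the sentence; only the embedding being multiplied changes.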


🧭 Step 3: Attention Scores (Who Should I Listen To?)

To figure out how much attention a word like “sat” gives to others:

  • Use its Query
  • Compare it with Keys of other words

This is done using a dot product:

score = dot(Query_sat, Key_other)

Then pass these scores through softmax to get weights:

attention weights = softmax(scores)

Example:

Scores:  [1.0, 2.1, 1.0]
Softmax: [0.2, 0.6, 0.2]
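The scoring-then-softmax step can be sketched in a few lines of numpy. The Query and Key numbers below are hypothetical 2-d values chosen so that “cat” scores highest:

```python
import numpy as np

def softmax(scores):
    """Turn raw scores into positive weights that sum to 1."""
    exp = np.exp(scores - np.max(scores))  # subtract max for numerical stability
    return exp / exp.sum()

# Hypothetical Query for "sat" and Keys for each word (made-up 2-d numbers)
q_sat = np.array([1.0, 1.0])
keys = [np.array([0.5, 0.5]),   # K_The
        np.array([1.0, 1.1]),   # K_Cat
        np.array([0.5, 0.5])]   # K_Sat

scores = np.array([q_sat @ k for k in keys])  # dot products: 1.0, 2.1, 1.0
weights = softmax(scores)
print(weights.round(2))  # "cat" gets the most attention
```

Subtracting the maximum score before exponentiating doesn’t change the result but keeps `np.exp` from overflowing on large scores.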

📦 Step 4: Value Comes In — Blending the Information

Once the attention weights are known, we use them to combine the Value vectors of all words:

final_output_for_sat =
  0.2 × Value(The) +
  0.6 × Value(Cat) +
  0.2 × Value(Sat)

✅ Note: You do not separately add the original embedding or value of “sat”. It’s already included through the weighted sum.
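The weighted sum above is a single matrix-vector product. A sketch with hypothetical 2-d Value vectors:

```python
import numpy as np

# Attention weights for "sat" (from the softmax step) over The, Cat, Sat
weights = np.array([0.2, 0.6, 0.2])

# Hypothetical 2-d Value vectors, one row per word
values = np.array([[1.0, 0.0],   # Value(The)
                   [0.0, 1.0],   # Value(Cat)
                   [1.0, 1.0]])  # Value(Sat)

# The new vector for "sat" is just the weighted sum of all Values
output_for_sat = weights @ values
print(output_for_sat)  # 0.2*[1,0] + 0.6*[0,1] + 0.2*[1,1] = [0.4, 0.8]
```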


🧠 Why Not Just Use Embeddings Directly?

Because:

  • Query/Key allow us to compute dynamic importance per context
  • Value lets us blend useful information from others

This way, “bank” in:

“I went to the bank to deposit money.”

…has a different meaning than in:

“We sat on the river bank.”

Even though both start from the same word embedding, the attention mechanism updates each occurrence based on its surrounding context.


🧾 Real-World Analogy

Imagine each word is a person in a meeting:

  • Query = What I want to know
  • Key = What I represent
  • Value = What I can share

When it’s “sat”’s turn to understand its role, it asks:

“Who can help me?” (uses Query)

Others answer with their Key. Based on how relevant they are, “sat” blends their Values to update its understanding.


🔁 How This Works in Practice (Recap)

Let’s say our sentence is: “The cat sat”

You compute a new vector for each word like this:

Word   Uses this Query   Compares with Keys of      Blends these Values        Gets new vector for
The    Q_The             [K_The, K_Cat, K_Sat]      [V_The, V_Cat, V_Sat]      new “The” embedding
Cat    Q_Cat             [K_The, K_Cat, K_Sat]      [V_The, V_Cat, V_Sat]      new “Cat” embedding
Sat    Q_Sat             [K_The, K_Cat, K_Sat]      [V_The, V_Cat, V_Sat]      new “Sat” embedding

Each word is updated based on the full sentence context.

🔄 The model doesn’t calculate how much sat influences others when computing sat’s output. Instead, it computes how much others influence sat.


🔄 Residual Connections

After attention, the model often adds back the original embedding:

output = LayerNorm(original_embedding + attention_output)

This is called a residual connection, and it helps preserve the original identity of the word while adding context.
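A minimal sketch of this add-and-normalize step, with made-up numbers and the learned scale/shift parameters of a real LayerNorm omitted for simplicity:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize to zero mean and unit variance (learned scale/shift omitted)."""
    return (x - x.mean()) / np.sqrt(x.var() + eps)

original_embedding = np.array([1.0, 0.0, 1.0])  # made-up numbers
attention_output = np.array([0.5, 0.5, 0.8])    # made-up numbers

# Residual connection: add the original back in, then normalize
output = layer_norm(original_embedding + attention_output)
print(output.round(2))
```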


🧠 TL;DR

  • Each word starts with a fixed embedding.
  • From this, it produces a Query, Key, and Value.
  • A word uses its Query to score all Keys.
  • It uses the scores to blend all Values.
  • The result is a new, contextual embedding.
  • This repeats for every word in parallel.
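The whole recipe above fits in one short function. Here is a single-head sketch in numpy using the row-vector convention (so `X @ Wq` rather than `Wq × embedding`), including the scaling by √d that real transformers apply to the scores:

```python
import numpy as np

def self_attention(X, Wq, Wk, Wv):
    """Minimal single-head self-attention: each row of X is one word's vector."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv          # one Q/K/V per word, in parallel
    scores = Q @ K.T / np.sqrt(K.shape[-1])   # real models scale by sqrt(d_k)
    exp = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = exp / exp.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ V                        # blended Values for every word

# "The cat sat" with 3-d embeddings; identity weights keep the arithmetic visible
X = np.array([[1., 0., 1.],
              [0., 1., 1.],
              [1., 1., 0.]])
I = np.eye(3)
new_X = self_attention(X, I, I, I)
print(new_X.round(2))
```

Each row of `new_X` is the new, contextual embedding for one word, computed from the full sentence at once.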

🧱 What Happens Across 96 Layers?

In deep models (GPT-3, for example, has 96 layers; GPT-4’s exact architecture hasn’t been published), the attention process is repeated at every layer. Here’s what happens:

  • Every layer has its own set of Q, K, V matrices (learned separately)
  • Each word’s vector from the previous layer is used to create new Q/K/V
  • New attention scores and weighted sums are computed
  • This produces a new version of the word vector, which is passed to the next layer
  • The vector size (e.g., 384, 1024) stays the same across layers

So in each layer, we’re refining the word’s representation:

Layer   What the vector captures
1       Word and its local context
10      Grammar patterns, chunked phrases
30      Deeper relationships, semantic alignment
96      Task-specific reasoning, deep meaning

Each of the 96 layers performs attention using its own Q/K/V and its own feedforward block.
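A toy sketch of the stacking, with random stand-in weights (3 layers of size 4 instead of 96 layers of size 12288, and the feedforward block left out). The point to notice is that every layer has its own Wq/Wk/Wv and the vector size never changes:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_layers = 4, 3   # toy sizes; GPT-3 uses d=12288 and 96 layers

# Every layer gets its OWN randomly initialized Wq/Wk/Wv (learned in a real model)
layers = [{m: rng.normal(size=(d, d)) / np.sqrt(d) for m in ("Wq", "Wk", "Wv")}
          for _ in range(n_layers)]

X = rng.normal(size=(2, d))   # 2 words, each a d-dim vector
for layer in layers:
    Q, K, V = X @ layer["Wq"], X @ layer["Wk"], X @ layer["Wv"]
    scores = Q @ K.T / np.sqrt(d)
    exp = np.exp(scores - scores.max(axis=-1, keepdims=True))
    X = (exp / exp.sum(axis=-1, keepdims=True)) @ V + X   # attention + residual

print(X.shape)  # the vector size never changes: still (2, 4)
```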


🔘 What Are Those Circles in Diagrams?

Those circles or nodes in neural network diagrams do not represent Q/K/V matrices. Instead, they usually represent:

The individual dimensions (neurons) of the vector being passed through the model.

For example:

  • If the embedding dimension is 384 → you’ll see 384 neurons/nodes per layer
  • These represent the vector features learned by the model

So:

  • In this single-head picture, the Wq, Wk, Wv matrices are each 384×384 in size (real models split this work across multiple attention heads)
  • Each transforms the entire 384-dim vector of a word

You do this for every word at every layer. So if your input is a 10-word sentence:

  • You compute Q/K/V for all 10 words in parallel, at each layer

🔂 Full Example: 2 Layers with Real Numbers

Let’s walk through a mini version using:

  • Sentence: “The cat sat”
  • Embedding size: 3 (just to keep math easy)
  • 2 transformer layers

Embeddings:

The = [1, 0, 1]
Cat = [0, 1, 1]
Sat = [1, 1, 0]

Layer 1 Matrices:

Wq = [[1, 0, 1], [0, 1, 1], [1, 1, 0]]
Wk = Identity matrix
Wv = Identity matrix

Q for “sat” = Wq × [1, 1, 0] = [1, 1, 2]

K for all words = same as embeddings (since Wk = identity)

Attention scores for “sat”:

Q • K_The = 3
Q • K_Cat = 3
Q • K_Sat = 2
Softmax = [0.42, 0.42, 0.16]

Weighted sum of V:

= 0.42×[1,0,1] + 0.42×[0,1,1] + 0.16×[1,1,0]
= [0.58, 0.58, 0.84] → new vector for "sat"

Layer 2: Feed this into new Q/K/V and repeat the same steps!
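A few lines of numpy reproduce the Layer 1 arithmetic:

```python
import numpy as np

# Numbers from the Layer 1 walk-through above (Wk and Wv are the identity)
Wq = np.array([[1, 0, 1], [0, 1, 1], [1, 1, 0]])
E = {"The": np.array([1, 0, 1]),
     "Cat": np.array([0, 1, 1]),
     "Sat": np.array([1, 1, 0])}

q_sat = Wq @ E["Sat"]                                             # [1, 1, 2]
scores = np.array([q_sat @ E[w] for w in ("The", "Cat", "Sat")])  # [3, 3, 2]
weights = np.exp(scores) / np.exp(scores).sum()                   # softmax
new_sat = weights @ np.array([E["The"], E["Cat"], E["Sat"]])
print(weights.round(2), new_sat.round(2))
```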


✅ Final Summary

  • You don’t have just 3 Q/K/V matrices total — you have them per layer
  • Each circle in diagrams = a dimension of a word vector
  • Attention happens in every layer, and word meaning is updated step by step
  • What you saw above is exactly how “cat” or “sat” gains context like subject/object roles, references, and relationships
