The Story That Explains LLMs
They never "know" facts the way a database does. Instead, they've learned that after the phrase "The capital of France is", the word "Paris" follows with overwhelming probability — because they've seen that pattern ten million times.
That is exactly how a Large Language Model works. It is a statistical pattern machine of staggering scale, trained to predict the next token given all previous tokens. Everything else — reasoning, coding, translation — emerges from that single objective.
A Large Language Model (LLM) is a neural network — specifically a Transformer — trained on massive text corpora to model the probability distribution over sequences of tokens. Given a sequence of tokens as context, the model outputs a probability distribution over its vocabulary for the next token. Sampling from this distribution, repeatedly, produces fluent, coherent text.
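The sample-and-append loop described above can be sketched with a toy stand-in for the model. The lookup table `TOY_MODEL` below is hypothetical and conditions only on the previous token; a real LLM computes the distribution from the entire context with a Transformer.

```python
import random

# Hypothetical toy "model": next-token probabilities given only the previous token.
TOY_MODEL = {
    "<s>":     {"The": 1.0},
    "The":     {"capital": 1.0},
    "capital": {"of": 1.0},
    "of":      {"France": 0.9, "Spain": 0.1},
    "France":  {"is": 1.0},
    "Spain":   {"is": 1.0},
    "is":      {"Paris": 0.8, "Madrid": 0.2},
    "Paris":   {"</s>": 1.0},
    "Madrid":  {"</s>": 1.0},
}

def generate(model, max_tokens=20, seed=0):
    """Autoregressive sampling: sample a token, append it, repeat."""
    rng = random.Random(seed)
    tokens = ["<s>"]
    for _ in range(max_tokens):
        dist = model[tokens[-1]]                      # distribution over the next token
        choices, probs = zip(*dist.items())
        nxt = rng.choices(choices, weights=probs, k=1)[0]
        if nxt == "</s>":                             # end-of-sequence token stops generation
            break
        tokens.append(nxt)
    return tokens[1:]                                 # drop the start marker

print(" ".join(generate(TOY_MODEL)))
```

Changing `seed` can change the sampled continuation, which is the same reason an LLM can answer the same prompt differently on different runs.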
LLMs are trained with a deceptively simple objective: given tokens [t₁, t₂, …, tₙ], predict tₙ₊₁. The loss is cross-entropy between the predicted distribution and the true next token. Everything — grammar, facts, logic, code generation — is learned as a side-effect of minimising this loss at massive scale. This is called self-supervised learning: the labels (next tokens) come free from the data itself.
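As a minimal sketch of that objective, the function below computes the cross-entropy for a single position from raw logits. The 4-token vocabulary and the logit values are made up for illustration.

```python
import math

def next_token_loss(logits, target_id):
    """Cross-entropy for one position: -log softmax(logits)[target_id]."""
    m = max(logits)                                   # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_z - logits[target_id]

# Hypothetical 4-token vocabulary; the true next token has id 2.
confident = [0.0, 0.0, 5.0, 0.0]   # model strongly favours the correct token
uniform = [0.0, 0.0, 0.0, 0.0]     # model has no preference

print(f"confident: {next_token_loss(confident, 2):.4f}")
print(f"uniform:   {next_token_loss(uniform, 2):.4f}")   # equals log(4)
```

The loss shrinks as the model puts more probability on the token that actually came next, and a uniform guess over V tokens costs log(V).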
From Words to Tokens — Tokenisation
| Text | GPT-4 Tokens | Token Count | Note |
|---|---|---|---|
| "Hello, world!" | ["Hello", ",", " world", "!"] | 4 | Common words → single tokens |
| "unbelievable" | ["un", "believ", "able"] | 3 | Rare word → sub-word split |
| "antidisestablishmentarianism" | ["ant", "idis", "estab", "lishment", "arian", "ism"] | 6 | Very rare → many pieces |
| "🚀" | ["<0xF0>", "<0x9F>", "<0x9A>", "<0x80>"] | 4 | Emoji = raw UTF-8 bytes |
| "Paris" | ["Paris"] | 1 | Very common → single token |
Context windows are measured in tokens, not words. On average, 1 English word ≈ 1.3 tokens, so a 128K-token context window holds roughly 96,000 to 98,000 words, or about 190 to 200 pages at 500 words per page. Non-English languages are often less efficient: Chinese and Arabic can use 2–3× more tokens for the same content, because the tokenizer's vocabulary was built mostly from English text.
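The budget arithmetic can be checked with a few lines. The 1.3 tokens-per-word ratio is the average quoted above; the 500 words per page and the 2.5× multiplier are assumptions for illustration. With exactly 1.3, the result is about 98,000 words, close to the rounded figure in the text.

```python
def words_that_fit(context_tokens, tokens_per_word=1.3):
    """Approximate number of words that fit in a context window."""
    return int(context_tokens / tokens_per_word)

def pages(words, words_per_page=500):
    """Approximate page count, assuming a fixed number of words per page."""
    return words / words_per_page

english = words_that_fit(128_000)
print(f"English: ~{english:,} words, ~{pages(english):.0f} pages")

# A language that needs 2.5x more tokens per word (assumed multiplier)
other = words_that_fit(128_000, tokens_per_word=1.3 * 2.5)
print(f"2.5x less efficient: ~{other:,} words")
```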
Embeddings — Turning Tokens into Vectors
After tokenisation, each token ID is mapped to a dense vector called an embedding. This is a lookup table: token 4821 → a vector of 4096 numbers (for a 7B parameter model). These vectors live in a high-dimensional space where semantic similarity maps to geometric proximity.
The famous example: King − Man + Woman ≈ Queen. The vector arithmetic works, approximately, because the embedding space encodes relations such as "royalty" and "gender" as roughly consistent directions that are largely independent of one another.
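The analogy can be reproduced with hand-made vectors. The 2-D embeddings below are invented for illustration (axis 0 stands for "royalty", axis 1 for "gender"); learned embeddings have thousands of dimensions and only satisfy such relations approximately.

```python
import math

# Hypothetical toy embeddings: (royalty, gender)
TOY_EMBEDDINGS = {
    "king":  (1.0, 1.0),
    "queen": (1.0, -1.0),
    "man":   (0.0, 1.0),
    "woman": (0.0, -1.0),
    "apple": (-1.0, 0.0),
}

def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))

def analogy(a, b, c, vocab):
    """Return the word closest to vec(a) - vec(b) + vec(c), excluding the inputs."""
    target = tuple(x - y + z for x, y, z in zip(vocab[a], vocab[b], vocab[c]))
    candidates = {w: v for w, v in vocab.items() if w not in (a, b, c)}
    return max(candidates, key=lambda w: cosine(target, candidates[w]))

print(analogy("king", "man", "woman", TOY_EMBEDDINGS))
```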
import torch
import torch.nn as nn

# Token embedding lookup, scaled by sqrt(d_model) as in the original Transformer.
# Positional information is not added here; it is handled separately.
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):  # x: [batch, seq_len] token IDs
        return self.embed(x) * (self.d_model ** 0.5)  # scale by √d_model

# Example usage
vocab_size, d_model, seq_len = 50257, 768, 10
embedder = TokenEmbedding(vocab_size, d_model)
tokens = torch.randint(0, vocab_size, (1, seq_len))  # 1 batch, 10 tokens
embedded = embedder(tokens)
print(embedded.shape)  # torch.Size([1, 10, 768])
The Transformer — Architecture Overview
The Transformer (Vaswani et al., 2017, "Attention Is All You Need") is the backbone of every modern LLM. It replaced RNNs and LSTMs by discarding recurrence entirely and relying solely on attention mechanisms. The result: massively parallelisable training and far superior long-range dependency modelling.
That parallel, mutual-awareness process is self-attention. The Transformer runs it across all token positions simultaneously, which is why it parallelises so well on GPUs.
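Each Transformer block applies two residual sub-layers. In the pre-norm form used by most modern LLMs:
x = x + Attention(Norm(x))
x = x + FFN(Norm(x))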
Self-Attention — The Heart of the Transformer
Self-attention is the mechanism that lets every token "look at" every other token in the sequence and decide how much information to gather from each. It computes three vectors for every token — Query (Q), Key (K), and Value (V) — and uses dot products to compute relevance scores.
In self-attention: every token is simultaneously the searcher and a book on the shelf. The word "bank" in "river bank" searches for context and finds "river" and "fish" most relevant — so it blends their information to resolve its meaning. This is how attention resolves ambiguity.
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: [batch, heads, seq, d_k]
    K: [batch, heads, seq, d_k]
    V: [batch, heads, seq, d_v]
    """
    d_k = Q.shape[-1]
    # Step 1: Compute raw scores, scaled by sqrt(d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # scores: [batch, heads, seq, seq]
    # Step 2: Apply the mask (positions where mask == 0 get a large negative score)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # Step 3: Softmax → attention weights
    attn_weights = F.softmax(scores, dim=-1)
    # Step 4: Weighted sum of values
    output = torch.matmul(attn_weights, V)
    return output, attn_weights


# --- Demo ---
batch, heads, seq_len, d_k = 1, 8, 5, 64
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)
# Causal mask: lower triangle = 1 (allowed), upper = 0 (masked)
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
out, weights = scaled_dot_product_attention(Q, K, V, mask)
print(f"Output shape: {out.shape}")
print(f"Weights shape: {weights.shape}")
print(f"Weight row sum (should be 1.0): {weights[0,0,2].sum():.4f}")
Multi-Head Attention — Many Perspectives at Once
A single attention head might learn to track grammatical subject-verb agreement. Another might track pronoun references. Another might group semantically related words. Multi-Head Attention runs H parallel attention heads, each with its own Q/K/V projection weights, then concatenates and projects their outputs.
A single head learns one "type" of relationship. Multiple heads in parallel learn many relationship types simultaneously. GPT-3 (175B) uses 96 heads per layer; GPT-4's configuration is undisclosed. Each head operates on a d_head = d_model / H sub-space (e.g. 4096/32 = 128 dimensions per head). The total computation is roughly the same as one big head of width d_model, but the multi-perspective representation is far richer.
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # Projection matrices for Q, K, V and output
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        B, T, D = x.shape

        # Project and split into heads: [B, T, D] → [B, H, T, d_head]
        def split_heads(t):
            return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        Q = split_heads(self.W_q(x))
        K = split_heads(self.W_k(x))
        V = split_heads(self.W_v(x))
        # Attention for all heads in parallel
        attn_out, _ = scaled_dot_product_attention(Q, K, V, mask)
        # Recombine heads: [B, H, T, d_head] → [B, T, D]
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, D)
        return self.W_o(attn_out)  # final output projection
The Feed-Forward Network — Where Knowledge Lives
After attention gathers context, a Feed-Forward Network (FFN) processes each token independently. It's a 2-layer MLP with a wide hidden layer (4×d_model), applying a non-linear activation. Researchers believe this is where factual knowledge is stored — attention routes information, the FFN transforms it.
Standard FFN:

| Component | Detail |
|---|---|
| Layer 1 | Linear: d_model → 4·d_model |
| Activation | ReLU (max(0, x)) in the original Transformer; GELU in GPT-1, GPT-2, BERT |
| Layer 2 | Linear: 4·d_model → d_model |
| Used in | Original Transformer, GPT-1, GPT-2, BERT |

Gated FFN (SwiGLU):

| Component | Detail |
|---|---|
| Layer 1 | Two parallel linears: gate + value |
| Activation | Swish(gate) · value (gated) |
| Layer 2 | Linear: hidden → d_model |
| Used in | LLaMA, Mistral (Gemma uses the closely related GeGLU) |
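A quick parameter count shows why gated FFNs use a hidden width below 4·d_model: the third matrix is paid for by shrinking the hidden layer. The sketch assumes bias-free linear layers; the width 11008 is the LLaMA-7B value used in this article's code.

```python
def standard_ffn_params(d_model, hidden):
    """Two bias-free matrices: d_model → hidden and hidden → d_model."""
    return 2 * d_model * hidden

def gated_ffn_params(d_model, hidden):
    """Three bias-free matrices: gate, value, and the output projection."""
    return 3 * d_model * hidden

d = 4096
print(f"Standard, hidden = 4*d:   {standard_ffn_params(d, 4 * d):,}")   # 134,217,728
print(f"Gated, hidden = 11008:    {gated_ffn_params(d, 11008):,}")      # 135,266,304
```

The two layouts land within about 1% of each other, so a gated FFN can replace a standard one at nearly equal parameter cost.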
class SwiGLU_FFN(nn.Module):
    """SwiGLU feed-forward, used in LLaMA and Mistral (Gemma uses the related GeGLU)."""
    def __init__(self, d_model: int, hidden: int):
        super().__init__()
        self.gate = nn.Linear(d_model, hidden, bias=False)
        self.value = nn.Linear(d_model, hidden, bias=False)
        self.proj = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x):
        # Swish(gate(x)) acts as a learned, smooth gating signal
        return self.proj(
            F.silu(self.gate(x)) * self.value(x)
        )

# Sanity check
ffn = SwiGLU_FFN(d_model=4096, hidden=11008)  # LLaMA-7B dims
x = torch.randn(1, 10, 4096)
print(ffn(x).shape)  # [1, 10, 4096]: same shape in, same shape out
# Count parameters in FFN alone (7B model has 32 layers)
total_ffn_params = sum(p.numel() for p in ffn.parameters())
print(f"FFN params per layer: {total_ffn_params:,}")
print(f"All 32 layers FFN total: {total_ffn_params * 32:,}")
Research (Dai et al., 2022) found that specific neurons in the FFN layers activate for specific factual associations — e.g. a neuron that fires when the model is about to output "Paris" for the query "capital of France". Suppressing these neurons degrades factual accuracy. This suggests the FFN is where factual memories are stored, while attention is the routing mechanism.
Normalisation — Keeping Training Stable
Deep networks suffer from vanishing and exploding gradients. Normalisation layers stabilise the distribution of activations, enabling training of 100+ layer networks. Modern LLMs have largely moved from LayerNorm to RMSNorm.
class RMSNorm(nn.Module):
    """RMS Normalisation, used in LLaMA, Mistral, Gemma."""
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model))  # learnable γ

    def forward(self, x):
        # RMS = sqrt(mean(x²) + eps)
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return (x / rms) * self.scale
Positional Encoding — Giving Tokens a Sense of Place
Attention by itself has no notion of order: it is permutation-equivariant, so shuffling the input tokens simply shuffles the outputs in the same way. Positional encodings inject sequence order so the model knows "this token is at position 47".
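One concrete scheme is the sinusoidal encoding from the original Transformer paper, sketched below in plain Python. It is shown as an example only; many modern LLMs use rotary embeddings (RoPE) instead, and the base of 10000 follows the original paper.

```python
import math

def sinusoidal_position(pos, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)); PE[pos, 2i+1] = cos(same angle)."""
    vec = []
    for i in range(0, d_model, 2):            # i runs over the even dimension indices
        angle = pos / (10000 ** (i / d_model))
        vec.append(math.sin(angle))
        vec.append(math.cos(angle))
    return vec[:d_model]                      # trim in case d_model is odd

# Each position gets a distinct vector that is added to the token embedding.
for pos in (0, 1, 47):
    print(pos, [round(v, 3) for v in sinusoidal_position(pos, 8)])
```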
Scaling Laws — Why Bigger Works Better
In 2022, DeepMind published the Chinchilla paper (Hoffmann et al.) and showed that GPT-3 (175B params, 300B tokens) was severely undertrained. The compute-optimal recipe: for a model of N parameters, train on approximately 20×N tokens. Chinchilla (70B params, 1.4T tokens) outperformed GPT-3 at a fraction of the inference cost.
Modern LLMs like LLaMA-3 push further: 8B params on 15T tokens. They are deliberately overtrained relative to the Chinchilla optimum. Training is paid for once, while inference cost recurs on every request, so a smaller model trained for longer is cheaper to serve at the same quality.
| Model | Parameters | Training Tokens | Tokens/Param | Status |
|---|---|---|---|---|
| GPT-3 | 175B | 300B | 1.7× | Undertrained (pre-Chinchilla) |
| Chinchilla | 70B | 1.4T | 20× | Compute-optimal |
| LLaMA-2 7B | 7B | 2T | 286× | Intentionally overtrained |
| LLaMA-3 8B | 8B | 15T | 1,875× | Heavily overtrained for inference |
| Mistral 7B | 7B | Undisclosed | n/a | Efficient, architecture innovations |
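The tokens-per-parameter column and the 20×N rule of thumb can be reproduced directly. The parameter and token counts below are the ones quoted in the table; the factor 20 is the approximate Chinchilla ratio, not an exact constant.

```python
def compute_optimal_tokens(n_params, ratio=20):
    """Approximate Chinchilla-optimal token budget: about `ratio` tokens per parameter."""
    return n_params * ratio

def tokens_per_param(n_params, n_tokens):
    """How many training tokens the model saw per parameter."""
    return n_tokens / n_params

models = {
    "GPT-3":      (175e9, 300e9),
    "Chinchilla": (70e9, 1.4e12),
    "LLaMA-3 8B": (8e9, 15e12),
}
for name, (n, d) in models.items():
    print(f"{name}: {tokens_per_param(n, d):.1f} tokens/param, "
          f"compute-optimal budget ≈ {compute_optimal_tokens(n) / 1e12:.2f}T tokens")
```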
Grouped Query Attention & KV Cache
During inference, the model generates tokens one at a time. Without caching, every step would recompute K and V for all T previous tokens and rerun attention over the full sequence, which costs O(T²) attention work per step. The KV cache stores past K and V tensors, so each step only projects the new token and attends over the cached entries: O(T) per step.
For a 70B-class model with 80 layers, 64 heads of dimension 128 (d_model = 8192), a 128K context, FP16 values, and full multi-head attention: KV cache = 2 × 80 × 64 × 128 × 128,000 × 2 bytes ≈ 336 GB for a single sequence, more than the roughly 140 GB of FP16 weights. This is why long-context inference tends to be limited by memory rather than compute.
Multi-Head Attention (MHA):

| Property | Value |
|---|---|
| KV heads | Equal to Q heads (e.g. 32) |
| KV cache size | 2 × n_heads × d_head × seq_len (per layer) |
| Memory cost | Very high at long contexts |
| Used in | GPT-2, original GPT-3 |

Grouped Query Attention (GQA):

| Property | Value |
|---|---|
| KV heads | Fraction of Q heads (e.g. 8) |
| KV cache size | 2 × n_kv_heads × d_head × seq_len (per layer) |
| Memory cost | 4–8× smaller cache |
| Used in | LLaMA-2 70B, LLaMA-3, Mistral 7B, Gemma 2 |
Multi-Query Attention (Shazeer 2019) takes it to the extreme: all Q heads share a single K and V head. This reduces the KV cache by a factor of n_heads, with a small quality loss and a massive memory saving. Used in Falcon and early PaLM. GQA is a compromise: the Q heads are split into G groups, and each group shares one KV head. LLaMA-3 8B uses G=8 (32 Q heads, 8 KV heads), a 4× cache reduction vs MHA.
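The cache sizes for the three attention variants can be compared with a small helper. The shape below (80 layers, 64 query heads of dimension 128, 128,000 tokens, 2 bytes per value) is an assumed 70B-class configuration for illustration, not a published specification.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_value=2, batch=1):
    """Bytes needed to cache K and V for every layer, KV head, and position."""
    # The leading factor 2 accounts for storing both K and V.
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_value * batch

cfg = dict(n_layers=80, d_head=128, seq_len=128_000)
for name, n_kv_heads in [("MHA (64 KV heads)", 64), ("GQA (8 KV heads)", 8), ("MQA (1 KV head)", 1)]:
    size = kv_cache_bytes(n_kv_heads=n_kv_heads, **cfg)
    print(f"{name}: {size / 1e9:.1f} GB")
```

Only the number of KV heads changes between the variants; the number of query heads does not affect the cache size.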
Pre-Training — The Foundation Phase
Pre-training is the most expensive phase: billions of parameters, trillions of tokens, thousands of GPUs, months of wall-clock time. The goal is to build a model with rich world knowledge and language understanding. Everything downstream — instruction following, RLHF — fine-tunes this foundation.
Tensor Parallelism: Split attention heads across GPUs within a node.
Pipeline Parallelism: Different layers on different nodes. Micro-batches fill the pipeline.
Modern: ZeRO sharding (DeepSpeed/FSDP) eliminates optimizer state redundancy.
# Simplified pre-training loop (conceptual).
# `LLM` and `dataloader` are placeholders for a model class and a data pipeline defined elsewhere.
import math

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

model = LLM(vocab_size=128256, d_model=4096, n_layers=32, n_heads=32)
optimiser = AdamW(model.parameters(), lr=3e-4,
                  betas=(0.9, 0.95), weight_decay=0.1)
scheduler = CosineAnnealingLR(optimiser, T_max=1_000_000, eta_min=3e-5)

for step, batch in enumerate(dataloader):
    input_ids = batch["input_ids"]       # [B, T]
    labels = input_ids[:, 1:]            # next-token labels: the inputs shifted left by one
    logits = model(input_ids[:, :-1])    # [B, T-1, vocab]
    # Cross-entropy loss over all positions
    loss = F.cross_entropy(
        logits.reshape(-1, logits.shape[-1]),
        labels.reshape(-1),
        ignore_index=-100,               # label value marking positions to skip (e.g. padding)
    )
    optimiser.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimiser.step()
    scheduler.step()
    if step % 100 == 0:
        # perplexity = exp(loss); loss.item() is a Python float, so use math.exp
        print(f"Step {step}: loss={loss.item():.4f}, perplexity={math.exp(loss.item()):.2f}")
Post-Training — Instruction Tuning & Alignment
A pre-trained LLM is a raw text completer, not an assistant. If you prompt it "What is the capital of France?", it might continue "...was debated by historians...". Post-training transforms it into a helpful, safe, instruction-following model.
Mixture of Experts — Efficient Scaling
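A Mixture of Experts (MoE) model activates only part of its parameters for each token. The FFN in each Transformer layer is replaced with E expert FFNs, and a learned router selects the top-K experts for each token. Only K of the E experts run, so per-token FFN compute scales with K rather than E: a model with 8 experts and top-2 routing holds 8× the FFN parameters of a dense model but spends only about 2× its FFN compute per token.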
| Model | Total Params | Active Params | Experts | Top-K |
|---|---|---|---|---|
| Mixtral 8×7B | 47B | 13B | 8 experts | Top 2 per token |
| Mixtral 8×22B | 141B | 39B | 8 experts | Top 2 per token |
| GPT-4 (rumoured) | ~1.8T | ~220B | ~16 experts | Top 2 per token |
| DeepSeek-V3 | 671B | 37B | 256 experts | Top 8 per token |
| Qwen2-57B-A14B | 57B | 14B | 64 experts | Top 8 per token |
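The gap between total and active parameters follows from simple bookkeeping: attention and embedding weights are shared by every token, while only K of the E expert FFNs run. The numbers below are hypothetical round figures for an 8-expert, top-2 model, not the published breakdown of any model in the table.

```python
def moe_param_counts(shared, per_expert, n_experts, top_k):
    """Return (total, active) parameter counts for a sparse MoE model.

    shared:     parameters every token uses (attention, embeddings, norms)
    per_expert: parameters of one expert, summed over all layers
    """
    total = shared + n_experts * per_expert
    active = shared + top_k * per_expert
    return total, active

# Hypothetical split: 2B shared, 5.6B per expert
total, active = moe_param_counts(2_000_000_000, 5_600_000_000, n_experts=8, top_k=2)
print(f"Total:  {total / 1e9:.1f}B")
print(f"Active: {active / 1e9:.1f}B ({active / total:.0%} of total)")
```

This also shows why "8×7B" does not mean 56B parameters: the shared weights are counted once, not eight times.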
class MoELayer(nn.Module):
    """Sparse Mixture of Experts FFN layer."""
    def __init__(self, d_model: int, n_experts: int, top_k: int, hidden: int):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k
        # Router: maps each token to expert logits
        self.router = nn.Linear(d_model, n_experts, bias=False)
        # E independent FFN experts
        self.experts = nn.ModuleList([
            SwiGLU_FFN(d_model, hidden) for _ in range(n_experts)
        ])

    def forward(self, x):
        B, T, D = x.shape
        x_flat = x.view(-1, D)  # [B*T, D]
        # Route: pick top-K experts per token
        logits = self.router(x_flat)  # [B*T, E]
        weights = F.softmax(logits, dim=-1)
        top_w, top_idx = torch.topk(weights, self.top_k, dim=-1)  # [B*T, K]
        top_w = top_w / top_w.sum(dim=-1, keepdim=True)  # renormalise
        # Aggregate expert outputs
        out = torch.zeros_like(x_flat)
        for k in range(self.top_k):
            expert_id = top_idx[:, k]  # [B*T]
            for e in range(self.n_experts):
                mask = (expert_id == e)
                if mask.any():
                    out[mask] += top_w[mask, k:k+1] * self.experts[e](x_flat[mask])
        return out.view(B, T, D)
Quantisation — Smaller, Faster, Cheaper
A 70B parameter model in FP32 requires ~280 GB of VRAM for the weights alone, which already fills four 80 GB A100s. Quantisation reduces the precision of weights and/or activations, dramatically shrinking memory and speeding up inference. The tradeoff: some accuracy loss.
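As a minimal sketch of the idea, the code below applies symmetric absmax quantisation to 8-bit integers and estimates weight memory at several precisions. This is one simple scheme chosen for illustration; production methods (such as NF4) use per-block scales and non-uniform levels. The weight values are made up.

```python
def quantise_int8(weights):
    """Symmetric absmax quantisation: map floats to integers in [-127, 127]."""
    scale = max(abs(w) for w in weights) / 127 or 1.0   # fall back to 1.0 if all weights are zero
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale

def dequantise(q, scale):
    """Recover approximate float values from the integers."""
    return [v * scale for v in q]

def weight_memory_gb(n_params, bits):
    """Memory for the weights alone, ignoring scales and other overhead."""
    return n_params * bits / 8 / 1e9

weights = [0.5, -1.27, 0.003, 1.0]          # hypothetical weight values
q, scale = quantise_int8(weights)
print(q, [round(w, 4) for w in dequantise(q, scale)])

for bits in (32, 16, 8, 4):
    print(f"{bits}-bit: {weight_memory_gb(70_000_000_000, bits):.0f} GB")
```

Small weights such as 0.003 round to zero, which is the accuracy loss in miniature: the rounding error per weight is at most half of `scale`.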
# Load a 70B model in 4-bit quantisation using bitsandbytes
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

quantisation_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute in bf16 for speed
    bnb_4bit_quant_type="nf4",              # NormalFloat4: best quality at 4-bit
    bnb_4bit_use_double_quant=True,         # quantise the quantisation constants too
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B-Instruct",
    quantization_config=quantisation_config,
    device_map="auto",                      # auto-distribute across available GPUs
)
print(f"Model loaded. Memory: {model.get_memory_footprint() / 1e9:.1f} GB")
Fine-Tuning Efficiently — LoRA & PEFT
Full fine-tuning a 70B model requires 70B×4 bytes of weights + gradients + optimiser states ≈ 1–2 TB VRAM. Parameter-Efficient Fine-Tuning (PEFT) methods freeze most of the model and add a tiny number of trainable parameters.
from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoModelForCausalLM

# LoRA config: rank 16, applied to Q and V projection matrices
lora_config = LoraConfig(
    r=16,                                 # rank of the update ΔW = B·A, with A: [r, d_in] and B: [d_out, r]
    lora_alpha=32,                        # scaling factor: the update is multiplied by alpha / r
    target_modules=["q_proj", "v_proj"],  # which weight matrices to adapt
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # reports trainable vs total parameter counts
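The saving from a low-rank update is easy to quantify. For a weight matrix of shape d_out × d_in, LoRA trains two small matrices with r·(d_in + d_out) parameters in total. The 4096 × 4096 shape below is an assumed example of a square projection matrix.

```python
def lora_params(d_in, d_out, r):
    """Trainable parameters added by a rank-r LoRA adapter on one weight matrix."""
    return r * (d_in + d_out)      # A: [r, d_in] plus B: [d_out, r]

d_in = d_out = 4096                # assumed square projection matrix
full = d_in * d_out
added = lora_params(d_in, d_out, r=16)
print(f"Full matrix:  {full:,} parameters")
print(f"LoRA (r=16):  {added:,} parameters ({added / full:.2%} of the full matrix)")
```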
Inference & Decoding Strategies
After training, the model outputs logits (unnormalised scores) over its vocabulary at each step. How you convert those logits to actual text is a decoding strategy — and the choice dramatically affects the character of the output.
| Strategy | How It Works | When to Use | Risk |
|---|---|---|---|
| Greedy | Always pick the highest probability token | Deterministic tasks, structured outputs | Repetitive, degenerate loops |
| Temperature Sampling | Divide logits by T before softmax. T<1 = sharp, T>1 = flat | Creative writing, dialogue | High T → incoherent text |
| Top-K | Sample only from K most probable tokens (e.g. K=50) | General purpose | Fixed K ignores dynamic prob. mass |
| Top-P (Nucleus) | Sample from the smallest set of tokens whose cumulative prob. ≥ P | Best general-purpose default | Slight overhead |
| Beam Search | Keep B candidate sequences, expand all, prune to B best | Translation, summarisation | Slow; generic, "safe" outputs |
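The effect of temperature in the table can be seen on a small example. The four logits below are invented for illustration; the function is a plain softmax applied after dividing the logits by T.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]

logits = [2.0, 1.0, 0.5, -1.0]     # hypothetical scores for four candidate tokens
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: {[round(p, 3) for p in probs]}")
```

Lower temperatures concentrate probability on the top token, approaching greedy decoding, while higher temperatures flatten the distribution.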
import torch
import torch.nn.functional as F


def sample_next_token(logits, temperature=0.8, top_p=0.9, top_k=50):
    """
    logits: [vocab_size], raw model output for the next position
    Returns: sampled token id (int)
    """
    # 1. Temperature scaling (creates a new tensor, so the caller's logits are untouched)
    logits = logits / temperature
    # 2. Top-K filtering: mask out everything below the K-th largest logit
    if top_k > 0:
        k = min(top_k, logits.size(-1))  # guard against top_k > vocab size
        kth_value = torch.topk(logits, k).values[-1]
        logits[logits < kth_value] = -float("inf")
    # 3. Top-P (nucleus) filtering
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop a token once the cumulative probability before it already exceeds top_p
    remove_mask = cumulative - sorted_probs > top_p
    sorted_probs[remove_mask] = 0
    probs = torch.zeros_like(logits).scatter_(0, sorted_idx, sorted_probs)
    probs = probs / probs.sum()  # renormalise
    # 4. Sample
    return torch.multinomial(probs, num_samples=1).item()
LLM Architecture Comparison — Modern Landscape
| Model | Params | Context | Architecture Innovations | Licence |
|---|---|---|---|---|
| GPT-4 | ~1.8T (rumoured MoE) | 128K | Closed — architecture undisclosed | Proprietary |
| Claude 3.5 | Undisclosed | 200K | Constitutional AI, long context | Proprietary |
| LLaMA-3 70B | 70B | 8K (128K in LLaMA-3.1) | GQA, RoPE, SwiGLU, 15T tokens | Open weights |
| Mistral 7B | 7B | 32K | Sliding window attention, GQA | Apache 2.0 |
| Mixtral 8×7B | 47B (13B active) | 32K | Sparse MoE, 8 experts, top-2 | Apache 2.0 |
| DeepSeek-V3 | 671B (37B active) | 128K | MoE 256 experts, MLA attention | MIT |
| Gemma 2 27B | 27B | 8K | Alternating sliding/global attn | Open weights |
Evaluation — How We Measure LLM Quality
If benchmark questions appear in training data, scores are inflated and meaningless. Major labs run contamination analyses, but it is nearly impossible to fully verify at 15T-token scale. As a rule, treat high benchmark scores published without methodology with scepticism. Chatbot Arena (real users, blind pairwise comparison) is among the hardest evaluations to contaminate and is widely treated as a trusted signal.
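A common first-pass contamination check looks for long n-gram overlaps between benchmark items and training documents. The sketch below is a simplified illustration: the whitespace tokenisation, the window size of 8, and the example strings are all assumptions, and real analyses work at far larger scale with more robust matching.

```python
def ngrams(text, n):
    """Set of word-level n-grams, using lowercased whitespace tokenisation."""
    toks = text.lower().split()
    return {tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)}

def overlap_fraction(benchmark_item, training_doc, n=8):
    """Fraction of the benchmark item's n-grams that also appear in the training document."""
    item = ngrams(benchmark_item, n)
    if not item:
        return 0.0
    return len(item & ngrams(training_doc, n)) / len(item)

question = "what is the capital of france and when was it founded"
leaked_doc = "quiz dump: " + question + " answer paris"
clean_doc = "paris is a city in europe known for the eiffel tower and its museums"

print(overlap_fraction(question, leaked_doc))   # 1.0: every 8-gram of the question appears
print(overlap_fraction(question, clean_doc))    # 0.0: no shared 8-gram
```

A high overlap fraction flags an item for manual review; a low one does not prove the item is clean, since paraphrased or translated copies evade exact matching.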