The Story That Explains LSTMs
Now imagine a different kind of reader: one who only remembers the last sentence they read. When they reach the lake scene, they have no idea who Marcus is or why he's trembling. This is the problem that plagued early neural networks called Vanilla RNNs.
Long Short-Term Memory (LSTM) networks were invented to solve exactly this. They are neural networks that can selectively remember, selectively forget, and selectively output, carrying relevant context across long spans of time steps (often hundreds), just like a great detective who knows which clues to file away and which to ignore.
An LSTM is a special type of Recurrent Neural Network (RNN) designed to learn long-range dependencies in sequential data. Introduced by Hochreiter & Schmidhuber in 1997, it was designed to mitigate the vanishing gradient problem that made vanilla RNNs impractical for long sequences. LSTMs have powered speech recognition, machine translation, time-series forecasting, and music generation systems, and they remain in use for many of these tasks.
An LSTM maintains two separate memory channels: a cell state (long-term memory, the novel's plot) and a hidden state (working memory, what the reader is currently focused on). Three learned gates control what flows in, what flows out, and what gets erased. The elegance is that these gates are differentiable, so they train end-to-end with backpropagation.
Why Vanilla RNNs Fail: The Vanishing Gradient
To understand why LSTMs exist, you must feel the pain of what came before. A Vanilla RNN processes sequences step by step, passing a hidden state from one time step to the next. The hidden state is the only memory it has.
During backpropagation through time (BPTT), gradients are multiplied together at every time step. If the recurrent weight matrix has eigenvalues of magnitude below 1, these products shrink exponentially, vanishing toward zero before reaching distant steps, and the derivative of a saturating activation (at most 1 for tanh) only makes this worse. The network effectively cannot learn that "cat" (step 1) causes "sat" (step 20). Conversely, eigenvalues of magnitude above 1 can cause exploding gradients: unstable training that diverges. Vanilla RNNs sit on a knife's edge between these two catastrophes.
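The shrink-or-explode behaviour can be illustrated with a small NumPy sketch. It assumes a simplified linear RNN whose recurrent matrix is an orthogonal matrix multiplied by a chosen `scale`, so every backward step multiplies the gradient norm by exactly that factor. The function name and sizes are illustrative, and a real RNN with tanh would shrink gradients further.

```python
import numpy as np

def backprop_norms(scale, steps=50, size=16, seed=0):
    """Norm of a gradient vector after each backward step through a linear RNN.

    The recurrent matrix is an orthogonal matrix times `scale`, so every
    singular value equals `scale` and each step multiplies the norm by it.
    """
    rng = np.random.default_rng(seed)
    q, _ = np.linalg.qr(rng.standard_normal((size, size)))  # orthogonal matrix
    w = scale * q
    grad = np.ones(size) / np.sqrt(size)  # unit-norm gradient at the last step
    norms = []
    for _ in range(steps):
        grad = w.T @ grad  # one step of backpropagation through time
        norms.append(np.linalg.norm(grad))
    return np.array(norms)

vanishing = backprop_norms(scale=0.9)
exploding = backprop_norms(scale=1.1)
print(f"scale 0.9, after 50 steps: {vanishing[-1]:.2e}")
print(f"scale 1.1, after 50 steps: {exploding[-1]:.2e}")
```

After only 50 steps the two gradient norms differ by more than four orders of magnitude, which is the knife's edge described above.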
LSTM Cell Anatomy: The Animated Working Diagram
An LSTM cell takes three inputs (the current input xₜ, the previous hidden state hₜ₋₁, and the previous cell state Cₜ₋₁) and produces two outputs: a new hidden state hₜ and a new cell state Cₜ. Inside, four learned transformations govern everything.
In the diagram, the blue highway at the top is the cell state: information can flow across the entire sequence largely untouched.
The Four Equations: What Actually Happens Inside
Every LSTM step runs exactly four vector-valued equations. Everything else is just wiring those equations together. Learn these four, and you understand the entire model.
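For reference, the standard LSTM formulation can be written in four steps as follows. The grouping into four lines and the concatenation notation $[h_{t-1}, x_t]$ are assumptions made here for compactness; $\sigma$ is the logistic sigmoid and $\odot$ is element-wise multiplication.

```latex
\begin{aligned}
f_t &= \sigma\!\left(W_f\,[h_{t-1}, x_t] + b_f\right)
  && \text{(forget gate)} \\
i_t &= \sigma\!\left(W_i\,[h_{t-1}, x_t] + b_i\right), \quad
\tilde{C}_t = \tanh\!\left(W_c\,[h_{t-1}, x_t] + b_c\right)
  && \text{(input gate and candidate)} \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t
  && \text{(cell state update)} \\
o_t &= \sigma\!\left(W_o\,[h_{t-1}, x_t] + b_o\right), \quad
h_t = o_t \odot \tanh(C_t)
  && \text{(output gate and hidden state)}
\end{aligned}
```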
The cell state update Cₜ = fₜ ⊙ Cₜ₋₁ + iₜ ⊙ C̃ₜ uses element-wise multiplication (⊙), not matrix multiplication. This is what mitigates vanishing gradients: the gradient flows backward along this additive path, scaled only by the forget gate, rather than through a repeated weight matrix and a saturating activation. It is like a highway with on-ramps (iₜ ⊙ C̃ₜ) and off-ramps (1 − fₜ), rather than a roundabout where information must fully interact at every step.
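A single LSTM step can be sketched directly in NumPy using the standard gate equations. The dictionary layout for the weights (keys `"f"`, `"i"`, `"c"`, `"o"`), the sizes, and the random initialisation are assumptions for illustration, not how any particular library stores its parameters.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM time step.

    W: dict of weight matrices "f", "i", "c", "o", each (hidden, hidden + input).
    b: dict of bias vectors with the same keys, each (hidden,).
    """
    z = np.concatenate([h_prev, x_t])        # [h_{t-1}, x_t]
    f_t = sigmoid(W["f"] @ z + b["f"])       # forget gate, in (0, 1)
    i_t = sigmoid(W["i"] @ z + b["i"])       # input gate, in (0, 1)
    c_hat = np.tanh(W["c"] @ z + b["c"])     # candidate values, in (-1, 1)
    c_t = f_t * c_prev + i_t * c_hat         # element-wise cell state update
    o_t = sigmoid(W["o"] @ z + b["o"])       # output gate, in (0, 1)
    h_t = o_t * np.tanh(c_t)                 # new hidden state
    return h_t, c_t

rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4
W = {k: 0.1 * rng.standard_normal((hidden_size, hidden_size + input_size)) for k in "fico"}
b = {k: np.zeros(hidden_size) for k in "fico"}

h, c = np.zeros(hidden_size), np.zeros(hidden_size)
h, c = lstm_step(rng.standard_normal(input_size), h, c, W, b)
print(h.shape, c.shape)
```

Note that the only operations touching `c_prev` are an element-wise scale and an addition, which is the "highway" property in code form.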
Gate Intuition: The Three Security Guards
Guard 1, The Archivist (Forget Gate): Every morning, she reviews the files and stamps "DESTROY" on anything outdated. If a citizen moved cities, the old address gets shredded. Her stamp strength (0 to 1) decides how much of each record to keep.
Guard 2, The Intake Officer (Input Gate × Candidate): New documents arrive at the door. He first decides which documents are worth filing (input gate, 0 to 1), then decides what those documents actually say (candidate values, -1 to +1). Only approved content in the right amount gets filed.
Guard 3, The Spokesperson (Output Gate): A reporter asks for information. The spokesperson decides which filed records to share publicly. Even if the room holds sensitive long-term data, only selected parts get released as the "hidden state" that downstream layers read.
The filing room itself is the cell state: a long-term memory that can hold information for hundreds of "days" (time steps) with little degradation.
Forget gate: output near 0 erases the stored value; output near 1 preserves it unchanged.
Example: when parsing "he left the bank" after many finance words, the forget gate resets the "bank = finance" memory.
Input gate example: when "bank" appears in the context of "river", write "bank = geography" strongly.
Output gate: near 0, keep the fact private; near 1, broadcast it.
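The near-0 and near-1 behaviour of the three gates can be checked with hand-picked numbers. The values below are made up for illustration; real gates come from a sigmoid, so they approach 0 and 1 without reaching them exactly.

```python
import numpy as np

c_prev = np.array([2.0, -1.0, 0.5])      # three stored "records"
f = np.array([1.0, 0.0, 0.5])            # forget gate: keep, erase, keep half
i = np.array([0.0, 1.0, 0.5])            # input gate: ignore, write, write half
c_hat = np.array([0.9, -0.9, 0.9])       # candidate content
o = np.array([1.0, 0.0, 0.5])            # output gate: broadcast, hide, share half

c_new = f * c_prev + i * c_hat           # what the filing room holds afterwards
h_new = o * np.tanh(c_new)               # what the spokesperson releases
print("new cell state:", c_new)
print("hidden state:  ", h_new.round(3))
```

The first record survives untouched, the second is fully overwritten by the candidate yet hidden from the output, and the third is a blend of old and new.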
LSTM Unrolled Through Time
An LSTM doesn't just process one step; it processes a sequence. The same cell (same weights) is applied repeatedly, with the hidden state and cell state passed forward at each step. This is called unrolling.
Crucially, all three LSTM cells in the diagram above share the exact same weights (Wf, Wi, Wc, Wo and their biases). The "same cell applied repeatedly" means the total number of parameters does not grow with sequence length. A sequence of 1,000 words uses the same parameter count as a sequence of 10 words. Transformers also have a fixed parameter count, but their attention computation and memory grow with input length, whereas an LSTM carries a fixed-size state from step to step.
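Weight sharing across time can be made concrete with a minimal NumPy unrolling loop. The stacked weight layout (all four transformations in one matrix) and the function name `run_lstm` are assumptions chosen for brevity; the point is only that one set of weights serves sequences of any length.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def run_lstm(xs, W, b, hidden_size):
    """Apply one LSTM cell, with one set of weights, to every step of xs."""
    h = np.zeros(hidden_size)
    c = np.zeros(hidden_size)
    for x_t in xs:                                   # unrolling through time
        z = W @ np.concatenate([h, x_t]) + b         # all four pre-activations at once
        f, i, g, o = np.split(z, 4)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
    return h, c

rng = np.random.default_rng(0)
input_size, hidden_size = 2, 8
# Stacked weights for the forget, input, candidate, and output transformations.
W = 0.1 * rng.standard_normal((4 * hidden_size, hidden_size + input_size))
b = np.zeros(4 * hidden_size)
n_params = W.size + b.size

h_short, _ = run_lstm(rng.standard_normal((10, input_size)), W, b, hidden_size)
h_long, _ = run_lstm(rng.standard_normal((1000, input_size)), W, b, hidden_size)
print(f"parameters: {n_params}, state sizes: {h_short.shape} and {h_long.shape}")
```

Both runs use the same 352 parameters and end with a state of the same size; only the number of loop iterations changes.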
LSTM Variants: The Family Tree
Key Hyperparameters: What You Actually Tune
| Hyperparameter | Typical Range | Effect | Tip |
|---|---|---|---|
| hidden_size / units | 64 to 512 | Dimensions of h and C vectors. Bigger = more capacity, more compute | Start at 128. Double if model underfits. |
| num_layers | 1 to 4 | Depth of stacked LSTM. More layers = more hierarchical abstraction | 2 is usually enough; 3+ needs strong regularisation |
| dropout | 0.1 to 0.5 | Applied between LSTM layers (not inside recurrent connections) | Use 0.2 to 0.3 first. Variational dropout if overfitting persists |
| sequence_length | task-dependent | How many time steps the LSTM sees at once (BPTT window) | Keep to ≤ 200 for stability. Truncated BPTT for longer sequences |
| learning_rate | 1e-4 to 1e-2 | LSTMs are sensitive to LR. Too high = divergence; too low = slow | Use Adam with lr=1e-3. Add scheduler to reduce on plateau |
| gradient_clipping | 0.5 to 5.0 | Prevents exploding gradients. Clips gradient norm to this threshold | Always use with LSTMs. Value of 1.0 is safe default |
| bidirectional | True / False | Doubles parameters and computation; allows future-context awareness | Use for classification/NER. Avoid for generation or real-time prediction |
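To make the gradient_clipping row concrete, here is a minimal NumPy sketch of clipping by global norm. It is similar in spirit to what `clip_grad_norm_` does in PyTorch, but the function name `clip_by_global_norm` and the small `1e-6` safety term are assumptions of this sketch, not a library API.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))  # never scale gradients up
    return [g * scale for g in grads], total_norm

grads = [np.array([3.0, 4.0]), np.array([12.0])]   # joint norm = 13
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print(f"norm before: {norm_before:.1f}, after: {norm_after:.3f}")
```

All gradients are scaled by the same factor, so the update direction is preserved and only its length is capped. Gradients already under the threshold pass through unchanged.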
Python Implementation: From Scratch to Production (PyTorch)
We'll build an LSTM for stock price forecasting, a classic time-series task. The model predicts the next day's closing price given the last 60 days of price data.
Step 1: Imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
# Reproducibility
torch.manual_seed(42)
np.random.seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
Step 2: Data Preparation
# Load CSV with columns: Date, Close
df = pd.read_csv('stock_prices.csv', parse_dates=['Date'], index_col='Date')
prices = df[['Close']].values  # shape: (N, 1)
# Train / test split (80% / 20%), chronological: NO shuffling for time series!
split = int(len(prices) * 0.8)
# Normalise to [0, 1], which is critical for stable LSTM training.
# Fit the scaler on the training portion only so no test-set statistics leak in.
scaler = MinMaxScaler(feature_range=(0, 1))
train_data = scaler.fit_transform(prices[:split])
test_data = scaler.transform(prices[split:])
print(f"Train samples: {len(train_data)} | Test samples: {len(test_data)}")
Step 3: Sliding Window Dataset
class StockDataset(Dataset):
    def __init__(self, data, seq_len=60):
        self.seq_len = seq_len
        X, y = [], []
        for i in range(len(data) - seq_len):
            X.append(data[i : i + seq_len])  # shape: (seq_len, 1)
            y.append(data[i + seq_len])      # next value
        self.X = torch.FloatTensor(np.array(X))  # (N, seq_len, 1)
        self.y = torch.FloatTensor(np.array(y))  # (N, 1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

SEQ_LEN = 60
BATCH = 32
train_ds = StockDataset(train_data, SEQ_LEN)
test_ds = StockDataset(test_data, SEQ_LEN)
train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=BATCH, shuffle=False)
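The sliding-window logic is easier to verify on a toy series than on real prices. The sketch below re-implements the same windowing in plain NumPy; the helper name `make_windows` is hypothetical and the series 0..9 is made up, so each window and its target can be read off by eye.

```python
import numpy as np

def make_windows(series, seq_len):
    """Return (X, y): X[k] holds seq_len consecutive values, y[k] is the value right after."""
    series = np.asarray(series, dtype=np.float32)
    X = np.stack([series[k : k + seq_len] for k in range(len(series) - seq_len)])
    y = series[seq_len:]
    return X[..., None], y[..., None]          # add a trailing feature axis

X, y = make_windows(np.arange(10), seq_len=3)
print(X.shape, y.shape)          # (7, 3, 1) (7, 1)
print(X[0].ravel(), y[0])        # first window and its target
```

A series of length N yields N - seq_len samples, which is why the test split must be comfortably longer than the window.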
Step 4: The LSTM Model
class LSTMForecaster(nn.Module):
    def __init__(self, input_size=1, hidden_size=128,
                 num_layers=2, dropout=0.2, output_size=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Core LSTM: 2 stacked layers with dropout between them
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,     # between layers only
            batch_first=True     # input: (batch, seq, features)
        )
        # Prediction head
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.Linear(64, output_size)
        )

    def forward(self, x):
        # x: (batch, seq_len, input_size)
        # Initialise h0 and c0 to zeros
        h0 = torch.zeros(self.num_layers, x.size(0),
                         self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0),
                         self.hidden_size).to(x.device)
        # out: (batch, seq_len, hidden_size)
        # hn, cn: (num_layers, batch, hidden_size)
        out, (hn, cn) = self.lstm(x, (h0, c0))
        # Take ONLY the last time step's hidden state
        last_hidden = out[:, -1, :]   # (batch, hidden_size)
        return self.fc(last_hidden)   # (batch, 1)

model = LSTMForecaster(
    input_size=1,
    hidden_size=128,
    num_layers=2,
    dropout=0.2,
    output_size=1
).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")
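The printed parameter count can be cross-checked by hand. The arithmetic below assumes PyTorch's layout for `nn.LSTM`, where each layer stores input weights, recurrent weights, and two bias vectors for each of the four transformations. The helper names are hypothetical.

```python
def lstm_layer_params(input_size, hidden_size):
    """Parameters of one LSTM layer: four transformations, each with input
    weights, recurrent weights, and two bias vectors (the PyTorch layout)."""
    return 4 * (hidden_size * input_size + hidden_size * hidden_size + 2 * hidden_size)

def linear_params(in_features, out_features):
    return in_features * out_features + out_features

layer1 = lstm_layer_params(input_size=1, hidden_size=128)
layer2 = lstm_layer_params(input_size=128, hidden_size=128)   # fed by layer 1's output
head = linear_params(128, 64) + linear_params(64, 1)
total = layer1 + layer2 + head
print(f"LSTM layer 1: {layer1:,}")
print(f"LSTM layer 2: {layer2:,}")
print(f"Head:         {head:,}")
print(f"Total:        {total:,}")
```

Under these assumptions the total is 207,489, and the second LSTM layer accounts for most of it because its input is 128-dimensional rather than 1-dimensional.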
Step 5: Training Loop with Gradient Clipping
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Halve the learning rate when validation loss stops improving for 5 epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=5, factor=0.5
)
EPOCHS = 50
CLIP_NORM = 1.0  # gradient clipping threshold
train_losses = []
val_losses = []
for epoch in range(1, EPOCHS + 1):
    # -- Training phase --
    model.train()
    epoch_loss = 0.0
    for X_batch, y_batch in train_dl:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)
        optimizer.zero_grad()
        preds = model(X_batch)             # forward pass
        loss = criterion(preds, y_batch)
        loss.backward()                    # BPTT
        nn.utils.clip_grad_norm_(model.parameters(),
                                 CLIP_NORM)  # critical: clip before the optimizer step
        optimizer.step()
        epoch_loss += loss.item() * len(X_batch)
    train_loss = epoch_loss / len(train_ds)
    train_losses.append(train_loss)
    # -- Validation phase (the test split doubles as the validation set here) --
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for X_val, y_val in test_dl:
            X_val = X_val.to(device)
            y_val = y_val.to(device)
            val_preds = model(X_val)
            val_loss += criterion(val_preds, y_val).item() * len(X_val)
    val_loss /= len(test_ds)
    val_losses.append(val_loss)
    scheduler.step(val_loss)
    if epoch % 10 == 0:
        print(f"Epoch {epoch:3d}/{EPOCHS} | Train MSE: {train_loss:.6f} | Val MSE: {val_loss:.6f}")
Step 6: Inference and Inverse Transform
model.eval()
all_preds, all_true = [], []
with torch.no_grad():
    for X_batch, y_batch in test_dl:
        X_batch = X_batch.to(device)
        pred = model(X_batch).cpu().numpy()
        all_preds.append(pred)
        all_true.append(y_batch.numpy())
preds_scaled = np.concatenate(all_preds)  # (N, 1)
true_scaled = np.concatenate(all_true)
# Inverse MinMax transform back to real price values
preds_price = scaler.inverse_transform(preds_scaled)
true_price = scaler.inverse_transform(true_scaled)
mae = np.mean(np.abs(preds_price - true_price))
rmse = np.sqrt(np.mean((preds_price - true_price) ** 2))
print(f"MAE: ${mae:.2f}")
print(f"RMSE: ${rmse:.2f}")
# Save model
torch.save(model.state_dict(), 'lstm_forecaster.pth')
print("Model saved.")
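The inverse transform step matters because errors measured on scaled values are not in price units. The sketch below redoes min-max scaling by hand on made-up numbers, assuming the minimum and maximum come from the training portion only. The function names `scale` and `unscale` are illustrative, not scikit-learn APIs.

```python
import numpy as np

train = np.array([100.0, 110.0, 120.0, 150.0])   # made-up training prices
lo, hi = train.min(), train.max()                 # statistics from training data only

def scale(x):
    return (x - lo) / (hi - lo)

def unscale(x_scaled):
    return x_scaled * (hi - lo) + lo

true_price = np.array([130.0, 140.0])
pred_scaled = scale(true_price) + np.array([0.02, -0.04])   # pretend model output
pred_price = unscale(pred_scaled)

mae_scaled = np.mean(np.abs(pred_scaled - scale(true_price)))
mae_price = np.mean(np.abs(pred_price - true_price))
print(f"MAE (scaled): {mae_scaled:.3f} | MAE (price): {mae_price:.2f}")
```

The scaled error is multiplied by the training price range (50 here) when converted back, so a small-looking scaled MSE can still correspond to a meaningful error in dollars.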
Keras / TensorFlow Implementation
A similar model in Keras (the second LSTM layer and the dense head are narrower here): fewer lines, identical concept. Keras is often preferable for rapid prototyping and research iteration.
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, callbacks
# Assumption: reuse the sliding-window tensors from Step 3 as NumPy arrays.
# X_*: (N, 60, 1), y_*: (N, 1)
X_train, y_train = train_ds.X.numpy(), train_ds.y.numpy()
X_test, y_test = test_ds.X.numpy(), test_ds.y.numpy()
# Build model with the functional API for clarity
inputs = keras.Input(shape=(60, 1))            # (seq_len, features)
x = layers.LSTM(128, return_sequences=True,
                dropout=0.2)(inputs)           # layer 1: passes full sequence
x = layers.LSTM(64, return_sequences=False,
                dropout=0.2)(x)                # layer 2: returns last step
x = layers.Dense(32, activation='relu')(x)
outputs = layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
model.compile(
    # `learning_rate` replaces the deprecated `lr` argument; clipnorm clips gradients
    optimizer=keras.optimizers.Adam(learning_rate=1e-3, clipnorm=1.0),
    loss='mse',
    metrics=['mae']
)
model.summary()
# Callbacks
cbs = [
    callbacks.EarlyStopping(patience=10, restore_best_weights=True),
    callbacks.ReduceLROnPlateau(patience=5, factor=0.5),
    callbacks.ModelCheckpoint('best_lstm.keras', save_best_only=True)
]
history = model.fit(
    X_train, y_train,
    validation_data=(X_test, y_test),
    epochs=100,
    batch_size=32,
    callbacks=cbs,
    verbose=1
)
In a stacked LSTM, all layers except the last must use return_sequences=True so they pass the full sequence of hidden states to the next layer. The final LSTM layer uses return_sequences=False to return only the last hidden state, which then feeds into the Dense prediction head. Getting this wrong is one of the most common LSTM bugs.
LSTM vs Transformer β When to Use Which
| Property | LSTM | Transformer |
|---|---|---|
| Memory mechanism | Recurrent cell state | Self-attention over all tokens |
| Long-range dependency | Good (60 to 200 steps) | Excellent (1000s of tokens) |
| Parallelisation during training | Sequential, cannot parallelise across time steps | Fully parallel, GPU-friendly |
| Real-time / streaming inference | Excellent, O(1) per step | Expensive, needs full context window |
| Parameter count vs performance | Very efficient at small scale | Needs scale to shine |
| Time-series forecasting | Still competitive | Patching approaches catching up |
| On-device / edge deployment | Lightweight, fast | Often too large |
| When to choose | Streaming data, IoT, low-latency systems, moderate-length sequences | Large NLP tasks, long context, when compute is abundant |
Despite transformers dominating NLP headlines, LSTMs remain the go-to choice for real-time time-series tasks: ECG monitoring, financial tick data, sensor fusion in robotics, and speech synthesis on embedded devices. Their O(1) per-step inference cost and small memory footprint make them irreplaceable in production edge systems where a 70B parameter transformer is simply not an option.
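The O(1) per-step claim can be demonstrated with PyTorch's `nn.LSTM`: feeding one time step at a time while carrying the `(h, c)` state forward gives the same outputs as processing the whole sequence at once. The layer sizes and the random input below are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=1, hidden_size=16, batch_first=True)
lstm.eval()

seq = torch.randn(1, 50, 1)          # (batch, seq_len, features), synthetic signal

with torch.no_grad():
    # Offline: process the whole sequence in one call.
    full_out, _ = lstm(seq)

    # Streaming: feed one time step at a time, carrying (h, c) forward.
    state = None                     # None means zero initial state
    stream_out = []
    for t in range(seq.size(1)):
        step_out, state = lstm(seq[:, t : t + 1, :], state)
        stream_out.append(step_out)
    stream_out = torch.cat(stream_out, dim=1)

print(torch.allclose(full_out, stream_out, atol=1e-5))
```

Each streaming step touches only the newest input and a fixed-size state, so cost and memory per step stay constant however long the stream runs.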
Real-World Applications: Where LSTMs Live in Production
Golden Rules: Non-Negotiable LSTM Practices
Always clip gradients: use clip_grad_norm_(..., 1.0) in PyTorch or clipnorm=1.0 in Keras. Exploding gradients can destroy training and produce NaN losses. This is one of the most common LSTM training bugs.
Set return_sequences=True for all layers except the last. A common mistake is forgetting this and getting a shape mismatch error, or silently feeding only the last hidden state to a layer that expects a full sequence.
Reset state between independent sequences: pass h0, c0 = None, None or explicitly zero them between batches. Carrying stale state from a different sequence is a subtle but devastating bug.
Use model.eval() and torch.no_grad() during inference. LSTMs have dropout, which behaves differently during training vs inference. Forgetting model.eval() means dropout keeps randomly zeroing activations, so your model will appear unreliable when it isn't.
For very long sequences, use truncated BPTT and cut the graph between chunks with h.detach(). Full BPTT over 1,000 steps is numerically unstable and uses prohibitive memory.
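Truncated BPTT with detached state can be sketched as follows. The sine-wave data, the chunk length of 100, and the tiny model are all assumptions made to keep the example self-contained; the part that matters is carrying the state across chunks while detaching it from the previous chunk's graph.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=1, hidden_size=8, batch_first=True)
head = nn.Linear(8, 1)
params = list(lstm.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

series = torch.sin(torch.linspace(0, 20, 1001)).view(1, -1, 1)  # synthetic data
inputs, targets = series[:, :-1], series[:, 1:]                  # next-step prediction
CHUNK = 100                                                      # BPTT window (assumed)

state = None
losses = []
for start in range(0, inputs.size(1), CHUNK):
    x = inputs[:, start : start + CHUNK]
    y = targets[:, start : start + CHUNK]

    out, state = lstm(x, state)
    loss = nn.functional.mse_loss(head(out), y)

    optimizer.zero_grad()
    loss.backward()                       # gradients stop at the chunk boundary
    nn.utils.clip_grad_norm_(params, 1.0)
    optimizer.step()

    # Keep the state values, but cut the graph so the next chunk
    # does not backpropagate into this one.
    state = tuple(s.detach() for s in state)
    losses.append(loss.item())

print(f"chunks: {len(losses)}, last loss: {losses[-1]:.4f}")
```

Without the detach line, the second call to `loss.backward()` would try to backpropagate through the first chunk's already-freed graph and raise an error.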
Quick Reference: LSTM Cheat Sheet
| Task | Architecture | Loss | Activation |
|---|---|---|---|
| Regression / Forecasting | Stacked LSTM → Dense(1) | MSE | Linear output |
| Binary Classification | LSTM → Dense(1) | BCE | Sigmoid output |
| Multi-class Classification | LSTM → Dense(n_classes) | CrossEntropy | Softmax output |
| NER / Sequence Labelling | BiLSTM → CRF | CRF NLL | Per-token softmax |
| Text Generation | Stacked LSTM → Dense(vocab) | CrossEntropy | Temperature softmax |
| Anomaly Detection | LSTM Autoencoder | Reconstruction MSE | Linear decoder |
You have gone from the filing room analogy to the four gate equations to a full PyTorch training loop with gradient clipping. The key insight to carry forward: an LSTM works because its cell state is a gradient highway. Additions rather than multiplications allow errors to propagate backwards across hundreds of time steps without vanishing. The three gates are all learned, differentiable valves that the optimizer adjusts to control information flow. That is the entire secret.