从 GPT-2 到现代模型
MiniGPT 能跑了,但它用的是「教学版零件」:Post-Norm、ReLU FFN、正弦位置编码。你打开 LLaMA 3 的代码会发现——几乎每个零件都换过了。
这一节从教学版 Block 出发,逐步升级到现代 LLM 的真实组件,看看每个改进到底解决了什么问题。
先澄清一个容易混淆的事实:GPT-2 并不是简单照搬原始 Transformer 的 Post-Norm + ReLU 配方。GPT-2 已经使用 Pre-LN 风格的 Block、GELU 激活函数和可学习的位置嵌入。
所以这一节的对比对象不是「真实 GPT-2 的每一行代码」,而是我们前面为了教学写出的朴素 MiniGPT。它 代表了原始 Transformer 里最容易理解的一组零件:LayerNorm、ReLU FFN、Post-Norm、正弦位置编码。
为什么现代 LLM 多数选择 Decoder-Only?
Transformer 原论文里明明有 Encoder 和 Decoder 两部分,现代 LLM 却大多只保留了 Decoder。
先给一个直觉:Decoder-Only 把所有任务都统一成了「继续往后写」。
翻译可以写成:英文:I love you\n中文:,让模型继续写「我爱你」。问答可以写成:问题:太阳为什么会发光?\n回答:,让模型继续写答案。写代码、总结、推理、对话,本质上也都能变成「给一段前缀,让模型预测后面的 token」。
这个统一非常重要。Encoder-Only(比如 BERT)擅长看完整句话做理解,但它不是天然的生成器;Encoder-Decoder(比如 T5)当然能生成,但结构更复杂,训练和推理时要同时维护 Encoder 端和 Decoder 端。Decoder-Only 更像一台单向打字机:它只学一件事——根据已经看到的内容预测下一个 token。
在大规模训练中,这一点尤其划算——互联网文本天然就是一长串 token。对于一段长度为 1000 的文本,Decoder-Only 可以在几乎每个位置都产生训练信号:第 1 个 token 预测第 2 个,第 2 个预测第 3 个,一直到最后。数据格式简单,loss 定义简单,模型目标也简单。
推理阶段也有天然优势。自回归生成时,模型每次只新增一个 token,历史 token 的 K、V 可以缓存在 KV Cache 里,下次只算新 token 对历史的 Attention。这种增量计算的模式让工程优化 变得很自然。
所以不是 Encoder 没价值,而是现代通用 LLM 的核心需求变成了:既要理解输入,又要继续生成输出。Decoder-Only 通过大规模 next-token prediction 学会了在同一个框架里同时做这两件事。理解能力不是单独接一个 Encoder 得来的,而是在海量「预测下一个 token」中长出来的。
从教学版到现代版
Decoder-Only 架构确定之后,Block 内部的每个零件就成了优化的重点。从 2017 年到今天,几乎每个组件都经历了替换:
| 教学版零件 | 现代替代 | 解决的问题 |
|---|---|---|
| Post-Norm | Pre-Norm(+ RMSNorm) | 深层训练梯度不稳定 |
| ReLU | GELU / SwiGLU | 负数区梯度消失 |
| 正弦位置编码 | RoPE | 长度外推能力差 |
| 标准 MHA | GQA / MLA | 长上下文 KV Cache 显存过大 |
每一行对应一个具体的工程问题。这一节就从教学版 Block 出发,依次升级归一化、激活函数、位置编码和注意力机制,看看每次替换到底解决了什么。
0. 回顾 Transformer Block
先把你已经熟悉的 教学版 Block 贴出来,我们就在它上面一个一个改。
# === 这就是 教学版 MiniGPT 的版本,我们接下来的「改造对象」 ===
import torch.nn as nn
import torch.nn.functional as F
class FeedForward_Old(nn.Module):
"""原始 FFN:两层 Linear,中间 ReLU"""
def __init__(self, d_model, d_ff=None):
super().__init__()
d_ff = d_ff or 4 * d_model
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
def forward(self, x):
return self.fc2(F.relu(self.fc1(x)))
class TransformerBlock_Old(nn.Module):
"""Post-Norm + LayerNorm + ReLU FFN(教学版)"""
def __init__(self, d_model, num_heads, d_ff=None):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
self.ffn = FeedForward_Old(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model) # ← 普通 LayerNorm
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, mask=None):
# Post-Norm: 先算子层,再 +残差,最后 Norm
x = self.norm1(x + self.attention(x, x, x, need_weights=False)[0])
x = self.norm2(x + self.ffn(x))
return x
print("✅ 这是教学版 MiniGPT 的版本。接下来我们逐个升级。")
print("升级路线: LayerNorm→RMSNorm → ReLU→SwiGLU → Post-Norm→Pre-Norm")
1. 改进一:LayerNorm → RMSNorm
1.1 LayerNorm 到底在算什么?— 用 4 个数字手算一遍
假设一个 token 的向量是 [1, 3, 5, 7](4 维,为了手算方便,这里也取 4 维)。
LayerNorm 做的事:把这个向量的均值变成 0,标准差变成 1。
就像给一群学生调分——不管原始分数多高多低,调完之后平均分 0 分,分数分散程度统一。
具体步骤:
1. 算均值 μ = (x₁ + x₂ + x₃ + x₄) / 4
2. 算方差 σ² = ((x₁-μ)² + (x₂-μ)² + (x₃-μ)² + (x₄-μ)²) / 4
3. 算标准差 σ = √σ²
4. 归一化 x' = (x - μ) / σ
5. 缩放 y = γ × x' + β
γ 和 β 是可学习的参数(让模型自己决定调完之后「放多大」「往哪偏」)。
下面用 [1, 3, 5, 7] 手算一遍:
# === LayerNorm 手工计算 ===
import torch
print("=== LayerNorm 手算:输入 x = [1, 3, 5, 7] ===")
print()
x = torch.tensor([1.0, 3.0, 5.0, 7.0])
# Step 1: 均值
mu = x.mean()
print(f"Step 1 — 均值 μ = (1+3+5+7)/4 = {mu:.1f}")
# Step 2: 方差(这里除以 N,不是 N-1)
var = torch.mean((x - mu) ** 2)
print(f"Step 2 — 方差 σ² = ((1-4)²+(3-4)²+(5-4)²+(7-4)²)/4")
print(f" = (9 + 1 + 1 + 9)/4 = {var:.1f}")
# Step 3: 标准差
sigma = torch.sqrt(var)
print(f"Step 3 — 标准差 σ = √{var:.1f} = {sigma:.4f}")
# Step 4: 归一化
x_norm = (x - mu) / sigma
print(f"Step 4 — 归一化: (x - 4)/{sigma:.4f}")
for i, (xi, xni) in enumerate(zip(x.tolist(), x_norm.tolist())):
print(f" x[{i}] = ({xi:.1f} - 4) / {sigma:.4f} = {xni: .4f}")
print(f" 归一化后: {[f'{v:.4f}' for v in x_norm.tolist()]}")
print(f" 均值: {x_norm.mean():.4f} (=0), 标准差: {x_norm.std(unbiased=False):.4f} (=1)")
# Step 5: 缩放(假设 γ=[1,1,1,1], β=[0,0,0,0])
# 初始时 γ 全 1,β 全 0,所以输出就是 x_norm
print(f"Step 5 — 缩放: γ=1, β=0 时输出 = 归一化结果")
print(f" γ 和 β 是随着训练学出来的,让模型自己决定怎么调")
1.2 LayerNorm 的问题:多算了一个不需要的东西
注意 LayerNorm 算了两个统计量:
- μ(均值):把数据中心移到 0
- σ(标准差):把数据分散程度统一
RMSNorm 的想法是问一个更务实的问题:如果把「去均值」这一步去掉,只保留缩放,模型还能不能训练好?