Attention 机制与 Transformer Block
上一节,我们给每个 token 的向量加上了位置信息。至此,输入层的工作完成:每个 token 既有语义信息("这个词是什么"),也有位置信息("这个词在第几个位置")。但这里有一个问题——这些向量是独立计算的,彼此之间没有交流。换句话说,模型此时不知道 token 之间有什么关系。
这一节从这个问题出发,引入 Attention 机制,让每个 token 能够从上下文中按需提取信息。我们先把 Attention 的计算过程拆开手算一遍,再逐步加入因果遮蔽、多头机制、前馈网络、残差连接和 LayerNorm,最终组装成一个 Transformer Block。
以 "the cat sat on the mat" 为例。如果不看上下文,sat 只是一个表示"坐"的向量。但有了上下文,模型就能知道谁在坐(cat),坐在哪里(mat)。再看 "bank":在 "river bank" 里指河岸,在 "bank account" 里指银行——四个字母完全相同,含义由周围的词决定。
这些例子指向一个核心需求:token 之间需要交换信息。Transformer 的整体结构因此分为三段:
输入层:Tokenizer → Embedding → Position Encoding ← 前三节已完成
核心层:N 个 Transformer Block ← 本节重点
输出层:Linear → logits ← 下一节
本节要搭建的是中间的 Transformer Block。Block 的核心机制叫 Attention——让每个 token 先判断上下文中谁更相关,再按相关程度加权混合信息。具体来说,我们会按以下顺序逐步搭建:
- Scaled Dot-Product Attention:用 Q、K、V 三个向量算出注意力权重,再加权混合——这是最核心的计算
- Causal Mask:遮住未来位置,让 GPT 只能看前面的 token
- Multi-Head Attention:多组 Q/K/V 并行工作,从不同角度提取上下文信息
- Transformer Block:把 Multi-Head Attention 和前馈网络(FFN)组装在一起,加上残差连接和 LayerNorm
每一步只在前一步的基础上加一个新组件,最后拼成完整的 Block。
本节要点
通过这一节的学习,以下问题应该能够回答:
- Attention 在解决什么问题?
- Q、K、V 分别代表什么?
- Scaled Dot-Product Attention 的四步计算是什么?
- GPT 为什么需要 Causal Mask?
- Multi-Head Attention 和 Transformer Block 是怎样组装起来的?
1. Attention 的直觉
Attention 的输出可以用一组权重来描述。假设模型处理 "the cat sat on the mat",当它处理 sat 时,可能给出的关注比例是:
the cat sat on the mat
sat: 0.05 0.35 0.10 0.05 0.05 0.40
这些数字加起来等于 1。cat 和 mat 权重大,说明 sat 会从它们那里读入更多信息。最终得到的 sat 新向量,就融合了"猫坐在垫子上"这个上下文。
计算这些关注比例是 Attention 要解决的核心问题。它引入三个向量——Q(Query)、K(Key)、V(Value),分别扮演"提问"、"贴标签"和"提供内容"三个角色:
- Q(Query):我想找什么?比如 sat 在问:"谁在做这个动作?动作发生在哪里?"
- K(Key):我是什么标签?比如 cat 的 Key 可能让别人看出:"我是一个动作的发出者。"
- V(Value):我能提供什么内容?如果 cat 被关注,真正被混合进输出的是它的 Value。
一个简单的类比是查资料:Query 是想查的问题,Key 是每份资料的标题或标签,Value 是资料正文。先用 Query 和 Key 判断哪份资料相关,再把相关资料的 Value 按比例读进来。
在 Self-Attention 中,Q、K、V 都从同一个输入矩阵 X 计算而来,只是乘了三套不同的权重矩阵:
Q = X @ W_Q
K = X @ W_K
V = X @ W_V
同一份输入经过三种不同的线性投影,变成了三种角色。下面用代码把 X 构造出来,再一步步手算 Attention。
# === 词表 → 句子 → Token IDs → Embedding 查表 → X ===
import torch
import torch.nn as nn
_ = torch.manual_seed(42)
vocab = {"the": 0, "a": 1, "cat": 2, "dog": 3, "sat": 4,
"ran": 5, "on": 6, "mat": 7, "[PAD]": 8, "[UNK]": 9}
vocab_size = len(vocab)
id2word = {v: k for k, v in vocab.items()}
# 句子 → token ids
sentence = ["the", "cat", "sat"]
token_ids = [vocab[w] for w in sentence] # [0, 2, 4]
# token ids → Embedding 查表 → X
d_model = 4
embedding = nn.Embedding(vocab_size, d_model)
token_ids_tensor = torch.tensor(token_ids) # [3]
X = embedding(token_ids_tensor) # [3, 4]
print(f"词表大小: {vocab_size}, 句子: {' '.join(sentence)}")
print(f"Token IDs: {token_ids}")
print(f"X 形状: {list(X.shape)} ← [seq_len=3, d_model=4]")
print(f"\nX = (来自 Embedding 查表,不是 randn):\n{X}")
print(f"\n解释: X[0]='the' → {X[0].tolist()}")
print(f" X[1]='cat' → {X[1].tolist()}")
print(f" X[2]='sat' → {X[2].tolist()}")
# === 从上一步的 X 出发,计算 Q/K/V ===
import torch.nn as nn
seq_len = X.shape[0] # = 3
d_k = 4
# Q/K/V 都来自 X,但各自乘不同的权重矩阵(用 nn.Linear,和后面 MHA 一致)
W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_k, bias=False)
Q = W_Q(X) # [3, 4] — 每个 token 的「查询」
K = W_K(X) # [3, 4] — 每个 token 的「标签」
V = W_V(X) # [3, 4] — 每个 token 的「内容」
print(f"X 形状: {X.shape} → Q/K/V 形状: {Q.shape}")
print(f"→ Q、K、V 都来自同一个 X,乘了不同的矩阵")
2. Scaled Dot-Product Attention
Attention 的计算分四步。第一步是用 Q 和 K 算相关度分数——第 i 行、第 j 列表示 token i 对 token j 有多感兴趣,点积越大匹配越强。
# Step 1: 注意力分数 = Q × K^T
# 第 i 行第 j 列 = token i 对 token j 的原始相关度
attention_scores = Q @ K.T # [3, 4] @ [4, 3] = [3, 3]
print(f"注意力分数矩阵 {list(attention_scores.shape)} = [{seq_len}×{seq_len}]:")
print(attention_scores)
print(f"\n第 i 行 = token {list(range(seq_len))} 对各 token 的分数")
缩放
点积结果可能很大。数值太大时,softmax 会变得过于自信,训练不稳定。
所以要除以 √d_k,让分数更稳。
为什么是 √d_k,而不是别的数?可以从方差的角度理解。假设 Q 和 K 的每个元素都是独立随机变量,均值为 0、方差为 1。那么 的方差也是 1,点积 的方差就是 。
也就是说,维度越高,点积的绝对值通常越大。当 时,点积的典型量级大约在 8 附近(),而 softmax 输入在 ±8 这个范围内已经接近饱和,梯度会非常小。
除以 相当于把点积的方差拉回 1,让 softmax 工作在梯度充足的区间。
# Step 2: 缩放 / √d_k → 防止 d_k 大时点积过大,softmax 梯度消失
import math
d_k = Q.shape[-1]
scaled_scores = attention_scores / math.sqrt(d_k)
print(f"缩放因子: √{d_k} = {math.sqrt(d_k):.2f}")
print(f"缩放前 token 0: {attention_scores[0].tolist()}")
print(f"缩放后 token 0: {scaled_scores[0].tolist()} ← 值变小,相对大小不变")
第三步,softmax 把一行分数变成概率——每一行加起来等于 1,表示这个 token 对其他 token 的关注比例。
第四步,用这些权重去混合 V。谁权重大,谁的信息就进来得多。这样每个 token 的输出就融合了它关注到的上下文。
# Step 3: Softmax → 把分数变成概率(每行加起来 = 1)
import torch.nn.functional as F
attention_weights = F.softmax(scaled_scores, dim=-1)
print(f"注意力权重矩阵 {list(attention_weights.shape)}:")
print(attention_weights)
# 验证每行和为 1
print(f"\n每行和: {attention_weights.sum(dim=-1).tolist()} ← 都是 1.0")
# Step 4: 加权求和 — 用注意力权重混合 V
output = attention_weights @ V # [3, 3] @ [3, 4] = [3, 4]
print(f"输出形状: {list(output.shape)} = [{seq_len}, {d_model}]")
print(f"\n输 入 token 0: {X[0].tolist()}")
print(f"输出 token 0: {output[0].tolist()}")
print(f"→ 不一样!因为 token 0 按权重融合了 token 1、2 的信息")
四步总结
Step 1 — 线性投影生成 Q/K/V:
同一份输入 乘三套不同权重,得到"提问""标签""内容"三个角色。
Step 2 — 算相关度分数:
第 行第 列是 token 对 token 的原始匹配分数,点积越大表示越相关。
Step 3 — 缩放 + Softmax:
除以 防止点积过大导致 softmax 饱和。Softmax 之后每行加起来等于 1,得到一组注意力权重。
Step 4 — 加权求和:
谁的权重大,谁的信息就进来得多。每个 token 的输出向量融合了它关注到的上下文 。
合并成一条公式:
3. 因果遮蔽
上一节实现的 Scaled Dot-Product Attention 里,每个 token 可以看到序列中的所有 token,包括自己后面的。在理解整句话的语义时这没有问题。但 GPT 是生成模型,它的工作方式是一个字一个字往后写。训练时也是这个逻辑:给它一句话,让它练习"看到前面的词,猜下一个词"。
GPT 为什么要练习"猜下一个词"
先退一步想:GPT 最终要做什么?给它一段开头,它接着往后写。比如输入 "从前有座山,山里有座",它应该接着写 "庙"。也就是说,GPT 的全部能力可以归结为一件事:给定已经写好的文字,判断下一个词最可能是什么。
如果它总是能猜对下一个词,就能一直写下去,写出连贯的长文。那怎么练出这个能力?就像学写作文:老师给你一句话的开头 "春 天来了,冰雪",让你填下一个词。你填 "融化",对了;填 "消融",也对;填 "键盘",就不太对。通过大量这样的练习,你逐渐学会了什么样的上文之后应该跟什么样的词。
GPT 的训练完全一样。给它一句话 "the cat sat on the mat",让它反复练习:
看到 "the" → 猜下一个词(答案是 cat)
看到 "the cat" → 猜下一个词(答案是 sat)
看到 "the cat sat" → 猜下一个词(答案是 on)
看到 "the cat sat on" → 猜下一个词(答案是 the)
看到 "the cat sat on the" → 猜下一个词(答案是 mat)
一句话就能练 5 次。训练数据里有几十亿句话,每句话都这样拆成很多个"看前文、猜后文"的练习题。练得多了,模型就学会了:the 后面通常跟名词,cat sat 后面通常跟介词,on 后面通常跟 the……
问题:Attention 让每个 token 都能看到所有 token
这些猜词任务是同时完成的。模型把 6 个 token 一起送进去,一次前向传播算出 6 个输出向量,每个输出向量各自拿去猜自己位置的下一个词。
回忆 Scaled Dot-Product Attention 的计算。第一步算出分数矩阵 。对于 3 个 token(the, cat, sat),这个矩阵是 的:
the cat sat
the: [0.2, 0.3, 0.1]
cat: [0.3, 0.5, 0.8]
sat: [0.1, 0.2, 0.6]