长上下文外推
LLaMA 2 训练时只见过 4096 个 token 长度的文本。但今天的模型能处理 128K 甚至 1M——同样的架构,训练时从没见过的长度,推理时居然能用。
答案在位置编码的外推技术里。这一节从 Attention 的长度瓶颈出发,逐步理解 RoPE 为什么能外推、PI 做了什么、NTK 和 YaRN 又改进了什么,最后学会测试长上下文能力。
长上下文外推(Length Extrapolation)要解决的问题是:模型在训练时只见过位置 0 到 4095 的编码,推理时却要处理位置 10000——这个位置的编码对模型来说是完全陌生的。
解决方案不是重新训练,而是让训练时学到的位置编码规律在更长的序列上继续生效。具体方法取决于位置编码本身的设计:RoPE 用旋转矩阵编码相对位置,天然具备一定的外推能力;PI 压缩位置编号;NTK 调整频率基;YaRN 对不同频率维度做差异化处理。这四种方法构成一条递进的优化路径,也是这一节的主线。
1. 什么是外推
外推(Extrapolation)= 用已知范围的规律,推测范围之外的情况。
举个生活中的例子:
- 你测过水温 0°C → 冰,50°C → 液体,100°C → 沸腾
- 现在问你:200°C 的水会怎样?你虽然没测过,但根据规律能推断 → 还是气体
- 这就是「外推」
LLM 面临同样的问题:
- 训练时:模型学习了位置 0 到 4095 之间的 attention 规律
- 推理时:用户给了一篇 10000 token 的文章
- 问题:位置 4096~9999 这些位置的 token,模型训练时从没 见过,它能正确处理它们吗?
答案取决于你用什么位置编码。
2. 位置编码回顾
(如果你对 Part 3 的 Embedding + Position 已经很熟,可以跳过这一节。但后面的 RoPE 依赖这个基础,不确定的话最好看一眼。)
Attention 本身是不关心顺序的。你把 "猫 坐 垫子" 和 "垫子 坐 猫" 喂给 attention,它算出来的 attention 分数完全相同——因为 attention 只看 token 之间「有多相关」,不看「谁在前谁在后」。
但顺序显然很重要。「我爱你」和「你爱我」意思完全不同。
位置编码就是给每个 token 贴上一个「我是第几个」的标签,让 attention 在计算相关性时能用到这个信息。
贴标签有三种方式:
# 直观感受:顺序对 attention 的影响
print("句子 A: 我 爱 你")
print("句子 B: 你 爱 我")
print()
print("没有位置编码时:")
print(" '我' 和 '你' 的 attention 分数完全一样——不管谁在前面")
print(" 模型分不清 '我爱你' 和 '你爱我'")
print()
print("有位置编码时:")
print(" '我' 在位置 0 和位置 2 有不同的向量 → attention 可以区分")
print()
print("问题来了:训练时最长只见过 4096 个位置,")
print("推理时来了 10000 个位置 → 第 4097 个位置的「标签」长什么样?")
句子 A: 我 爱 你
句子 B: 你 爱 我
没有位置编码时:
'我' 和 '你' 的 attention 分数完全一样——不管谁在前面
模型分不清 '我爱你' 和 '你爱我'
有位置编码时:
'我' 在位置 0 和位置 2 有不同的向量 → attention 可以区分
问题来了:训练时最长只见过 4096 个位置,
推理时来了 10000 个位置 → 第 4097 个位置的「标签」长什么样?
3. 三种位置编码的外推能力
| 方案 | 怎么做 | 代表模型 | 能外推吗? | 为什么? |
|---|---|---|---|---|
| 学出来的位置 | 训练时给每个位置随机初始化一个向量,训练过程中调整 | GPT-2 | ❌ 完全不能 | 只学了 0~1023 位置的向量,1024 位置的向量根本没存在过 |
| 正弦位置编码 | 用 sin/cos 函数手工算每个位置的值,不用学 | 原始 Transformer | 理论上能,实际很差 | 函数本身连续,但模型没学会利用连续性 |
| RoPE(旋转位置编码) | 用「旋转」来编码位置,位置差 = 旋转角度差 | LLaMA、Qwen、Mistral | ✅ 可以! | 相对位置天然具有外推性,且频域特性可以利用 |
RoPE 现在是几乎所有开源 LLM 的标配。我们下面搞懂它。
4. RoPE:用旋转编码相对位置
第 2 节回顾了位置编码的基本问题:Attention 本身不区分"我爱你"和"你爱我"。正弦位置编码的解法是在 Input Embedding 上加一个位置向量,让每个 token 的输入表示自带位置信息。
RoPE(Rotary Position Embedding,旋转位置编码)的思路不同。它不修改输入端的 Embedding,而是直接介入 Q 和 K 的点积计算——用旋转矩阵把相对位置信息"写进"点积结果里。
4.1 目标:让 Q 和 K 的点积依赖于相对位置
回忆 Attention 的核心操作:位置 m 的 query 向量 q_m 和位置 n 的 key 向量 k_n 做点积,点积值决定了两个 token 之间的 Attention 权重。
没有位置编码时,q_m 和 k_n 的点积只取决于两个 token 的语义内容。同一个词对,不论出现在句首还是句尾,点积完全相同——因为 Embedding 查出来的向量只跟 token ID 有关,跟位置无关。
RoPE 的目标是设计一个操作 f,把位置信息注入到 q 和 k 中,使得点积只依赖于相对位置 m-n:
为什么追求"只依赖于 m-n"?考虑一个具体的例子。句子"我 爱 你"里,"爱"在位置 1,"你"在位置 2,距离是 1。句子"昨天 我 爱 你"里,"爱"移到了位置 2,"你"移到了位置 3——但"爱"和"你"之间的距离仍然是 1。在这两个句子里,"爱"和"你"之间的语义关系应该相同。如果点积只依赖于相对距离,距离相同的词对 Attention 就相同,模型就学到了语言中最基本的平移不变性。
正弦位置编码通过加法注入位置,但加法之后的信息是混在一起的——q_m 和 k_n 各自包含绝对位置 m 和 n,它们的点积无法干净地只依赖 m-n。RoPE 通过旋转来实现这个干净的依赖关系。
4.2 二维旋转:从一个向量对的旋转入手
先看最简单的情况:向量只有 2 维。平面上一个点 (x, y) 逆时针旋转角度 θ,用旋转矩阵表示:
旋转后的结果是:
旋转不改变向量的长度:。这意味着旋转不会扭曲 token 的语义信息——向量的模长(代表语义强度)保持不变,改变的只是方向。
现在把这套到 Attention 上。把位置 m 的查询向量 q_m 旋转一个与 m 成正比的角度 mθ,把位置 n 的键向量 k_n 旋转 nθ。然后算点积:
关键一步。利用旋转矩阵的两个性质:
性质一:。旋转矩阵的转置等于反向旋转——把角度取负即可。
性质二:。先反向转 α°,再正向转 β°,等价于直接转 (β-α)°。
代入得到:
等号右边只出现了 ,不出现单独的 或 。m 和 n 各自的绝对位置抵消了,只剩下它们的差。
这意味着:旋转后 q_m 和 k_n 的点积,只依赖于它们的相对位置 (n-m),不依赖于它们各自在哪个位置。 相对位置信息被天然编码进了 Q 和 K 的点积之中。
下面用具体数字手动验证这个性质。设两个向量 q = (1, 1) 和 k = (1, 1),单位旋转角 θ = 15°,位置 m=2,n=5(距离为 3)。同时再算一组 m=10,n=13(距离也是 3)——如果两组的点积相同,就验证了"只依赖 m-n"。
import math
import torch
# === 手动验证:旋转后点积只依赖于相对位置 ===
# 设 q = (1, 1), k = (1, 1), θ = 15° = π/12
theta = math.pi / 12 # 15 度
def rotate_2d(v, angle):
"""对一个二维向量施加旋转矩阵"""
x, y = v[0], v[1]
cos_a, sin_a = math.cos(angle), math.sin(angle)
x_new = x * cos_a - y * sin_a
y_new = x * sin_a + y * cos_a
return torch.tensor([x_new, y_new])
q = torch.tensor([1.0, 1.0])
k = torch.tensor([1.0, 1.0])
# 情况 A:m=2, n=5 → 距离 = 3
q_rot_A = rotate_2d(q, 2 * theta)
k_rot_A = rotate_2d(k, 5 * theta)
dot_A = torch.dot(q_rot_A, k_rot_A)
# 情况 B:m=10, n=13 → 距离也是 3
q_rot_B = rotate_2d(q, 10 * theta)
k_rot_B = rotate_2d(k, 13 * theta)
dot_B = torch.dot(q_rot_B, k_rot_B)
# 情况 C:m=0, n=3 → 距离也是 3(更极端的对比)
q_rot_C = rotate_2d(q, 0 * theta)
k_rot_C = rotate_2d(k, 3 * theta)
dot_C = torch.dot(q_rot_C, k_rot_C)
# 情况 D:m=2, n=6 → 距离 = 4(不同距离,对照)
q_rot_D = rotate_2d(q, 2 * theta)
k_rot_D = rotate_2d(k, 6 * theta)
dot_D = torch.dot(q_rot_D, k_rot_D)
print("=== 手算验证:旋转后点积只依赖于 m-n ===")
print()
print("q = (1, 1), k = (1, 1), 单位角 θ = 15°")
print()
print("情况 A: m=2, n=5 (距离=3) → 点积 =", f"{dot_A:.6f}")
print("情况 B: m=10, n=13 (距离=3) → 点积 =", f"{dot_B:.6f}")
print("情况 C: m=0, n=3 (距离=3) → 点积 =", f"{dot_C:.6f}")
print("情况 D: m=2, n=6 (距离=4) → 点积 =", f"{dot_D:.6f}")
print()
print("关键观察:")
print(" 1. A、B、C 的 m,n 各自不同,但距离都是 3 → 点积完全相同")
print(" 2. D 的距离是 4 → 点积与前三组不同")
print(" 3. 验证完毕:点积只取决于 n-m,不取决于 m 和 n 各自的绝对位置")
print()
print("这就是 RoPE 最核心的数学性质。")
# 补充:旋转前后点积的变化
dot_original = torch.dot(q, k)
print(f"\n旋转前:q·k = {dot_original:.4f}")
print(f"旋转后(距离=3):点积 ≈ {dot_A:.4f}")
print(f"相对位置的差异被编码进了点积的数值变化中")
5. 从二维推广到 d 维
上一节在二维向量上验证了旋转编码相对位置的原理。实际的 q 和 k 通常是 64 或 128 维。怎么把二维旋转推广到 d 维?
5.1 d 维向量的分组旋转
方案比想象的直接:把 d 维拆成 d/2 个互不重叠的二维对:
对 0: (dim₀, dim₁) → 在自己的 2D 平面上旋转
对 1: (dim₂, dim₃) → 在自己的 2D 平面上旋转
...
对 d/2-1: (dim_{d-2}, dim_{d-1}) → 在自己的 2D 平面上旋转
每一对独立在自己所属的二维平面上旋转,各对的旋转互不干扰。写成矩阵形式,d 维的旋转矩阵是一个分块对角矩阵——对角线上有 d/2 个 2×2 的小旋转矩阵,其余位置全是 0:
R = [R_{θ₀} 0 ... 0 ]
[ 0 R_{θ₁} ... 0 ]
[ ... ... ... ... ]
[ 0 0 ... R_{θ_{d/2-1}}]
在实际代码里不会真的构造这个 d×d 的大矩阵(大部分是 0,浪费计算和显存),而是直接用向量化操作:把相邻维度两两分组,每对独立计算旋转。
5.2 每对维度有不同的旋转速度
所有维度对都用同样的转速是不行的——那样所有维度对携带相同的位置信息,浪费了 d 维的表达能力。RoPE 给不同的维度对分配不同的旋转速度。第 i 对的单位旋转角是:
- i=0(第一对,dim₀ 和 dim₁):θ₀ = 1.0,每个位置转 1 弧度(≈ 57°),转得飞快。训练窗口 4096 个位置相当于转了 652 个完整圈,sin/cos 所有可能的值都见过了。它的职责是区分紧挨着的两个位置。
- i=31(最后一对,dim₆₂ 和 dim₆₃,d=64 时):θ₃₁ ≈ 0.00013,每个位置只转约 0.008°,极慢。在 4096 个训练位置内只走了约 0.55 弧度(31°),远不到一圈。它的职责是承载远距离的位置关系。
这和正弦位置编码的高频/低频设计思路相同,但作用的机制不同:正弦方案是不同频率波形叠加出一个位置向量加在 embedding 上;RoPE 是不同频率的旋转直接作用于 Q 和 K,通过点积编码相对距离。
5.3 RoPE 施加在 Q 和 K 上的完整流程
现在可以把 RoPE 的完整计算串起来了。给定一个 token 序列:
-
Embedding + 投影: 经过 Embedding 和 Q/K/V 投影矩阵,得到 ,,
-
分组:每个 q 和 k 按相邻维度两两分组,(dim₀, dim₁) 是第一对,(dim₂, dim₃) 是第二对,以此类推
-
逐对旋转:位置 m 的 q 的第 i 对旋转角度 ,位置 n 的 k 的第 i 对旋转角度 。用 4.2 节的公式逐对计算旋转后的坐标
-
点积:旋转后的 q 和 k 做点积。根据 4.2 节的推导,第 i 对维度对点积的贡献包含 ,整条向量的点积是各对贡献之和——结果只依赖于相对位置
-
后续不变:softmax + 乘以 V 的步骤和标准 Attention 完全相同
和正弦位置编码的结构性区别。正弦方案:位置向量加在 Embedding 上,后续 Q/K 投影会混合 token 语义和位置信息。RoPE:位置信息绕过 Embedding,直接作用在 Q 和 K 上,通过旋转矩阵精确控制点积中的位置依赖。V(Value)不参与旋转——因为 Attention 输出中不需要通过 V 来传递位置信息 ,V 只需要提供被"加权聚合"的内容即可。
下面用代码实现完整的 RoPE,并验证它的相对位置性质。
# 直接看:不同维度对的旋转速度差多少
import torch
import math
d_k = 64 # 总共 64 维,两两结对 → 32 对
base = 10000 # RoPE 默认 base
pair_indices = torch.arange(0, d_k, 2).float() # [0, 2, 4, ..., 62]
freqs = 1.0 / (base ** (pair_indices / d_k))
print(f"共 {len(freqs)} 对维度")
print(f"第 0 对(最快)频率: {freqs[0]:.4f} → 每位置转 {math.degrees(freqs[0]):.1f}°")
print(f"第 16 对(中等)频率: {freqs[16]:.6f} → 每位置转 {math.degrees(freqs[16]):.4f}°")
print(f"第 31 对(最慢)频率: {freqs[31]:.8f} → 每位置转 {math.degrees(freqs[31]):.6f}°")
slowest_period = 2 * math.pi / freqs[31]
print(f"\n最慢那对走完一圈需要 {slowest_period:.0f} 个位置")
# → 训练窗口 4096,最慢的指针连一圈 都没走完 → 这就是外推的瓶颈
共 32 对维度
第 0 对(最快)频率: 1.0000 → 每位置转 57.3°
第 16 对(中等)频率: 0.010000 → 每位置转 0.5730°
第 31 对(最慢)频率: 0.00013335 → 每位置转 0.007641°
最慢那对走完一圈需要 47117 个位置
# 画出来:不同维度的指针随位置移动(cos 值)
import torch
import matplotlib.pyplot as plt
seq_len = 200
positions = torch.arange(seq_len).float()
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax_idx, (pair_idx, label) in enumerate([
(0, "快(秒针)"),
(16, "中(分针)"),
(31, "慢(时针)")
]):
theta = positions * freqs[pair_idx]
ax = axes[ax_idx]
ax.plot(positions.numpy(), theta.cos().numpy(), linewidth=1)
ax.set_xlabel('Position'); ax.set_ylabel('cos(angle)')
ax.set_title(f'Pair {pair_idx} — {label}\n{math.degrees(freqs[pair_idx]):.2f} deg per step')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# 左边转了好几圈(密集波形)→ 高频 → 区分邻居
# 右边不到半圈(缓慢曲线)→ 低频 → 承载远距离信息
/var/folders/fv/xkn6r25n41j9fm98mh1l73hm0000gn/T/ipykernel_49974/1492510485.py:19: UserWarning: Glyph 24555 (\N{CJK UNIFIED IDEOGRAPH-5FEB}) missing from font(s) DejaVu Sans.
plt.tight_layout()
/var/folders/fv/xkn6r25n41j9fm98mh1l73hm0000gn/T/ipykernel_49974/1492510485.py:19: UserWarning: Glyph 65288 (\N{FULLWIDTH LEFT PARENTHESIS}) missing from font(s) DejaVu Sans.
plt.tight_layout()
/var/folders/fv/xkn6r25n41j9fm98mh1l73hm0000gn/T/ipykernel_49974/1492510485.py:19: UserWarning: Glyph 31186 (\N{CJK UNIFIED IDEOGRAPH-79D2}) missing from font(s) DejaVu Sans.
plt.tight_layout()
/var/folders/fv/xkn6r25n41j9fm98mh1l73hm0000gn/T/ipykernel_49974/1492510485.py:19: UserWarning: Glyph 38024 (\N{CJK UNIFIED IDEOGRAPH-9488}) missing from font(s) DejaVu Sans.
plt.tight_layout()
/var/folders/fv/xkn6r25n41j9fm98mh1l73hm0000gn/T/ipykernel_49974/1492510485.py:19: UserWarning: Glyph 65289 (\N{FULLWIDTH RIGHT PARENTHESIS}) missing from font(s) DejaVu Sans.
plt.tight_layout()
/var/folders/fv/xkn6r25n41j9fm98mh1l73hm0000gn/T/ipykernel_49974/1492510485.py:19: UserWarning: Glyph 20013 (\N{CJK UNIFIED IDEOGRAPH-4E2D}) missing from font(s) DejaVu Sans.
plt.tight_layout()
/var/folders/fv/xkn6r25n41j9fm98mh1l73hm0000gn/T/ipykernel_49974/1492510485.py:19: UserWarning: Glyph 20998 (\N{CJK UNIFIED IDEOGRAPH-5206}) missing from font(s) DejaVu Sans.
plt.tight_layout()
/var/folders/fv/xkn6r25n41j9fm98mh1l73hm0000gn/T/ipykernel_49974/1492510485.py:19: UserWarning: Glyph 24930 (\N{CJK UNIFIED IDEOGRAPH-6162}) missing from font(s) DejaVu Sans.
plt.tight_layout()
/var/folders/fv/xkn6r25n41j9fm98mh1l73hm0000gn/T/ipykernel_49974/1492510485.py:19: UserWarning: Glyph 26102 (\N{CJK UNIFIED IDEOGRAPH-65F6}) missing from font(s) DejaVu Sans.
plt.tight_layout()
/Users/sanbu/miniconda3/lib/python3.12/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 24555 (\N{CJK UNIFIED IDEOGRAPH-5FEB}) missing from font(s) DejaVu Sans.
fig.canvas.print_figure(bytes_io, **kw)
/Users/sanbu/miniconda3/lib/python3.12/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 65288 (\N{FULLWIDTH LEFT PARENTHESIS}) missing from font(s) DejaVu Sans.
fig.canvas.print_figure(bytes_io, **kw)
/Users/sanbu/miniconda3/lib/python3.12/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 31186 (\N{CJK UNIFIED IDEOGRAPH-79D2}) missing from font(s) DejaVu Sans.
fig.canvas.print_figure(bytes_io, **kw)
/Users/sanbu/miniconda3/lib/python3.12/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 38024 (\N{CJK UNIFIED IDEOGRAPH-9488}) missing from font(s) DejaVu Sans.
fig.canvas.print_figure(bytes_io, **kw)
/Users/sanbu/miniconda3/lib/python3.12/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 65289 (\N{FULLWIDTH RIGHT PARENTHESIS}) missing from font(s) DejaVu Sans.
fig.canvas.print_figure(bytes_io, **kw)
/Users/sanbu/miniconda3/lib/python3.12/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 20013 (\N{CJK UNIFIED IDEOGRAPH-4E2D}) missing from font(s) DejaVu Sans.
fig.canvas.print_figure(bytes_io, **kw)
/Users/sanbu/miniconda3/lib/python3.12/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 20998 (\N{CJK UNIFIED IDEOGRAPH-5206}) missing from font(s) DejaVu Sans.
fig.canvas.print_figure(bytes_io, **kw)
/Users/sanbu/miniconda3/lib/python3.12/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 24930 (\N{CJK UNIFIED IDEOGRAPH-6162}) missing from font(s) DejaVu Sans.
fig.canvas.print_figure(bytes_io, **kw)
/Users/sanbu/miniconda3/lib/python3.12/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 26102 (\N{CJK UNIFIED IDEOGRAPH-65F6}) missing from font(s) DejaVu Sans.
fig.canvas.print_figure(bytes_io, **kw)

6. 直接外推为什么失败
训练时最长序列是 4096。推理时来了 8192 个 token。最简单的想法:让 RoPE 按正常规则往后数——位置 4097, 4098, …, 8191,正常用 m×θ_i 算旋转角。行不行?
结论是:不行。但理解为什么不行,是掌握所有外推方法的关键。
6.1 快维度和慢维度的差异
从第 5 节我们知道,不同的维度对转速差异极大。这意味着在 4096 个训练位置内,不同维度经历的角度范围不同,它们的"阅历"也不同。下面这张表列出了 d=64 时几个典型维度对在训练窗口内的情况:
| 维度对 | 频率 θ_i | 训练窗内转动总角度 | 折合圈数 | sin/cos 见过的范围 |
|---|---|---|---|---|
| i=0(最快) | 1.0 | 4096 rad | ~652 圈 | 全部 [-1, 1],所有形状 |
| i=8 | ~0.1 | ~410 rad | ~65 圈 | 全部 [-1, 1] |
| i=16 | ~0.01 | ~41 rad | ~6.5 圈 | 全部 [-1, 1] |
| i=24 | ~0.001 | ~4.1 rad | ~0.65 圈 | 约 0.65 个周期的 sin/cos,[-1, 1] 的约 65% |
| i=31(最慢) | ~0.00013 | ~0.55 rad | ~0.09 圈 | 约 0.09 个周期的 sin/cos,[0, 0.52] 区 间 |
核心观察:快维度(i≤16)在 4096 个位置内转了很多整圈,sin 和 cos 在 [-1, 1] 之间的各种值——上升、下降、波峰、波谷、拐点——都反复见过了。它们对任何角度值都有充分的训练。
慢维度(i≥24)则不然。i=31 那对在整个训练窗口内只转了约 31°(0.55 弧度),连一圈的 1/10 都不到。它们只见过 cos 从 1 缓慢降到 0.85、sin 从 0 升到 0.52 这段小区间。任何超出这个区间的角度值,对它们来说都是完全陌生的。
6.2 外推触碰了陌生角度
当推理长度扩展到 8192 时,位置 8191 在最快维度上的角度是 8191 弧度,但由于已经转过很多圈,8191 rad mod 2π 仍然在熟知的范围内——快维度对任何角度值都有经验。
问题出在慢维度。i=31 在位置 8191 的角度是 8191 × 0.00013 ≈ 1.1 弧度(63°)。在 4096 的训练窗口内,这个维度最大只见过 0.55 弧度(31°)。63° 远远超出了训练范围——sin(63°) ≈ 0.89 和 cos(63°) ≈ 0.45 这两个值,模型从未在这个维度上见过。
这就引出了一个关键认识:不是所有维度都有外推问题。只有那些在训练窗口内没走完一圈的慢维度,才会在超长位置遇到陌生角度。 快维度转了很多圈,什么角度都见过,可以直接外推。
这解释了为什么学出来的位置编码(GPT-2)完全不能外推——它的每一维都是一个独立的、只在训练窗口内学过的值,没有"转圈"的概念,所有维度等价于"没走完一圈的慢维度"。也解释了为什么正弦位置编码理论上能外推但实际效果差——虽然所有维度都用 sin/cos,但模型很难学会利用这种连续性的外推能力。而 RoPE 天然暴露了不同维度的频域结构,这恰好给了我们一个可以操作的杠杆。
6.3 用时钟来理解
快维度:一个走了 652 圈的秒针——什么地方都去过,经验丰富。慢维度:一个只在 0° 到 31° 之间来回摆动的时针——只熟悉这个小区间。现在让它指向 63°,它无法理解。
下面画出 i=31 这个最慢维度在训练窗口内外的余弦值,直观感受"角度超标"。
# 直接外推的问题:低频维度在训练窗口外角度超标
import torch
import matplotlib.pyplot as plt
import math
train_len, extrap_len = 4096, 8192
slow_pair = 31
positions_train = torch.arange(train_len).float()
positions_extrap = torch.arange(extrap_len).float()
theta_train = positions_train * freqs[slow_pair]
theta_extrap = positions_extrap * freqs[slow_pair]
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(positions_extrap.numpy(), theta_extrap.cos().numpy(),
linewidth=1, color='orange', label='cos value at inference')
ax.axvline(x=train_len, color='red', linestyle='--', linewidth=2, alpha=0.7)
ax.fill_between(range(train_len, extrap_len), -1.2, 1.2,
alpha=0.1, color='red', label='unseen angle range')
ax.set_xlabel('Position'); ax.set_ylabel('cos(angle)')
ax.set_title(f'Direct extrapolation to 8192 (low-frequency dim #{slow_pair})\nRight of red line = unseen angle range')
ax.legend(); ax.grid(True, alpha=0.3)
plt.show()
# 训练时角度 0~{train_deg:.0f}°,外推时到了 {extrap_deg:.0f}°,超出 {over:.0f}°
train_deg = math.degrees(theta_train[-1].item())
extrap_deg = math.degrees(theta_extrap[-1].item())
print(f"训练时最大角度: {train_deg:.1f}° → 外推后最大角 度: {extrap_deg:.1f}° → 超出 {extrap_deg - train_deg:.1f}°")

训练时最大角度: 31.3° → 外推后最大角度: 62.6° → 超出 31.3°
7. 核心思想:控制角度范围
第 6 节的结论很明确:外推失败不是因为 RoPE 本身有缺陷,而是因为慢维度的旋转角超出了训练范围。快维度没有问题。
解决思路:想办法让位置 4096~8191 在慢维度上产生的角度,也落在 训练时见过的 [0, 训练最大角] 区间内。对快维度尽量不干预——它们已经见过各种角度,不需要额外处理。
7.1 三种压缩策略的直觉
把模型想象成一个只认识 0 到 4095 号门牌的人。现在要让它识别 4096 到 8191 号,有三种策略:
PI(Position Interpolation):把所有门牌号除以 2。新 4096 号被映射回旧 2048 号,新 8191 号被映射回旧 4095 号——全都在认识的范围里。代价是原来区别清楚的 1 号和 2 号,现在变成了 0.5 号和 1 号,区分能力下降。
NTK-aware:只压缩"不太认识"的慢维度,不碰"已经很熟"的快维度。通过一个参数(base 值)的调整,利用频率公式自身的非线性,自动完成差异化压缩。
YaRN:在 NTK 基础上进一步细化。维度不是简单的"快"和"慢"二分——中间那些"不快不慢"的维度需要平滑过渡。此外,压缩后 Attention 的 softmax 分布会变化,需要温度修正来校准。
7.2 从频率公式看三种方法的差异
所有方法归根结底都在改这个公式中的参数:
-
PI:不碰 θ_i,而是把位置 m 替换为 m/scale。等效于所有 θ_i 等比缩小 scale 倍。所有频率同步变慢。
-
NTK-aware:不碰位置 m,而是把 10000 替换为 。因为分母是 ,指数依赖导致:小 i(快维度)分母变化小 → 频率几乎不变;大 i(慢维度)分母变化大 → 频率大幅降低。一个参数实现了差异化。
-
YaRN:在 NTK 改 base 的基础上,对每个维度 i 计算其波长(走完一圈所需的 token 数)。根据 λ_i 和目标长度的关系,将维度分为三段:λ_i 很短的维度不缩放,λ_i 很长的维度缩放到 scale 倍,中间的维度平滑过渡。温度修正则处理 Attention 分布 sharpness 的变化。
它们之间的关系不是并列的三种"方法",而是一条理解逐步加深的递进路径:PI 发现了"压缩"这件事 → NTK 发现了"只需要压缩慢维度" → YaRN 发现了"压缩需要分段平滑 + 温度校准"。
下面的三节逐一展开每种方法的具体做法和代码。
8. 方法一:Position Interpolation
论文: Meta, 2023 — Extending Context Window via Position Interpolation
想法:把位置编号直接等比例压缩。
目标: 把 4096 窗口扩展到 8192
缩放因子 α = 4096 / 8192 = 0.5
新位置 = 真实位置 × 0.5
真实位置 0 → 给模型的位置 = 0 × 0.5 = 0
真实位置 2048 → 给模型的位置 = 2048 × 0.5 = 1024
真实位置 8192 → 给模型的位置 = 8192 × 0.5 = 4096 ← 刚好落在训练边界!
打个比方:你家门牌号是 1 到 100 号,你只认识 1-50 号。现在来了 51-100 号,你把所有号码除以 2——51 号变成 25.5 号,100 号变成 50 号,全都在你认识的范围内。
代价:所有门牌号都被压缩了。原来能清楚区分 1 号和 2 号,现在 1 号和 2 号变成了 0.5 号和 1 号,差别变小 了——近程分辨力下降。
# PI 的实现:位置 × 缩放因子,再正常算 RoPE
import torch
import matplotlib.pyplot as plt
train_len, target_len = 4096, 8192
alpha = train_len / target_len # 0.5
pair_indices = torch.arange(0, 64, 2).float()
freqs_orig = 1.0 / (10000 ** (pair_indices / 64))
freqs_pi = freqs_orig * alpha
# 原始 RoPE vs PI 压缩后的波形
positions_orig = torch.arange(target_len).float()
angles_orig = positions_orig * freqs_orig[31]
positions_pi = torch.arange(target_len).float() * alpha
angles_pi = positions_pi * freqs_orig[31]
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
axes[0].plot(angles_orig.cos().numpy(), linewidth=1, label='Original RoPE')
axes[0].plot(angles_pi.cos().numpy(), linewidth=1, label=f'PI (×{alpha:.2f})')
axes[0].axvline(x=train_len, color='gray', linestyle='--', alpha=0.5)
axes[0].set_xlabel('Position'); axes[0].set_ylabel('cos(angle)')
axes[0].set_title(f'Low-frequency dim #{31} wave\nPI stretches the wave (half frequency)')
axes[0].legend(); axes[0].grid(True, alpha=0.3)
# 右:所有维度频率被等比例压缩
axes[1].plot(freqs_orig.numpy(), 'o-', markersize=3, label='Original frequency')
axes[1].plot(freqs_pi.numpy(), 's-', markersize=3, label='After PI scaling')
axes[1].set_xlabel('Dimension pair index (0=fast, 31=slow)')
axes[1].set_ylabel('Frequency')
axes[1].set_title('PI scales all dimensions equally\nLocal resolution drops, light tuning needed')
axes[1].legend(); axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

9. 方法二:NTK-aware
论文: NTK-Aware Scaled RoPE, bloc97, 2023
PI 的毛病:它对所有维度一视同仁地压缩。但从第 5 节我们知道,不同维度转的速度不一样:
- 快的维度(秒针):在 4096 个位置内已经转了好多圈,各种角度都见过了 → 不需要压缩
- 慢的维度(时针):4096 个位置内连一圈都没走完,没见过的角度多 → 需要压缩
所以 NTK-aware 的想法是:只压慢的,不压快的。
怎么做到?把 base 从 10000 改大。 这是 NTK-aware 最巧妙的地方:
频率公式: freq_i = 1 / base^(2i/d)
base = 10000 → 频率快 → 慢维度也走不完一圈
base = 100000 → 频率变慢 → 慢维度走得更慢 → 在同样位置内角度更小 → 不超出训练范围!
而且:
低 i(快维度): freq ≈ 1 → 改 base 几乎不影响 ← 快的不用调
高 i(慢维度): freq ≈ 1/base → 改 base 影响大 ← 慢的调得多
这恰好实现了「快手不调,慢手多调」!一个参数的改动,自动完成了差异化压缩。
# 演示 NTK:改 base 对不同维度的影响
import torch
import matplotlib.pyplot as plt
base_old, scale = 10000, 2
# NTK 公式:新 base = 旧 base × scale^(d/(d-2))
base_new = base_old * (scale ** (64 / 62))
pair_indices = torch.arange(0, 64, 2).float()
freqs_old = 1.0 / (base_old ** (pair_indices / 64))
freqs_new = 1.0 / (base_new ** (pair_indices / 64))
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
axes[0].plot(freqs_old.numpy(), 'o-', markersize=3, label=f'base={base_old}')
axes[0].plot(freqs_new.numpy(), '^-', markersize=3, label=f'base={base_new:.0f}')
axes[0].set_xlabel('Dimension pair index (0=fast, 31=slow)'); axes[0].set_ylabel('Frequency')
axes[0].set_title('NTK-aware: increase base\nFast dims stay similar, slow dims slow down')
axes[0].legend(); axes[0].grid(True, alpha=0.3)
# 每个维度的压缩比
ratio = freqs_new / freqs_old
axes[1].bar(range(len(ratio)), ratio.numpy())
axes[1].set_xlabel('Dimension pair index'); axes[1].set_ylabel('New frequency / old frequency')
axes[1].set_title('Compression ratio by dimension\nFast dims ~100%, slow dims ~50%')
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# NTK 只需把 base 从 10000 改到 ~86000,不 改模型结构,大多数情况不需微调

10. 方法三:YaRN
论文: YaRN, 2023
NTK 已经很好了,但 YaRN 发现了一个问题:改 base 之后,中间那些维度(不快不慢的)的 attention 会变得「不够果断」。
什么意思?回想 attention 的 softmax 步骤:
softmax([2, 1, 0.5]) → [0.59, 0.22, 0.13, 0.06] ← 比较「尖锐」,注意力集中
softmax([1, 0.5, 0.25]) → [0.42, 0.26, 0.19, 0.14] ← 比较「平滑」,注意力分散
用「温度」可以调节 softmax 的尖锐程度:
温度低 → softmax 更尖锐 → 注意力更集中 → 适合近程信息
温度高 → softmax 更平滑 → 注意力更分散 → 适合远程信息(反正远距离也不需要精确到哪个 token)
YaRN 的做法:NTK 改 base + 对不同维度组做分段缩放/调节。 它是常见的强基线之一,但不是所有模型和任务上的“唯一最优”。后续还有 LongRoPE、LongRoPE2 这类极长上下文方法。
- 快维度:温度 = 1(不调,保持精确)
- 中间维度:温度平滑过渡
- 慢维度:温度稍高(让远程 attention 更平滑)
# YaRN 的分段策略:根据波长把维度分成三组
import torch
import matplotlib.pyplot as plt
import math
scale, target_len = 4, 16384
pair_indices = torch.arange(0, 64, 2).float()
base_new = 10000 * (scale ** (64 / 62))
freqs_new = 1.0 / (base_new ** (pair_indices / 64))
wavelengths = 2 * math.pi / freqs_new # 走完一圈需要的位置数
# 分段阈值
low_bound = target_len / 1.0 # 波长 > 此值 → 低频(需缩放)
high_bound = target_len / 4.0 # 波长 < 此值 → 高频(不缩放)
# ramping: 从 0(不调)平滑过渡到 1(缩放 scale 倍)
smooth = torch.clamp((wavelengths - high_bound) / (low_bound - high_bound), 0.0, 1.0)
dim_scale = (1 - smooth) * 1.0 + smooth * scale
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
axes[0].bar(range(32), dim_scale.numpy())
axes[0].axhline(y=1.0, color='green', linestyle='--', alpha=0.5, label='No scaling')
axes[0].axhline(y=scale, color='red', linestyle='--', alpha=0.5, label=f'Scale {scale}x')
axes[0].set_xlabel('Dimension pair index (0=fast, 31=slow)'); axes[0].set_ylabel('Scale factor')
axes[0].set_title(f'YaRN: keep early pairs, scale later pairs {scale}x\nSmooth transition in the middle')
axes[0].legend()
axes[1].plot(wavelengths.numpy(), 'o-', markersize=3)
axes[1].axhline(y=high_bound, color='green', linestyle='--', alpha=0.5, label=f'High-frequency threshold ({high_bound:.0f})')
axes[1].axhline(y=low_bound, color='red', linestyle='--', alpha=0.5, label=f'Low-frequency threshold ({low_bound:.0f})')
axes[1].set_xlabel('Dimension pair index'); axes[1].set_ylabel('Wavelength (tokens per cycle)')
axes[1].set_yscale('log'); axes[1].set_title('Wavelength by dimension\nShort=high freq (unchanged), long=low freq (scaled)')
axes[1].legend(); axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# YaRN = NTK 改 base + 分段平滑过渡,是常见强基线之一
