LLM 蒸馏
一个好老师不只给标准答案,还会解释思路、在学生犯错时及时纠正。大模型和小模型之间的关系也类似——与其让小模型自己摸索,不如让大模型「教」它。
这一节理解蒸馏的三种方法:Logit-based(学输出分布)、Feature-based(学中间表示)、On-Policy(大模型实时批改),并走通从大模型蒸馏到小模型的完整流程。
在 LLM 上下文中,Teacher 是能力更强的模型(可以是闭源 API、大开源模型或同模型的增强版本),Student 是目标部署模型。传统做法是让 Teacher 写出标准答案,Student 照着学习。
但 Student 在背熟 Teacher 的输出之后,一旦独立生成,质量就会下降。原因在于 Student 的训练数据和它自己生成时的数据分布不同,这个差异称为分布偏移(distribution shift)。
最先出现也最经典的方法是 Logit-based 蒸馏——让 Student 不只学答案,还学 Teacher 对每个词的概率判断。
import numpy as np
np.random.seed(42)
1. 蒸馏的本质
普通 SFT(监督微调):
Teacher 输出: "巴黎"
Student 学习: 输入"法国首都是?" → 输出"巴黎"
问题: 只学了答案,没学推理过程
蒸馏:
Teacher 输出: 每个词的概率分布 [巴黎:0.9, 伦敦:0.05, 柏林:0.03, ...]
Student 学习: 不仅输出"巴黎",还要让整个概率分布接近 Teacher
好处: Student 学到了 Teacher 的「判断力」——知道"巴黎"最可能,"伦敦"也有可能但概率低
为什么概率分布比答案更有价值?
Teacher 说「巴黎 90%,伦敦 5%,柏林 3%」比只说「巴黎」多了两条信息:
- 伦敦和柏林也是合理的(只是不太对)——这叫「暗知识」
- 其他几百个城市概率接近 0——明确告诉 Student 哪些是错的
这就是 Hinton 等人在知识蒸馏中 强调的“软目标/暗知识”直觉。参考:Distilling the Knowledge in a Neural Network。
import numpy as np
# 交互演示:硬标签 vs 软标签,用真实概率分布对比
print("=== 硬标签 vs 软标签 ===")
print()
cities = ["巴黎", "伦敦", "柏林", "罗马", "马德里", "东京", "北京", "悉尼"]
teacher_logits = np.array([5.0, 2.0, 1.0, 0.5, 0.1, -3.0, -4.0, -5.0])
# 硬标签 (SFT): one-hot
hard_labels = np.zeros(len(cities))
hard_labels[0] = 1.0
print("问题: 法国的首都是?")
print()
print("硬标签 (SFT):")
for city, prob in zip(cities[:5], hard_labels[:5]):
bar = "█" * int(prob * 40)
print(f" {city}: {prob:.1%} {bar}")
print(" → Student 只知道「巴黎是对的」")
print()
# 软标签 (蒸馏): 概率分布
temperature = 3.0
scaled_logits = teacher_logits / temperature
soft_labels = np.exp(scaled_logits) / np.exp(scaled_logits).sum()
print("软标签 (蒸馏, T=3):")
for city, prob in zip(cities, soft_labels):
bar = "█" * int(prob * 40)
print(f" {city}: {prob:.1%} {bar}")
print(" → Student 学到了:")
print(" 1. 巴黎最对")
print(" 2. 伦敦、柏林也是欧洲首都(相似性知识)")
print(" 3. 东京、北京概率≈0(完全不相关)")
print()
# 量化信息量差异
hard_entropy = -np.sum(hard_labels * np.log(hard_labels + 1e-10))
soft_entropy = -np.sum(soft_labels * np.log(soft_labels + 1e-10))
print(f"硬标签信息熵: {hard_entropy:.2f} bits")
print(f"软标签信息熵: {soft_entropy:.2f} bits")
print(f"→ 软标签包含约 {soft_entropy:.1f} bits 信息,远多于硬标签的 {hard_entropy:.2f} bits!")
2. 方法一:Logit 蒸馏(最经典)
让 Student 的输出概率分布逼近 Teacher 的输出概率分布。
Loss 公式:
其中:
- :Student 和正确答案的交叉熵(保证基本正确)
- :Student 和 Teacher 概率分布的 KL 散度(学习暗知识)
- :温度参数,越大 Teacher 的分布越「软」(暗知识越明显)
- :平衡两个 loss 的权重
温度 T 的作用:
T=1: [0.90, 0.05, 0.03, 0.02] ← 很尖锐,暗知识不明显
T=5: [0.40, 0.25, 0.20, 0.15] ← 软化了,暗知识浮现
T=20: [0.28, 0.26, 0.24, 0.22] ← 太软了,变成均匀分布
T 太大 → 所有词概率差不多 → 没信息量 T 太小 → 和硬标签没区别 → 没暗知识 T=3~10 是常见实验起点,不是固定规则;不同任务、模型和 loss 权重要通过验证集调整。
import numpy as np
# 演示温度对概率分布的影响
print("=== 温度 T 对软标签的影响 ===")
print()
logits = np.array([5.0, 2.0, 1.0, 0.5, 0.1, 0.01, 0.001, 0.0001])
labels = ["巴黎", "伦敦", "柏林", "罗马", "马德里", "维也纳", "布拉格", "华沙"]
for T in [1, 3, 10, 20]:
scaled = logits / T
probs = np.exp(scaled) / np.exp(scaled).sum()
print(f"T={T:2d}: ", end="")
for i in range(5):
bar = "█" * int(probs[i] * 50)
print(f"{labels[i]}:{probs[i]:.3f} {bar} ", end="")
print()
print()
print("T=1: 几乎只有巴黎 → 暗知识被掩盖")
print("T=3: 伦敦、柏林也有一定概率 → 暗知识浮现")
print("T=10: 分布更均匀 → 暗知识丰富但信号变弱")
print("T=20: 几乎均匀 → 信息量太少")