LLM 蒸馏
一个好老师不只给标准答案,还会解释思路、在学生犯错时及时纠正。大模型和小模型之间的关系也类似——与其让小模型自己摸索,不如让大模型「教」它。
这一节理解蒸馏的三种方法:Logit-based(学输出分布)、Feature-based(学中间表示)、On-Policy(大模型实时批改),并走通从大模型蒸馏到小模型的完整流程。
在 LLM 上下文中,Teacher 是能力更强的模型(可以是闭源 API、大开源模型或同模型的增强版本),Student 是目标部署模型。传统做法是让 Teacher 写出标准答案,Student 照着学习。
但 Student 在背熟 Teacher 的输出之后,一旦独立生成,质量就会下降。原因在于 Student 的训练数据和它自己生成时的数据分布不同,这个差异称为分布偏移(distribution shift)。
最先出现也最经典的方法是 Logit-based 蒸馏——让 Student 不只学答案,还学 Teacher 对每个词的概率判断。
import numpy as np
np.random.seed(42)
1. 蒸馏的本质
普通 SFT(监督微调):
Teacher 输出: "巴 黎"
Student 学习: 输入"法国首都是?" → 输出"巴黎"
问题: 只学了答案,没学推理过程
蒸馏:
Teacher 输出: 每个词的概率分布 [巴黎:0.9, 伦敦:0.05, 柏林:0.03, ...]
Student 学习: 不仅输出"巴黎",还要让整个概率分布接近 Teacher
好处: Student 学到了 Teacher 的「判断力」——知道"巴黎"最可能,"伦敦"也有可能但概率低
为什么概率分布比答案更有价值?
Teacher 说「巴黎 90%,伦敦 5%,柏林 3%」比只说「巴黎」多了两条信息:
- 伦敦和柏林也是合理的(只是不太对)——这叫「暗知识」
- 其他几百个城市概率接近 0——明确告诉 Student 哪些是错的
这就是 Hinton 等人在知识蒸馏中强调的“软目标/暗知识”直觉。参考:Distilling the Knowledge in a Neural Network。
import numpy as np
# 交互演示:硬标签 vs 软标签,用真实概率分布对比
print("=== 硬标签 vs 软标签 ===")
print()
cities = ["巴黎", "伦敦", "柏林", "罗马", "马德里", "东京", "北京", "悉尼"]
teacher_logits = np.array([5.0, 2.0, 1.0, 0.5, 0.1, -3.0, -4.0, -5.0])
# 硬标签 (SFT): one-hot
hard_labels = np.zeros(len(cities))
hard_labels[0] = 1.0
print("问题: 法国的首都是?")
print()
print("硬标签 (SFT):")
for city, prob in zip(cities[:5], hard_labels[:5]):
bar = "█" * int(prob * 40)
print(f" {city}: {prob:.1%} {bar}")
print(" → Student 只知道「巴黎是对的」")
print()
# 软标签 (蒸馏): 概率分布
temperature = 3.0
scaled_logits = teacher_logits / temperature
soft_labels = np.exp(scaled_logits) / np.exp(scaled_logits).sum()
print("软标签 (蒸馏, T=3):")
for city, prob in zip(cities, soft_labels):
bar = "█" * int(prob * 40)
print(f" {city}: {prob:.1%} {bar}")
print(" → Student 学到了:")
print(" 1. 巴黎最对")
print(" 2. 伦敦、柏林也是欧洲首都(相似性知识)")
print(" 3. 东京、北京概率≈0(完全不相关)")
print()
# 量化信息量差异
hard_entropy = -np.sum(hard_labels * np.log(hard_labels + 1e-10))
soft_entropy = -np.sum(soft_labels * np.log(soft_labels + 1e-10))
print(f"硬标签信息熵: {hard_entropy:.2f} bits")
print(f"软标签信息熵: {soft_entropy:.2f} bits")
print(f"→ 软标签包含约 {soft_entropy:.1f} bits 信息,远多于硬标签的 {hard_entropy:.2f} bits!")
=== 硬标签 vs 软标签 ===
问题: 法国的首都是?
硬标签 (SFT):
巴黎: 100.0% ████████████████████████████████████████
伦敦: 0.0%
柏林: 0.0%
罗马: 0.0%
马德里: 0.0%
→ Student 只知道「巴黎是对的」
软标签 (蒸馏, T=3):
巴黎: 45.4% ██████████████████
伦敦: 16.7% ██████
柏林: 12.0% ████
罗马: 10.1% ████
马德里: 8.9% ███
东京: 3.2% █
北京: 2.3%
悉尼: 1.6%
→ Student 学到了:
1. 巴黎最对
2. 伦敦、柏林也是欧洲首都(相似性知识)
3. 东京、北京概率≈0(完全不相关)
硬标签信息熵: -0.00 bits
软标签信息熵: 1.62 bits
→ 软标签包含约 1.6 bits 信息,远多于硬标签的 -0.00 bits!
2. 方法一:Logit 蒸馏(最经典)
让 Student 的输出概率分布逼近 Teacher 的输出概率分布。
Loss 公式:
其中:
- :Student 和正确答案的交叉熵(保证基本正确)
- :Student 和 Teacher 概率分布的 KL 散度(学习暗知识)
- :温度参数,越大 Teacher 的分布越「软」(暗知识越明显)
- :平衡两个 loss 的权重
温度 T 的作用:
T=1: [0.90, 0.05, 0.03, 0.02] ← 很尖锐,暗知识不明显
T=5: [0.40, 0.25, 0.20, 0.15] ← 软化了,暗知识浮现
T=20: [0.28, 0.26, 0.24, 0.22] ← 太软了,变成均匀分布
T 太大 → 所有词概率差不多 → 没信息量 T 太小 → 和硬标签没区别 → 没暗知识 T=3~10 是常见实验起点 ,不是固定规则;不同任务、模型和 loss 权重要通过验证集调整。
import numpy as np
# 演示温度对概率分布的影响
print("=== 温度 T 对软标签的影响 ===")
print()
logits = np.array([5.0, 2.0, 1.0, 0.5, 0.1, 0.01, 0.001, 0.0001])
labels = ["巴黎", "伦敦", "柏林", "罗马", "马德里", "维也纳", "布拉格", "华沙"]
for T in [1, 3, 10, 20]:
scaled = logits / T
probs = np.exp(scaled) / np.exp(scaled).sum()
print(f"T={T:2d}: ", end="")
for i in range(5):
bar = "█" * int(probs[i] * 50)
print(f"{labels[i]}:{probs[i]:.3f} {bar} ", end="")
print()
print()
print("T=1: 几乎只有巴黎 → 暗知识被掩盖")
print("T=3: 伦敦、柏林也有一定概率 → 暗知识浮现")
print("T=10: 分布更均匀 → 暗知识丰富但信号变弱")
print("T=20: 几乎均匀 → 信息量太少")
=== 温度 T 对软标签的影响 ===
T= 1: 巴黎:0.903 █████████████████████████████████████████████ 伦敦:0.045 ██ 柏林:0.017 罗马:0.010 马德里:0.007
T= 3: 巴黎:0.382 ███████████████████ 伦敦:0.141 ███████ 柏林:0.101 █████ 罗马:0.085 ████ 马德里:0.075 ███
T=10: 巴黎:0.182 █████████ 伦敦:0.135 ██████ 柏林:0.122 ██████ 罗马:0.116 █████ 马德里:0.112 █████
T=20: 巴黎:0.152 ███████ 伦敦:0.130 ██████ 柏林:0.124 ██████ 罗马:0.121 ██████ 马德里:0.119 █████
T=1: 几乎只有巴黎 → 暗知识被掩盖
T=3: 伦敦、柏林也有一定概率 → 暗知识浮现
T=10: 分布更均匀 → 暗知识丰富但信号变弱
T=20: 几乎均匀 → 信息量太少
3. 方法二:数据蒸馏(最容易落地)
直接做 full-vocab logit 蒸馏时,Student 和 Teacher 最好共享 tokenizer 或能建立可靠的跨 tokenizer 对齐;闭源 API 通常拿不到完整 logits,所以 LLM 场景里更常见的是数据蒸馏或只用 top-k / sampled-token 近似信号。
数据蒸馏绕过了这个问题:让 Teacher 生成训练数据,Student 在这些数据上做 SFT。
Step 1: 收集 prompts(从你的业务场景中)
["写一首关于春天的诗", "解释量子力学", "翻译: Hello → 中文", ...]
Step 2: Teacher 为每个 prompt 生成高质量回答
GPT-4: "春天来了,万物复苏..."
GPT-4: "量子力学是研究微观粒子..."
Step 3: 用 (prompt, teacher_answer) 对训练 Student
Student 做标准 SFT,学习模仿 Teacher 的输出风格和质量
优点:不要求词表相同,任何 Teacher 可以教任何 Student。 缺点:只学到了「答案长什么样」,没学到「概率分布中的暗知识」。
数据蒸馏的进阶技巧:
- 多轮对话蒸馏:Teacher 生成多轮对话,Student 学会对话节奏
- CoT 蒸馏:Teacher 生成带推理过程的答案,Student 学会推理
- 拒绝采样:Teacher 生成多个答案,只保留最好的给 Student 学
# 模拟数据蒸馏流程
print("=== 数据蒸馏流程模拟 ===")
print()
prompts = [
"解释什么是机器学习",
"写一首关于秋天的五言诗",
"Python 中 list 和 tuple 的区别",
]
# 模拟 Teacher (GPT-4) 生成
teacher_responses = [
"机器学习是人工智能的一 个分支,让计算机从数据中学习模式,而不需要显式编程。",
"秋风扫落叶,霜降百花残。独坐寒窗下,思君衣可单。",
"list 是可变的(可以增删改),tuple 是不可变的(创建后不能修改)。list 用 [],tuple 用 ()。",
]
print("生成训练数据:")
for i, (prompt, response) in enumerate(zip(prompts, teacher_responses)):
print(f"\n--- 样本 {i+1} ---")
print(f"User: {prompt}")
print(f"Assistant: {response}")
print()
print(f"共生成 {len(prompts)} 条训练数据")
print("Student 在这些数据上做 SFT,学习模仿 Teacher 的风格。")
print()
print("实际项目需要多少数据取决于任务宽度、teacher 质量和 student 基座能力;几千条可以跑通概念,几万到更多样本才更可能形成稳定效果。")
=== 数据蒸馏流程模拟 ===
生成训练数据:
--- 样本 1 ---
User: 解释什么是机器学习
Assistant: 机器学习是人工智能的一个分支,让计算机从数据中学习模式,而不需要显式编程。
--- 样本 2 ---
User: 写一首关于秋天的五言诗
Assistant: 秋风扫落叶,霜降百花残。独坐寒窗下,思君衣可单。
--- 样本 3 ---
User: Python 中 list 和 tuple 的区别
Assistant: list 是可变的(可以增删改),tuple 是不可变的(创建后不能修改)。list 用 [],tuple 用 ()。
共生成 3 条训练数据
Student 在这些数据上做 SFT,学习模仿 Teacher 的风格。
实际项目需要多少数据取决于任务宽度、teacher 质量和 student 基座能力;几千条可以跑通概念,几万到更多样本才更可能形成稳定效果。
4. 方法三:特征蒸馏(进阶)
不仅学输出分布,还学中间层的表示。
Teacher (GPT-4, 96层):
Layer 1 → Layer 2 → ... → Layer 48 → ... → Layer 96 → Output
↑
Student (7B, 32层): | 让 Student 第 16 层的输出
Layer 1 → Layer 2 → ... → Layer 16 → ... → Layer 32 → Output 逼近 Teacher 第 48 层
为什么有效? 中间层包含了「怎么理解这句话」的信息,比最终输出更丰富。
为什么少用?
- 需要访问 Teacher 的内部表示(闭源模型不行)
- Teacher 和 Student 的维度不同,需要投影矩阵对齐
- 计算量大,显存消耗高
在闭源 teacher + 开源 student 的 LLM 场景里,数据蒸馏最容易落地;特征蒸馏更多用于能访问 teacher hidden states 的白盒设置,视觉模型和小型语言模型 中都能看到类似思路。
5. 实战:蒸馏 7B 模型
前面讲了三种蒸馏方法的原理,现在把它们串成一次完整的蒸馏流程。整个过程分为四步:
- 准备训练数据:收集一批高质量的 prompt,覆盖目标领域——数学推理、代码生成、或者通用对话
- Teacher 生成:用强 teacher 模型为每个 prompt 生成回答;具体模型和价格会随时间变化,保存为 (prompt, teacher_answer) 对
- Student 训练:如果能拿到 teacher logits,可用 KL 做 logit 蒸馏;如果只能拿到文本回答,就用 SFT 做数据蒸馏
- 评估对比:用评测 benchmark 对比 Student 在蒸馏前后的分数变化
下面每一步都给出可执行的代码。即使没有真正的 GPT-4 API key,也可以用本地的 MiniGPT 来模拟 Teacher 和 Student 的角色,完整跑通流程。
print("=== 实战:GPT-4 → 7B 蒸馏流程 ===")
print()
steps = [
("Step 1: 选基座模型", [
"推荐: Qwen2.5-7B / Llama-3-8B / Mistral-7B",
"要求: 基座模型本身有一定能力(不能太差)",
"选 Instruct 版本(已经会遵循指令)",
]),
("Step 2: 收集 prompts", [
"来源 1: 你的业务数据(用户真实问题)",
"来源 2: 开源数据集(OpenHermes, ShareGPT, WildChat)",
"来源 3: 自建——用另一个 LLM 生成多样化 prompt",
"数量: 几千条可跑通概念;生产效果通常需要更多高质量、多样化样本,并通过验证集决定是否继续扩充",
]),
("Step 3: Teacher 生成回答", [
"用 GPT-4 API 为每个 prompt 生成回答",
"system prompt: '你是一个有帮助的助手,请详细、准确地回答。'",
"temperature=0.7(保留一定多样性)",
"成本估算要按当日 API 价格、输入/输出 token、重试率和过滤率计算;这里只能作为估算方法,不写固定美元数",
]),
("Step 4: 数据清洗", [
"去掉太短的回答(<20 token)",
"去掉包含 '作为 AI' 等拒绝回答的",
"去掉格式错乱的",
"去重(相似度 > 0.9 的只保留一条)",
]),
("Step 5: SFT 训练", [
"