混合专家模型(MoE)
到目前为止的 GPT,每一层只有一个 FFN,所有 token 共享。参数和计算是绑在一起的——参数翻倍,每个 token 的计算也翻倍。
这一节介绍 MoE。它把一个大 FFN 拆成多个专家,每个 token 只激活少数几个。总参数可以很大,计算量不怎么变。我们从零写一个 MoE 层,重点看路由器怎么选专家、负载怎么均衡。
MoE 的思路:一个 FFN 拆成多个小的专家 FFN,前面加一个路由器决定每个 token 走哪几个。比如 8 个专家,每次只走 2 个——总参数 8 倍,每个 token 的计算 2 倍。参数多,能装的知识多;计算变化不大,推理速度不受影响。
训一段时间后,路由通常会出现某种结构化分工:Mixtral 论文观察到相邻 token 往往分到 相同专家,也有句法相关的路由模式;DeepSeekMoE 则用“细粒度专家 + 共享专家”推动更专门化的分工。但要注意,这不等于每个专家都能被简单命名成“语法专家”或“数字专家”。一些解释性研究提醒,专家更可能学到细粒度语言操作、局部语义模式,或者跨领域都常用的核心能力。参考:Mixtral、DeepSeekMoE、MoE Interpretability、Core Experts。
MoE 还有一个麻烦:路由器可能偷懒,把大部分 token 都发给少数几个专家,其余专家闲着。怎么让负载均衡,是 MoE 训练最核心的问题。
1. 普通 Transformer 的 FFN 层
Dense 模型和 MoE 模型的根本区别在于参数和计算量的关系:
Dense 模型:
所有 token → 同一个 FFN → 输出
参数量 = 每次推理的计算量(全部参数都参与计算)
想让模型更强 → 扩大 FFN → 计算量同步增长
MoE 模型:
每个 token → 路由器选 top-k 个专家 → 加权输出
总参数量 = N × 单个专家参数量(可以很大)
每次计算量 = k × 单个专家参数量(和 Dense 差不多)
参数和计算量不再同步增长
Dense 模型中,想让模型拥有更大的知识容量,就要把 FFN 的维度成倍扩大。但 FFN 扩大后,每个 token 都要经过这个更大的 FFN,推理成本同步增长。MoE 的思路是把知识分散到多个专家里,每个 token 只激活少数几个——总参数可以很大,但每次推理的计算量基本不变。
下面先回顾标准 FFN 的结构,再在这个基础上改造成 MoE。
每个 Transformer Block 里有一个 FFN(前馈网络),结构很简单——两个线性变换,中间夹一个激活函数:
输入 x (d_model=512)
↓
Linear(512 → 2048) ← 升维,给模型更大的「思考空间」
↓
ReLU / GELU ← 非线性,让模型能学复杂模式
↓
Linear(2048 → 512) ← 降维,回到原来的维度
↓
输出
参数量:512 × 2048 + 2048 × 512 ≈ 2M(两个矩阵,各约 1M)。
这个 FFN 是 Transformer 能力的核心来源之一。Attention 负责「从哪些位置获取信息」,FFN 负责「对获取的信息做什么处理」。Attention 告诉你「这个词和哪个词相关」,FFN 告诉你「知道了相关性之后,该把这个词变成什么」。
但现在的设计有一个隐含假设:所有 token 共享同一个 FFN。不管输入是「的」还是「量子力学」,都经过同两个矩阵,做同样的变换。这在小模型里没问题,但模型变大后,用一个 FFN 同时处理所有类型的知识就越来越难——就像用同一套规则处理语法问题和数学问题,规则本身会变得臃肿。
1.1 为什么 MoE 通常替换 FFN,而不是替换 Attention
你可能会问:既然 Transformer Block 里有 Attention 和 FFN 两个大部件,为什么 MoE 通常拿 FFN 开刀,而不是把 Attention 拆成专家?
原因有三个。
第一,FFN 是逐 token 计算的。每个 token 过 FFN 时,本来就像一个独立样本:输入一个向量,输出一个向量。路由器很容易对每个 token 单独决定「找哪几个专家」。
第二,FFN 通常参数很多。Dense FFN 的两层或三层大矩阵占了 Block 里相当多的参数。把它拆成多个专家,能显著增加模型容量;而每个 token 只走 top-k 个专家,计算不会按总专家数线性爆炸。
第三,Attention 负责 token 之间的信息交换。如果把 Attention 也动态路由,系统会更复杂:不同 token 不只要选专家,还要互相看,通信和缓存都会变麻烦。MoE 先替换 FFN,是收益大、改动相对清楚的一步。
所以你可以把 MoE 理解成:保留 Attention 这条信息交换通道,把后面的 FFN 加工车间扩建成多个专家车间。
2. MoE 的核心思想
既然一个 FFN 处理所有 token 负担太重,那换一个思路:把一个大 FFN 拆成 N 个小的专家 FFN,让每个 token 只找其中少数几个专家来处理。
┌─────────────┐
│ 路由器 │ ← 决定每个 token 找哪几个专家
│ (Gate) │
└──┬──┬──┬───┘
│ │ │
┌───────┘ │ └───────┐
↓ ↓ ↓
┌────────┐┌────────┐┌────────┐
│ 专家 1 ││ 专家 2 ││ 专家 3 │ ... (共 8 个)
│ (FFN) ││ (FFN) ││ (FFN) │
└────────┘└────────┘└────────┘
↓ ↓ ↓
└──────────┴──────────┘
加权求和
路由器的输入是 token 的 hidden state(一个 d_model 维的向量),输出是 N 个分数(每个专家一个分)。分数高的专家被选中。
每个 token 只激活 top-k 个专家(通常 k=2),不是全部 8 个:
token "的" → 路由器 → 专家 1, 5 (功能词,通用专家处理)
token "量子" → 路由器 → 专家 3, 7 (物理类知识)
token "hello" → 路由器 → 专家 2, 6 (英文类知识)
效果:
- 总参数量 = N × 一个专家的参数量(8 倍增长)
- 每次推理的计算量 = k × 一个专家的参数量(2 倍增长)
- 参数多但计算少 ← 参数和计算不再同步增长
路由器本身也是一个可训练的参数矩阵——nn.Linear(d_model, num_experts)。它和专家 FFN 一起训练,通过梯度学习如何为每个 token 选择最合适的专家。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoELayer(nn.Module):
"""
MoE FFN 层
参数:
d_model: 隐藏维度
num_experts: 专家数量
top_k: 每个 token 激活几个专家
"""
def __init__(self, d_model, num_experts=8, top_k=2, expert_dim=None):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
expert_dim = expert_dim or 4 * d_model
# 路由器:输入 d_model,输出 num_experts 个分数
self.gate = nn.Linear(d_model, num_experts, bias=False)
# N 个专家,每个是一个小 FFN
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(d_model, expert_dim),
nn.ReLU(),
nn.Linear(expert_dim, d_model)
)
for _ in range(num_experts)
])
def forward(self, x):
"""
x: [batch, seq_len, d_model]
流程:
1. 路由器给每个 token 打分
2. 选 top-k 个专家
3. 只算这 k 个专家的输出
4. 加权求和
"""
batch_size, seq_len, d_model = x.shape
# Step 1: 路由器打分
gate_logits = self.gate(x) # [batch, seq_len, num_experts]
# Step 2: 选 top-k
top_k_logits, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1)
top_k_weights = F.softmax(top_k_logits, dim=-1) # 归一化权重
# Step 3 & 4: 对每个 token,算选中专家的输出并加权求和
output = torch.zeros_like(x)
for b in range(batch_size):
for s in range(seq_len):
token = x[b, s] # [d_model]
for k in range(self.top_k):
expert_idx = top_k_indices[b, s, k].item()
weight = top_k_weights[b, s, k]
expert_out = self.experts[expert_idx](token.unsqueeze(0)).squeeze(0)
output[b, s] += weight * expert_out
return output
print("MoE 层定义完成!")
print(f"8 个专家,每个 token 只激活 2 个")
# 演示 MoE 的路由过程
import torch
import torch.nn.functional as F
torch.manual_seed(42)
moe = MoELayer(d_model=16, num_experts=8, top_k=2)
# 模拟 3 个 token
x = torch.randn(1, 3, 16)
# 看路由器给每个 token 的评分
with torch.no_grad():
gate_scores = moe.gate(x).squeeze(0) # [3, 8]
top_k_vals, top_k_idx = torch.topk(gate_scores, 2, dim=-1)
print("=== 路由器为 3 个 token 选择的专家 ===")
print()
for i in range(3):
experts = top_k_idx[i].tolist()
weights = F.softmax(top_k_vals[i], dim=-1).tolist()
print(f"Token {i}: 选中专家 {experts}, 权重 {[f'{w:.2f}' for w in weights]}")
print()
print("每个 token 只激活 2/8 = 25% 的专家")
print("总参数是 8 个专家的和,但计算量只有 2 个专家的量")
2.1 对比:普通 Transformer Block vs MoE Block
要看清 MoE 的结构,最好的办法不是只看公式,而是把模型结构直接打印出来。
普通 decoder block 的主线是:
x → Attention → FFN → output
MoE decoder block 的主线是:
x → Attention → Router → 多个 FFN experts 里选 top-k → output
也就是说,MoE 通常不是替换 Attention,而 是把 Transformer Block 里的 FFN 换成“路由器 + 多个专家 FFN”。
# 用真实 nn.Module 打印:Dense FFN Block vs MoE Block
import torch.nn as nn
class TinyDenseBlock(nn.Module):
"""普通 Transformer decoder block 的骨架:Attention + 单个 FFN"""
def __init__(self, d_model=16, num_heads=2, ffn_dim=64):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, ffn_dim),
nn.ReLU(),
nn.Linear(ffn_dim, d_model),
)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x):
attn_out, _ = self.self_attn(x, x, x, need_weights=False)
x = self.norm1(x + attn_out)
ffn_out = self.ffn(x)
x = self.norm2(x + ffn_out)
return x
class TinyMoEBlock(nn.Module):
"""MoE decoder block 的骨架:Attention + Router + 多个 FFN experts"""
def __init__(self, d_model=16, num_heads=2, num_experts=4, top_k=2, expert_dim=64):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
self.moe = MoELayer(d_model, num_experts=num_experts, top_k=top_k, expert_dim=expert_dim)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x):
attn_out, _ = self.self_attn(x, x, x, need_weights=False)
x = self.norm1(x + attn_out)
moe_out = self.moe(x)
x = self.norm2(x + moe_out)
return x
dense_block = TinyDenseBlock()
moe_block = TinyMoEBlock(num_experts=4, top_k=2)
print("=== 普通 Transformer Block ===")
print(dense_block)
print()
print("=== MoE Transformer Block ===")
print(moe_block)
上面的打印结果要看两个位置:
- 普通 block 里只有一个
ffn。 - MoE block 里
self_attn后面接的是moe,里面有gate和多个experts。
这就是 MoE 的核心结构证据:Attention 还在,FFN 变成了多个专家。
# trace 一次 forward:看 shape 和路由发生在哪里
import torch
trace_x = torch.randn(1, 3, 16)
print("=== Dense Block Trace ===")
with torch.no_grad():
dense_attn_out, _ = dense_block.self_attn(trace_x, trace_x, trace_x, need_weights=False)
dense_after_attn = dense_block.norm1(trace_x + dense_attn_out)
dense_ffn_out = dense_block.ffn(dense_after_attn)
dense_out = dense_block.norm2(dense_after_attn + dense_ffn_out)
print(f"input: {tuple(trace_x.shape)}")
print(f"attention: {tuple(dense_attn_out.shape)}")
print(f"single FFN: {tuple(dense_ffn_out.shape)}")
print(f"output: {tuple(dense_out.shape)}")
print()
print("=== MoE Block Trace ===")
with torch.no_grad():
moe_attn_out, _ = moe_block.self_attn(trace_x, trace_x, trace_x, need_weights=False)
moe_after_attn = moe_block.norm1(trace_x + moe_attn_out)
gate_logits = moe_block.moe.gate(moe_after_attn)
top_vals, top_idx = torch.topk(gate_logits, moe_block.moe.top_k, dim=-1)
moe_out_raw = moe_block.moe(moe_after_attn)
moe_out = moe_block.norm2(moe_after_attn + moe_out_raw)
print(f"input: {tuple(trace_x.shape)}")
print(f"attention: {tuple(moe_attn_out.shape)}")
print(f"gate logits: {tuple(gate_logits.shape)} # 每个 token 对每个 expert 的分数")
print(f"top-k index: {tuple(top_idx.shape)} # 每个 token 选中的 expert 编号")
print(f"MoE FFN: {tuple(moe_out_raw.shape)}")
print(f"output: {tuple(moe_out.shape)}")
print()
print("每个 token 选中的 experts:")
for token_i, experts in enumerate(top_idx[0].tolist()):
print(f"token {token_i}: experts {experts}")
2.2 打印 HuggingFace 里的真实 MoE 模型
小模型的骨架看懂之后,再看真实工程里的 decoder layer。 这里不下载权重,只用 config 初始化一个很小的 Qwen2-MoE / Mixtral,目的只有一个: 把真实源码里的模块顺序打印出来。
读打印结果时盯住三件事:
self_attn仍然在前面。- 普通 FFN 的位置,变成了
mlp/SparseMoeBlock。 - MoE 里面有
gate和多个experts。
新版 transformers 会把多个 expert 的权重打包成大 tensor,打印时不一定显示成
expert_0, expert_1, ...。所以除了打印 layer,还要打印参数形状:
experts.gate_up_proj 的第 0 维就是 expert 数量。
# 打印真实 HuggingFace MoE decoder layer,并 trace 一 次 router 输出
import inspect
import warnings
import torch
warnings.filterwarnings("ignore", message="IProgress not found.*")
from transformers import Qwen2MoeConfig, Qwen2MoeForCausalLM
from transformers import MixtralConfig, MixtralForCausalLM
qwen_cfg = Qwen2MoeConfig(
vocab_size=128,
hidden_size=32,
intermediate_size=64,
moe_intermediate_size=64,
shared_expert_intermediate_size=64,
num_hidden_layers=1,
num_attention_heads=4,
num_key_value_heads=4,
num_experts=4,
num_experts_per_tok=2,
)
qwen_moe = Qwen2MoeForCausalLM(qwen_cfg)
qwen_layer = qwen_moe.model.layers[0]
print("=== Qwen2-MoE decoder layer ===")
print(qwen_layer)
print()
print("=== Qwen2-MoE MoE 参数形状 ===")
for name, param in qwen_layer.mlp.named_parameters():
if name.startswith("gate") or name.startswith("experts") or name.startswith("shared"):
print(f"{name:32s} {tuple(param.shape)}")
mixtral_cfg = MixtralConfig(
vocab_size=128,
hidden_size=32,
intermediate_size=64,
num_hidden_layers=1,
num_attention_heads=4,
num_key_value_heads=4,
num_local_experts=4,
num_experts_per_tok=2,
)
mixtral_moe = MixtralForCausalLM(mixtral_cfg)
mixtral_layer = mixtral_moe.model.layers[0]
print()
print("=== Mixtral decoder layer ===")
print(mixtral_layer)
print()
print("=== Mixtral MoE 参数形状 ===")
for name, param in mixtral_layer.mlp.named_parameters():
if name.startswith("gate") or name.startswith("experts"):
print(f"{name:32s} {tuple(param.shape)}")
print()
print("=== Qwen2-MoE layer.forward 源码摘录 ===")
source = inspect.getsource(type(qwen_layer).forward).splitlines()
for line in source:
if "self_attn" in line or "mlp" in line or "layernorm" in line:
print(line)
print()
print("=== Qwen2-MoE mlp.forward 源码摘录 ===")
source = inspect.getsource(type(qwen_layer.mlp).forward).splitlines()
for line in source:
if "gate" in line or "expert" in line or "router" in line or "top" in line:
print(line)
input_ids = torch.randint(0, 128, (1, 6))
with torch.no_grad():
qwen_out = qwen_moe(input_ids=input_ids, output_router_logits=True)
router_logits = qwen_out.router_logits[0]
top_vals, top_idx = torch.topk(router_logits, qwen_cfg.num_experts_per_tok, dim=-1)
print()
print("=== Qwen2-MoE router trace ===")
print(f"input_ids: {tuple(input_ids.shape)}")
print(f"logits: {tuple(qwen_out.logits.shape)}")
print(f"router logits: {tuple(router_logits.shape)}")
print(f"top-k experts: {tuple(top_idx.shape)}")
print()
print("前 6 个 token 选中的 experts:")
for token_i, experts in enumerate(top_idx[:6].tolist()):
print(f"token {token_i}: experts {experts}")
3. MoE 的参数 vs 计算量
这是 MoE 最核心的优势。用具体数字感受一下:
假设一个 Dense 模型的 FFN 有 2M 参数(d_model=512,d_ff=2048):
普通 Dense 模型:
FFN 参数: 2M
每个 token 的计算: 2M 次参数运算
参数和计算绑定 → 参数翻倍,计算也翻倍
MoE 模型 (8 专家, top-2):
FFN 总参数: 8 × 2M = 16M ← 参数翻了 8 倍
每个 token 的计算: 2 × 2M = 4M ← 计算只翻了 2 倍
参数和计算解耦
一个 token 在 MoE 中的实际计算路径是这样的:
- 经过路由器:
W_gate @ x,这是一个很小的矩阵乘法(d_model × num_experts) - 经过 top-k 个专家的 FFN:每个专家内部是两次矩阵乘法(升维 + 降维)
- 加权求和:把 k 个专家的输出按路由权重加起来
路由器的计算量远小于 FFN(d_model × num_experts << d_model × d_ff),所以忽略不计。主要计算就是 k 个专家的 FFN 计算。
这就是为什么 Mixtral 8×7B 虽然总参数 47B,但推理速度和 7B 的 Dense 模型差不多——每次推理只激活约 13B 参数(2 个专家的 FFN + 共享的 Attention 层)。
4. MoE 的训练难题:负载均衡
MoE 有一个内生的工程问题:路由器可能偷懒,只把 token 发给少数几个专家。
为什么会这样?路由器的训练目标只有一个——让模型预测下一个 token 的 loss 尽可能低。如果路由器发现「把所有 token 都给专家 3 和专家 5 就能让 loss 很低」,它就没有动力去用其他专家。
坏情况(负载不均):
专家 1: ████████████████████ (被过度使用)
专家 2: ████
专家 3: █
专家 4-8: (几乎空闲,参数白训了)
好情况(负载均衡):
专家 1: ██████
专家 2: ██████
专家 3: ██████
...
专家 8: ██████
负载不均的后果很严重:被过度使用的专家成为瓶颈(计算慢),空闲专家的参数没学到东西(浪费容量),最终模型退化成「只有 2-3 个有效专家的 Dense 模型」。
传统解决方案:辅助 loss
在语言模型的主 loss 之外,加一个额外的 loss 项,鼓励每个专家处理大致相同数量的 token。
# 负载均衡 loss(简化版)
load_balance_loss = 0
for expert_i in range(num_experts):
actual_load = count_tokens_assigned_to(expert_i)
ideal_load = total_tokens * top_k / num_experts
load_balance_loss += (actual_load - ideal_load) ** 2
total_loss = lm_loss + alpha * load_balance_loss
系数 α 需要手动调节。α 太大会让路由器被迫选不合适的专家(影响模型质量),α 太小则负载均衡无效。
改进:无辅助 Loss 的负载均衡(DeepSeek-V3)
辅助 loss 方案虽然有效,但有一个内在矛盾:辅助 loss 和语言模型的主 loss 是竞争关系。辅助 loss 鼓励路由均匀分布,主 loss 鼓励路由把 token 发给最合适的专家。两者之间的平衡需要仔细调节系数 α——α 太大会干扰模型收敛,α 太小则负载不均。
DeepSeek-V3 提出了一种更直接的方案:不给路由加 loss,而是直接调整路由的偏置。给每个专家维护一个偏置项 ,加在路由 logits 上——
正常路由: gate_logits = W_gate @ x
改进后: gate_logits = W_gate @ x + b (b 不参与梯度计算)
的更新规则很简单:统计这一轮每个专家处理了多少 token,负载超过均值的专家偏置减小 γ(让它下次少收到 token),低于均值的偏置增大 γ(让它下次多收到 token)。γ 通常取很小的值(如 0.001),偏置在训练过程中自然收敛到平衡点。
# 每个 training step 结束后执行
expert_loads = count_tokens_per_expert(selected_experts)
mean_load = total_tokens * top_k / num_experts
for i in range(num_experts):
if expert_loads[i] > mean_load:
bias[i] -= gamma # 超载 → 降偏置
else:
bias[i] += gamma # 欠载 → 升偏置