混合专家模型(MoE)
到目前为止的 GPT,每一层只有一个 FFN,所有 token 共享。参数和计算是绑在一起的——参数翻倍,每个 token 的计算也翻倍。
这一节介绍 MoE。它把一个大 FFN 拆成多个专家,每个 token 只激活少数几个。总参数可以很大,计算量不怎么变。我们从零写一个 MoE 层,重点看路由器怎么选专家、负载怎么均衡。
MoE 的思路:一个 FFN 拆成多个小的专家 FFN,前面加一个路由器决定每个 token 走哪几个。比如 8 个专家,每次只走 2 个——总参数 8 倍,每个 token 的计算 2 倍。参数多,能装的知识多;计算变 化不大,推理速度不受影响。
训一段时间后,路由通常会出现某种结构化分工:Mixtral 论文观察到相邻 token 往往分到相同专家,也有句法相关的路由模式;DeepSeekMoE 则用“细粒度专家 + 共享专家”推动更专门化的分工。但要注意,这不等于每个专家都能被简单命名成“语法专家”或“数字专家”。一些解释性研究提醒,专家更可能学到细粒度语言操作、局部语义模式,或者跨领域都常用的核心能力。参考:Mixtral、DeepSeekMoE、MoE Interpretability、Core Experts。
MoE 还有一个麻烦:路由器可能偷懒,把大部分 token 都发给少数几个专家,其余专家闲着。怎么让负载均衡,是 MoE 训练最核心的问题。
1. 普通 Transformer 的 FFN 层
Dense 模型和 MoE 模型的根本区别在于参数和计算量的关系:
Dense 模型:
所有 token → 同一个 FFN → 输出
参数量 = 每次推理的计算量(全部参数都参与计算)
想让模型更强 → 扩大 FFN → 计算量同步增长
MoE 模型:
每个 token → 路由器选 top-k 个专家 → 加权输出
总参数量 = N × 单个专家参数量(可以很大)
每次计算量 = k × 单个专家参数量(和 Dense 差不多)
参数和计算量不再同步增长
Dense 模型中,想让模型拥有更大的知识容量,就要把 FFN 的维度成倍扩大。但 FFN 扩大后,每个 token 都要经过这个更大的 FFN,推理成本同步增长。MoE 的思路是把知识分散到多个专家里,每个 token 只激活少数几个——总参数可以很大,但每次推理的计算量基本不变。
下面先回顾标准 FFN 的结构,再在这个基础上改造成 MoE。
每个 Transformer Block 里有一个 FFN(前馈网络),结构很简单——两个线性变换,中间夹一个激活函数:
输入 x (d_model=512)
↓
Linear(512 → 2048) ← 升维,给模型更大的「思考空间」
↓
ReLU / GELU ← 非线性,让模型能学复杂模式
↓
Linear(2048 → 512) ← 降维,回到原来的维度
↓
输出
参数量:512 × 2048 + 2048 × 512 ≈ 2M(两个矩阵,各约 1M)。
这个 FFN 是 Transformer 能力的核心来源之一。Attention 负责「从哪些位置获取信息」,FFN 负责「对获取的信息做什么处理」。Attention 告诉你「这个词和哪个词相关」,FFN 告诉你「知道了相关性之后,该把这个词变成什么」。
但现在的设计有一个隐含假设:所有 token 共享同一个 FFN。不管输入是「的」还是「量子力学」,都经过同两个矩阵,做同样的变换。这在小模型里没问题,但模型变大后,用一个 FFN 同时处理所有类型的知识就越来越难——就像用同一套规则处理语法问题和数学问题,规则本身会变得臃肿。
1.1 为什么 MoE 通常替换 FFN,而不是替换 Attention
你可能会问:既然 Transformer Block 里有 Attention 和 FFN 两个大部件,为什么 MoE 通常拿 FFN 开刀,而不是把 Attention 拆成专家?
原因有三个。
第一,FFN 是逐 token 计算的。每个 token 过 FFN 时,本来就像一个独立样本:输入一个向量,输出一个向量。路由器很容易对每个 token 单独决定「找哪几个专家」。
第二,FFN 通常参数很多。Dense FFN 的两层或三层大矩阵占了 Block 里相当多的参数。把它拆成多个专家,能显著增加模型容量;而每个 token 只走 top-k 个专家,计算不会按总专家数线性爆炸。
第三,Attention 负责 token 之间的信息交换。如果把 Attention 也动态路由,系统会更复杂:不同 token 不只要选专家,还要互相看,通信和缓存都会变麻烦。MoE 先替换 FFN,是收益大、改动相对清楚的一步。
所以你可以把 MoE 理解成:保留 Attention 这条信息交换通道,把后面的 FFN 加工车间扩建成多个专家车间。
2. MoE 的核心思想
既然一个 FFN 处理所有 token 负担太重,那换一个思路:把一个大 FFN 拆成 N 个小的专家 FFN,让每个 token 只找其中少数几个专家来处理。
┌─────────────┐
│ 路由器 │ ← 决定每个 token 找哪几个专家
│ (Gate) │
└──┬──┬──┬───┘
│ │ │
┌───────┘ │ └───────┐
↓ ↓ ↓
┌────────┐┌────────┐┌────────┐
│ 专家 1 ││ 专家 2 ││ 专家 3 │ ... (共 8 个)
│ (FFN) ││ (FFN) ││ (FFN) │
└────────┘└────────┘└────────┘
↓ ↓ ↓
└──────────┴──────────┘
加权求和
路由器的输入是 token 的 hidden state(一个 d_model 维的向量),输出是 N 个分数(每个专家一个分)。分数高的专家被选中。
每个 token 只激活 top-k 个专家(通常 k=2),不是全部 8 个:
token "的" → 路由器 → 专家 1, 5 (功能词,通用专家处理)
token "量子" → 路由器 → 专家 3, 7 (物理类知识)
token "hello" → 路由器 → 专家 2, 6 (英文类知识)
效果:
- 总参数量 = N × 一个专家的参数量(8 倍增长)
- 每次推理的计算量 = k × 一个专家的参数量(2 倍增长)
- 参数多但计算少 ← 参数和计算不再同步增长
路由器本身也是一个可训练的参数矩阵——nn.Linear(d_model, num_experts)。它和专家 FFN 一起训练,通过梯度学习如何为每个 token 选择最合适的专家。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoELayer(nn.Module):
"""
MoE FFN 层
参数:
d_model: 隐藏维度
num_experts: 专家数量
top_k: 每个 token 激活几个专家
"""
def __init__(self, d_model, num_experts=8, top_k=2, expert_dim=None):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
expert_dim = expert_dim or 4 * d_model
# 路由器:输入 d_model,输出 num_experts 个分数
self.gate = nn.Linear(d_model, num_experts, bias=False)
# N 个专家,每个是一个小 FFN
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(d_model, expert_dim),
nn.ReLU(),
nn.Linear(expert_dim, d_model)
)
for _ in range(num_experts)
])
def forward(self, x):
"""
x: [batch, seq_len, d_model]
流程:
1. 路由器给每个 token 打分
2. 选 top-k 个专家
3. 只算这 k 个专家的输出
4. 加权求和
"""
batch_size, seq_len, d_model = x.shape
# Step 1: 路由器打分
gate_logits = self.gate(x) # [batch, seq_len, num_experts]
# Step 2: 选 top-k
top_k_logits, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1)
top_k_weights = F.softmax(top_k_logits, dim=-1) # 归一化权重
# Step 3 & 4: 对每个 token, 算选中专家的输出并加权求和
output = torch.zeros_like(x)
for b in range(batch_size):
for s in range(seq_len):
token = x[b, s] # [d_model]
for k in range(self.top_k):
expert_idx = top_k_indices[b, s, k].item()
weight = top_k_weights[b, s, k]
expert_out = self.experts[expert_idx](token.unsqueeze(0)).squeeze(0)
output[b, s] += weight * expert_out
return output
print("MoE 层定义完成!")
print(f"8 个专家,每个 token 只激活 2 个")
MoE 层定义完成!
8 个专家,每个 token 只激活 2 个
# 演示 MoE 的路由过程
import torch
import torch.nn.functional as F
torch.manual_seed(42)
moe = MoELayer(d_model=16, num_experts=8, top_k=2)
# 模拟 3 个 token
x = torch.randn(1, 3, 16)
# 看路由器给每个 token 的评分
with torch.no_grad():
gate_scores = moe.gate(x).squeeze(0) # [3, 8]
top_k_vals, top_k_idx = torch.topk(gate_scores, 2, dim=-1)
print("=== 路由器为 3 个 token 选择的专家 ===")
print()
for i in range(3):
experts = top_k_idx[i].tolist()
weights = F.softmax(top_k_vals[i], dim=-1).tolist()
print(f"Token {i}: 选中专家 {experts}, 权重 {[f'{w:.2f}' for w in weights]}")
print()
print("每个 token 只激活 2/8 = 25% 的专家")
print("总参数是 8 个专家的和,但计算量只有 2 个专家的量")
2.1 对比:普通 Transformer Block vs MoE Block
要看清 MoE 的结构,最好的办法不是只看公式,而是把模型结构直接打印出来。
普通 decoder block 的主线是:
x → Attention → FFN → output
MoE decoder block 的主线是:
x → Attention → Router → 多个 FFN experts 里选 top-k → output
也就是说,MoE 通常不是替换 Attention,而是把 Transformer Block 里的 FFN 换成“路由器 + 多个专家 FFN”。
# 用真实 nn.Module 打印:Dense FFN Block vs MoE Block
import torch.nn as nn
class TinyDenseBlock(nn.Module):
"""普通 Transformer decoder block 的骨架:Attention + 单个 FFN"""
def __init__(self, d_model=16, num_heads=2, ffn_dim=64):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, ffn_dim),
nn.ReLU(),
nn.Linear(ffn_dim, d_model),
)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x):
attn_out, _ = self.self_attn(x, x, x, need_weights=False)
x = self.norm1(x + attn_out)
ffn_out = self.ffn(x)
x = self.norm2(x + ffn_out)
return x
class TinyMoEBlock(nn.Module):
"""MoE decoder block 的骨架:Attention + Router + 多个 FFN experts"""
def __init__(self, d_model=16, num_heads=2, num_experts=4, top_k=2, expert_dim=64):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
self.moe = MoELayer(d_model, num_experts=num_experts, top_k=top_k, expert_dim=expert_dim)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x):
attn_out, _ = self.self_attn(x, x, x, need_weights=False)
x = self.norm1(x + attn_out)
moe_out = self.moe(x)
x = self.norm2(x + moe_out)
return x
dense_block = TinyDenseBlock()
moe_block = TinyMoEBlock(num_experts=4, top_k=2)
print("=== 普通 Transformer Block ===")
print(dense_block)
print()
print("=== MoE Transformer Block ===")
print(moe_block)
=== 普通 Transformer Block ===
TinyDenseBlock(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=16, out_features=16, bias=True)
)
(norm1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
(ffn): Sequential(
(0): Linear(in_features=16, out_features=64, bias=True)
(1): ReLU()
(2): Linear(in_features=64, out_features=16, bias=True)
)
(norm2): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
=== MoE Transformer Block ===
TinyMoEBlock(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=16, out_features=16, bias=True)
)
(norm1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
(moe): MoELayer(
(gate): Linear(in_features=16, out_features=4, bias=False)
(experts): ModuleList(
(0-3): 4 x Sequential(
(0): Linear(in_features=16, out_features=64, bias=True)
(1): ReLU()
(2): Linear(in_features=64, out_features=16, bias=True)
)
)
)
(norm2): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
上面的打印结果要看两个位置:
- 普通 block 里只有一个
ffn。 - MoE block 里
self_attn后面接的是moe,里面有gate和多个experts。
这就是 MoE 的核心结构证据:Attention 还在,FFN 变成了多个专家。
# trace 一次 forward:看 shape 和路由发生在哪里
import torch
trace_x = torch.randn(1, 3, 16)
print("=== Dense Block Trace ===")
with torch.no_grad():
dense_attn_out, _ = dense_block.self_attn(trace_x, trace_x, trace_x, need_weights=False)
dense_after_attn = dense_block.norm1(trace_x + dense_attn_out)
dense_ffn_out = dense_block.ffn(dense_after_attn)
dense_out = dense_block.norm2(dense_after_attn + dense_ffn_out)
print(f"input: {tuple(trace_x.shape)}")
print(f"attention: {tuple(dense_attn_out.shape)}")
print(f"single FFN: {tuple(dense_ffn_out.shape)}")
print(f"output: {tuple(dense_out.shape)}")
print()
print("=== MoE Block Trace ===")
with torch.no_grad():
moe_attn_out, _ = moe_block.self_attn(trace_x, trace_x, trace_x, need_weights=False)
moe_after_attn = moe_block.norm1(trace_x + moe_attn_out)
gate_logits = moe_block.moe.gate(moe_after_attn)
top_vals, top_idx = torch.topk(gate_logits, moe_block.moe.top_k, dim=-1)
moe_out_raw = moe_block.moe(moe_after_attn)
moe_out = moe_block.norm2(moe_after_attn + moe_out_raw)
print(f"input: {tuple(trace_x.shape)}")
print(f"attention: {tuple(moe_attn_out.shape)}")
print(f"gate logits: {tuple(gate_logits.shape)} # 每个 token 对每个 expert 的分数")
print(f"top-k index: {tuple(top_idx.shape)} # 每个 token 选中的 expert 编号")
print(f"MoE FFN: {tuple(moe_out_raw.shape)}")
print(f"output: {tuple(moe_out.shape)}")
print()
print("每个 token 选中的 experts:")
for token_i, experts in enumerate(top_idx[0].tolist()):
print(f"token {token_i}: experts {experts}")
=== Dense Block Trace ===
input: (1, 3, 16)
attention: (1, 3, 16)
single FFN: (1, 3, 16)
output: (1, 3, 16)
=== MoE Block Trace ===
input: (1, 3, 16)
attention: (1, 3, 16)
gate logits: (1, 3, 4) # 每个 token 对每个 expert 的分数
top-k index: (1, 3, 2) # 每个 token 选中的 expert 编号
MoE FFN: (1, 3, 16)
output: (1, 3, 16)
每个 token 选中的 experts:
token 0: experts [3, 2]
token 1: experts [1, 3]
token 2: experts [3, 2]
2.2 打印 HuggingFace 里的真实 MoE 模型
小模型的骨架看懂之后,再看真实工程里的 decoder layer。 这里不下载权重,只用 config 初始化一个很小的 Qwen2-MoE / Mixtral,目的只有一个: 把真实源码里的模块顺序打印出来。
读打印结果时盯住三件事:
self_attn仍然在前面。- 普通 FFN 的位置,变成了
mlp/SparseMoeBlock。 - MoE 里面有
gate和多个experts。
新版 transformers 会把多个 expert 的权重打包成大 tensor,打印时不一定显示成
expert_0, expert_1, ...。所以除了打印 layer,还要打印参数形状:
experts.gate_up_proj 的第 0 维就是 expert 数量。
# 打印真实 HuggingFace MoE decoder layer,并 trace 一次 router 输出
import inspect
import warnings
import torch
warnings.filterwarnings("ignore", message="IProgress not found.*")
from transformers import Qwen2MoeConfig, Qwen2MoeForCausalLM
from transformers import MixtralConfig, MixtralForCausalLM
qwen_cfg = Qwen2MoeConfig(
vocab_size=128,
hidden_size=32,
intermediate_size=64,
moe_intermediate_size=64,
shared_expert_intermediate_size=64,
num_hidden_layers=1,
num_attention_heads=4,
num_key_value_heads=4,
num_experts=4,
num_experts_per_tok=2,
)
qwen_moe = Qwen2MoeForCausalLM(qwen_cfg)
qwen_layer = qwen_moe.model.layers[0]
print("=== Qwen2-MoE decoder layer ===")
print(qwen_layer)
print()
print("=== Qwen2-MoE MoE 参数形状 ===")
for name, param in qwen_layer.mlp.named_parameters():
if name.startswith("gate") or name.startswith("experts") or name.startswith("shared"):
print(f"{name:32s} {tuple(param.shape)}")
mixtral_cfg = MixtralConfig(
vocab_size=128,
hidden_size=32,
intermediate_size=64,
num_hidden_layers=1,
num_attention_heads=4,
num_key_value_heads=4,
num_local_experts=4,
num_experts_per_tok=2,
)
mixtral_moe = MixtralForCausalLM(mixtral_cfg)
mixtral_layer = mixtral_moe.model.layers[0]
print()
print("=== Mixtral decoder layer ===")
print(mixtral_layer)
print()
print("=== Mixtral MoE 参数形状 ===")
for name, param in mixtral_layer.mlp.named_parameters():
if name.startswith("gate") or name.startswith("experts"):
print(f"{name:32s} {tuple(param.shape)}")
print()
print("=== Qwen2-MoE layer.forward 源码摘录 ===")
source = inspect.getsource(type(qwen_layer).forward).splitlines()
for line in source:
if "self_attn" in line or "mlp" in line or "layernorm" in line:
print(line)
print()
print("=== Qwen2-MoE mlp.forward 源码摘录 ===")
source = inspect.getsource(type(qwen_layer.mlp).forward).splitlines()
for line in source:
if "gate" in line or "expert" in line or "router" in line or "top" in line:
print(line)
input_ids = torch.randint(0, 128, (1, 6))
with torch.no_grad():
qwen_out = qwen_moe(input_ids=input_ids, output_router_logits=True)
router_logits = qwen_out.router_logits[0]
top_vals, top_idx = torch.topk(router_logits, qwen_cfg.num_experts_per_tok, dim=-1)
print()
print("=== Qwen2-MoE router trace ===")
print(f"input_ids: {tuple(input_ids.shape)}")
print(f"logits: {tuple(qwen_out.logits.shape)}")
print(f"router logits: {tuple(router_logits.shape)}")
print(f"top-k experts: {tuple(top_idx.shape)}")
print()
print("前 6 个 token 选中的 experts:")
for token_i, experts in enumerate(top_idx[:6].tolist()):
print(f"token {token_i}: experts {experts}")
=== Qwen2-MoE decoder layer ===
Qwen2MoeDecoderLayer(
(self_attn): Qwen2MoeAttention(
(q_proj): Linear(in_features=32, out_features=32, bias=True)
(k_proj): Linear(in_features=32, out_features=32, bias=True)
(v_proj): Linear(in_features=32, out_features=32, bias=True)
(o_proj): Linear(in_features=32, out_features=32, bias=False)
)
(mlp): Qwen2MoeSparseMoeBlock(
(gate): Qwen2MoeTopKRouter()
(experts): Qwen2MoeExperts(
(act_fn): SiLUActivation()
)
(shared_expert): Qwen2MoeMLP(
(gate_proj): Linear(in_features=32, out_features=64, bias=False)
(up_proj): Linear(in_features=32, out_features=64, bias=False)
(down_proj): Linear(in_features=64, out_features=32, bias=False)
(act_fn): SiLUActivation()
)
(shared_expert_gate): Linear(in_features=32, out_features=1, bias=False)
)
(input_layernorm): Qwen2MoeRMSNorm((32,), eps=1e-06)
(post_attention_layernorm): Qwen2MoeRMSNorm((32,), eps=1e-06)
)
=== Qwen2-MoE MoE 参数形状 ===
gate.weight (4, 32)
experts.gate_up_proj (4, 128, 32)
experts.down_proj (4, 32, 64)
shared_expert.gate_proj.weight (64, 32)
shared_expert.up_proj.weight (64, 32)
shared_expert.down_proj.weight (32, 64)
shared_expert_gate.weight (1, 32)
=== Mixtral decoder layer ===
MixtralDecoderLayer(
(self_attn): MixtralAttention(
(q_proj): Linear(in_features=32, out_features=32, bias=False)
(k_proj): Linear(in_features=32, out_features=32, bias=False)
(v_proj): Linear(in_features=32, out_features=32, bias=False)
(o_proj): Linear(in_features=32, out_features=32, bias=False)
)
(mlp): MixtralSparseMoeBlock(
(gate): MixtralTopKRouter()
(experts): MixtralExperts(
(act_fn): SiLUActivation()
)
)
(input_layernorm): MixtralRMSNorm((32,), eps=1e-05)
(post_attention_layernorm): MixtralRMSNorm((32,), eps=1e-05)
)
=== Mixtral MoE 参数形状 ===
gate.weight (4, 32)
experts.gate_up_proj (4, 128, 32)
experts.down_proj (4, 32, 64)
=== Qwen2-MoE layer.forward 源码摘录 ===
hidden_states = self.input_layernorm(hidden_states)
hidden_states, _ = self.self_attn(
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
=== Qwen2-MoE mlp.forward 源码摘录 ===
shared_expert_output = self.shared_expert(hidden_states_reshaped)
_, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output
expert_output = expert_output + shared_expert_output
expert_output = expert_output.reshape(batch_size, sequence_length, hidden_dim)
return expert_output
=== Qwen2-MoE router trace ===
input_ids: (1, 6)
logits: (1, 6, 128)
router logits: (6, 4)
top-k experts: (6, 2)
前 6 个 token 选中的 experts:
token 0: experts [2, 0]
token 1: experts [3, 2]
token 2: experts [0, 2]
token 3: experts [2, 3]
token 4: experts [2, 1]
token 5: experts [0, 3]
3. MoE 的参数 vs 计算量
这是 MoE 最核心的优势。用具体数字感受一下:
假设一个 Dense 模型的 FFN 有 2M 参数(d_model=512,d_ff=2048):