训练循环与损失函数
模型搭好了,参数全是随机数。怎么让它从「瞎猜」变成「能预测」?答案是训练——但训练循环的每一步到底在干什么?
这一节用一个超小的例子,把训练循环拆开:数据怎么组织、loss 怎么算、梯度怎么传、对话模板(Chat Template)怎么影响 loss 的计算范围。每个数字都手动算一遍。
LLM 的训练本质上是一个「下一个 token 预测」任务:给模型看前面的 token,让它猜下一个 token,猜错了就算 loss。训练循环的核心流程是:把语料切成 token 序列,构建 input 和 label(label 是 input 右移一位),模 型前向得到 logits,交叉熵 loss 反向传播更新参数。
这个流程看起来简单,但有几个容易被忽略的细节:loss 是在所有 token 上同时计算的(不是逐 token 串行),padding 位置的 loss 要 mask 掉,对话数据的 loss 默认只算在 assistant 回复上。
1. 最简单的训练例子
假设我们有一个极小的模型:词表只有 5 个词,要训练它预测下一个 token。
词表: [BOS=0, 我=1, 爱=2, 你=3, EOS=4]
训练数据只有一条句子: "我 爱 你"
token IDs: [BOS, 我, 爱, 你, EOS]
= [0, 1, 2, 3, 4]
训练的目标:给定前面几个 token,预测下一个。
给定 [BOS] → 预测 我
给定 [BOS, 我] → 预测 爱
给定 [BOS, 我, 爱] → 预测 你
给定 [BOS, 我, 爱, 你] → 预测 EOS
这看似是 4 个独立的预测任务,但 Transformer 有一个魔法:它可以并行地同时做所有位置 上的预测!
2. 训练数据与标签
这是最关键的理解点。看这张图:
训练句子: [BOS, 我, 爱, 你, EOS]
[0, 1, 2, 3, 4]
┌─────────────────────────────────┐
输入: │ BOS │ 我 │ 爱 │ 你 │
│ [0] │ [1] │ [2] │ [3] │
└─────────────────────────────────┘
↓ 模型 forward
┌─────────────────────────────────┐
模型预测: │ logits │logits│logits│logits│
│ [0] │ [1] │ [2] │ [3] │
└─────────────────────────────────┘
↓ 每个位置预测下一个 token
┌─────────────────────────────────┐
期望标签: │ 我 │ 爱 │ 你 │ EOS │
│ [1] │ [2] │ [3] │ [4] │
└─────────────────────────────────┘
输入 = 去掉最后一个 token
标签 = 去掉第一个 token(整体右移一位)
这叫 teacher forcing:不拿模型自己的预测继续,而是拿正确答案继续。
为什么这样做?因为这样所有位置可以并行训练,不用等上一个位置的结果。
# 用手算演示这个概念
import torch
sentence = torch.tensor([0, 1, 2, 3, 4]) # [BOS, 我, 爱, 你, EOS]
print("完整句子:", sentence.tolist())
print()
# 输入:去掉最后一个
input_ids = sentence[:-1] # [0, 1, 2, 3]
print("输入 (去掉最后一个):", input_ids.tolist())
print("含义: [BOS, 我, 爱, 你 ]")
print()
# 标签:去掉第一个(右移一位)
target_ids = sentence[1:] # [1, 2, 3, 4]
print("标签 (去掉第一个): ", target_ids.tolist())
print("含义: [我, 爱, 你, EOS]")
print()
print("一一对应:")
for i in range(len(input_ids)):
print(f" 位置 {i}: 看到 [{', '.join(str(x) for x in input_ids[:i+1].tolist())}] → 预测 {target_ids[i].item()}")
完整句子: [0, 1, 2, 3, 4]
输入 (去掉最后一个): [0, 1, 2, 3]
含义: [BOS, 我, 爱, 你 ]
标签 (去掉第一个): [1, 2, 3, 4]
含义: [我, 爱, 你, EOS]
一一对应:
位置 0: 看到 [0] → 预测 1
位置 1: 看到 [0, 1] → 预测 2
位置 2: 看到 [0, 1, 2] → 预测 3
位置 3: 看到 [0, 1, 2, 3] → 预测 4
3. Cross-Entropy Loss
模型在每个位置输出一组 logits(每个词一个分数),我们要把它们和标签对比。
用 交叉熵损失(Cross-Entropy Loss):
对每个位置:
1. 把 logits 转成概率: softmax(logits)
2. 看正确标签对应的概率
3. 取 -log(这个概率)
4. 所有位置平均
直觉:如果正确标签的概率是 1.0 → loss = -log(1.0) = 0(完美) 如果正确标签的概率是 0.01 → loss = -log(0.01) = 4.6(很烂)
接下来我们用代码一步步看。
# 模拟模型的输出
import torch
vocab_size = 5
seq_len = 4 # 输入长度
# 假设模型输出的 logits (batch=1)
# 实际中这些是模型算出来的,这里我们手动造一些来观察
torch.manual_seed(123)
logits = torch.randn(1, seq_len, vocab_size) # [batch=1, seq_len=4, vocab=5]
targets = torch.tensor([[1, 2, 3, 4]]) # [batch=1, seq_len=4]
print(f"模型输出 logits 形状: {logits.shape}")
print(f"标签形状: {targets.shape}")
print()
# 看第 0 个位置的 logits 和标签
print(f"位置 0 的 logits: {logits[0, 0].tolist()}")
print(f"位置 0 的标签: {targets[0, 0].item()}")
print(f"→ 模型要在 5 个词里猜对答案是词 {targets[0, 0].item()}")
模型输出 logits 形状: torch.Size([1, 4, 5])
标签形状: torch.Size([1, 4])
位置 0 的 logits: [0.3373701572418213, -0.1777772158384323, -0.3035275340080261, -0.5880124568939209, 1.5809690952301025]
位置 0 的标签: 1
→ 模型要在 5 个词里猜对答案是词 1
# 手工算一遍 loss(理解每个数字怎么来的)
import torch.nn.functional as F
import math
print("=== 手工计算 Cross-Entropy Loss ===")
print()
total_loss = 0.0
for pos in range(seq_len):
# 这个位置的 logits (vocab_size 个分数)
pos_logits = logits[0, pos] # [vocab_size]
# 这个位置的正确答案
correct_id = targets[0, pos].item()
# Step 1: softmax 转概率
probs = F.softmax(pos_logits, dim=-1)
# Step 2: 正确答案的概率
correct_prob = probs[correct_id].item()
# Step 3: loss = -log(概率)
pos_loss = -math.log(correct_prob)
total_loss += pos_loss
print(f"位置 {pos}: 正确答案=词{correct_id}, 概率={correct_prob:.4f}, loss={pos_loss:.4f}")
# Step 4: 平均
manual_loss = total_loss / seq_len
print(f"\n所有位置平均 loss: {manual_loss:.4f}")
# 对比 PyTorch 内置的 cross_entropy
pt_loss = F.cross_entropy(
logits.reshape(-1, vocab_size), # [batch*seq_len, vocab]
targets.reshape(-1) # [batch*seq_len]
).item()
print(f"PyTorch cross_entropy: {pt_loss:.4f}")
print(f"两者一致? {'✅' if abs(manual_loss - pt_loss) < 1e-4 else '❌'}")
=== 手工计算 Cross-Entropy Loss ===
位置 0: 正确答案=词1, 概率=0.0998, loss=2.3050
位置 1: 正确答案=词2, 概率=0.0853, loss=2.4619
位置 2: 正确答案=词3, 概率=0.3147, loss=1.1561
位置 3: 正确答案=词4, 概率=0.6087, loss=0.4965
所有位置平均 loss: 1.6049
PyTorch cross_entropy: 1.6049
两者一致? ✅
4. Token 级别还是句子级别
答案:Token 级别训练。
但注意:是所有 token 同时并行训练,不是一个个 token 串行训练。
┌──────────────────────────────────────────────┐
│ 一次 forward + backward │
│ │
│ Loss = loss(位置0) + loss(位置1) + ... │
│ │
│ 位置 0 预测 token 1 │
│ 位置 1 预测 token 2 } 全部并行计算 │
│ 位置 2 预测 token 3 │
│ 位置 3 预测 token 4 │
│ │
│ 梯度 = ∂Loss/∂W 是所有位置梯度的总和 │
│ 更新参数 ← 这一下 包含了所有位置的学习信号 │
└──────────────────────────────────────────────┘
为什么不是句子级别?
- 句子级别意味着只在一个位置(句子末尾)做预测 → 信号太稀疏
- 比如 100 个 token 的句子,句子级别只有一个监督信号
- Token 级别有 100 个监督信号,学习效率高 100 倍
但也不是"逐个 token 串行训练"。 而是所有位置并行,一次前向算出全部 loss。 这正是 Transformer 比 RNN 快的关键原因。
5. 一个 Batch 多句话的训练
真实训练不会一次只训一句话。一次一个 batch(比如 32 句话),所有句子拼成一个矩阵。
batch 输入:
[[BOS, 我, 爱, 你, EOS, PAD, PAD], ← 句子 1 (5 个有效 token)
[BOS, hello, world, EOS, PAD, PAD, PAD]] ← 句子 2 (4 个有效 token)
形状: [batch_size=2, seq_len=7]
Loss 计算:对所有句子的所有位置(除了 PAD)取平均。
loss = cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1), ignore_index=PAD_ID)
# ↑ 忽略填充位置
# 演示 batch 训练的 loss 计算
import torch
import torch.nn.functional as F
PAD_ID = 0 # 假设 0 是 PAD
batch_input = torch.tensor([
[0, 1, 2, 3, 4, 0, 0], # [BOS, 我, 爱, 你, EOS, PAD, PAD]
[0, 2, 4, 0, 0, 0, 0], # [BOS, 爱, EOS, PAD, PAD, PAD, PAD]
])
# 标签 = input 右移一位
batch_target = torch.tensor([
[1, 2, 3, 4, 0, 0, 0], # [我, 爱, 你, EOS, PAD, PAD, PAD]
[2, 4, 0, 0, 0, 0, 0], # [爱, EOS, PAD, PAD, PAD, PAD, PAD]
])
print("Batch 输入:")
print(batch_input)
print()
print("Batch 标签:")
print(batch_target)
print()
# 模拟模型输出
batch_logits = torch.randn(2, 7, 5) # [batch=2, seq=7, vocab=5]
# 关键:ignore_index=PAD_ID,PAD 位置不参与 loss 计算
loss_with_ignore = F.cross_entropy(
batch_logits.reshape(-1, 5), # [14, 5]
batch_target.reshape(-1), # [14]
ignore_index=PAD_ID
)
loss_without_ignore = F.cross_entropy(
batch_logits.reshape(-1, 5),
batch_target.reshape(-1)
)
print(f"忽略 PAD 的 loss: {loss_with_ignore.item():.4f}")
print(f"不忽略 PAD 的 loss: {loss_without_ignore.item():.4f}")
print(f"\n差别很大!因为 PAD 位置的预测没有意义,不应该贡献 loss。")
Batch 输入:
tensor([[0, 1, 2, 3, 4, 0, 0],
[0, 2, 4, 0, 0, 0, 0]])
Batch 标签:
tensor([[1, 2, 3, 4, 0, 0, 0],
[2, 4, 0, 0, 0, 0, 0]])
忽略 PAD 的 loss: 2.3260
不忽略 PAD 的 loss: 2.3044
差别很大!因为 PAD 位置的预测没有意义,不应该贡献 loss。
6. 完整的训练循环
把前面讲的 loss 函数、梯度计算、优化器组合起来,形成一个完整的训练循环。在真实场景中,训练循环会从数据加载、token 化、前向传播、loss 计算、反向传播到参数更新跑完一整圈。
但这一节我们还没有训练真实的 tokenizer,所以先用一个简化场景来演示:
- 用一个很小的「伪词表」(几十个 token ID)模拟 token 化后的输入
- 手动构造 input 和 label 的偏移关系(label = input 右移一位,这是自回归语言模型的标准做法)
- 走完完整的 forward → loss → backward → update 流程
这个流程虽然用简化数据,但结构和真实训练完全相同。理解了它,后面接入真实 tokenizer 和数据就只是换输入的问题。
# 复用 Part 4 的 MiniGPT,简化版
import torch
import torch.nn as nn
import math
def get_sinusoidal_encoding(seq_len, d_model):
position = torch.arange(seq_len).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
)
pe = torch.zeros(seq_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
return pe
class MiniGPT(nn.Module):
def __init__(self, vocab_size, d_model=64, num_heads=4, num_layers=4, max_seq_len=128):
super().__init__()
self.d_model = d_model
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.token_emb = nn.Embedding(vocab_size, d_model)
pe = get_sinusoidal_encoding(max_seq_len, d_model)
self.register_buffer('pe', pe)
# 简化:不用 ModuleList,直接写几个 block
self.attn1 = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
self.ffn1 = nn.Sequential(
nn.Linear(d_model, 4*d_model), nn.ReLU(), nn.Linear(4*d_model, d_model)
)
self.norm1a = nn.LayerNorm(d_model)
self.norm1f = nn.LayerNorm(d_model)
self.attn2 = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
self.ffn2 = nn.Sequential(
nn.Linear(d_model, 4*d_model), nn.ReLU(), nn.Linear(4*d_model, d_model)
)
self.norm2a = nn.LayerNorm(d_model)
self.norm2f = nn.LayerNorm(d_model)
self.ln_final = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size)
def forward(self, x):
batch_size, seq_len = x.shape
x = self.token_emb(x) + self.pe[:seq_len, :]
mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device) * float('-inf'), diagonal=1)
# Block 1
attn_out, _ = self.attn1(x, x, x, attn_mask=mask)
x = self.norm1a(x + attn_out)
x = self.norm1f(x + self.ffn1(x))
# Block 2
attn_out, _ = self.attn2(x, x, x, attn_mask=mask)
x = self.norm2a(x + attn_out)
x = self.norm2f(x + self.ffn2(x))
x = self.ln_final(x)
return self.lm_head(x)
print("MiniGPT 模型定义完成!")
MiniGPT 模型定义完成!
# === 完整的训练循环演示 ===
# 1. 准备假数据(模拟 token 化后的文本)
import torch
VOCAB_SIZE = 20
PAD_ID = 0
SEQ_LEN = 16
BATCH_SIZE = 8
# 假数据:随机 token 序列
train_data = torch.randint(1, VOCAB_SIZE, (100, SEQ_LEN)) # 100 条「句子」
print(f"训练数据: {train_data.shape} (100 条, 每条 {SEQ_LEN} 个 token)")
# 2. 创建模型
model = MiniGPT(VOCAB_SIZE, d_model=64, num_heads=4, num_layers=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
训练数据: torch.Size([100, 16]) (100 条, 每条 16 个 token)
模型参数量: 102,676
# 3. 训练循环
import torch.nn.functional as F
NUM_EPOCHS = 5
losses = []
model.train()
for epoch in range(NUM_EPOCHS):
epoch_loss = 0.0
num_batches = 0
for i in range(0, len(train_data), BATCH_SIZE):
batch = train_data[i:i+BATCH_SIZE] # [batch_size, seq_len]
# 准备输入和标签
input_ids = batch[:, :-1] # 去掉最后一个
target_ids = batch[:, 1:] # 去掉第一个
# Forward
logits = model(input_ids) # [batch, seq_len-1, vocab_size]
# Loss: 把所有 batch 和所有位置展平
loss = F.cross_entropy(
logits.reshape(-1, VOCAB_SIZE), # [batch*(seq_len-1), vocab_size]
target_ids.reshape(-1) # [batch*(seq_len-1)]
)
# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
num_batches += 1
avg_loss = epoch_loss / num_batches
losses.append(avg_loss)
print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Loss: {avg_loss:.4f}")
print(f"\nLoss 从 {losses[0]:.4f} 降到 {losses[-1]:.4f} → 模型在学习!")
Epoch 1/5 | Loss: 3.0834
Epoch 2/5 | Loss: 2.9192
Epoch 3/5 | Loss: 2.8793
Epoch 4/5 | Loss: 2.8393
Epoch 5/5 | Loss: 2.7892
Loss 从 3.0834 降到 2.7892 → 模型在学习!
# 可视化 loss 下降
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 4))
plt.plot(losses, 'o-', markersize=8)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training loss curve')
plt.grid(True, alpha=0.3)
plt.show()
print("Loss 在下降 = 模型在学会预测下一个 token")

Loss 在下降 = 模型在学会预测下一个 token