
DAPO Source Code Walkthrough Notes

Algorithm source: ByteDance's 2025 paper "DAPO: Decoupled Clip and Dynamic Sampling Policy Optimization"

Experimental setup: Qwen2.5-1.5B-Instruct + a single 48 GB GPU + the Chinese GSM8K dataset + 300 training steps (about 60 minutes)


What is DAPO?

DAPO is an improved version of the GRPO algorithm used by DeepSeek, optimized specifically for long-form CoT (Chain-of-Thought) reasoning training.

Core idea: estimate advantages by comparing multiple responses within a Group, replacing PPO's Critic network and dropping the separate value function entirely. On top of GRPO, it makes three key improvements:

| Improvement | GRPO | DAPO |
| --- | --- | --- |
| ① Clip range | symmetric [1-ε, 1+ε] | decoupled [1-0.2, 1+0.28], larger upper bound |
| ② Degenerate samples | no filtering | Dynamic Sampling: skip when advantages are all zero |
| ③ Loss granularity | sequence-level (tokens averaged away) | token-level (sum / total token count) |

The overall training pipeline

Training runs in an off-policy mode: first sample a batch of experiences, then reuse that batch for num_iterations training passes.
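A minimal sketch of that loop (simplified pseudocode; the real logic, with buffering and gradient accumulation, is in train() in the appendix):

# Sketch: sample once, reuse num_iterations times (details omitted).
for batch in dataloader:
    experiences = trainer.generate_experiences(batch)   # rollout + rewards + advantages
    for _ in range(args.num_iterations):                # reuse the same experiences
        loss = trainer.compute_loss(trainer.model, experiences)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()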


Dataset: GSM8KDataset

class GSM8KDataset(Dataset):
    def __getitem__(self, index):
        return {
            "prompt": sample["question_zh-cn"],  # Chinese math problem
            "answer": sample["answer_only"]      # bare numeric answer, e.g. "72"
        }
  • Dataset: the Chinese version of GSM8K (about 8,500 grade-school math problems)
  • Each sample holds just the question and the answer; the format is minimal

Core data structure: Samples (one Group)

@dataclass
class Samples:
    prompt_response_ids  # [num_gen, seq_len] token ids of the full prompt+response
    response_ids         # [num_gen, resp_len] response portion only
    attention_mask       # [num_gen, seq_len] 1 at non-pad positions
    action_mask          # [num_gen, resp_len] 1 at response tokens that are not eos/pad
    num_actions          # maximum response length
    response_length      # [num_gen] number of actually valid tokens per response

The Group concept: generate num_generations=4 different responses to the same question; together they form a Group, used for intra-group comparison when computing advantages.


Hyperparameters: DapoArguments

| Parameter | Value | Description |
| --- | --- | --- |
| num_generations | 4 | Group size: 4 responses per question |
| clip_eps_high | 0.28 | Clip-Higher upper bound (larger than standard PPO's 0.2) |
| clip_eps_low | 0.2 | Clip-Higher lower bound |
| beta | 0.0 | KL penalty coefficient; 0 means no reference model (saves memory) |
| gradient_accumulation_steps | 2 | gradient accumulation to simulate a larger batch |
| num_iterations | 1 | how many times each batch of experiences is reused |
| max_prompt_length | 256 | maximum input length |
| max_generate_length | 128 | maximum output length |
| lr | 1e-6 | learning rate |

generate_samples: batched sampling

def generate_samples(self, inputs):

Key design points:

Left Padding

tokenizer.padding_side = "left"

At generation time the prompts in a batch have different lengths; left padding aligns all sequences at the right edge, so the response portion starts at the same offset in every row.
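A quick way to see this (assumes the local model directory from the run instructions below):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./Qwen2.5-1.5B-Instruct")
tok.padding_side = "left"
enc = tok(["1+1=?", "What is 2 plus 3?"], padding=True, return_tensors="pt")
print(enc["input_ids"])       # pad ids fill the left edge; real tokens end flush right
print(enc["attention_mask"])  # 0 over the left padding, 1 elsewhere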

Generating 4 responses in one batch

inputs_enc = self.tokenizer([input_text] * self.args.num_generations, ...)
prompt_response_ids = self.model.generate(**inputs_enc, temperature=0.9, top_p=1, top_k=50)

temperature=0.9 injects randomness so the 4 responses differ from one another, producing a meaningful intra-group comparison.

Building the action_mask

action_mask = (
    response_ids.ne(self.tokenizer.eos_token_id)
    & response_ids.ne(self.tokenizer.pad_token_id)
).to(dtype=torch.long)

Filtering out eos and pad keeps only the "actually generated content"; the subsequent loss is computed only at these positions.
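A toy check of the mask logic (hypothetical token ids: eos=2, pad=0):

import torch

eos, pad = 2, 0
response_ids = torch.tensor([[11, 12, 13, eos, pad, pad]])
action_mask = (response_ids.ne(eos) & response_ids.ne(pad)).to(dtype=torch.long)
print(action_mask)  # tensor([[1, 1, 1, 0, 0, 0]]) — only the 3 real tokens count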


generate_experiences: building the experience data

def generate_experiences(self, inputs):

Flow: sampling → reward computation → group normalization (advantages) → Dynamic Sampling filter → old/ref log prob caching.

Reward computation:

# rewards_per_func: [num_funcs, num_generations]
rewards = rewards_per_func * torch.tensor(reward_weights).unsqueeze(1)
rewards = rewards.sum(dim=0) # → [num_generations]

Intra-group normalization (advantage estimation):

advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

Normalizing with the group's mean and standard deviation removes the effect of reward scale on gradient magnitude; this is the core idea shared by GRPO and DAPO.
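A worked toy example, including the degenerate case that Dynamic Sampling (next) exists for:

import torch

# one correct response (high reward) vs. three wrong ones
rewards = torch.tensor([3.0, 0.5, 0.5, 0.5])
print((rewards - rewards.mean()) / (rewards.std() + 1e-8))
# tensor([ 1.5000, -0.5000, -0.5000, -0.5000])

# degenerate group: identical rewards → zero std → advantages all ~0
rewards = torch.tensor([0.5, 0.5, 0.5, 0.5])
print((rewards - rewards.mean()) / (rewards.std() + 1e-8))
# tensor([0., 0., 0., 0.]) — no gradient signal, so the group is skipped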

[DAPO innovation ②] Dynamic Sampling:

nonzero_num = advantages.count_nonzero().item()
if nonzero_num == 0:
    print("group advantages are all zero, skipping")
    continue

If the 4 responses in a group receive identical rewards (all correct or all wrong), the advantages are all zero and the gradient contribution is zero, so the group is skipped outright to avoid wasting compute.

Note: the source only filters nonzero_num == 0; the paper additionally filters the nonzero_num == len(advantages) case (all values distinct). This is a small deviation in this reproduction.


compute_loss: the core DAPO algorithm

def compute_loss(self, model, inputs):

The full objective (reconstructed from the implementation below):

$$\mathcal{L}_{\text{DAPO}}(\theta) = -\frac{1}{\sum_{i=1}^{G}|o_i|}\sum_{i=1}^{G}\sum_{t=1}^{|o_i|}\min\!\Big(r_{i,t}(\theta)\,\hat{A}_i,\;\operatorname{clip}\!\big(r_{i,t}(\theta),\,1-\varepsilon_{\text{low}},\,1+\varepsilon_{\text{high}}\big)\,\hat{A}_i\Big)$$

where:

  • $r_{i,t}(\theta)=\pi_\theta(o_{i,t}\mid q,\,o_{i,<t})\,/\,\pi_{\theta_{\text{old}}}(o_{i,t}\mid q,\,o_{i,<t})$: the importance sampling ratio
  • $\hat{A}_i=\big(R_i-\operatorname{mean}(\{R_j\})\big)\,/\,\big(\operatorname{std}(\{R_j\})+10^{-8}\big)$: the group-normalized advantage (a sequence-level scalar, broadcast to the token level)
  • $\sum_{i=1}^{G}|o_i|$: the total number of valid tokens

[DAPO innovation ①] Clip-Higher implementation:

coef_1 = torch.exp(action_log_probs - old_action_log_probs)  # r(θ)
coef_2 = torch.clamp(coef_1,
                     1 - self.args.clip_eps_low,   # lower bound 0.8
                     1 + self.args.clip_eps_high)  # upper bound 1.28

Compared with standard PPO's symmetric clipping [0.8, 1.2], DAPO's larger upper bound (1.28) allows bigger positive updates, strengthening exploration.

per_token_loss1 = coef_1 * advantages.unsqueeze(1)             # unclipped
per_token_loss2 = coef_2 * advantages.unsqueeze(1)             # clipped
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)  # pessimistic choice

[DAPO innovation ③] Token-Level Loss implementation:

# reshape: [batch*gen, seq] → [batch, gen, seq]
per_token_loss = per_token_loss.view(-1, self.args.num_generations, num_actions)
action_mask = action_mask.view(-1, self.args.num_generations, num_actions)

# sum over all tokens / total number of valid tokens
loss = per_token_loss.sum(-1).sum(-1) / action_mask.sum(-1).sum(-1)
loss = loss.mean()

Compared with GRPO (average tokens within each sequence, then average across sequences), DAPO normalizes directly by total token count, so long sequences contribute proportionally more — better suited to encouraging long CoT reasoning.
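A toy contrast of the two aggregations (hypothetical per-token losses: a 2-token response vs. an 8-token one in the same group):

import torch

loss_short = torch.full((2,), 1.0)  # 2 valid tokens
loss_long = torch.full((8,), 0.1)   # 8 valid tokens

grpo = (loss_short.mean() + loss_long.mean()) / 2  # 0.55: per-sequence mean first
dapo = (loss_short.sum() + loss_long.sum()) / 10   # 0.28: one pool of 10 tokens
print(grpo.item(), dapo.item())
# Under GRPO each long-sequence token weighs 1/16 of the loss; under DAPO, 1/10.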

KL penalty (skipped when beta = 0):

log_ratio = ref_action_log_probs - action_log_probs
k3 = log_ratio.exp() - 1 - log_ratio  # k3 KL estimator, numerically more stable
per_token_loss = per_token_loss + self.args.beta * k3

get_action_log_probs: computing log probabilities

def get_action_log_probs(self, model, input_ids, attention_mask, num_actions):
    logits = model(input_ids, attention_mask=attention_mask).logits
    # logits[:, :-1, :] predicts input_ids[:, 1:] (the causal-LM shift trick)
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    # gather the log prob of each actual token
    log_probs_labels = log_probs.gather(dim=-1, index=input_ids[:, 1:].unsqueeze(-1))
    # keep only the response portion (the last num_actions tokens)
    action_log_probs = log_probs_labels.squeeze(-1)[:, -num_actions:]
    return action_log_probs  # [batch * num_gen, num_actions]

The shift trick: in a causal LM, logits[i] predicts token[i+1], so logits[:, :-1] is paired with input_ids[:, 1:], and gather then picks out each actual token's probability.
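A shape-level sanity check of the shift (random logits stand in for a causal LM; vocab=5, seq=4):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 4, 5)          # [batch, seq, vocab]
input_ids = torch.tensor([[3, 1, 4, 2]])

log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)  # positions 0..2 predict tokens 1..3
token_logps = log_probs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
print(token_logps.shape)  # torch.Size([1, 3]) — one log prob per predicted token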


Reward function design (reward_func.py)

Reward hierarchy overview

| Reward function | Weight | Nature | Role |
| --- | --- | --- | --- |
| correctness_reward | 2.0 | sparse, hard | the core signal: high score only for a correct answer |
| digit_reward | 0.5 | dense | mitigates sparsity; encourages numeric output |
| hard_format_reward | 0.5 | sparse, hard | strict XML format match |
| mark_reward | 1.0 | dense, soft | per-tag reward that guides format learning |

Design philosophy: soft rewards (digit_reward, mark_reward) fill the gaps left by sparse rewards (correctness_reward, hard_format_reward), so the model has a sense of direction even early in training.
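To see the layers stack up, here is a hypothetical ideal response scored with the four functions from the appendix (the answer is passed as a tensor because correctness_reward calls .item()):

import torch
from reward_func import correctness_reward, digit_reward, hard_format_reward, mark_reward

response = "<think>\n36 * 2 = 72\n</think>\n<answer>\n72\n</answer>\n"
answers = [torch.tensor(72)]
total = sum(
    f(prompts=None, responses=[response], answers=answers)[0]
    for f in (correctness_reward, digit_reward, hard_format_reward, mark_reward)
)
print(total)  # 2.0 + 0.5 + 0.5 + 0.5 = 3.5 — every layer fires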

Answer extraction

def extract_answer(text):
    # extract the last number inside <answer>...</answer>
    answer_regex = r"<answer>(.*?)<\/answer>"
    num_regex = r"\d+\.\d+|\d+/\d+|\d+"  # supports decimals, fractions, integers
    nums = re.findall(num_regex, answer_content)
    return nums[-1]  # take the last number (usually the final result)
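Usage, for intuition (full implementation in the appendix):

print(extract_answer("<answer>\n8 * 9 = 72, so the result is 72\n</answer>"))  # "72"
print(extract_answer("no tags at all"))                                       # ""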

The two-tier format reward design

def hard_format_reward(prompts, responses, answers):
    # strict regex match on the full XML structure; 0.5 only if everything matches
    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
    ...

def mark_reward(prompts, responses, answers):
    # per-tag scoring: 0.125 per tag, 0.5 maximum
    reward = 0
    if text.count("<think>\n") == 1: reward += 0.125
    if text.count("</think>\n") == 1: reward += 0.125
    if text.count("<answer>\n") == 1: reward += 0.125
    if text.count("</answer>\n") == 1: reward += 0.125

hard_format_reward is the end-state target and mark_reward is the transitional guide; together they form a progressive curriculum for format learning.


Training stability design

Gradient accumulation

loss = loss / self.args.gradient_accumulation_steps
loss.backward()
if (step + 1) % self.args.gradient_accumulation_steps == 0:
    optimizer.step()
    optimizer.zero_grad()

With limited single-GPU memory, gradient accumulation simulates a larger effective batch size.

The buffer mechanism

Dynamic Sampling filters out Groups whose advantages are all zero, so the number of valid Groups per batch is unstable. The buffer accumulates batch_size valid samples before each parameter update, keeping the amount of data per update constant.
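In sketch form (simplified from train() in the appendix; generate_experiences and train_on stand in for the real calls):

buffer = []
for batch in dataloader:
    buffer += generate_experiences(batch)["advantages"]  # only groups that survive filtering
    if len(buffer) < args.batch_size:
        continue                          # keep sampling until a full batch has accumulated
    train_on(buffer[:args.batch_size])    # update on exactly batch_size valid groups
    buffer = buffer[args.batch_size:]     # carry the overflow into the next update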

beta = 0: no reference model

beta = 0.0  # disable the KL penalty; no reference model needed

This saves roughly one full model copy's worth of GPU memory. The DAPO paper likewise recommends β = 0, since Clip-Higher already keeps policy drift under control.


Full comparison with GRPO

| Dimension | GRPO | DAPO |
| --- | --- | --- |
| Clip range | symmetric [1-ε, 1+ε] | decoupled [1-ε_low, 1+ε_high], larger upper bound |
| Degenerate samples | no filtering | skip when advantages are all zero (Dynamic Sampling) |
| Loss granularity | sequence-level (tokens averaged away) | token-level (direct sum normalization) |
| KL penalty | usually β > 0, needs a reference model | β = 0, no reference model, saves memory |
| Target scenario | general RL training | long-CoT reasoning training |
| Exploration | moderate | stronger (looser upper bound) |
| Long sequences | moderate | better (token-level loss does not dilute long sequences) |

How to run

# install dependencies
pip install transformers datasets torch

# prepare the model and data
# ./Qwen2.5-1.5B-Instruct (model directory)
# ./gsm8k_chinese (dataset directory)

# start training
python train.py

Training outputs:

  • ./output/training_losses.txt: per-step loss log
  • ./output/accuracy_losses.txt: accuracy logged every 10 steps
  • ./output/checkpoint_N/: a checkpoint saved every 100 steps
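Both log files are plain "step,value" CSV lines, so plotting the loss curve is straightforward (a minimal sketch, assuming matplotlib is installed):

import matplotlib.pyplot as plt

steps, losses = [], []
with open("./output/training_losses.txt", encoding="utf-8") as f:
    for line in f:
        step, loss = line.strip().split(",")
        steps.append(int(step))
        losses.append(float(loss))
plt.plot(steps, losses)
plt.xlabel("update step")
plt.ylabel("DAPO loss")
plt.show()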

Key code index

| Feature | File | Location |
| --- | --- | --- |
| Dataset loading | train.py | GSM8KDataset |
| Hyperparameter config | train.py | DapoArguments |
| Group sampling | train.py | generate_samples() |
| Experience generation + Dynamic Sampling | train.py | generate_experiences() |
| Clip-Higher + Token-Level Loss | train.py | compute_loss() |
| Log prob computation | train.py | get_action_log_probs() |
| Answer extraction | reward_func.py | extract_answer() |
| The four reward functions | reward_func.py | correctness/digit/hard_format/mark_reward() |

These notes were compiled from the annotated source files: train_annotated.py + reward_func_annotated.py


Appendix: full source code

reward_func_annotated.py

import re

# ============================================================
# Reward function module
# Design principle: multi-level, progressive rewards to combat reward sparsity
#
# Reward hierarchy (hard to soft):
# 1. correctness_reward (weight 2.0): the core signal; high score only when correct, but sparse
# 2. digit_reward (weight 0.5): scores any numeric answer, mitigating sparsity
# 3. hard_format_reward (weight 0.5): strict format match, guides canonical output
# 4. mark_reward (weight 1.0): tag-level soft reward, gradually guides format learning
# ============================================================


def extract_answer(text):
    """
    Extract the numeric answer inside the <answer>...</answer> tags of a model output.

    Extraction logic:
    1. Regex-match the <answer>...</answer> content (DOTALL spans multiple lines)
    2. Find all numbers in the content (decimals, fractions, integers)
    3. Take the last number (usually the final computed result)

    Returns:
        A number string such as "72"; "" if nothing can be extracted.
    """
    answer_regex = r"<answer>(.*?)<\/answer>"
    answer_match = re.search(answer_regex, text, re.DOTALL)
    if not answer_match:
        return ""  # no answer tag

    answer_content = answer_match.group(1)
    if not answer_content:
        return ""  # empty tag content

    # three supported number formats: decimal (3.14), fraction (3/4), integer (42)
    num_regex = r"\d+\.\d+|\d+/\d+|\d+"
    nums = re.findall(num_regex, answer_content)
    if len(nums) == 0:
        return ""

    # take the last number: math solutions usually state the final answer last
    return nums[-1].strip()


def mark_num(text):
    """
    Soft per-tag XML format reward (0.125 per tag, 0.5 maximum).
    """
    reward = 0
    if text.count("<think>\n") == 1: reward += 0.125
    if text.count("</think>\n") == 1: reward += 0.125
    if text.count("<answer>\n") == 1: reward += 0.125
    if text.count("</answer>\n") == 1: reward += 0.125
    return reward


def correctness_reward(prompts, responses, answers):
    """Correctness reward (weight 2.0): 2.0 for a correct answer, otherwise 0.0"""
    extracted_responses = [extract_answer(r) for r in responses]
    rewards = []
    for response, ans in zip(extracted_responses, answers):
        if response == str(ans.item()):
            rewards.append(2.0)
        else:
            rewards.append(0.0)
    return rewards


def digit_reward(prompts, responses, answers):
    """Digit reward (weight 0.5): 0.5 if the extracted answer is a number, otherwise 0.0"""
    extracted_responses = [extract_answer(r) for r in responses]
    return [0.5 if response.isdigit() else 0.0 for response in extracted_responses]


def hard_format_reward(prompts, responses, answers):
    """Strict format reward (weight 0.5): 0.5 for a full XML structure match, otherwise 0.0"""
    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
    matches = [re.match(pattern, response, re.DOTALL) for response in responses]
    return [0.5 if match else 0.0 for match in matches]


def mark_reward(prompts, responses, answers):
    """Mark reward (weight 1.0): per-tag soft reward, 0 to 0.5"""
    return [mark_num(response) for response in responses]

train_annotated.py

# ============================================================
# DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization)
# An RL training algorithm from ByteDance, optimized for long-CoT reasoning
# Key improvements: ① Clip-Higher ② Dynamic Sampling ③ Token-Level Loss
# ============================================================

from transformers import (
    AutoModelForCausalLM,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedModel,
)
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from copy import deepcopy
from datasets import load_dataset
from reward_func import *
import os


# ============================================================
# 1. Dataset
# ============================================================

class GSM8KDataset(Dataset):
    """
    Wrapper around the Chinese GSM8K math dataset.
    Each sample contains:
        - prompt: the Chinese problem statement (question_zh-cn)
        - answer: the reference numeric answer (answer_only)
    """

    def __init__(self, data_path, tokenizer, split: str = "train", test_size: int = 100):
        self.tokenizer = tokenizer
        data = load_dataset(data_path)
        self.data = data[split]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        answer = sample["answer_only"]
        prompt = sample["question_zh-cn"]
        return {"prompt": prompt, "answer": answer}


# ============================================================
# 2. Data structure: one Group's sampling result
# ============================================================

@dataclass
class Samples:
    """
    Holds one Group of generations for a single question.
    Group: num_generations different responses to the same prompt.

    Fields:
        prompt_response_ids : [num_generations, seq_len]
        response_ids        : [num_generations, resp_len]
        attention_mask      : [num_generations, seq_len]
        action_mask         : [num_generations, resp_len]
        num_actions         : maximum response length
        response_length     : [num_generations] valid token count per response
    """
    prompt_response_ids: torch.Tensor
    response_ids: torch.Tensor
    prompt: Any
    answer: Any
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]
    num_actions: Union[int, torch.Tensor]
    response_length: int


# ============================================================
# 3. Hyperparameter configuration
# ============================================================

class DapoArguments:
    output_dir = "./output"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lr = 0.000001
    save_steps = 100
    epoch = 3
    num_generations = 4          # Group size
    max_prompt_length = 256
    max_generate_length = 128
    reward_weights: List[float] = None
    beta = 0.0                   # KL coefficient; 0 = no reference model
    clip_eps_high = 0.28         # Clip-Higher upper bound
    clip_eps_low = 0.2           # Clip-Higher lower bound
    gradient_accumulation_steps = 2
    num_iterations = 1
    batch_size = 1


# ============================================================
# 4. DAPO Trainer
# ============================================================

class DapoTrainer:
    def __init__(
        self,
        model=None,
        reward_funcs: Union[List[str], List[Callable]] = None,
        args=None,
        train_dataset: Optional[Union[Dataset]] = None,
        eval_dataset: Optional[Union[Dataset]] = None,
        tokenizer=None,
        reward_tokenizers=None,
    ):
        self.args = args

        if isinstance(model, str):
            model = AutoModelForCausalLM.from_pretrained(model)
        self.model = model.to(self.args.device)

        # reference model (unused when beta=0, saving memory)
        self.ref_model = None
        if self.args.beta != 0.0:
            self.ref_model = deepcopy(model)
            self.ref_model.eval()

        if isinstance(tokenizer, str):
            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.tokenizer = self.get_tokenizer(tokenizer)

        if isinstance(reward_funcs, str):
            reward_funcs = [reward_funcs]

        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1
                ).to(self.args.device)

        self.reward_funcs = reward_funcs

        if reward_tokenizers is None:
            reward_tokenizers = [None] * len(reward_funcs)
        elif isinstance(reward_tokenizers, str):
            reward_tokenizers = [reward_tokenizers]
        else:
            if len(reward_tokenizers) != len(reward_funcs):
                raise ValueError("reward_tokenizers must match reward_funcs in number")

        for i, (reward_tokenizer, reward_func) in enumerate(
            zip(reward_tokenizers, reward_funcs)
        ):
            if isinstance(reward_func, PreTrainedModel):
                if reward_tokenizer is None:
                    reward_tokenizer = AutoTokenizer.from_pretrained(
                        reward_func.config._name_or_path
                    )
                if reward_tokenizer.pad_token_id is None:
                    reward_tokenizer.pad_token = reward_tokenizer.eos_token
                reward_func.config.pad_token_id = reward_tokenizer.pad_token_id
                reward_tokenizers[i] = reward_tokenizer

        self.reward_tokenizers = reward_tokenizers
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.input_buffer = [None] * self.args.gradient_accumulation_steps
        self.update_steps = 0

    def get_tokenizer(self, tokenizer):
        tokenizer.padding_side = "left"  # left padding for generation
        return tokenizer

    # ---- 4.1 Sampling ----

    def generate_samples(self, inputs):
        """Generate num_generations responses per question, forming one Group"""
        samples_list = []
        self.model.eval()

        prompts = [prompt for prompt in inputs["prompt"]]
        answers = [None] * len(prompts)
        if "answer" in inputs:
            answers = [answer for answer in inputs["answer"]]

        max_length = self.args.max_generate_length + self.args.max_prompt_length

        for prompt, answer in zip(prompts, answers):
            input_text = self.tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                add_generation_prompt=True,
                tokenize=False,
            )

            inputs_enc = self.tokenizer(
                [input_text] * self.args.num_generations,
                padding="max_length",
                max_length=self.args.max_prompt_length,
                truncation=True,
                return_tensors="pt",
            )
            prompt_ids = inputs_enc["input_ids"]

            with torch.no_grad():
                prompt_response_ids = self.model.generate(
                    **inputs_enc.to(self.args.device),
                    max_new_tokens=self.args.max_generate_length,
                    temperature=0.9,
                    top_p=1,
                    top_k=50,
                )

            # pad or truncate every Group to the same total length
            if prompt_response_ids.size(1) >= max_length:
                prompt_response_ids = prompt_response_ids[:, :max_length]
            else:
                pad_len = max_length - prompt_response_ids.size(1)
                prompt_response_ids = torch.cat(
                    [
                        prompt_response_ids,
                        torch.full(
                            (prompt_response_ids.size(0), pad_len),
                            fill_value=self.tokenizer.pad_token_id,
                            device=prompt_response_ids.device,
                        ),
                    ],
                    dim=1,
                )

            attention_mask = prompt_response_ids.ne(self.tokenizer.pad_token_id).to(
                dtype=torch.long
            )
            response_ids = prompt_response_ids[:, prompt_ids.size(1):]
            action_mask = (
                response_ids.ne(self.tokenizer.eos_token_id)
                & response_ids.ne(self.tokenizer.pad_token_id)
            ).to(dtype=torch.long)

            samples = Samples(
                prompt_response_ids=prompt_response_ids,
                response_ids=response_ids,
                prompt=prompt,
                answer=answer,
                attention_mask=attention_mask,
                action_mask=action_mask,
                num_actions=action_mask.size(1),
                response_length=action_mask.float().sum(dim=-1),
            )
            samples_list.append(samples)

        return samples_list

    # ---- 4.2 Experience generation ----

    def generate_experiences(self, inputs):
        """Sampling + rewards + advantage computation + Dynamic Sampling"""
        self.model.eval()
        samples_list = self.generate_samples(inputs)

        batch_prompt_response_ids = []
        batch_attention_mask = []
        batch_action_mask = []
        batch_advantages = []
        batch_old_action_log_probs = []
        batch_ref_action_log_probs = []

        for samples in samples_list:
            prompt_response_ids = samples.prompt_response_ids
            response_ids = samples.response_ids
            answer = samples.answer
            attention_mask = samples.attention_mask
            action_mask = samples.action_mask
            num_actions = samples.num_actions
            prompt = samples.prompt

            with torch.no_grad():
                rewards_per_func = torch.zeros(
                    len(self.reward_funcs),
                    self.args.num_generations,
                    device=self.args.device,
                )

                response_texts = self.tokenizer.batch_decode(
                    response_ids, skip_special_tokens=True
                )
                prompt_texts = [prompt] * len(response_texts)
                prompt_response_texts = [
                    p + r for p, r in zip(prompt_texts, response_texts)
                ]

                for i, (reward_func, reward_tokenizer) in enumerate(
                    zip(self.reward_funcs, self.reward_tokenizers)
                ):
                    if isinstance(reward_func, PreTrainedModel):
                        with torch.inference_mode():
                            reward_model_inputs = reward_tokenizer(
                                prompt_response_texts,
                                return_tensors="pt",
                                padding=True,
                            )
                            rewards_per_func[i] = reward_func(
                                **reward_model_inputs.to(self.args.device)
                            ).logits.squeeze(-1)
                    else:
                        answers = [answer] * len(prompt_texts)
                        output_reward_func = reward_func(
                            prompts=prompt_texts,
                            responses=response_texts,
                            answers=answers,
                        )
                        output_reward_func = [
                            r if r is not None else torch.nan
                            for r in output_reward_func
                        ]
                        rewards_per_func[i] = torch.tensor(
                            output_reward_func,
                            dtype=torch.float32,
                            device=self.args.device,
                        )

                if not self.args.reward_weights:
                    self.args.reward_weights = [1.0] * len(self.reward_funcs)
                if len(self.args.reward_weights) != len(self.reward_funcs):
                    raise ValueError("reward_weights must match reward_funcs in number")

                rewards = rewards_per_func * torch.tensor(
                    self.args.reward_weights,
                    dtype=torch.float32,
                    device=rewards_per_func.device,
                ).unsqueeze(1)
                rewards = rewards.sum(dim=0)

                # intra-group normalization
                mean_group_rewards = rewards.mean()
                std_group_rewards = rewards.std()
                advantages = (rewards - mean_group_rewards) / (
                    std_group_rewards + 1e-8
                )

                # [DAPO innovation ②] Dynamic Sampling: skip if advantages are all zero
                nonzero_num = advantages.count_nonzero().item()
                if nonzero_num == 0:
                    print("group advantages are all zero, skipping")
                    continue

                print(f"rewards: {rewards}")
                batch_advantages.append(advantages)

                old_action_log_probs = self.get_action_log_probs(
                    self.model, prompt_response_ids, attention_mask, num_actions
                )
                batch_old_action_log_probs.append(old_action_log_probs)

                if self.ref_model:
                    ref_action_log_probs = self.get_action_log_probs(
                        self.ref_model,
                        prompt_response_ids,
                        attention_mask,
                        num_actions,
                    )
                    batch_ref_action_log_probs.append(ref_action_log_probs)

            batch_prompt_response_ids.append(prompt_response_ids)
            batch_attention_mask.append(attention_mask)
            batch_action_mask.append(action_mask)

        return {
            "prompt_response_ids": batch_prompt_response_ids,
            "attention_mask": batch_attention_mask,
            "action_mask": batch_action_mask,
            "old_action_log_probs": batch_old_action_log_probs,
            "ref_action_log_probs": batch_ref_action_log_probs if self.ref_model else None,
            "advantages": batch_advantages,
        }

    # ---- 4.3 DAPO Loss (the core) ----

    def compute_loss(self, model, inputs):
        """
        DAPO policy-gradient loss
        loss = -min(r(θ)*A, clip(r(θ), 1-ε_low, 1+ε_high)*A)
        Token-Level: sum over all tokens / total number of valid tokens
        """
        prompt_response_ids = inputs["prompt_response_ids"]
        attention_mask = inputs["attention_mask"]
        action_mask = inputs["action_mask"]
        num_actions = action_mask.size(1)

        action_log_probs = self.get_action_log_probs(
            model, prompt_response_ids, attention_mask, num_actions
        )

        if self.args.beta != 0.0:
            ref_action_log_probs = inputs["ref_action_log_probs"]
            log_ratio = ref_action_log_probs - action_log_probs
            log_ratio = log_ratio * action_mask
            k3 = log_ratio.exp() - 1 - log_ratio  # k3 KL estimator

        advantages = inputs["advantages"]

        old_action_log_probs = (
            inputs["old_action_log_probs"]
            if self.args.num_iterations > 1
            else action_log_probs.detach()
        )

        # importance sampling ratio
        coef_1 = torch.exp(action_log_probs - old_action_log_probs)

        # [DAPO innovation ①] Clip-Higher: asymmetric clipping
        coef_2 = torch.clamp(coef_1, 1 - self.args.clip_eps_low, 1 + self.args.clip_eps_high)

        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        per_token_loss = per_token_loss * action_mask

        if self.args.beta != 0.0:
            per_token_loss = per_token_loss + self.args.beta * k3

        # [DAPO innovation ③] Token-Level Loss
        per_token_loss = per_token_loss.view(-1, self.args.num_generations, num_actions)
        action_mask = action_mask.view(-1, self.args.num_generations, num_actions)
        loss = per_token_loss.sum(-1).sum(-1) / action_mask.sum(-1).sum(-1)
        loss = loss.mean()

        return loss

    # ---- 4.4 Log prob computation ----

    def get_action_log_probs(self, model, input_ids, attention_mask, num_actions):
        """Compute the log prob of every token in the response portion"""
        output = model(input_ids, attention_mask=attention_mask)
        logits = output.logits
        log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
        log_probs_labels = log_probs.gather(
            dim=-1, index=input_ids[:, 1:].unsqueeze(-1)
        )
        action_log_probs = log_probs_labels.squeeze(-1)[:, -num_actions:]
        return action_log_probs

    # ---- 4.5 One training step ----

    def train_step(self, model, inputs, optimizer, step):
        model.train()
        loss = self.compute_loss(model, inputs)
        loss = loss / self.args.gradient_accumulation_steps
        loss.backward()

        if (step + 1) % self.args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            print(f"step: {self.update_steps}/{self.global_steps} dapo_loss: {loss.item():.8f}")
            loss_file_path = os.path.join(self.args.output_dir, "training_losses.txt")
            os.makedirs(self.args.output_dir, exist_ok=True)
            with open(loss_file_path, "a", encoding="utf-8") as f:
                f.write(f"{self.update_steps},{loss.item():.8f}\n")

        torch.cuda.empty_cache()

    # ---- 4.6 Main training loop ----

    def train(self):
        print(f"\nstep {self.update_steps}: === starting evaluation ===")
        accuracy = self.evaluate(num_samples=100, batch_size=20)
        print(f"step {self.update_steps}: model accuracy: {accuracy:.2f}")
        accuracy_file_path = os.path.join(self.args.output_dir, "accuracy_losses.txt")
        os.makedirs(self.args.output_dir, exist_ok=True)
        with open(accuracy_file_path, "a", encoding="utf-8") as f:
            f.write(f"{self.update_steps},{accuracy:.2f}\n")

        self.global_steps = (
            self.args.num_iterations
            * self.args.epoch
            * len(self.train_dataset)
            // (self.args.batch_size * self.args.gradient_accumulation_steps)
        )

        for _ in range(self.args.epoch):
            dataloader = DataLoader(
                self.train_dataset, batch_size=self.args.batch_size, shuffle=True
            )

            buffer = {
                "prompt_response_ids": [],
                "attention_mask": [],
                "action_mask": [],
                "old_action_log_probs": [],
                "ref_action_log_probs": [],
                "advantages": [],
            }

            idx = 0
            for batch in dataloader:
                inputs = self.generate_experiences(batch)

                buffer["prompt_response_ids"] += inputs["prompt_response_ids"]
                buffer["attention_mask"] += inputs["attention_mask"]
                buffer["action_mask"] += inputs["action_mask"]
                buffer["old_action_log_probs"] += inputs["old_action_log_probs"]
                if self.ref_model is not None:
                    buffer["ref_action_log_probs"] += inputs["ref_action_log_probs"]
                else:
                    buffer["ref_action_log_probs"] = None
                buffer["advantages"] += inputs["advantages"]

                # wait until enough valid Groups survive Dynamic Sampling
                if len(buffer["prompt_response_ids"]) < self.args.batch_size:
                    continue

                if self.ref_model is not None:
                    inputs = {k: v[: self.args.batch_size] for k, v in buffer.items()}
                    inputs = {k: torch.cat(v, dim=0) for k, v in inputs.items()}
                    buffer = {k: v[self.args.batch_size :] for k, v in buffer.items()}
                else:
                    inputs = {
                        k: v[: self.args.batch_size]
                        for k, v in buffer.items()
                        if k != "ref_action_log_probs"
                    }
                    inputs = {k: torch.cat(v, dim=0) for k, v in inputs.items()}
                    inputs["ref_action_log_probs"] = None
                    buffer = {
                        k: v[self.args.batch_size :]
                        for k, v in buffer.items()
                        if k != "ref_action_log_probs"
                    }
                    buffer["ref_action_log_probs"] = None

                self.input_buffer[idx % self.args.gradient_accumulation_steps] = inputs

                if (idx + 1) % self.args.gradient_accumulation_steps == 0:
                    for _ in range(self.args.num_iterations):
                        for step, inputs in enumerate(self.input_buffer):
                            self.train_step(self.model, inputs, self.optimizer, step)

                    self.update_steps += 1

                    if self.update_steps % 10 == 0:
                        print(f"\nstep {self.update_steps}: === starting evaluation ===")
                        accuracy = self.evaluate(num_samples=100, batch_size=25)
                        print(f"step {self.update_steps}: model accuracy: {accuracy:.2f}")
                        accuracy_file_path = os.path.join(
                            self.args.output_dir, "accuracy_losses.txt"
                        )
                        os.makedirs(self.args.output_dir, exist_ok=True)
                        with open(accuracy_file_path, "a", encoding="utf-8") as f:
                            f.write(f"{self.update_steps},{accuracy:.2f}\n")

                    if self.update_steps % self.args.save_steps == 0:
                        self.model.save_pretrained(
                            self.args.output_dir + f"/checkpoint_{self.update_steps}"
                        )
                        self.tokenizer.save_pretrained(
                            self.args.output_dir + f"/checkpoint_{self.update_steps}"
                        )

                idx += 1
                del inputs

    # ---- 4.7 Evaluation ----

    def evaluate(self, num_samples=100, batch_size=20):
        if len(self.eval_dataset) > num_samples:
            indices = torch.randperm(len(self.eval_dataset))[:num_samples].tolist()
            eval_subset = torch.utils.data.Subset(self.eval_dataset, indices)
        else:
            eval_subset = self.eval_dataset
            num_samples = len(self.eval_dataset)

        self.model.eval()
        correct_count = 0
        total_count = 0

        with torch.no_grad():
            dataloader = DataLoader(eval_subset, batch_size=batch_size, shuffle=False)

            for i, batch in enumerate(dataloader):
                current_batch_size = len(batch["prompt"])
                batch_start_idx = i * batch_size

                if batch_start_idx >= num_samples:
                    break

                prompts = batch["prompt"]
                answers = batch["answer"]

                input_texts = []
                for prompt in prompts:
                    input_text = self.tokenizer.apply_chat_template(
                        [
                            {"role": "system", "content": SYSTEM_PROMPT},
                            {"role": "user", "content": prompt},
                        ],
                        add_generation_prompt=True,
                        tokenize=False,
                    )
                    input_texts.append(input_text)

                inputs = self.tokenizer(
                    input_texts,
                    padding="max_length",
                    max_length=self.args.max_prompt_length,
                    truncation=True,
                    return_tensors="pt",
                ).to(self.args.device)

                prompt_response_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=self.args.max_generate_length,
                    temperature=0.9,
                    top_p=1,
                    top_k=50,
                )

                response_texts = []
                for j in range(current_batch_size):
                    # fixed: strip the prompt by its padded length; every prompt is
                    # padded to max_prompt_length, so the response starts at that offset
                    response_ids = prompt_response_ids[j, inputs["input_ids"].size(1):]
                    response_text = self.tokenizer.decode(
                        response_ids, skip_special_tokens=True
                    )
                    response_texts.append(response_text)

                for j in range(current_batch_size):
                    predicted_answer = extract_answer(response_texts[j])
                    pred_normalized = str(predicted_answer).strip()
                    true_normalized = str(answers[j].item()).strip()
                    if pred_normalized == true_normalized:
                        correct_count += 1
                    total_count += 1

                if total_count >= num_samples:
                    break

        accuracy = correct_count / total_count if total_count > 0 else 0.0
        self.model.train()
        return accuracy

    def save_model(self):
        self.model.save_pretrained(self.args.output_dir)
        self.tokenizer.save_pretrained(self.args.output_dir)


# ============================================================
# 5. Entry point
# ============================================================

if __name__ == "__main__":

    # The system prompt is kept in Chinese (the training data is Chinese);
    # it asks the model to answer in the <think>/<answer> format.
    SYSTEM_PROMPT = """
按照如下格式回答问题:
<think>
你的思考过程
</think>
<answer>
你的回答
</answer>
"""

    args = DapoArguments()

    tokenizer = AutoTokenizer.from_pretrained("./Qwen2.5-1.5B-Instruct")
    model = AutoModelForCausalLM.from_pretrained("./Qwen2.5-1.5B-Instruct")

    prompts_dataset = GSM8KDataset("./gsm8k_chinese", tokenizer, split="train")
    test_dataset = GSM8KDataset("./gsm8k_chinese", tokenizer, split="test")

    trainer = DapoTrainer(
        model=model,
        reward_funcs=[correctness_reward, digit_reward, hard_format_reward, mark_reward],
        args=args,
        train_dataset=prompts_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()
    trainer.save_model()