跳到主要内容

RAG:检索增强生成

问 LLM「昨天的新闻」或者「公司内部文档里写了什么」,它要么说不知道,要么编造一个看似合理的答案。训练完成后知识就固定了——这是 LLM 的根本限制。

这一节从零实现 RAG(Retrieval-Augmented Generation),让 LLM 在生成时能「查阅资料」,回答参数记忆之外、但检索资料中存在的问题。参考:RAG paper

RAG 解决的核心问题是:模型参数里的知识不会自动更新,也不会天然访问你的私有文档。训练完成后,参数固定;如果不接外部检索或工具,模型只能依赖上下文和参数记忆回答。

RAG 的思路是在回答前先检索相关文档,把检索到的内容拼入 prompt,然后基于这些文档生成回答。常见 RAG 推理流程不修改模型参数,而是改变输入上下文;它更像“检索增强的上下文注入”。有些系统还会结合微调、reranker 训练或反馈学习,所以不要把 RAG 理解成永远无训练。

完整的 RAG 主流程通常包括:文档切分、向量化、索引、检索、重排/过滤和生成。

1. RAG 的基本思路

RAG 的核心想法很简单:让模型「开卷考试」

传统 LLM:
用户提问 → 模型凭记忆回答 → 可能答错

RAG:
用户提问 → 检索相关文档 → 把文档塞进 prompt → 模型看着文档回答

完整的 RAG 主流程可以先理解为五步:

  1. 文档切分(Chunking):把长文档切成小段
  2. 向量化(Embedding):把每段文本变成一个向量
  3. 索引(Indexing):把向量存起来,方便快速查找
  4. 检索(Retrieval):用户提问时,找到最相关的文档段落
  5. 生成(Generation):把检索到的内容拼入 prompt,让模型回答

2. 文档切分(Chunking)

为什么要切分?因为:

  • 模型的上下文窗口有限:常见模型从几 K 到数十万 token 不等,部分商业模型已到 1M 级,但仍不适合把所有资料无脑塞进去。参考:OpenAI GPT-4.1 1M contextGPT-4.1 发布说明
  • 检索精度需要「粒度合适」的文本块——太长会混入无关内容,太短会丢失上下文

最简单的切分方式是按固定字符数切:

# 模拟一段文档
doc = """Transformer 是一种基于自注意力机制的神经网络架构,由 Vaswani 等人在 2017 年提出。
它的核心创新是完全摒弃了循环结构(RNN),只依赖注意力机制来建模序列中的依赖关系。
Transformer 由 Encoder 和 Decoder 两部分组成。原始论文用它做机器翻译,Encoder 负责理解源语言,Decoder 负责生成目标语言。
GPT 系列只使用了 Decoder 部分,采用自回归方式逐个生成 token。
BERT 只使用了 Encoder 部分,通过掩码语言模型(MLM)进行预训练。
现代大语言模型(如 GPT-4、LLaMA、Qwen)都基于 Transformer 架构,但在原始设计上做了很多改进。
主要改进包括:LayerNorm 改为 RMSNorm、ReLU 激活函数改为 SwiGLU、位置编码从正弦编码改为 RoPE 等。
训练一个大语言模型需要大量的计算资源和数据。以 LLaMA-65B 为例,训练使用了 1.4T token 的数据,在 2048 张 A100 GPU 上训练了约 21 天。
模型参数量越大,需要的训练数据也越多。Chinchilla 定律指出,对于给定的计算预算,模型参数量和训练数据量应该同步增长。
推理时,大语言模型采用自回归方式逐个生成 token。朴素实现会在每一步重复处理历史 token;实际部署通常使用 KV Cache 复用历史 K/V,避免完整重算,但新 token 仍要和历史 cache 做 attention。
这就是为什么推理速度是 LLM 部署的核心挑战之一。KV Cache、FlashAttention、量化等技术都是用来加速推理的。KV Cache 的作用可以参考 HuggingFace 文档:https://huggingface.co/docs/transformers/main/en/kv_cache。
"""

def chunk_fixed(text, chunk_size=100, overlap=20):
"""固定长度切分,带重叠"""
chunks = []
start = 0
while start < len(text):
end = start + chunk_size
chunk = text[start:end]
# 尝试在句号处断开
last_period = chunk.rfind('。')
if last_period > chunk_size // 2:
chunk = text[start:start + last_period + 1]
end = start + last_period + 1
chunks.append(chunk.strip())
start = end - overlap
return [c for c in chunks if len(c) > 10]

chunks = chunk_fixed(doc, chunk_size=120, overlap=20)
print(f"文档总长: {len(doc)} 字符")
print(f"切分后: {len(chunks)} 个片段\n")
for i, chunk in enumerate(chunks):
print(f"Chunk {i}: [{len(chunk)}字] {chunk[:60]}...")

3. 向量化(Embedding)

切分后,每个 chunk 需要变成一个向量,这样就能用数学方法衡量「两段文本有多相似」。

实际中会用专门的 Embedding 模型(如 BGE、text-embedding-3-small)。这里用随机向量模拟,重点是理解流程。

# 模拟 Embedding:用随机向量代表每个 chunk
# 实际中会用模型(如 sentence-transformers)生成有语义的向量

import numpy as np
import random

def fake_embedding(text, dim=64):
"""模拟文本向量化:基于字符频率生成伪向量"""
vec = np.zeros(dim)
for i, ch in enumerate(text[:dim*3]):
vec[i % dim] += ord(ch) * 0.01
# 加一点随机性模拟模型输出
vec += np.random.randn(dim) * 0.1
return vec / np.linalg.norm(vec) # 归一化

# 为每个 chunk 生成向量
chunk_vectors = np.array([fake_embedding(c) for c in chunks])

print(f"{len(chunks)} 个 chunk,每个向量化为 {chunk_vectors.shape[1]} 维")
print(f"向量范数(归一化后应接近 1.0): {np.linalg.norm(chunk_vectors[0]):.4f}")

4. 向量检索

有了向量后,检索就变成了“找相近的向量”。余弦相似度是常见距离之一,实际系统也可能用 dot product、L2 或向量数据库的近似最近邻配置:

向量归一化后,余弦相似度就是点积。

# 实现向量检索

import numpy as np

def cosine_similarity(a, B):
"""查询向量 a 与矩阵 B 中每行的余弦相似度"""
# 假设已归一化,直接点积
return B @ a

def retrieve(query, chunks, chunk_vectors, top_k=2):
"""检索与 query 最相关的 top_k 个 chunk"""
query_vec = fake_embedding(query)
scores = cosine_similarity(query_vec, chunk_vectors)
top_indices = np.argsort(scores)[-top_k:][::-1]
results = []
for idx in top_indices:
results.append({
"chunk": chunks[idx],
"score": float(scores[idx]),
"index": int(idx)
})
return results

# 测试检索
query = "Transformer 的推理速度怎么优化?"
results = retrieve(query, chunks, chunk_vectors, top_k=2)

print(f"查询: {query}\n")
print("检索结果:")
for r in results:
print(f" [score={r['score']:.4f}] Chunk {r['index']}: {r['chunk'][:70]}...")

5. 拼接 Prompt 并生成

检索到相关文档后,把它拼入 prompt。一个典型的 RAG prompt 结构:

[System Prompt]
请根据以下参考资料回答用户的问题。如果参考资料中没有相关信息,请说「我不知道」。

[参考资料]
{检索到的文档段落}

[用户问题]
{用户的问题}
# 构建 RAG prompt

def build_rag_prompt(query, retrieved_chunks):
"""构造 RAG 的完整 prompt"""
context = "\n\n".join([
f"[{i+1}] {r['chunk']}"
for i, r in enumerate(retrieved_chunks)
])

prompt = f"""请根据以下参考资料回答用户的问题。如果参考资料中没有相关信息,请说「根据已有资料无法回答」。

参考资料:
{context}

用户问题:{query}"""
return prompt

prompt = build_rag_prompt(query, results)
print(f"RAG prompt 长度: {len(prompt)} 字符\n")
print(prompt)

6. 评估检索质量

RAG 的效果很大程度取决于检索质量。如果检索到的内容不相关,模型再强也回答不好。

两个常用指标:

  • 召回率(Recall@K):相关文档中有多少被检索到了
  • MRR(Mean Reciprocal Rank):第一个相关文档排第几位
# 模拟评估检索质量

# 假设有 5 个查询,每个有标注的「相关 chunk」
test_queries = [
{"query": "谁提出了 Transformer?", "relevant": [0]},
{"query": "GPT 用了 Transformer 的哪部分?", "relevant": [1, 2]},
{"query": "LLaMA 训练用了多少 GPU?", "relevant": [5]},
{"query": "什么是 Chinchilla 定律?", "relevant": [6]},
{"query": "推理加速有哪些方法?", "relevant": [8, 9]},
]

def evaluate_retrieval(test_queries, chunks, chunk_vectors, top_k=3):
"""评估检索的召回率和 MRR"""
recall_sum = 0
mrr_sum = 0

for tq in test_queries:
results = retrieve(tq['query'], chunks, chunk_vectors, top_k=top_k)
retrieved_ids = set(r['index'] for r in results)
relevant_ids = set(tq['relevant'])

# 召回率
if len(relevant_ids) > 0:
recall = len(retrieved_ids & relevant_ids) / len(relevant_ids)
else:
recall = 0
recall_sum += recall

# MRR:第一个相关结果排在第几位
for rank, r in enumerate(results, 1):
if r['index'] in relevant_ids:
mrr_sum += 1.0 / rank
break

n = len(test_queries)
return recall_sum / n, mrr_sum / n

recall, mrr = evaluate_retrieval(test_queries, chunks, chunk_vectors, top_k=3)
print(f"Recall@3: {recall:.2f}")
print(f"MRR: {mrr:.2f}")
print(f"\n注意:这里用的是随机向量(fake_embedding),所以检索质量不高")
print(f"实际中用真实 Embedding、BM25、多路召回或 reranker 后,Recall@K 往往会明显高于 fake_embedding;具体数值必须用自己的数据集测。")

7. 进阶优化

基础的「切分 → 向量化 → 检索 → 生成」能跑通 RAG 主流程;真实生产效果通常还需要围绕召回、重排、引用、权限和评测做优化:

技巧解决什么问题怎么做
Re-ranker初步检索可能不准,用更强(但更慢)的模型重排先用向量/BM25 检索较多候选,再用 cross-encoder 或 reranker 排到更少候选
HyDE用户的提问很短,和文档的表达方式不匹配让模型先写一个「假答案」,用假答案去检索
多路召回向量检索可能漏掉关键词匹配很重要的场景同时用向量检索 + 关键词检索(BM25),合并结果
Metadata 过滤文档有结构化属性(日期、分类)检索前先按元数据过滤,缩小范围
上下文窗口压缩检索到的内容太多,塞不进上下文窗口用压缩模型、规则或 reranker 只保留和问题相关的证据;注意压缩也可能丢信息

Re-ranker 的直觉

向量检索(快但粗):
问:「LLaMA 训练成本」→ 返回 20 个可能相关的 chunk

Re-ranker(慢但精):
把问题和每个 chunk 拼在一起,过一遍交叉编码器
→ 重新排序,取 top-5 最相关的

这就好比先用搜索引擎快速找到一批候选网页,再仔细阅读标题和摘要选出最相关的几篇。

# 模拟 Re-ranker 的效果

# 第一阶段:向量检索 top-5(模拟粗筛)
query = "大模型训练需要多少计算资源?"
stage1_results = retrieve(query, chunks, chunk_vectors, top_k=5)

print("=== 第一阶段:向量检索 Top-5 ===")
for r in stage1_results:
print(f" [{r['score']:.4f}] {r['chunk'][:50]}...")

# 第二阶段:模拟 Re-ranker(给更相关的打更高分)
def fake_rerank(query, results):
"""模拟 Re-ranker:基于关键词匹配重排"""
query_keywords = set(query.replace('?', '').replace('?', ''))
for r in results:
keyword_overlap = len(query_keywords & set(r['chunk']))
r['rerank_score'] = r['score'] + keyword_overlap * 0.1
results.sort(key=lambda x: x['rerank_score'], reverse=True)
return results

stage2_results = fake_rerank(query, stage1_results[:3])

print(f"\n=== 第二阶段:Re-rank 后 Top-3 ===")
for r in stage2_results:
print(f" [rerank={r['rerank_score']:.4f}] {r['chunk'][:50]}...")

print(f"\n关键观察:Re-ranker 可能改变排序,让最相关的结果排到前面")

8. 工业界的 RAG 架构

实际生产中的 RAG 系统比上面的演示复杂得多,但核心流程相同:

                    ┌──────────────┐
│ 用户提问 │
└──────┬───────┘

┌──────▼───────┐
│ Query 改写 │ ← 可选:让问题更容易检索
└──────┬───────┘

┌────────────┼────────────┐
│ │ │
┌─────▼────┐ ┌────▼────┐ ┌────▼────┐
│ 向量检索 │ │ BM25 │ │ 知识图谱 │ ← 多路召回
└─────┬────┘ └────┬────┘ └────┬────┘
│ │ │
└────────────┼────────────┘

┌──────▼───────┐
│ Re-ranker │ ← 精排
└──────┬───────┘

┌──────▼───────┐
│ LLM 生成 │ ← 带检索结果的 prompt
└──────────────┘

常用工具:

组件常用工具
向量数据库Chroma、FAISS、Milvus、Pinecone、Qdrant
Embedding 模型BGE、E5、text-embedding-3-small、Cohere Embed
Re-rankerbge-reranker、Cohere Rerank
框架LangChain、LlamaIndex

小结

  • RAG 的核心思路是让 LLM「开卷考试」:先检索相关文档,再基于文档生成回答
  • 完整流程:文档切分 → 向量化 → 索引 → 检索 → 拼接 prompt → 生成
  • Chunking 需要平衡粒度:太长会混入无关内容,太短会丢失上下文
  • 检索质量是 RAG 效果的关键:用 Recall@K 和 MRR 评估
  • 进阶优化包括 Re-ranker、HyDE、多路召回、Metadata 过滤
  • 常见 RAG 推理不改变模型参数,而是把检索证据放进输入;完整系统也可能训练 embedding、reranker 或做模型微调

作业

作业 1:实现一个按段落切分的 chunk 函数(以空行或换行符分割),对比它和固定长度切分的效果差异。

小提示:先按 \n\n 分割,再对过长的段落做二次切分。

作业 2:修改 retrieve 函数,让它支持 top_k 和相似度阈值双重过滤(只返回 score > threshold 的结果)。

小提示:在 argsort 后加一个阈值判断,过滤掉低分结果。

作业 3:实现 BM25 检索(基于词频),和向量检索的结果做对比。

小提示:BM25 的核心是对每个词计算 IDF(逆文档频率),然后根据词频和文档长度计算相关性分数。