RAG:检索增强生成
问 LLM「昨天的新闻」或者「公司内部文档里写了什么」,它要么说不知道,要么编造一个看似合理的答案。训练完成后知识就固定了——这是 LLM 的根本限制。
这一节从零实现 RAG(Retrieval-Augmented Generation),让 LLM 在生成时能「查阅资料」,回答参数记忆之外、但检索资料中存在的问题。参考:RAG paper。
RAG 解决的核心问题是:模型参数里的知识不会自动更新,也不会天然访问你的私有文档。训练完成后,参数固定;如果不接外部检索或工具,模型只能依赖上下文和参数记忆回答。
RAG 的思路是在回答前先检索相关文档,把检索到的内容拼入 prompt,然后基于这些文档生成回答。常见 RAG 推理流程不修改模型参数,而是改变输入上下文;它更像“检索增强的上下文注入”。有些系统还会结合微调、reranker 训练或反馈学习,所以不要把 RAG 理解成永远无训练。
完整的 RAG 主流程通常包括:文档切分、向量化、索引、检索、重排/过滤和生成。
1. RAG 的基本思路
RAG 的核心想法很简单:让模型「开卷考试」。
传统 LLM:
用户提问 → 模型凭记忆回答 → 可能答错
RAG:
用户提问 → 检索相关文档 → 把文档塞进 prompt → 模型看着文档回答
完整的 RAG 主流程可以先理解为五步:
- 文档切分(Chunking):把长文档切成小段
- 向量化(Embedding):把每段文本变成一个向量
- 索引(Indexing):把向量存起来,方便快速查找
- 检索(Retrieval):用户提问时,找到最相关的文档段落
- 生成(Generation):把检索到的内容拼入 prompt,让模型回答
2. 文档切分(Chunking)
为什么要切分?因为:
- 模型的上下文窗口有限:常见模型从几 K 到数十万 token 不等,部分商业模型已到 1M 级,但仍不适合把所有资料无脑塞进去。参考:OpenAI GPT-4.1 1M context、GPT-4.1 发布说明
- 检索精度需要「粒度合适」的文本块——太长会混入无关内容,太短会丢失上下文
最简单的切分方式是按固定字符数切:
# 模拟一段文档
doc = """Transformer 是一种基于自注意力机制的神经网络架构,由 Vaswani 等人在 2017 年提出。
它的 核心创新是完全摒弃了循环结构(RNN),只依赖注意力机制来建模序列中的依赖关系。
Transformer 由 Encoder 和 Decoder 两部分组成。原始论文用它做机器翻译,Encoder 负责理解源语言,Decoder 负责生成目标语言。
GPT 系列只使用了 Decoder 部分,采用自回归方式逐个生成 token。
BERT 只使用了 Encoder 部分,通过掩码语言模型(MLM)进行预训练。
现代大语言模型(如 GPT-4、LLaMA、Qwen)都基于 Transformer 架构,但在原始设计上做了很多改进。
主要改进包括:LayerNorm 改为 RMSNorm、ReLU 激活函数改为 SwiGLU、位置编码从正弦编码改为 RoPE 等。
训练一个大语言模型需要大量的计算资源和数据。以 LLaMA-65B 为例,训练使用了 1.4T token 的数据,在 2048 张 A100 GPU 上训练了约 21 天。
模型参数量越大,需要的训练数据也越多。Chinchilla 定律指出, 对于给定的计算预算,模型参数量和训练数据量应该同步增长。
推理时,大语言模型采用自回归方式逐个生成 token。朴素实现会在每一步重复处理历史 token;实际部署通常使用 KV Cache 复用历史 K/V,避免完整重算,但新 token 仍要和历史 cache 做 attention。
这就是为什么推理速度是 LLM 部署的核心挑战之一。KV Cache、FlashAttention、量化等技术都是用来加速推理的。KV Cache 的作用可以参考 HuggingFace 文档:https://huggingface.co/docs/transformers/main/en/kv_cache。
"""
def chunk_fixed(text, chunk_size=100, overlap=20):
"""固定长度切分,带重叠"""
chunks = []
start = 0
while start < len(text):
end = start + chunk_size
chunk = text[start:end]
# 尝试在句号处断开
last_period = chunk.rfind('。')
if last_period > chunk_size // 2:
chunk = text[start:start + last_period + 1]
end = start + last_period + 1
chunks.append(chunk.strip())
start = end - overlap
return [c for c in chunks if len(c) > 10]
chunks = chunk_fixed(doc, chunk_size=120, overlap=20)
print(f"文档总长: {len(doc)} 字符")
print(f"切分后: {len(chunks)} 个片段\n")
for i, chunk in enumerate(chunks):
print(f"Chunk {i}: [{len(chunk)}字] {chunk[:60]}...")
3. 向量化(Embedding)
切分后,每个 chunk 需要变成一个向量,这样就能用数学方法衡量「两段文本有多相似」。
实际中会用专门的 Embedding 模型(如 BGE、text-embedding-3-small)。这里用随机向量模拟,重点是理解流程。
# 模拟 Embedding:用随机向量代表每个 chunk
# 实际中会用模型(如 sentence-transformers)生成有语义的向量
import numpy as np
import random
def fake_embedding(text, dim=64):
"""模拟文本向量化:基于字符频率生成伪向量"""
vec = np.zeros(dim)
for i, ch in enumerate(text[:dim*3]):
vec[i % dim] += ord(ch) * 0.01
# 加一点随机性模拟模型输出
vec += np.random.randn(dim) * 0.1
return vec / np.linalg.norm(vec) # 归一化
# 为每个 chunk 生成向量
chunk_vectors = np.array([fake_embedding(c) for c in chunks])
print(f"{len(chunks)} 个 chunk,每个向量化为 {chunk_vectors.shape[1]} 维")
print(f"向量范数(归一化后应接近 1.0): {np.linalg.norm(chunk_vectors[0]):.4f}")
4. 向量检索
有了向量后,检索就变成了“找相近的向量”。余弦相似度是常见距离之一,实际系统也可能用 dot product、L2 或向量数据库的近似最近邻配置:
向量归一化后,余弦相似度就是点积。
# 实现向量检索
import numpy as np
def cosine_similarity(a, B):
"""查询向量 a 与矩阵 B 中每行的余弦相似度"""
# 假设已归一化,直接点积
return B @ a
def retrieve(query, chunks, chunk_vectors, top_k=2):
"""检索与 query 最相关的 top_k 个 chunk"""
query_vec = fake_embedding(query)
scores = cosine_similarity(query_vec, chunk_vectors)
top_indices = np.argsort(scores)[-top_k:][::-1]
results = []
for idx in top_indices:
results.append({
"chunk": chunks[idx],
"score": float(scores[idx]),
"index": int(idx)
})
return results
# 测试检索
query = "Transformer 的推理速度怎么优化?"
results = retrieve(query, chunks, chunk_vectors, top_k=2)
print(f"查询: {query}\n")
print("检索结果:")
for r in results:
print(f" [score={r['score']:.4f}] Chunk {r['index']}: {r['chunk'][:70]}...")
5. 拼接 Prompt 并生成
检索到相关文档后,把它拼入 prompt。一个典型的 RAG prompt 结构:
[System Prompt]
请根据以下参考资料回答用户的问题。如果参考资料中没有相关信息,请说「我不知道」。
[参考资料]
{检索到的文档段落}
[用户问题]
{用户的问题}
# 构建 RAG prompt
def build_rag_prompt(query, retrieved_chunks):
"""构造 RAG 的完整 prompt"""
context = "\n\n".join([
f"[{i+1}] {r['chunk']}"
for i, r in enumerate(retrieved_chunks)
])
prompt = f"""请根据以下参考资料回答用户的问题。如果参考资料中没有相关信息,请说「根据已有资料无法回答」。
参考资料:
{context}
用户问题:{query}"""
return prompt
prompt = build_rag_prompt(query, results)
print(f"RAG prompt 长度: {len(prompt)} 字符\n")
print(prompt)