- 机构:华盛顿大学、艾伦人工智能研究所、IBM
- 代码:https://github.com/AkariAsai/self-rag
- 论文:https://arxiv.org/abs/2310.11511
- 发布:2023.10
现有问题:LLM的事实不准确性
LLM经常产生幻觉,特别是在长尾情况下,它们的知识变得过时,并且缺乏归因。
检索增强生成是否是万能药?
传统RAG可以无差别地检索和合并一定数量的检索段落,无论检索是否必要或段落是否相关,都可能导致无用的生成。
1. Self-RAG?
自我反思检索增强生成(Self-RAG),本质是微调两个大模型。一个是评估大模型(critic model),另一个是生成大模型(generator model)。微调的内容不是领域知识,而是作为RAG应用所应具备的技能,比如什么时候去检索、生成内容是否有幻觉、如何确保生成内容的真实可靠可用。我分别从模型训练和推理去介绍Self-RAG。
Critic模型训练
学习面对各种各样的query时,1.是否需要检索,2.检索知识是否和Query相关(准确性),3.基于检索生成的内容是否真的来自检索知识还是大模型自己yy的(事实支持性),以及4.检索生成的内容是否真的对用户Query有帮助(有用性)。训练critic模型的目标函数是最大化似然度,其中 是数据集, 是自省标记(reflection tokens)。从公式可以看到,根据x和y来计算r的条件生成概率。而自省标记r包括4种:粗体的文本表示最理想的自省标记。x、y、d分别表示输入、输出和相关文档段落。
Generator模型训练
有了Critic模型生成自省tokens能力作为基础,进一步构建增强(Augmented)训练数据,下面是构建流程,其中 表示critic模型, 是检索模块, 是相关文档段落。
- 用Critic模型判断x是否需要检索,并预测 [Retrieve] token值,并把值拼接到x后面;如果是Yes再通过检索模块找出 K 个最相关的文档段落集合 ;
- 对于每个段落,Critic模型会进一步评估段落和x是否相关,并预测 [IsREL] token值;如果段落是相关的,Critic模型又会进一步评估,段落是否能支持模型的生成,并预测 [IsSUP] token值;最后把这两个token值拼接在检索生成内容y后面
- 当整个y生成出来后,再预测y的 [IsUSE] token值,并把值拼接到y后,
下面是增强数据的样例备注:文本chunk之间用 <p></p> 包住。
以此,生成整个数据集 ,并基于次数据集进行生成器模型训,目标函数即求x预测【y和r】的条件生成概率的最大对数似然估计因为纪要预测y,也要预测自省标记r,因此需要将r扩进词表中。
汇总
整个Self-RAG的训练过程伪代码如下:
2. 推理流程
大致推理流程如上,我们展开描述一下:
-
判断是否需要检索时,当 时,再基于 去检索相关知识片段
-
如果需要检索,假设检索出的知识片段集合为 ,对于每个 ,
注意,在每个时间步都用LLM进行并行推理输出 个不同的 候选集,并且记录他们的得分 然后进行都进行Beam Search(设置Beam大小为),如下简图所示,最终获取 个最优的候选片段序列 。 分数是 的加权,权重可以认为调整。而对应的自省token的得分也比较简单,看A.3附录即可。
- 预测判断 和 的相关性:
- 当前时间步的检索生成:
- 预测判断当前生成是否满足事实支持性和可用性:
- 基于三个自省标记( [IsREL]、[IsSUP]、[IsUSE] )的预测结果,对 进行打分
如果不需要检索,
- 预测 :
- 评估 分数:
注意:Critic模型是不参与Self-RAG的推理,但它在训练阶段的作用是至关重要的。它确保了Generator模型能够学习到如何生成高质量的输出,并在需要时进行有效的自我评估和批判。
实践
论文也开源了微调模型,可以下载一个GGUF版本,并使用llama.cpp进行推理。先安装,
pip install llama_cpp_python
pip install huggingface-hub
然后下载模型
huggingface-cli download m4r1/selfrag_llama2_7b-GGUF selfrag_llama2_7b.q4_k_m.gguf --local-dir ./model --local-dir-use-symlinks False
给一个简单的运行示例
from llama_cpp import Llama
# 定义模型参数和生成参数
MODEL_KWARGS = {
"logits_all": True,
"n_ctx": 2048,
"n_gpu_layers": 200
}
GENERATE_KWARGS = {
"temperature": 0.0,
"top_p": 1.0,
"max_tokens": 1024,
"logprobs": 1000
}
# 初始化模型
llm = Llama(model_path="selfrag_llama2_7b.q4_k_m.gguf", **MODEL_KWARGS)
# 格式化Prompt函数
def format_prompt(query, paragraph=None):
"""
格式化查询为模型所需的prompt格式。
:param query: 输入的问题或指令。
:param paragraph: 可选的,与查询相关的段落信息,用于检索。
:return: 格式化后的prompt字符串。
"""
prompt = "### Instruction:\n{0}\n\n### Response:\n".format(query)
if paragraph:
prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
return prompt
# 测试问题
queries = [
"撰写一首表达对老师的感激之情的短诗",
"简述一下人工智能在医疗领域的应用"
]
# 测试并打印结果
for query in queries:
prompt = format_prompt(query)
result = llm(prompt, **GENERATE_KWARGS)
# 提取并打印生成的文本
generated_text = result["choices"][0]["text"]
print("\nResponse:\n{0}".format(generated_text))
# 如果需要,打印详细信息
# print(result["choices"][0])