首页 文章详情

SELF-RAG:通过自我反思进行批判学习如何检索、生成和整理

GonFreecss | 209 2024-06-01 23:24 0 0 0
UniSMS (合一短信)
  • 机构:华盛顿大学、艾伦人工智能研究所、IBM
  • 代码:https://github.com/AkariAsai/self-rag
  • 论文:https://arxiv.org/abs/2310.11511
  • 发布:2023.10

现有问题:LLM的事实不准确性

LLM经常产生幻觉,特别是在长尾情况下,它们的知识变得过时,并且缺乏归因。

检索增强生成是否是万能药?

传统RAG可以无差别地检索和合并一定数量的检索段落,无论检索是否必要或段落是否相关,都可能导致无用的生成。

1. Self-RAG?

自我反思检索增强生成(Self-RAG),本质是微调两个大模型。一个是评估大模型(critic model),另一个是生成大模型(generator model)。微调的内容不是领域知识,而是作为RAG应用所应具备的技能,比如什么时候去检索、生成内容是否有幻觉、如何确保生成内容的真实可靠可用。我分别从模型训练和推理去介绍Self-RAG。

Critic模型训练

学习面对各种各样的query时,1.是否需要检索,2.检索知识是否和Query相关(准确性),3.基于检索生成的内容是否真的来自检索知识还是大模型自己yy的(事实支持性),以及4.检索生成的内容是否真的对用户Query有帮助(有用性)。训练critic模型的目标函数是最大化似然度,033fc31eaddefd37f6dc94cecd5b41b1.webp其中 是数据集, 是自省标记(reflection tokens)。从公式可以看到,根据x和y来计算r的条件生成概率。而自省标记r包括4种:530d242fc2532c5b5ce4fd7b53a437ce.webp粗体的文本表示最理想的自省标记。x、y、d分别表示输入、输出和相关文档段落。

Generator模型训练

有了Critic模型生成自省tokens能力作为基础,进一步构建增强(Augmented)训练数据,下面是构建流程,其中 表示critic模型, 是检索模块, 是相关文档段落。331de40f7d51ed29c13642cdf459a5f2.webp

  • 用Critic模型判断x是否需要检索,并预测 [Retrieve] token值,并把值拼接到x后面;如果是Yes再通过检索模块找出 K 个最相关的文档段落集合 ;
  • 对于每个段落,Critic模型会进一步评估段落和x是否相关,并预测 [IsREL] token值;如果段落是相关的,Critic模型又会进一步评估,段落是否能支持模型的生成,并预测 [IsSUP] token值;最后把这两个token值拼接在检索生成内容y后面
  • 当整个y生成出来后,再预测y的 [IsUSE] token值,并把值拼接到y后,

下面是增强数据的样例c4adaf4ddfb884c0478a674d5f788287.webp备注:文本chunk之间用 <p></p> 包住。

以此,生成整个数据集 ,并基于次数据集进行生成器模型训,目标函数即求x预测【y和r】的条件生成概率的最大对数似然估计54a02fe75d5d1f4b8fa9156344a63162.webp因为纪要预测y,也要预测自省标记r,因此需要将r扩进词表中。

汇总

整个Self-RAG的训练过程伪代码如下:faf33347c4a5cd3e1ea1db87124aa33d.webp

2. 推理流程

064987b510f294899a0496c6380b3e17.webp大致推理流程如上,我们展开描述一下:

  1. 判断是否需要检索时,当 时,再基于 去检索相关知识片段

  2. 如果需要检索,假设检索出的知识片段集合为 ,对于每个 ,

    注意,在每个时间步都用LLM进行并行推理输出 个不同的 候选集,并且记录他们的得分 然后进行都进行Beam Search(设置Beam大小为),如下简图所示,fc4e9c401c368ee1a5a8d7acff4b5bb0.webp最终获取 个最优的候选片段序列 。 分数是 的加权,权重可以认为调整。而对应的自省token的得分也比较简单,看A.3附录即可。

  • 预测判断 和 的相关性:
  • 当前时间步的检索生成:
  • 预测判断当前生成是否满足事实支持性和可用性:
  • 基于三个自省标记( [IsREL]、[IsSUP]、[IsUSE] )的预测结果,对 进行打分

如果不需要检索,

  • 预测 :
  • 评估 分数:

注意:Critic模型是不参与Self-RAG的推理,但它在训练阶段的作用是至关重要的。它确保了Generator模型能够学习到如何生成高质量的输出,并在需要时进行有效的自我评估和批判。

实践

论文也开源了微调模型,可以下载一个GGUF版本,并使用llama.cpp进行推理。先安装,

      
      pip install llama_cpp_python
pip install huggingface-hub

然后下载模型

      
      huggingface-cli download m4r1/selfrag_llama2_7b-GGUF selfrag_llama2_7b.q4_k_m.gguf --local-dir ./model --local-dir-use-symlinks False

给一个简单的运行示例

      
      from llama_cpp import Llama

# 定义模型参数和生成参数
MODEL_KWARGS = {
    "logits_all": True,
    "n_ctx": 2048,
    "n_gpu_layers": 200
}
GENERATE_KWARGS = {
    "temperature": 0.0,
    "top_p": 1.0,
    "max_tokens": 1024,
    "logprobs": 1000
}

# 初始化模型
llm = Llama(model_path="selfrag_llama2_7b.q4_k_m.gguf", **MODEL_KWARGS)

# 格式化Prompt函数
def format_prompt(query, paragraph=None):
    """
    格式化查询为模型所需的prompt格式。
    
    :param query: 输入的问题或指令。
    :param paragraph: 可选的,与查询相关的段落信息,用于检索。
    :return: 格式化后的prompt字符串。
    "
""
    prompt = "### Instruction:\n{0}\n\n### Response:\n".format(query)
    if paragraph:
        prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
    return prompt

# 测试问题
queries = [
    "撰写一首表达对老师的感激之情的短诗",
    "简述一下人工智能在医疗领域的应用"
]

# 测试并打印结果
for query in queries:
    prompt = format_prompt(query)
    result = llm(prompt, **GENERATE_KWARGS)
    
    # 提取并打印生成的文本
    generated_text = result["choices"][0]["text"]
    print("\nResponse:\n{0}".format(generated_text))
    
    # 如果需要,打印详细信息
    # print(result["choices"][0])


good-icon 0
favorite-icon 0
收藏
回复数量: 0
    暂无评论~~
    Ctrl+Enter