跳转至

SFT/RLHF/RAG/LoRA⚓︎

约 6975 个字 预计阅读时间 23 分钟 总阅读量

SFT⚓︎

1. 解决的问题⚓︎

预训练模型(Base Model)是在海量文本上训练的,它的目标是“预测下一个词”。这使得它非常擅长补全文本,但并不直接具备遵循指令或进行对话的能力。

例如,如果你给一个未经 SFT 的预训练模型输入 "请介绍一下爱因斯坦的相对论",它可能并不会直接回答问题,而是续写成 "请介绍一下爱因斯坦的相对论和牛顿力学的区别",因为它在训练数据中见过类似的文章标题。

SFT的核心问题:如何让模型理解“指令-响应” (Instruction-Response) 这种交互模式,使其成为一个有用的助手,而不是一个只会续写的语言模型。

2. 作用与效果⚓︎

  • 注入能力:赋予模型遵循指令的能力。
  • 格式对齐:让模型的输出格式符合问答、对话的范式。
  • 知识引导:在特定领域进行 SFT,可以增强模型在该领域的知识和回答能力。

经过 SFT 后,模型就从 Base Model 变成了 SFT Model。现在你问它问题,它会尝试直接给出回答。但这个回答的质量可能参差不齐,可能包含事实错误、有害内容、或啰嗦无用的信息。它只是学会了“怎样答”,但不一定“答得好”。

3. 流程⚓︎

SFT 的流程本质上就是标准的目标监督学习

  1. 数据准备:这是最关键和成本最高的一步。需要构建一个高质量的 (Prompt, Response) 数据集。

    • Prompt:是用户输入的指令、问题或对话的上下文。
    • Response:是人类专家或高质量模型撰写的、理想的回答。
    • 例如:{"instruction": "写一首关于秋天的诗", "output": "金风送爽,枫叶染红山岗..."}
  2. 微调训练

    • 使用准备好的数据集对预训练好的 Base Model 进行微调。
    • 训练目标是标准的自回归语言模型损失(通常是交叉熵损失)。模型在看到 PromptResponse 的一部分后,要能准确地预测出 Response 的下一个词。
    • 一个重要的细节:在计算损失时,通常会Mask掉(忽略)Prompt部分的损失,只计算模型在生成 Response 时的损失。因为我们希望模型在给定 Prompt 的条件下学会生成 Response,而不是学会生成 Prompt 本身。

4. 输入与输出⚓︎

  • 输入:
    1. 一个预训练好的大语言模型 (Base LLM)。
    2. 一个高质量的指令-响应数据集 ((Prompt, Response) Pairs)。
  • 输出:
    1. 一个经过微调的 SFT 模型。这个模型已经具备了初步的指令遵循和对话能力。

RLHF: 基于人类反馈的强化学习⚓︎

1. 解决的问题⚓︎

SFT 模型虽然能回答问题了,但 "好" 的回答是复杂的、多维度的,难以用一个简单的损失函数来定义。什么样是 "好"?

  • 更有用 (Helpful):能真正解决用户问题。
  • 更诚实 (Honest):不捏造事实,承认自己的局限。
  • 更无害 (Harmless):拒绝生成暴力、歧视、危险的内容。
  • 更有趣 / 不啰嗦:语言风格更受人类喜欢。

这些标准非常主观,无法像 SFT 那样直接提供“标准答案”。

RLHF的核心问题如何将这些模糊、主观的人类偏好,量化成一个可优化的信号,从而引导模型生成人类更喜欢的回答

2. 作用与效果⚓︎

  • 行为对齐 (Alignment):使模型的行为和输出与人类的价值观和偏好对齐。这是 RLHF 最核心的作用。
  • 提升安全性:显著降低模型生成有害、不当内容的概率。
  • 提升有用性:让模型的回答更实际、更贴近用户需求。

经过 RLHF,模型(如 ChatGPT)才真正变得好用、可靠。

3. 流程 (以 PPO 为例)⚓︎

RLHF 的经典流程(来自 InstructGPT/ChatGPT 论文)分为三步:

第 1 步:训练一个奖励模型 (Reward Model, RM)

  • 目的:创建一个能够模仿人类偏好的“裁判”。这个裁判模型可以为任何一个 (Prompt, Response) 对打分,分数高低代表人类对这个回答的喜好程度。
  • 数据准备
    1. 选取一批 Prompt
    2. 用 SFT 模型对每个 Prompt 生成多个不同的回答 (比如 A, B, C, D)。
    3. 由人工标注者对这些回答进行排序(例如:D > B > A > C)。
    4. 将排序结果转换成成对的比较数据,例如 (Prompt, Chosen_Response, Rejected_Response)。比如 (Prompt, D, B), (Prompt, D, A), (Prompt, B, C) 等。
  • 训练
    • RM 的结构通常是一个去掉最后分类层的 SFT 模型。它输入一个 (Prompt, Response),输出一个标量 reward 值。
    • 训练目标是让 RM 给 Chosen_Response 的打分高于 Rejected_Response。使用的损失函数通常是Pairwise Ranking Loss:

      \[
      \text{Loss} = -\log(\sigma(r_{\theta}(x, y_c) - r_{\theta}(x, y_r)))
      \]

      其中,\(x\)Prompt\(y_c\) 是被选中的回答 (chosen),\(y_r\) 是被拒绝的回答 (rejected),\(r_{\theta}\) 是奖励模型。这个损失函数会驱动模型让 \(r_{\theta}(x, y_c)\) 的值尽可能大于 \(r_{\theta}(x, y_r)\)

第 2 步:使用强化学习 (PPO) 微调 SFT 模型

  • 目的:利用 RM 作为奖励信号,通过强化学习算法来优化 SFT 模型本身。
  • 核心组件
    • 策略 (Policy):就是我们正在优化的 LLM (从 SFT 模型初始化)。
    • 动作空间 (Action Space):词表中的所有词 (Token)。
    • 奖励函数 (Reward Function):主要由 RM 提供。
  • 流程

    1. 从一个数据集中随机取一个 Prompt \(x\)
    2. Policy (LLM) 生成一个回答 \(y\)
    3. Reward Model (RM)(x, y) 进行打分,得到奖励值 \(r\)
    4. 关键点:KL 散度惩罚。为了防止模型为了迎合 RM 而生成乱七八糟的、不流畅的文本(这种现象称为 "Reward Hacking"),会引入一个惩罚项。该惩罚项计算当前 Policy 模型与原始 SFT 模型的输出分布之间的 KL 散度。这确保了模型在学习人类偏好的同时,不会偏离其从 SFT 阶段学到的语言能力太远。

      \[
      \text{Final Reward} = r - \beta \cdot D_{KL}(\pi_{\text{RLHF}} || \pi_{\text{SFT}})
      \]
    5. 使用 PPO (Proximal Policy Optimization) 算法根据这个 Final Reward 更新 Policy (LLM) 的参数。PPO 是一种稳定高效的强化学习算法。

关于 DPO (Direct Preference Optimization):

PPO 流程复杂,训练不稳定。DPO 是一种更新、更简单的方法,它巧妙地证明了可以跳过显式训练奖励模型和强化学习这两个步骤

  • 核心思想:直接利用在第1步中收集的偏好数据 (Prompt, Chosen, Rejected),通过一个特殊的损失函数来直接优化 LLM。
  • 效果:这个损失函数等价于最大化奖励并施加 KL 散度约束,但它是在一个简单的分类任务框架下完成的。它直接让模型增加生成 Chosen 回答的概率,同时降低生成 Rejected 回答的概率。
  • 优点:更简单、更稳定、效果通常不亚于甚至优于 PPO。目前正在成为 RLHF 的主流方案。

4. 输入与输出⚓︎

  • 输入:
    1. 一个 SFT 模型。
    2. 一个人类偏好数据集 ((Prompt, Chosen_Response, Rejected_Response) Pairs)。
  • 输出:
    1. 一个经过对齐的、最终的、强大的对话模型(例如 GPT-4)。

总结对比⚓︎

方面 监督微调 (SFT) RLHF (PPO/DPO) RAG (检索增强生成)
核心目标 学会遵循指令 (Instruction Following) 学会人类偏好 (Human Preference Alignment) 为模型提供外部知识,解决知识局限和幻觉。
解决的问题 模型不会按指令格式回答问题 模型回答的质量不高、不安全 为大模型提供知识库,提高数据准确率
核心方法 监督学习 (Supervised Learning) 强化学习 / 直接偏好优化 -
数据格式 (指令, 理想回答) (指令, 胜出回答, 失败回答) -
训练信号 预测标准答案的交叉熵损失 奖励模型的分数 或 偏好对的直接优化 -
输出模型特点 能用,但不一定好用 更好用、更安全、更可靠 信息更及时、准确、外部

RAG: 检索增强生成⚓︎

1. 解决的问题⚓︎

RAG 主要为了解决大语言模型(LLM)固有的两大缺陷:

  1. 知识局限性 (Knowledge Limitation):

    • 知识截止 (Knowledge Cutoff):LLM 的知识被冻结在它训练数据截止的那个时间点。它不知道之后发生的新闻、事件或研究。例如,一个2022年训练的模型不知道2023年发布的新手机。
    • 领域知识缺失 (Domain-specific Knowledge Gap):通用 LLM 没有学习过特定企业内部的私有文档、技术手册、项目数据库等。你问它关于你公司的项目,它一无所知。
  2. 幻觉 (Hallucination):

    • 当 LLM 缺乏相关知识时,它不会简单地说“我不知道”,而是倾向于根据其学到的语言模式“编造”一个看似合理但实际上是错误的答案。这在需要事实准确性的场景中是致命的。

RAG 通过一种 “开卷考试” 的方式来解决这些问题。它不要求模型“背诵”所有知识(这是通过参数学习的,即微调),而是让模型在回答问题时,先去一个外部知识库里“查资料”,然后根据查到的资料来组织和生成答案。

具体而言,效果有:

  • 提高事实准确性:答案基于检索到的真实文档,大大减少了模型幻觉。
  • 增强知识时效性:只需更新外部知识库(这比重新训练模型便宜和快捷得多),模型就能获取最新信息。
  • 支持私有领域知识:可以轻松接入公司内部知识库、数据库,打造企业专属的问答机器人。
  • 提供可解释性和可追溯性:可以告知用户答案是基于哪些源文档生成的,增加了答案的可信度。

2. 一般流程⚓︎

RAG 的流程可以分为两个阶段:数据索引(离线)检索生成(在线)

阶段一:数据索引 (Indexing / Offline Process)

这个阶段是预处理阶段,为“开卷考试”准备“参考书”。

  1. 加载数据 (Load):从各种来源(如 PDF, HTML, Word, 数据库)加载原始文档。
  2. 切分文档 (Split):将长文档切分成更小的、语义完整的文本块 (Chunks)。为什么要切分?
    • LLM 的上下文窗口长度有限。
    • 更小的文本块在检索时更聚焦,能提高检索精度。
  3. 创建向量嵌入 (Embed):使用一个嵌入模型 (Embedding Model)(如 m3e-base, bge-large-zh, OpenAI's text-embedding-3-small 等)将每个文本块转换成一个高维向量。这个向量可以被认为是文本块在语义空间中的“坐标”。
  4. 存储索引 (Store):将文本块及其对应的向量索引存储到向量数据库 (Vector Database) 中(如 FAISS, Chroma, Pinecone, Milvus)。向量数据库专门为高效的向量相似度搜索做了优化。
阶段二:检索生成 (Retrieval & Generation / Online Process)

这个阶段是用户与系统交互的实时过程。

  1. 用户提问 (Query):用户输入一个问题或指令。
  2. 嵌入问题 (Embed Query):使用与索引阶段相同的嵌入模型将用户的问题也转换成一个向量。
  3. 检索文档 (Retrieve):在向量数据库中,计算问题向量与所有文本块向量之间的相似度(常用余弦相似度),找出与问题语义最相关的 Top-K 个文本块。
  4. 增强提示 (Augment Prompt):将检索到的这 Top-K 个文本块作为上下文,与用户的原始问题一起,组合成一个新的、信息丰富的提示 (Prompt)。这个提示通常会遵循一个模板,例如: > "请根据以下已知信息来回答用户的问题。如果信息不足,请说你不知道。 > > 已知信息: > [这里是检索到的文本块1] > [这里是检索到的文本块2] > ... > > 用户问题: > [这里是用户的原始问题]"
  5. 生成答案 (Generate):将这个增强后的提示喂给 LLM,LLM 会基于提供的上下文信息生成最终的、更准确的答案。

3. 用到的核心方法⚓︎

  • 信息检索 (Information Retrieval)
    • 稠密检索 (Dense Retrieval):即基于 Embedding 向量的语义相似度搜索。这是目前 RAG 的主流方法。
    • 稀疏检索 (Sparse Retrieval):传统的基于关键词匹配的检索,如 TF-IDF 或 BM25。
    • 混合搜索 (Hybrid Search):结合稠密检索和稀疏检索,既能捕捉语义相似性,又能精确匹配关键词,效果通常更好。
  • 提示工程 (Prompt Engineering):如何将检索到的上下文和原始问题有效地组织成一个提示,以引导 LLM 最好地利用这些信息,这是一项关键技术。

SFT/RLHF/RAG的辨析⚓︎

方面 监督微调 (SFT) RLHF (PPO/DPO) RAG (检索增强生成)
核心目标 学会遵循指令 (Instruction Following) 学会人类偏好 (Human Preference Alignment) 为模型提供外部知识,解决知识局限和幻觉。
解决的问题 模型不会按指令格式回答问题 模型回答的质量不高、不安全 知识过时、专业领域知识缺失、容易产生幻觉。
核心方法 监督学习 (Supervised Learning) 强化学习 / 直接偏好优化 信息检索 + 提示工程 (Information Retrieval + Prompt Engineering)
数据格式 (指令, 理想回答) (指令, 胜出回答, 失败回答) 外部知识库 (文档、网页等,被处理成文本块-向量索引)
训练信号 预测标准答案的交叉熵损失 奖励模型的分数 或 偏好对的直接优化 无 (通常是推理时技术,不涉及LLM训练)¹
输出模型特点 能用,但不一定好用 更好用、更安全、更可靠 答案可溯源、知识可更新、有效减少幻觉

RAG 和 SFT 的区别⚓︎

这是非常关键的区别,它们解决的问题和所处的环节完全不同。

特性 RAG (检索增强生成) SFT (监督微调)
核心目标 为模型提供外部知识,解决知识局限和幻觉。 教会模型一种行为模式或风格,如遵循特定指令、以特定角色对话。
发生环节 推理时 (Inference Time),实时发生。 训练时 (Training Time),一次性的学习过程。
知识来源 外部、显式的数据库,易于更新。 内部、隐式地存储在模型权重中。
更新知识 简单、快速、低成本。只需更新向量数据库中的文档。 复杂、慢、高成本。需要重新准备数据集并重新微调模型。
解决幻觉 非常有效,强制模型基于事实回答,且答案可溯源。 有限效果。可以教模型不要乱说,但无法根除它在知识盲区内的幻觉。
应用场景 知识库问答、客服机器人、文档分析、新闻摘要。 角色扮演、特定格式输出、指令遵循能力的通用提升。
好比 开卷考试(提供参考书) 考前辅导(训练答题技巧)
  • RAG 和 SFT 不是互斥的,而是互补的
  • 一个最佳实践是:先用 SFT 微调一个模型,让它更好地理解你的指令格式、输出风格(比如“请用中文、专业的语气回答”)。然后,在这个 SFT 后的模型基础上,再搭建 RAG 应用,为它提供实时的、私有的知识。
  • 如果你想让模型学习一种新的“能力”或“行为”,选择 SFT。
  • 如果你想让模型掌握一片新的“知识”,选择 RAG。

在当今的大模型落地应用中,RAG 因其低成本、高效率和高可靠性,已经成为了绝对的主流方案。


LoRA: 低秩自适应⚓︎

当然,LoRA 是大模型工程领域的一项明星技术,很好地解决了大模型训练和部署中的实际痛点。

1. LoRA 是什么技术?⚓︎

LoRA,全称是 Low-Rank Adaptation,即低秩自适应。它是一种参数高效微调 (Parameter-Efficient Fine-Tuning, PEFT) 技术。

核心思想:在对大模型进行微调时,我们不改变原始模型的任何权重(保持其冻结),而是在模型的关键部分(通常是 Transformer 中的 Attention 层的 QKV 线性投影层)旁边,注入两个小型的、可训练的“旁路”矩阵我们只训练这两个小矩阵,从而用极少的参数来“适配”大模型以完成新任务

2. LoRA 用在大模型工程的哪个阶段?⚓︎

LoRA 并不是一个独立的“阶段”,而是一种可以应用于任何需要微调模型的阶段技术

最常见的应用场景是:

  • 监督微调 (SFT) 阶段:当我们需要让一个基础模型 (Base Model) 学会遵循指令时,可以使用 LoRA 来进行 SFT。这样,我们不需要为每个 SFT 任务都保存一个完整的模型副本,只需要保存一个几十兆大小的 LoRA 适配器即可。
  • 个人/特定领域微调:用户想让模型模仿某种写作风格,或者让模型掌握某个特定领域的知识(如医疗、法律),可以使用 LoRA 在自己的数据上进行微调,成本极低。
  • RLHF 阶段:理论上也可以用 LoRA 来优化 Actor 模型,但实现会更复杂一些。

总而言之,只要你想对一个大模型进行微调,但又受限于计算资源(显存)和存储成本,LoRA 就是首选方案。

3. 具体流程如何?⚓︎

假设我们要在 Transformer 中的一个线性层(其权重为 \(W_0\))上应用 LoRA。\(W_0\) 是一个 \(d \times k\) 的大矩阵。

  1. 冻结原始权重:在训练开始前,将原始权重矩阵 \(W_0\) 设置为不可训练 (frozen)。

  2. 注入可训练的低秩矩阵

    • 创建两个新的、小的矩阵:矩阵 A(维度为 \(d \times r\))和矩阵 B(维度为 \(r \times k\))。
    • 这里的 \(r\) 就是“秩 (rank)”,是一个远小于 \(d\)\(k\) 的超参数(比如 8, 16, 64)。
    • 矩阵 A 通常用随机高斯分布初始化,矩阵 B 初始化为全零。B 初始化为零至关重要,因为它保证了在训练刚开始时,\(W_0 + \alpha(BA) = W_0\),LoRA 模块对原始模型没有影响,从而保证了训练的稳定性。
  3. 前向传播:当输入为 \(x\) 时,该层的计算方式变为: $$ h = W_0 x + \Delta W x = W_0 x + (\alpha \cdot B A) x $$

    • \(W_0x\) 是原始模型的计算路径。
    • \(BAx\) 是新增的 LoRA 旁路计算路径。
    • \(\alpha\) 是一个缩放系数(类似学习率),用于调整旁路计算结果的影响力。
  4. 反向传播:在训练过程中,只有矩阵 AB 的参数会被梯度更新,而巨大的 \(W_0\) 始终保持不变。

  5. 推理与部署

    • 训练完成后,我们得到了训练好的 AB
    • 为了不增加推理时的计算量,我们可以将 LoRA 模块“合并”回原始权重中: $$ W_{\text{trained}} = W_0 + \alpha \cdot BA $$
    • 这样,在部署时,我们使用的就是一个和原始模型结构完全相同的模型,其权重为 \(W_{\text{trained}}\),没有任何额外的计算延迟。你只需要保存 AB 这两个小矩阵,在需要时执行一次合并即可。

4. LoRA vs 全参数微调 (Full FT) 优劣对比⚓︎

特性 LoRA(低秩自适应) 全参数微调 (Full Fine-Tuning)
显存占用 。仅需存储A、B矩阵的梯度和优化器状态。 极高。需要存储整个模型所有参数的梯度和优化器状态。
训练速度 。需要计算的梯度量大大减少。 。反向传播需要计算所有参数的梯度。
存储成本 极低。每个任务只需保存几十MB的LoRA权重(A,B)。 极高。每个任务都需要保存一个完整的模型副本(几十GB)。
任务切换 灵活、高效。可以动态加载/卸载不同的LoRA适配器,实现一个模型服务多个任务。 笨重、昂贵。切换任务需要加载一个全新的、巨大的模型。
模型性能 通常与Full FT相当,有时甚至更好。因为它像一个正则化器,防止了对预训练知识的灾难性遗忘。 理论上有更高的性能上限,但也更容易在小数据集上过拟合或灾难性遗忘

5. 它的理论依据是什么?⚓︎

LoRA 的成功背后有一个非常重要的理论假设,来自微软的研究:

预训练的大语言模型具有很低的“内在维度” (intrinsic dimension),即低秩假设。

这个假设意味着,虽然大模型有数十亿个参数,但当我们将它适配到一个新的下游任务时,真正需要被改变的那些参数(权重的更新量 \(\Delta W\)),本质上是低秩的。也就是说,这个巨大的权重更新矩阵 \(\Delta W\) 可以被分解并用两个小得多的矩阵 \(B\)\(A\) 来高效地近似,即 \(\Delta W \approx BA\)

LoRA 正是基于这个洞察。它不去直接学习那个巨大的、难以训练的 \(\Delta W\),而是去学习它的低秩分解形式 \(B\)\(A\)。这不仅极大地减少了需要训练的参数量,还隐式地加入了一种正则化,限制了模型适配的自由度,使其更专注于学习任务相关的核心知识,而不过度改变其强大的预训练能力,从而有效防止了“灾难性遗忘”。

好的,我们再次深入 LoRA 的细节。这几个问题问得非常好,直接关系到 LoRA 的实现原理和效率。

6. LoRA 一般作用在神经网络的什么结构上?⚓︎

理论上,LoRA 可以作用于任何包含权重矩阵的神经网络层。但在 Transformer 架构中,研究和实践都表明,将 LoRA 应用于注意力层 (Attention Layers) 的效果是最好、最高效的。

具体来说,LoRA 主要作用于 Transformer Encoder 或 Decoder 块中的以下几个部分:

  1. 查询 (Query, Q) 的线性投影层\(W_q\)
  2. 键 (Key, K) 的线性投影层\(W_k\)
  3. 值 (Value, V) 的线性投影层\(W_v\)
  4. 注意力输出 (Output) 的线性投影层: \(W_o\)
  5. 前馈网络 (Feed-Forward Network, FFN) 中的线性层(通常有两个,up_projdown_proj)。

在实践中,最常见和最有效的做法是只对 Attention 层中的 QV 投影矩阵应用 LoRA。也有研究表明,将 LoRA 应用于所有的 Q, K, V, O 矩阵,甚至包括 FFN,可以带来微小的性能提升,但这会增加可训练参数量。因此,仅作用于 QV 是一个性价比极高的选择。

7. 微调哪些参数?⚓︎

记住 LoRA 的核心:只微调新增的、低秩的适配器参数,而冻结所有原始模型的参数。

具体到实现上,当我们对一个原始权重矩阵 \(W_0\) 应用 LoRA 时:

  • 冻结的参数:巨大的原始权重矩阵 \(W_0\)
  • 微调的参数:我们新注入的两个小矩阵,矩阵 A 和矩阵 B

在训练过程中,模型的所有梯度计算和参数更新都只发生在这些 AB 矩阵上。这正是 LoRA 能够做到参数高效微调的根本原因。

8. 对于 Attention 层,为什么对 Q 和 V 微调,不对 K 微调?⚓︎

这背后没有一个绝对的数学定论,更多是基于经验观察和一些直观的解释。一种被广泛接受的理解是,这与 Q, K, V 在自注意力机制中扮演的不同角色有关。

我们回顾一下自注意力的计算:\(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\)

  • \(Q \cdot K^T\) (查询与键的点积):这一步计算的是注意力分数 (Attention Scores)。它决定了在生成当前 token 的表示时,应该给予其他 token 多少“关注度”。
    • Q (Query) 代表了当前 token 为了获取信息而主动发出的“查询信号”。
    • K (Key) 代表了序列中每个 token 能够被“检索”的“内容标签”。
  • \(\text{Scores} \cdot V\) (与值相乘):这一步是根据计算出的注意力分数,对 token 的内容进行加权求和
    • V (Value) 代表了每个 token 实际包含的“内容信息”。

直观解释:

  1. V (Value) 为什么要微调? Value 是最终信息内容的直接来源。在新的下游任务中,我们需要模型学会提取和组合出与任务相关的新内容。比如,在法律文书微调中,我们希望 Value 能够携带法律术语的特定含义。因此,微调 V 可以直接改变模型能够生成和理解的内容,这是非常必要的。

  2. Q (Query) 为什么要微调? Query 决定了模型去“关注哪里”。新的任务意味着新的信息关注模式。比如,在摘要任务中,Query 需要学会更多地关注句子的开头和结尾;在代码生成任务中,Query 需要学会关注变量定义和函数调用。微调 Q 使得模型可以根据新任务的需求,调整其信息检索的模式。

  3. K (Key) 为什么可以不微调? Key 扮演的角色是“被检索的内容标签”。预训练好的大模型已经在海量数据上学会了一套非常稳定和泛化的“知识索引系统”。也就是说,模型已经知道如何为每个 token 打上一个合适的、能代表其内容的 Key。 在微调时,我们主要想做的是用新的方式去查询和组合这些已有的知识,而不是去改变这些知识本身的“索引标签”。保持 K 不变,相当于保留了原始模型强大的泛化知识库结构。我们只需要调整 Q(改变查询方式)和 V(改变提取的内容),就能有效地适应新任务。

9. LoRA 微调所需的参数量梳理⚓︎

我们可以来算一笔账,这样感受更直观。

假设我们要对一个 Transformer 模型中的某个线性层应用 LoRA,这个层的原始权重矩阵 \(W_0\) 的维度是 \(d \times k\)

  • 原始参数量 (Full FT)\(d \times k\)
  • LoRA 参数量
    • 我们注入一个矩阵 \(A\)(维度为 \(d \times r\))和一个矩阵 \(B\)(维度为 \(r \times k\))。
    • \(r\) 是 LoRA 的秩 (rank),是一个远小于 \(d\)\(k\) 的超参数。
    • LoRA 新增的参数量 = (参数量 of A) + (参数量 of B) = \((d \times r) + (r \times k) = r \cdot (d+k)\)

举例说明:

假设是一个典型的 Llama-7B 模型,其隐藏层维度 \(d_{\text{model}} = 4096\)。 我们考虑自注意力中的 \(W_q\) 矩阵,它的维度也是 \(4096 \times 4096\)

  • 全参数微调这一个矩阵所需的参数\(4096 \times 4096 = 16,777,216\) (约 16.7M 个参数)

  • 使用 LoRA 微调这个矩阵 (假设 rank \(r=8\)) 所需的参数

    • 矩阵 A 维度: \(4096 \times 8\)
    • 矩阵 B 维度: \(8 \times 4096\)
    • 总参数量: \((4096 \times 8) + (8 \times 4096) = 32768 + 32768 = 65,536\) (约 65.5K 个参数)

对比一下\(16,777,216\) vs \(65,536\)

LoRA 需要训练的参数量仅为全参数微调的 \(65536 / 16777216 \approx 0.39\%\)

通过用两个低秩矩阵来近似权重的更新,它将需要训练的参数量降低了几个数量级,从而实现了在消费级硬件上微调大模型的目标。