SFT/RLHF/RAG/LoRA⚓︎
约 6975 个字 预计阅读时间 23 分钟 总阅读量 次
SFT⚓︎
1. 解决的问题⚓︎
预训练模型(Base Model)是在海量文本上训练的,它的目标是“预测下一个词”。这使得它非常擅长补全文本,但并不直接具备遵循指令或进行对话的能力。
例如,如果你给一个未经 SFT 的预训练模型输入 "请介绍一下爱因斯坦的相对论",它可能并不会直接回答问题,而是续写成 "请介绍一下爱因斯坦的相对论和牛顿力学的区别",因为它在训练数据中见过类似的文章标题。
SFT的核心问题:如何让模型理解“指令-响应” (Instruction-Response) 这种交互模式,使其成为一个有用的助手,而不是一个只会续写的语言模型。
2. 作用与效果⚓︎
- 注入能力:赋予模型遵循指令的能力。
- 格式对齐:让模型的输出格式符合问答、对话的范式。
- 知识引导:在特定领域进行 SFT,可以增强模型在该领域的知识和回答能力。
经过 SFT 后,模型就从 Base Model
变成了 SFT Model
。现在你问它问题,它会尝试直接给出回答。但这个回答的质量可能参差不齐,可能包含事实错误、有害内容、或啰嗦无用的信息。它只是学会了“怎样答”,但不一定“答得好”。
3. 流程⚓︎
SFT 的流程本质上就是标准的目标监督学习。
-
数据准备:这是最关键和成本最高的一步。需要构建一个高质量的
(Prompt, Response)
数据集。- Prompt:是用户输入的指令、问题或对话的上下文。
- Response:是人类专家或高质量模型撰写的、理想的回答。
- 例如:
{"instruction": "写一首关于秋天的诗", "output": "金风送爽,枫叶染红山岗..."}
-
微调训练:
- 使用准备好的数据集对预训练好的 Base Model 进行微调。
- 训练目标是标准的自回归语言模型损失(通常是交叉熵损失)。模型在看到
Prompt
和Response
的一部分后,要能准确地预测出Response
的下一个词。 - 一个重要的细节:在计算损失时,通常会Mask掉(忽略)
Prompt
部分的损失,只计算模型在生成Response
时的损失。因为我们希望模型在给定Prompt
的条件下学会生成Response
,而不是学会生成Prompt
本身。
4. 输入与输出⚓︎
- 输入:
- 一个预训练好的大语言模型 (Base LLM)。
- 一个高质量的指令-响应数据集 (
(Prompt, Response)
Pairs)。
- 输出:
- 一个经过微调的 SFT 模型。这个模型已经具备了初步的指令遵循和对话能力。
RLHF: 基于人类反馈的强化学习⚓︎
1. 解决的问题⚓︎
SFT 模型虽然能回答问题了,但 "好" 的回答是复杂的、多维度的,难以用一个简单的损失函数来定义。什么样是 "好"?
- 更有用 (Helpful):能真正解决用户问题。
- 更诚实 (Honest):不捏造事实,承认自己的局限。
- 更无害 (Harmless):拒绝生成暴力、歧视、危险的内容。
- 更有趣 / 不啰嗦:语言风格更受人类喜欢。
这些标准非常主观,无法像 SFT 那样直接提供“标准答案”。
RLHF的核心问题:如何将这些模糊、主观的人类偏好,量化成一个可优化的信号,从而引导模型生成人类更喜欢的回答。
2. 作用与效果⚓︎
- 行为对齐 (Alignment):使模型的行为和输出与人类的价值观和偏好对齐。这是 RLHF 最核心的作用。
- 提升安全性:显著降低模型生成有害、不当内容的概率。
- 提升有用性:让模型的回答更实际、更贴近用户需求。
经过 RLHF,模型(如 ChatGPT)才真正变得好用、可靠。
3. 流程 (以 PPO 为例)⚓︎
RLHF 的经典流程(来自 InstructGPT/ChatGPT 论文)分为三步:
第 1 步:训练一个奖励模型 (Reward Model, RM)
- 目的:创建一个能够模仿人类偏好的“裁判”。这个裁判模型可以为任何一个
(Prompt, Response)
对打分,分数高低代表人类对这个回答的喜好程度。 - 数据准备:
- 选取一批
Prompt
。 - 用 SFT 模型对每个
Prompt
生成多个不同的回答 (比如 A, B, C, D)。 - 由人工标注者对这些回答进行排序(例如:
D > B > A > C
)。 - 将排序结果转换成成对的比较数据,例如
(Prompt, Chosen_Response, Rejected_Response)
。比如(Prompt, D, B)
,(Prompt, D, A)
,(Prompt, B, C)
等。
- 选取一批
- 训练:
- RM 的结构通常是一个去掉最后分类层的 SFT 模型。它输入一个
(Prompt, Response)
,输出一个标量reward
值。 -
训练目标是让 RM 给
Chosen_Response
的打分高于Rejected_Response
。使用的损失函数通常是Pairwise Ranking Loss:\[ \text{Loss} = -\log(\sigma(r_{\theta}(x, y_c) - r_{\theta}(x, y_r))) \]
其中,\(x\) 是
Prompt
,\(y_c\) 是被选中的回答 (chosen),\(y_r\) 是被拒绝的回答 (rejected),\(r_{\theta}\) 是奖励模型。这个损失函数会驱动模型让 \(r_{\theta}(x, y_c)\) 的值尽可能大于 \(r_{\theta}(x, y_r)\)。
- RM 的结构通常是一个去掉最后分类层的 SFT 模型。它输入一个
第 2 步:使用强化学习 (PPO) 微调 SFT 模型
- 目的:利用 RM 作为奖励信号,通过强化学习算法来优化 SFT 模型本身。
- 核心组件:
- 策略 (Policy):就是我们正在优化的 LLM (从 SFT 模型初始化)。
- 动作空间 (Action Space):词表中的所有词 (Token)。
- 奖励函数 (Reward Function):主要由 RM 提供。
-
流程:
- 从一个数据集中随机取一个
Prompt
\(x\)。 - Policy (LLM) 生成一个回答 \(y\)。
- Reward Model (RM) 对
(x, y)
进行打分,得到奖励值 \(r\)。 -
关键点:KL 散度惩罚。为了防止模型为了迎合 RM 而生成乱七八糟的、不流畅的文本(这种现象称为 "Reward Hacking"),会引入一个惩罚项。该惩罚项计算当前 Policy 模型与原始 SFT 模型的输出分布之间的 KL 散度。这确保了模型在学习人类偏好的同时,不会偏离其从 SFT 阶段学到的语言能力太远。
\[ \text{Final Reward} = r - \beta \cdot D_{KL}(\pi_{\text{RLHF}} || \pi_{\text{SFT}}) \]
-
使用 PPO (Proximal Policy Optimization) 算法根据这个
Final Reward
更新 Policy (LLM) 的参数。PPO 是一种稳定高效的强化学习算法。
- 从一个数据集中随机取一个
关于 DPO (Direct Preference Optimization):
PPO 流程复杂,训练不稳定。DPO 是一种更新、更简单的方法,它巧妙地证明了可以跳过显式训练奖励模型和强化学习这两个步骤。
- 核心思想:直接利用在第1步中收集的偏好数据
(Prompt, Chosen, Rejected)
,通过一个特殊的损失函数来直接优化 LLM。 - 效果:这个损失函数等价于最大化奖励并施加 KL 散度约束,但它是在一个简单的分类任务框架下完成的。它直接让模型增加生成
Chosen
回答的概率,同时降低生成Rejected
回答的概率。 - 优点:更简单、更稳定、效果通常不亚于甚至优于 PPO。目前正在成为 RLHF 的主流方案。
4. 输入与输出⚓︎
- 输入:
- 一个 SFT 模型。
- 一个人类偏好数据集 (
(Prompt, Chosen_Response, Rejected_Response)
Pairs)。
- 输出:
- 一个经过对齐的、最终的、强大的对话模型(例如 GPT-4)。
总结对比⚓︎
方面 | 监督微调 (SFT) | RLHF (PPO/DPO) | RAG (检索增强生成) |
---|---|---|---|
核心目标 | 学会遵循指令 (Instruction Following) | 学会人类偏好 (Human Preference Alignment) | 为模型提供外部知识,解决知识局限和幻觉。 |
解决的问题 | 模型不会按指令格式回答问题 | 模型回答的质量不高、不安全 | 为大模型提供知识库,提高数据准确率 |
核心方法 | 监督学习 (Supervised Learning) | 强化学习 / 直接偏好优化 | - |
数据格式 | (指令, 理想回答) |
(指令, 胜出回答, 失败回答) |
- |
训练信号 | 预测标准答案的交叉熵损失 | 奖励模型的分数 或 偏好对的直接优化 | - |
输出模型特点 | 能用,但不一定好用 | 更好用、更安全、更可靠 | 信息更及时、准确、外部 |
RAG: 检索增强生成⚓︎
1. 解决的问题⚓︎
RAG 主要为了解决大语言模型(LLM)固有的两大缺陷:
-
知识局限性 (Knowledge Limitation):
- 知识截止 (Knowledge Cutoff):LLM 的知识被冻结在它训练数据截止的那个时间点。它不知道之后发生的新闻、事件或研究。例如,一个2022年训练的模型不知道2023年发布的新手机。
- 领域知识缺失 (Domain-specific Knowledge Gap):通用 LLM 没有学习过特定企业内部的私有文档、技术手册、项目数据库等。你问它关于你公司的项目,它一无所知。
-
幻觉 (Hallucination):
- 当 LLM 缺乏相关知识时,它不会简单地说“我不知道”,而是倾向于根据其学到的语言模式“编造”一个看似合理但实际上是错误的答案。这在需要事实准确性的场景中是致命的。
RAG 通过一种 “开卷考试” 的方式来解决这些问题。它不要求模型“背诵”所有知识(这是通过参数学习的,即微调),而是让模型在回答问题时,先去一个外部知识库里“查资料”,然后根据查到的资料来组织和生成答案。
具体而言,效果有:
- 提高事实准确性:答案基于检索到的真实文档,大大减少了模型幻觉。
- 增强知识时效性:只需更新外部知识库(这比重新训练模型便宜和快捷得多),模型就能获取最新信息。
- 支持私有领域知识:可以轻松接入公司内部知识库、数据库,打造企业专属的问答机器人。
- 提供可解释性和可追溯性:可以告知用户答案是基于哪些源文档生成的,增加了答案的可信度。
2. 一般流程⚓︎
RAG 的流程可以分为两个阶段:数据索引(离线) 和 检索生成(在线)。
- 阶段一:数据索引 (Indexing / Offline Process)
-
这个阶段是预处理阶段,为“开卷考试”准备“参考书”。
- 加载数据 (Load):从各种来源(如 PDF, HTML, Word, 数据库)加载原始文档。
- 切分文档 (Split):将长文档切分成更小的、语义完整的文本块 (Chunks)。为什么要切分?
- LLM 的上下文窗口长度有限。
- 更小的文本块在检索时更聚焦,能提高检索精度。
- 创建向量嵌入 (Embed):使用一个嵌入模型 (Embedding Model)(如
m3e-base
,bge-large-zh
, OpenAI'stext-embedding-3-small
等)将每个文本块转换成一个高维向量。这个向量可以被认为是文本块在语义空间中的“坐标”。 - 存储索引 (Store):将文本块及其对应的向量索引存储到向量数据库 (Vector Database) 中(如
FAISS
,Chroma
,Pinecone
,Milvus
)。向量数据库专门为高效的向量相似度搜索做了优化。
- 阶段二:检索生成 (Retrieval & Generation / Online Process)
-
这个阶段是用户与系统交互的实时过程。
- 用户提问 (Query):用户输入一个问题或指令。
- 嵌入问题 (Embed Query):使用与索引阶段相同的嵌入模型将用户的问题也转换成一个向量。
- 检索文档 (Retrieve):在向量数据库中,计算问题向量与所有文本块向量之间的相似度(常用余弦相似度),找出与问题语义最相关的 Top-K 个文本块。
- 增强提示 (Augment Prompt):将检索到的这 Top-K 个文本块作为上下文,与用户的原始问题一起,组合成一个新的、信息丰富的提示 (Prompt)。这个提示通常会遵循一个模板,例如: > "请根据以下已知信息来回答用户的问题。如果信息不足,请说你不知道。 > > 已知信息: > [这里是检索到的文本块1] > [这里是检索到的文本块2] > ... > > 用户问题: > [这里是用户的原始问题]"
- 生成答案 (Generate):将这个增强后的提示喂给 LLM,LLM 会基于提供的上下文信息生成最终的、更准确的答案。
3. 用到的核心方法⚓︎
- 信息检索 (Information Retrieval):
- 稠密检索 (Dense Retrieval):即基于 Embedding 向量的语义相似度搜索。这是目前 RAG 的主流方法。
- 稀疏检索 (Sparse Retrieval):传统的基于关键词匹配的检索,如 TF-IDF 或 BM25。
- 混合搜索 (Hybrid Search):结合稠密检索和稀疏检索,既能捕捉语义相似性,又能精确匹配关键词,效果通常更好。
- 提示工程 (Prompt Engineering):如何将检索到的上下文和原始问题有效地组织成一个提示,以引导 LLM 最好地利用这些信息,这是一项关键技术。
SFT/RLHF/RAG的辨析⚓︎
方面 | 监督微调 (SFT) | RLHF (PPO/DPO) | RAG (检索增强生成) |
---|---|---|---|
核心目标 | 学会遵循指令 (Instruction Following) | 学会人类偏好 (Human Preference Alignment) | 为模型提供外部知识,解决知识局限和幻觉。 |
解决的问题 | 模型不会按指令格式回答问题 | 模型回答的质量不高、不安全 | 知识过时、专业领域知识缺失、容易产生幻觉。 |
核心方法 | 监督学习 (Supervised Learning) | 强化学习 / 直接偏好优化 | 信息检索 + 提示工程 (Information Retrieval + Prompt Engineering) |
数据格式 | (指令, 理想回答) |
(指令, 胜出回答, 失败回答) |
外部知识库 (文档、网页等,被处理成文本块-向量 索引) |
训练信号 | 预测标准答案的交叉熵损失 | 奖励模型的分数 或 偏好对的直接优化 | 无 (通常是推理时技术,不涉及LLM训练)¹ |
输出模型特点 | 能用,但不一定好用 | 更好用、更安全、更可靠 | 答案可溯源、知识可更新、有效减少幻觉 |
RAG 和 SFT 的区别⚓︎
这是非常关键的区别,它们解决的问题和所处的环节完全不同。
特性 | RAG (检索增强生成) | SFT (监督微调) |
---|---|---|
核心目标 | 为模型提供外部知识,解决知识局限和幻觉。 | 教会模型一种行为模式或风格,如遵循特定指令、以特定角色对话。 |
发生环节 | 推理时 (Inference Time),实时发生。 | 训练时 (Training Time),一次性的学习过程。 |
知识来源 | 外部、显式的数据库,易于更新。 | 内部、隐式地存储在模型权重中。 |
更新知识 | 简单、快速、低成本。只需更新向量数据库中的文档。 | 复杂、慢、高成本。需要重新准备数据集并重新微调模型。 |
解决幻觉 | 非常有效,强制模型基于事实回答,且答案可溯源。 | 有限效果。可以教模型不要乱说,但无法根除它在知识盲区内的幻觉。 |
应用场景 | 知识库问答、客服机器人、文档分析、新闻摘要。 | 角色扮演、特定格式输出、指令遵循能力的通用提升。 |
好比 | 开卷考试(提供参考书) | 考前辅导(训练答题技巧) |
- RAG 和 SFT 不是互斥的,而是互补的。
- 一个最佳实践是:先用 SFT 微调一个模型,让它更好地理解你的指令格式、输出风格(比如“请用中文、专业的语气回答”)。然后,在这个 SFT 后的模型基础上,再搭建 RAG 应用,为它提供实时的、私有的知识。
- 如果你想让模型学习一种新的“能力”或“行为”,选择 SFT。
- 如果你想让模型掌握一片新的“知识”,选择 RAG。
在当今的大模型落地应用中,RAG 因其低成本、高效率和高可靠性,已经成为了绝对的主流方案。
LoRA: 低秩自适应⚓︎
当然,LoRA 是大模型工程领域的一项明星技术,很好地解决了大模型训练和部署中的实际痛点。
1. LoRA 是什么技术?⚓︎
LoRA,全称是 Low-Rank Adaptation,即低秩自适应。它是一种参数高效微调 (Parameter-Efficient Fine-Tuning, PEFT) 技术。
核心思想:在对大模型进行微调时,我们不改变原始模型的任何权重(保持其冻结),而是在模型的关键部分(通常是 Transformer 中的 Attention
层的 QKV
线性投影层)旁边,注入两个小型的、可训练的“旁路”矩阵。我们只训练这两个小矩阵,从而用极少的参数来“适配”大模型以完成新任务。
2. LoRA 用在大模型工程的哪个阶段?⚓︎
LoRA 并不是一个独立的“阶段”,而是一种可以应用于任何需要微调模型的阶段的技术。
最常见的应用场景是:
- 监督微调 (SFT) 阶段:当我们需要让一个基础模型 (Base Model) 学会遵循指令时,可以使用 LoRA 来进行 SFT。这样,我们不需要为每个 SFT 任务都保存一个完整的模型副本,只需要保存一个几十兆大小的 LoRA 适配器即可。
- 个人/特定领域微调:用户想让模型模仿某种写作风格,或者让模型掌握某个特定领域的知识(如医疗、法律),可以使用 LoRA 在自己的数据上进行微调,成本极低。
- RLHF 阶段:理论上也可以用 LoRA 来优化 Actor 模型,但实现会更复杂一些。
总而言之,只要你想对一个大模型进行微调,但又受限于计算资源(显存)和存储成本,LoRA 就是首选方案。
3. 具体流程如何?⚓︎
假设我们要在 Transformer 中的一个线性层(其权重为 \(W_0\))上应用 LoRA。\(W_0\) 是一个 \(d \times k\) 的大矩阵。
-
冻结原始权重:在训练开始前,将原始权重矩阵 \(W_0\) 设置为不可训练 (frozen)。
-
注入可训练的低秩矩阵:
- 创建两个新的、小的矩阵:矩阵
A
(维度为 \(d \times r\))和矩阵B
(维度为 \(r \times k\))。 - 这里的 \(r\) 就是“秩 (rank)”,是一个远小于 \(d\) 和 \(k\) 的超参数(比如 8, 16, 64)。
- 矩阵
A
通常用随机高斯分布初始化,矩阵B
初始化为全零。B 初始化为零至关重要,因为它保证了在训练刚开始时,\(W_0 + \alpha(BA) = W_0\),LoRA 模块对原始模型没有影响,从而保证了训练的稳定性。
- 创建两个新的、小的矩阵:矩阵
-
前向传播:当输入为 \(x\) 时,该层的计算方式变为: $$ h = W_0 x + \Delta W x = W_0 x + (\alpha \cdot B A) x $$
- \(W_0x\) 是原始模型的计算路径。
- \(BAx\) 是新增的 LoRA 旁路计算路径。
- \(\alpha\) 是一个缩放系数(类似学习率),用于调整旁路计算结果的影响力。
-
反向传播:在训练过程中,只有矩阵
A
和B
的参数会被梯度更新,而巨大的 \(W_0\) 始终保持不变。 -
推理与部署:
- 训练完成后,我们得到了训练好的
A
和B
。 - 为了不增加推理时的计算量,我们可以将 LoRA 模块“合并”回原始权重中: $$ W_{\text{trained}} = W_0 + \alpha \cdot BA $$
- 这样,在部署时,我们使用的就是一个和原始模型结构完全相同的模型,其权重为 \(W_{\text{trained}}\),没有任何额外的计算延迟。你只需要保存
A
和B
这两个小矩阵,在需要时执行一次合并即可。
- 训练完成后,我们得到了训练好的
4. LoRA vs 全参数微调 (Full FT) 优劣对比⚓︎
特性 | LoRA(低秩自适应) | 全参数微调 (Full Fine-Tuning) |
---|---|---|
显存占用 | 低。仅需存储A、B矩阵的梯度和优化器状态。 | 极高。需要存储整个模型所有参数的梯度和优化器状态。 |
训练速度 | 快。需要计算的梯度量大大减少。 | 慢。反向传播需要计算所有参数的梯度。 |
存储成本 | 极低。每个任务只需保存几十MB的LoRA权重(A,B)。 | 极高。每个任务都需要保存一个完整的模型副本(几十GB)。 |
任务切换 | 灵活、高效。可以动态加载/卸载不同的LoRA适配器,实现一个模型服务多个任务。 | 笨重、昂贵。切换任务需要加载一个全新的、巨大的模型。 |
模型性能 | 通常与Full FT相当,有时甚至更好。因为它像一个正则化器,防止了对预训练知识的灾难性遗忘。 | 理论上有更高的性能上限,但也更容易在小数据集上过拟合或灾难性遗忘。 |
5. 它的理论依据是什么?⚓︎
LoRA 的成功背后有一个非常重要的理论假设,来自微软的研究:
预训练的大语言模型具有很低的“内在维度” (intrinsic dimension),即低秩假设。
这个假设意味着,虽然大模型有数十亿个参数,但当我们将它适配到一个新的下游任务时,真正需要被改变的那些参数(权重的更新量 \(\Delta W\)),本质上是低秩的。也就是说,这个巨大的权重更新矩阵 \(\Delta W\) 可以被分解并用两个小得多的矩阵 \(B\) 和 \(A\) 来高效地近似,即 \(\Delta W \approx BA\)。
LoRA 正是基于这个洞察。它不去直接学习那个巨大的、难以训练的 \(\Delta W\),而是去学习它的低秩分解形式 \(B\) 和 \(A\)。这不仅极大地减少了需要训练的参数量,还隐式地加入了一种正则化,限制了模型适配的自由度,使其更专注于学习任务相关的核心知识,而不过度改变其强大的预训练能力,从而有效防止了“灾难性遗忘”。
好的,我们再次深入 LoRA 的细节。这几个问题问得非常好,直接关系到 LoRA 的实现原理和效率。
6. LoRA 一般作用在神经网络的什么结构上?⚓︎
理论上,LoRA 可以作用于任何包含权重矩阵的神经网络层。但在 Transformer 架构中,研究和实践都表明,将 LoRA 应用于注意力层 (Attention Layers) 的效果是最好、最高效的。
具体来说,LoRA 主要作用于 Transformer Encoder 或 Decoder 块中的以下几个部分:
- 查询 (Query,
Q
) 的线性投影层:\(W_q\) - 键 (Key,
K
) 的线性投影层:\(W_k\) - 值 (Value,
V
) 的线性投影层:\(W_v\) - 注意力输出 (Output) 的线性投影层: \(W_o\)
- 前馈网络 (Feed-Forward Network, FFN) 中的线性层(通常有两个,
up_proj
和down_proj
)。
在实践中,最常见和最有效的做法是只对 Attention 层中的 Q
和 V
投影矩阵应用 LoRA。也有研究表明,将 LoRA 应用于所有的 Q, K, V, O
矩阵,甚至包括 FFN,可以带来微小的性能提升,但这会增加可训练参数量。因此,仅作用于 Q
和 V
是一个性价比极高的选择。
7. 微调哪些参数?⚓︎
记住 LoRA 的核心:只微调新增的、低秩的适配器参数,而冻结所有原始模型的参数。
具体到实现上,当我们对一个原始权重矩阵 \(W_0\) 应用 LoRA 时:
- 冻结的参数:巨大的原始权重矩阵 \(W_0\)。
- 微调的参数:我们新注入的两个小矩阵,矩阵
A
和矩阵B
。
在训练过程中,模型的所有梯度计算和参数更新都只发生在这些 A
和 B
矩阵上。这正是 LoRA 能够做到参数高效微调的根本原因。
8. 对于 Attention 层,为什么对 Q 和 V 微调,不对 K 微调?⚓︎
这背后没有一个绝对的数学定论,更多是基于经验观察和一些直观的解释。一种被广泛接受的理解是,这与 Q
, K
, V
在自注意力机制中扮演的不同角色有关。
我们回顾一下自注意力的计算:\(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\)
- \(Q \cdot K^T\) (查询与键的点积):这一步计算的是注意力分数 (Attention Scores)。它决定了在生成当前
token
的表示时,应该给予其他token
多少“关注度”。Q
(Query) 代表了当前token
为了获取信息而主动发出的“查询信号”。K
(Key) 代表了序列中每个token
能够被“检索”的“内容标签”。
- \(\text{Scores} \cdot V\) (与值相乘):这一步是根据计算出的注意力分数,对
token
的内容进行加权求和。V
(Value) 代表了每个token
实际包含的“内容信息”。
直观解释:
-
V
(Value) 为什么要微调?Value
是最终信息内容的直接来源。在新的下游任务中,我们需要模型学会提取和组合出与任务相关的新内容。比如,在法律文书微调中,我们希望Value
能够携带法律术语的特定含义。因此,微调V
可以直接改变模型能够生成和理解的内容,这是非常必要的。 -
Q
(Query) 为什么要微调?Query
决定了模型去“关注哪里”。新的任务意味着新的信息关注模式。比如,在摘要任务中,Query
需要学会更多地关注句子的开头和结尾;在代码生成任务中,Query
需要学会关注变量定义和函数调用。微调Q
使得模型可以根据新任务的需求,调整其信息检索的模式。 -
K
(Key) 为什么可以不微调?Key
扮演的角色是“被检索的内容标签”。预训练好的大模型已经在海量数据上学会了一套非常稳定和泛化的“知识索引系统”。也就是说,模型已经知道如何为每个token
打上一个合适的、能代表其内容的Key
。 在微调时,我们主要想做的是用新的方式去查询和组合这些已有的知识,而不是去改变这些知识本身的“索引标签”。保持K
不变,相当于保留了原始模型强大的泛化知识库结构。我们只需要调整Q
(改变查询方式)和V
(改变提取的内容),就能有效地适应新任务。
9. LoRA 微调所需的参数量梳理⚓︎
我们可以来算一笔账,这样感受更直观。
假设我们要对一个 Transformer 模型中的某个线性层应用 LoRA,这个层的原始权重矩阵 \(W_0\) 的维度是 \(d \times k\)。
- 原始参数量 (Full FT):\(d \times k\)
- LoRA 参数量:
- 我们注入一个矩阵 \(A\)(维度为 \(d \times r\))和一个矩阵 \(B\)(维度为 \(r \times k\))。
- \(r\) 是 LoRA 的秩 (rank),是一个远小于 \(d\) 和 \(k\) 的超参数。
- LoRA 新增的参数量 = (参数量 of A) + (参数量 of B) = \((d \times r) + (r \times k) = r \cdot (d+k)\)
举例说明:
假设是一个典型的 Llama-7B
模型,其隐藏层维度 \(d_{\text{model}} = 4096\)。
我们考虑自注意力中的 \(W_q\) 矩阵,它的维度也是 \(4096 \times 4096\)。
-
全参数微调这一个矩阵所需的参数: \(4096 \times 4096 = 16,777,216\) (约 16.7M 个参数)
-
使用 LoRA 微调这个矩阵 (假设 rank \(r=8\)) 所需的参数:
- 矩阵 A 维度: \(4096 \times 8\)
- 矩阵 B 维度: \(8 \times 4096\)
- 总参数量: \((4096 \times 8) + (8 \times 4096) = 32768 + 32768 = 65,536\) (约 65.5K 个参数)
对比一下: \(16,777,216\) vs \(65,536\)
LoRA 需要训练的参数量仅为全参数微调的 \(65536 / 16777216 \approx 0.39\%\)!
通过用两个低秩矩阵来近似权重的更新,它将需要训练的参数量降低了几个数量级,从而实现了在消费级硬件上微调大模型的目标。