Stick-Breaking自注意力的优势在于它会自动关注最近的相关字符,而不需要引入额外的位置信息,比如position embedding和relative position bias。
给定一个包含 t 个时间步的输入向量序列 x1, x2, …, xt,每个输入被投影到一系列key向量 k1, k2, …, kt 和一系列value向量 v1, v2, …, vt。为了计算时间 t 的注意力,输入 x_t 被投影到一个查询向量 q_t = W_q x_t,其中 W_q 是查询投影矩阵。对于所有之前的步骤和当前步骤 i ≤ t,计算时间步 i 的键与时间步 t 的查询匹配的概率:
需要注意的是,这个查询概率使用了sigmoid激活函数,所以没有归一化。接下来通过stick-breaking过程来对查询概率进去归一化:
这样,注意力就会自动分配给离t时刻最近,且具有较大查询概率的时刻。使得自注意力机制在没有额外的位置信息的情况下,也能对于相对位置进行有效的建模。最终,自注意力模块的输出是由注意力权重对历史的value向量进行加和并且投影得到:
ModuleFormer中的模块控制
预训练中的负载均衡
为了避免SMoE反复使用相同的模块并浪费其他模块的额外容量,一般采用负载平衡损失函数来调节每个专家的使用频率。与之前的SMoE模型 不同,团队希望最大化输入字符和模块之间的互信息(MI):
为了简化起见,假设在批次X中的令牌分布是均匀的,因此p(x) = 1/X。在去除所有常数成分后,可以简化互信息损失(公式6)为p(m)的熵与p(m | x)的条件熵之间的差异。
在上述内容中,p(m) = sum_x(g(m|x)p(x)),其中p(x)是批处理中每个字符的概率,H(m)是模块分布的边际熵,H(m | x)是模块在给定输入字符x的条件下的熵,|X |是输入字符的数量。对于长度为T的batch大小为B的小批量,字符的数量是|X | = BT,字符的概率是p(x) = 1/|X |。
直观地说,互信息损失最大化了模块的概率分布的边际熵,并最小化了给定输入x的模块条件分布的商。它平衡了整个batch中每个专家的负载(最大化H(m)),同时也鼓励每个输入x将其路由概率集中在较少的模块上(最小化H(m | x))。
微调中的负载集中
尽管团队希望在预训练期间最大限度地利用每个专家的能力,但在微调期间希望将少量的模块专注于下游任务。这样可以移除未使用的模块并减少微调后模型的参数数量。为了将负载集中在较少的模块上,团队引入了一个新的负载集中损失函数来最小化模块的边际熵:
这样可以鼓励模型使用更少的模块来处理下游任务。在微调后,可以计算在训练或验证集上使用的模块频率f_m。f_m代表了模块m对于这个任务的重要性,可以通过移除f_m小于某个特定阈值的专家来轻松实现模型剪枝。
用新的模块来学习新的知识
对于模块化模型来说,插入新模块是一种直接且参数高效的方法,可以在不对整个模型进行微调的情况下学习新知识。当向每一层插入N_new个随机初始化的模块时,还需要扩展路由器(方程2中的A)中的模块嵌入层A,使其包含一个形状为(N_new,D_rtr)的新矩阵A’。因此,
新的路由函数可以写成:
由于在微调期间其他的模块参数被冻结,因此使用新模块进行持续学习可以在很大程度上避免灾难性遗忘问题。
然而,灾难性遗忘仍然可能影响路由函数。当新模块在一个新领域进行训练时,如果路由函数错误地将来自旧领域的输入路由到新专家,模型可能会遭受灾难性遗忘。
为了避免这种情况,团队对路由函数进行了正则化以避免灾难性遗忘,并提出了两种训练策略:
1)全面微调路由,公式9中A和B使用预训练参数进行初始化,而A’则是随机初始化的。这个策略是为了训练数据中同时包含新旧数据的情况设计。
2)只训练A’,这个策略是为了连续学习(lifelong learning)的情况而设计的,不使用以前训练过的数据。由于这种情况可能导致新的模块使用频率过高,从而带来灾难性遗忘。团队引入了正则项来限制A’的范数:
与被指出存在缺陷的传统连续学习正则化方法(如衰减或L2损失)不同,路由正则化不限制专家的能力,而只限制对新专家的使用趋势。
评估
基于ModuleFormer,研究者在Pile数据集上预训练了三个不同体积和计算量的ModuleFormer Language Model(MoLM)语言模型:
基础性能评估
团队使用Language Model Evaluation Harness来评估零样本、少样本和语言建模任务中的语言模型。
对于零样本和少样本任务,目标是在给定上下文的基础上从一组给定选项中选择最合适的完成部分。最终选择在给定上下文下具有最高可能性的完成部分。
对于语言建模,在Wikitext数据集上进行测试。目标是最小化下一个标记预测的困惑度。
对于代码生成,在HumanEval数据集上评估模型。HumanEval包含164个手写的Python编程问题。模型需要根据任务描述提示完成一个函数,以便能够通过所有提供的测试案例。
表2和表3显示了MoLM和基准语言模型在常识推理、闭卷问答和代码生成基准上的性能。
总体而言,MoLM-4B-K2模型的性能与大约13亿参数的稠密模型相当,MoLM-4B-K4和MoLM-8B-K2模型的性能与大约27亿参数的稠密模型相当。
由于其稀疏计算结构,MoLM处理每个字符的激活参数仅(等同于计算量)相当于同等性能稠密模型的约25%。因此,它减少了50%的延迟,同时具有较低的内存使用峰值,并在GPU内存完全占用时将吞吐量提高了2倍。
通过增加模块学习新语言
在本节中,我们测试了模型学习新语言的能力。主要研究两种实验设置:连续联合预训练(continual joint pre-training)和连续终身预训练(continual lifelong pre-training)。
它们的区别在于是否有英文文本的存在。对于这两种设置,我们通过在CC-100语料库上进行语言模型任务,不断地对ModuleFormer和GPT-Neo进行预训练。为了评估质量,我们采用了由XGLM和mGPT引入的0-shot方法的mLAMA基准测试。
持续联合预训练:在这部分中,我们对联合训练的模型进行持续预训练。具体而言,我们混合了英语和一种新语言来构建一个新的训练语料库,并保持嵌入层可训练。联合训练[Caruana, 1997]是一种众所周知的多任务学习方法,展示了对旧任务和新任务的熟练掌握。然而,它经常在不同任务之间产生负面干扰。
表4显示了持续训练模型获得的结果。表格揭示了以下发现:
1)我们观察到稀疏模型在Fully Tuned的情况下经历较少干扰,最终得到了最好的的性能;2)ModuleFormer通过增加模块(Insert New Expert)的能力,比之前的LoRA方法展示出了更好的少量参数(Parameter Efficient)调优的能力。这些结果表明,稀疏架构带来了更强的抗遗忘能力。
持续终身预训练:对于这个实验设定,模型仅在新语言文本上进行训练。Abraham和Robins [2005] 提出了稳定性-可塑性困境,这解释了模型面临的一个困难挑战:1)模型应具有较高的可塑性以学习新语言,2)模型必须具有出色的稳定性,考虑到在众多的训练迭代中不会接触到任何英语标记。
表5显示了LoRA基准和我们的方法在不同的路由正则化损失权重下的结果。我们的ModuleFormer借助路由正则化损失表现出了强大的平衡稳定性和可塑性的能力。
当我们通过增加损失权重来限制新专家的使用时,模型获得了稳定性,但可塑性下降。相比之下,使用LoRA对GPT-Neo进行微调在稳定性和可塑性方面都落后。
相比于1.33亿可训练参数的高秩LoRA,低秩LoRA(减少训练参数到2400万)和基本正则化都无法改善稳定性。
微调和压缩模型
在本节中,我们展示了ModuleFormer中的模块可以被快速移除,以创建一个在尺寸上更小但性能不受损的任务专用模型。
我们首先从GitHub-code-clean数据集中创建了一个包含150亿个字符的子集,该子集只包含Python代码。然后,我们使用负载集中损失函数(权重为0.001)对MoLM-4B-K2模型在该数据集上进行精调。
在精调之后,我们在从精调数据集中随机抽样的小型评估集上,计算每个专家的激活频率,然后通过将每层除以层内最大频率来进行归一化。之后,我们设定一个阈值τ,并修剪了所有归一化频率低于该阈值的模块。
我们在HumanEval数据集上测试了我们修剪后的MoLM-4B-K2模型。
图2a说明了pass@k指标与剩余参数比例之间的相关性。图2b展示了剩余参数比例与阈值之间的关联。我们观察到:
1)修剪不必要的模块对结果影响不大。我们可以修剪40%至50%的参数而不牺牲性能。相反,适当的修剪(33%)使精调后的模型在任务上表现更好。
2)模块分布存在显著差异,大约有一半的模块的激活频率低于最常使用的专家的0.3%。这个结果显示了负载集中损失函数的有效性。
总结
在这篇论文中,我们提出了一种新的模块化架构ModuleFormer,以及与之相关的模块操作方法。
ModuleFormer包括几个新组件:新的Stickbreaking注意力机制、新的互信息负载平衡损失函数用于预训练,以及新的负载集中损失函数用于微调。
基于ModuleFormer,我们预训练了一个新的语言模型MoLM。我们的实验结果显示了MoLM的相对于稠密LLM展现出了一些新的能力:
1)它在更低的延迟(50%)和更小的内存占用下实现了与密集LLM相同的性能;从而提高了吞吐量超过2倍;
2)在对整个模型进行微调以适应新领域后,它对灾难性遗忘的鲁棒性较强,并且也可以轻松扩展以学习新的语言和知识;
3)它可以在下游任务上进行微调,以使一部分模块专注于任务,并且未被任务使用的模块可以被修剪而不影响性能。
论文地址:https://arxiv.org/abs/2306.04640
— 完 —
量子位 QbitAI · 头条号签约
关键词: