如何将 Llama 变成 Mamba?研究人员推出新工作并提高推理速度

aixo 2024-09-11 02:50:28
大模型 2024-09-11 02:50:28

先来看一张其乐融融的图片(一眼AI):

llama怎么读_llama是什么意思_Llama

右边的小羊驼代表Llama,而左边的蛇(Mamba)也是我们的老熟人了。

至于到底能不能其乐融融,咱就不管了,之所以有此场景,是因为Mamba方面又搞出了有意思的研究:

——如何把Llama变成Mamba?

llama怎么读_llama是什么意思_Llama

近日,来自康奈尔、普林斯顿等机构的研究人员推出了上面这篇工作,将Llama这样的大型提炼成了Mamba模型,

并且成功在Mamba架构上应用了带有硬件感知的推测解码算法,提高了整个模型的推理速度。

为什么要把Llama变成Mamba?

因为从头开始训练一个大模型太贵了。

Mamba也火了这么长时间了,相关的研究每天都有,但自己训练大尺寸Mamba模型的却很少。

目前比较有名的是AI21的Jamba(进化到了1.5版本,最大398B,MoE),以及的 模型(8B)。

llama是什么意思_llama怎么读_Llama

不过世界上有那么多成功的大模型,而知识就包含在这些模型参数里。

如果能够锁住知识,同时把微调成Mamba,不就解决问题了?

Llama_llama怎么读_llama是什么意思

在本文中,研究人员结合渐进式蒸馏、监督微调(SFT)和定向偏好优化(DPO)等方法达成了这一目标。

光是变大还不够,

在性能匹配的前提下,速度也要够快才行。

Mamba凭借固定的推理开销,在长序列中的优势明显,但这边也是有推理加速方案的,比如推测解码。

而由于Mamba本身的结构特性,不能直接应用这种方案,所以作者设计了全新的算法,并结合硬件的性质来实现基于Mamba的推测解码。

llama是什么意思_llama怎么读_Llama

最终,研究人员将-7B、Llama-3 8B提炼为了线性RNN模型(混合Mamba和),且性能与蒸馏之前的标准模型相当。

llama怎么读_llama是什么意思_Llama

整个训练过程只使用了20B的token,效果却能够与使用1.2T个token从头开始训练的Mamba 7B模型,以及使用3.5T个token训练的 模型相媲美。

llama怎么读_Llama_llama是什么意思

从 到 Mamba

在介绍Mamba 2的时候我们讲过,线性RNN(或SSM)跟线性注意力是一回事。

所以可以根据x,B,C与V,K,Q的对应关系直接复用注意力中的投影矩阵。

llama怎么读_llama是什么意思_Llama

额外的参数包括SSM需要的A矩阵和Δt(由x投影得到),这就完成了基本的参数初始化。

之后就是SSM的运算过程,再通过投影和累加得到输出。

模型架构和训练

下图给出了模型的架构,因为的知识存在于MLP层,所以冻结这部分参数。

除了用线性RNN层(Mamba)替换掉注意力头,还有一些组件需要处理,比如跨头共享键和值的分组查询注意力(GQA)。

Llama_llama怎么读_llama是什么意思

知识蒸馏( ,KD)是一种常用的压缩技术,用来训练模仿较大模型()行为的较小网络()。

llama怎么读_llama是什么意思_Llama

根据经验,这里采用逐步替换层的策略,先是每2层进行蒸馏,然后每4层继续蒸馏......

监督微调

有两种常见的蒸馏方法。一种方法是使用word-level的KL散度,此时训练模型去匹配模型输出的完整概率分布。

第二种方法是序列级知识蒸馏(SeqKD),直接使用模型的输出作为 truth来训练模型(也称为伪标签)。

llama怎么读_llama是什么意思_Llama

这里θ是模型的可训练参数,α和β分别控制序列和词的loss项的权重。

偏好优化

LLM指令调优的第二阶段是使其符合用户偏好。这个阶段,使用一组期望的偏好对来改进模型的输出。

优化的目标是使奖励模型最大化,同时保持产生的输出接近参考模型。

通常,参考模型使用上一步监督微调后的模型。这里因为是蒸馏,直接可以用模型:

llama是什么意思_llama怎么读_Llama

偏好模型的奖励函数定义取决于所使用的方法,本文采用直接偏好优化(DPO),通过直接梯度更新有效地到达优化目标。

Llama_llama怎么读_llama是什么意思

DPO表明,对于给定的提示x ,如果我们能够获得和两种输出,就可以将这个优化问题重新表述为:

llama怎么读_llama是什么意思_Llama

这种优化可以在序列级别上执行,让模型和模型一起对和输出进行评分,然后反向传播给模型。

推测解码

经过上面的一套小连招,模型转换就搞定了,下面开始想办法应用那边的推测解码。

推测解码( )可以简单理解为下面这张图。

llama是什么意思_llama怎么读_Llama

做推理的时候,除了要处理不断变长的KV cache之外,计算效率也是个问题。

因为显卡的设计是计算高于访存的,具体到计算单元就是做矩阵乘法。

而推理的时候每次只能进入一个词向量,显卡的很多计算就被浪费了。

Llama_llama是什么意思_llama怎么读

推测解码给出的解决方案是,使用一个小模型做生成,然后拿显卡多余的计算做验证。

小模型跑得快,可以一口气生成很多输出向量,但是可能效果差一点。这时候用大模型作为验证,一次计算之前生成的很多个向量。

所以小模型串行跑得快,大模型可以并行计算跑得也快,遇到验证不通过的就直接回滚,整体上提高了推理的速度。

llama怎么读_Llama_llama是什么意思

可以方便地回滚,因为KV cache跟时间是一一对应的,但Mamba这边只有一个当前的中间状态ht,你总不能把所有中间状态都存起来吧。

为了解决这个问题,研究人员设计了下面的算法:

llama怎么读_Llama_llama是什么意思

简单来说就是每次使用小模型(draft model)生成一组输出,然后大模型( model)验证这一组输出,根据验证匹配的位置来更新需要保存的中间状态。