From SSM to Mamba

Transformer的缺点

Transformer是当前大语言模型的事实上的标准架构。它在训练时,自注意力机制可以非常方便的并行进行,但是它存在一个主要的缺点——在推理时,每次生成下一个token,需要重新计算整个序列的自注意力,时间复杂度为O(L^2)

接下来,先看看以前RNN是如何解决推理慢的问题。RNN通过循环机制,将信息从上一步传递到下一步:

descript

在生成输出时,RNN 只需要考虑之前的隐藏状态和当前的输入。换句话说,RNN 可以快速进行推理,因为它__随序列长度线性缩放。理论上,它__甚至可以有无限的上下文长度

descript

但是,RNN也有缺点——因为RNN只考虑先前的一个状态,随着时间的推移,RNN 往往__会忘记先前信息__。同时,RNN 的这种__顺序性__产生了另一个问题,训练不能并行进行,因为它需要依次执行每个步骤。

那么,能否以某种方式__找到一种像 Transformer 这样并行训练的架构,同时仍然执行随序列长度线性扩展的推理?__

状态空间模型,SSM

状态空间是一种数学模型,传统上是用于控制理论中,通过状态变量对动态系统进行建模。最近被引入到深入学习之中,可以像Transformer、RNN一样建模和处理离散序列信息,同时也可以处理连续信号。

状态空间是什么?

状态空间包含__完整描述系统的最小数量的变量__。它是一种通过定义系统的可能状态,以数学方式表示问题的方法。下面举两个例子:

迷宫网格世界

在迷宫中,“状态空间”是包含所有可能位置(状态)的地图空间。每个点都代表着迷宫中一个唯一的位置,带有一些特定的细节信息,比如离出口有多远。

descript

“状态空间表示”表示您所在的位置(当前状态)、下一步可以去哪里(未来可能的状态)以及哪些变化会将您带到下一个状态(向右或向左)。描述状态的变量(在我们的示例中为 X 和 Y 坐标以及到出口的距离)可以表示为“状态向量”。

descript

语言模型中的嵌入或向量也经常用于描述输入序列的“状态”。例如,当前位置的向量(状态向量)可能看起来有点像这样:

descript

对于神经网络而言,系统的“状态”通常是其隐状态(hidden state)

弹簧-质量-阻尼系统

https://zhuanlan.zhihu.com/p/680846351

descript

状态空间模型是什么?

状态空间模型是__用于描述这些状态表示并根据某些输入预测其下一个状态可能是什么__的模型。

在t时刻,状态空间模型包含:

  • 输入x(t),例如在迷宫世界中的左移,下移
  • 隐状态表示h(t),例如离出口的距离和x/y坐标轴
  • 预测的输出y(t),例如向左移,更快到达出口

原始的状态空间模型使用连续信号而不是离散序列作为输入和输出:

descript

状态空间模型假设动态系统(例如在 3D 空间中移动的物体)可以通过两个方程从其在时间 t 时的状态进行预测。状态方程和输出方程(或者叫观测方程,Observation Equation)

descript

通过求解这两个方程,我们假设可以根据观察到的数据(输入和先前状态),发现系统状态的统计原理。

下面解释直观解释两个核心方程:

状态方程描述了__通过矩阵B输入怎么影响状态__,同时__状态怎么通过矩阵 A发生改变__。

descript

输出方程描述了__状态如何转换为输出__(通过矩阵 C)以及__输入如何影响输出__(通过矩阵 D)。

descript

注意:矩阵 A、B、C 和 D 通常也称为参数,因为它们是可学习的

可视化这两个方程,连续且时间不变的状态空间模型:

descript
descript

一步一步的来理解,这些矩阵如何影响学习过程:

descript

更新后的状态(类似于神经网络的隐藏状态)是一个潜在空间,包含了环境的核心 "知识"。我们将状态h与矩阵 A 相乘,矩阵 A 描述了所有内部状态之间的联系,因为它们代表了系统的潜在动态。

descript

矩阵A应用于状态表示创建之前,并在状态表示更新之后更新。

descript

然后,矩阵C描述状态如何转化为输出,最后,我们可以利用矩阵 D 提供从输入到输出的直接信号,这通常也称为__跳跃连接__(skip-connection)。

descript
descript
descript

通过观察数据和那两个方程,我们便可预测系统的状态。

离散化

通过零阶保持(Zero-order hold)技术,可以从离散信号中创建连续信号:

descript

保持原来值的时间,即步长由一个新的__可学习参数__表示,称为步长descript,代表输入的分辨率(resolution )。现在我们有了连续的输入信号,我们可以生成连续的输出,并且仅__根据输入的时间步长对值进行采样__。

descript

从数学上讲,我们可以按如下方式应用零阶保持:

descript

公式推导过程:

https://zhuanlan.zhihu.com/p/680534665

https://huggingface.co/blog/lbourdois/get-on-the-ssm-train

descript

有了离散化的SSM之后,允许我们以特定的时间步长而不是连续信号来表述问题。正如我们之前在 RNN 中看到的那样,循环方法在这里非常有用。

descript
descript

这种技术同时具备 RNN 的优点和缺点,即推理速度快,训练速度慢

卷积表示

可以用于 SSM 的另一种形式是卷积表示。在经典的图像识别任务中,我们应用过滤器(内核,kernels)来导出聚合特征:

descript

由于我们处理的是文本而不是图像,因此我们需要一维视角:

descript

我们用来表示这个“过滤器”的内核源自 SSM 公式:

descript

推导过程:

https://huggingface.co/blog/lbourdois/get-on-the-ssm-train

descript

让我们探讨一下这个__内核在实践中是如何工__作的。与卷积一样,我们可以使用 SSM 内核来检查每组标记并计算输出:

descript
descript
descript

将 SSM 表示为卷积的一个主要好处是它可以__像卷积神经网络 (CNN) 一样进行并行训练__。然而,由于内核大小固定,它们的推理不如 RNN 那样快速和无限制。

总结一下三种表示,连续表示、循环表示和卷积表示都有不同的优点和缺点:

descript

在训练期间,我们使用可以并行化的卷积表示,在推理期间,我们使用高效的循环表示。这些表示都有一个重要的属性,即__线性时不变性__ (LTI,Linear Time Invariance)。 LTI 规定 SSM 参数 A、B 和 C 对于所有时间步都是固定的。这意味着对于 SSM 生成的每个令牌,矩阵 A、B 和 C 都是相同的,与内容无关的静态表示。在探讨 Mamba 如何解决这个问题之前,让我们先探讨一下这个难题的最后一块:矩阵 A

矩阵A进化

可以说,SSM 公式最重要的方面之一是矩阵 A。正如我们之前在循环表示中看到的,它捕获有关先前状态的信息来构建新状态。

descript

本质上,矩阵 A 产生隐藏状态:

descript

因此,创建好矩阵 A__ 可能是只记住之前的几个标记__和__捕获我们迄今为止看到的每个标记之间的区别__。特别是在循环表示的上下文中,因为它只回顾以前的状态。那么我们怎样才能以保留大内存(上下文大小)的方式创建矩阵A呢?

我们使用HiPPO(__Hi__gh-order __P__olynomial __Pr__ojection __O__perators),HiPPO 尝试将迄今为止看到的所有输入信号压缩为系数向量。

descript

它使用矩阵 A 来建立一种状态表示法,能很好地捕捉最近的标记并衰减较早的标记。其计算公式如下:

descript

假设我们有一个方阵 A,这给我们:

descript

事实证明,使用 HiPPO 构建矩阵 A 比将其初始化为随机矩阵要好得多。因此,与较旧的信号(初始令牌)相比,它可以更准确地重建较新的信号(最近的令牌)。HiPPO 矩阵背后的想法是,它产生一个隐藏状态来记住其历史。从数学上讲,它是通过跟踪勒让德多项式(Legendre polynomial )的系数来实现的,这使得它能够逼近所有以前的历史。然后,HiPPO 被应用于我们之前看到的循环表示和卷积表示,以处理远程依赖性。结果是序列的__结构化状态空间__ (S4,Structured State Space for Sequences),一类可以有效处理长序列的 SSM。

S4由三个部分组成:

  • SSM
  • HiPPO矩阵构建矩阵A
  • 用于创建循环和卷积表示的离散化
descript

代码详细解剖S4:https://srush.github.io/annotated-s4/

选择性状态空间模型,Mamba

Mamba在S4的基础上增加了两点贡献:

  • 选择性扫描算法(selective scan algorithm),允许模型选择和过滤相关信息
descript
descript

Mamba 通过__合并输入的序列长度和批大小来根据输入制作矩阵 B 和 C__,甚至步长 Δ。这意味着对于每个输入标记,我们现在有不同的 B 和 C 矩阵,可以解决内容感知问题。矩阵 A 保持不变,因为我们希望状态本身保持静态,但它的影响方式(通过 B 和 C)是动态的。

descript

由于这些矩阵现在是动态的,因此无法使用卷积表示来计算它们,因为卷积需要固定内核。我们只能使用循环表示,而失去了卷积提供的并行性。

并行扫描算法:通过关联属性,假定我们进行运算的顺序并不重要。因此,我们可以分段计算序列,然后迭代合并:

descript
  • 一种硬件感知算法(hardware-aware algorithm),可通过并行扫描、内核融合和重新计算,高效存储(中间)结果。
descript
descript
descript

通过核融合,把h状态的计算都放在SRAM上,ABC矩阵都放在DRAM上,以减少IO次数。

以下内容被融合到一个内核中:

  • __Δ__和离散化
  • 选择性扫描算法
  • 乘矩阵C

Mamba

descript

参考资料

  1. Mamba: Linear-Time Sequence Modeling with Selective State Spaces
  2. https://maartengrootendorst.substack.com/p/a-visual-guide-to-mamba-and-state
  3. https://huggingface.co/blog/lbourdois/get-on-the-ssm-train
  4. https://zhuanlan.zhihu.com/p/680846351
  5. https://zhuanlan.zhihu.com/p/680534665