结构化状态空间模型可视化解析
原文towardsdatascience.com/structured-state-space-models-visually-explained-86cfe2757386https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/a3551757a56352eabdce39ea6b99502a.png图片由 Sascha Kirch 提供。这是我的新多部分系列的第二部分 迈向图像、视频和时间序列的 Mamba 状态空间模型。状态空间模型几十年来在许多工程学科中广为人知现在正在深度学习中崭露头角。在我们迈向 Mamba 选择性状态空间模型及其在研究中的最新成就的过程中理解状态空间模型至关重要。而且正如工程中经常发生的那样正是细节使得理论概念在实践中的应用成为可能。除了状态空间模型之外我们还必须讨论如何将它们应用于序列数据如何处理长距离依赖以及如何通过利用某些矩阵结构来有效地训练它们。结构化状态空间模型为 Mamba 构建了理论基础。然而它们与系统理论和高级代数之间的联系可能是采用这一新框架的障碍之一。因此让我们将其分解确保我们理解关键概念并通过可视化来揭示这一新旧理论的一些启示。即使你最终没有使用状态空间模型了解一些技巧比如为什么我们需要加快矩阵乘法以及我们如何利用矩阵的某些结构来实现这一点也肯定会提升你作为工程师或开发者的技能。图像、视频和时间序列的 Mamba 状态空间模型概述第一部分结束后的内容状态空间模型概述 2.1 状态空间模型的连续时间表示 2.2 状态空间模型的离散化 2.3 SSM 的递归和卷积表示启用长距离关系 3.1 状态大小很重要 3.2 HiPPO 框架对状态矩阵施加结构以提高效率 4.1 什么是结构化矩阵以及为什么它们很重要 4.2 整合一切 – S4结构化状态空间序列模型总结 5.1 结构化状态空间模型的局限性 5.2 继续阅读第三部分进一步阅读与资源1. 第一部分结束后的内容迈向图像、视频和时间序列的 Mamba 状态空间模型在第一部分我们回顾了循环神经网络RNNs和 Transformer 的优缺点以说明为什么我们需要一个新的模型架构。我们说过 RNNs 在推理时间快但在训练时慢而 Transformer 在训练时快但在推理时慢。我们想要的是一个在训练和推理时都快的模型同时与 Transformer 的性能具有竞争力。2. 状态空间模型概述Mamba 通过学习其各种矩阵在深度学习中构建了使用状态空间模型的想法。因此在探讨“结构化”部分之前让我们简要介绍一下状态空间模型。状态空间模型可以用连续时间表示来处理连续信号或者可以离散化以处理离散数据序列。这是我们主要感兴趣离散化状态空间模型因为它像循环神经网络RNNs和 Transformer 一样一个离散 SSM 处理数据序列例如文本标记或模拟时间信号的样本。2.1 状态空间模型的连续时间表示连续时间状态空间模型描述了输入信号通过具有状态的系统传播与产生的输出信号之间的关系。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/637b0df4c587809caed2c024f6409c51.png图 1状态空间模型的高级功能。图片由 Sascha Kirch 提供。这是一个从系统理论中借用的想法。输出取决于输入和系统的当前状态而当前状态取决于先前状态和输入。这种关系可以通过两个简单的方程有效地表达其中A、B、C和D是矩阵稍后我们将看到这些矩阵将被学习x(t)是输入信号y(t)是输出信号而h(t)和h’(t)分别是当前状态和更新后的状态。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/69ee5e889495623f4c4a775ee1db27fe.png方程式 1连续时间 SSM 的状态和输出方程。矩阵D将输入x(t)转换并映射到输出y(t)这通常包含在经典 SSM 的输出方程中。然而当将 SSM 应用于深度学习时我们移除了这种转换并将其建模为一个简单的跳跃连接从而简化了 SSM。我们可以将这些方程表示在一个类似于这样的框图中https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/d7be8818c5ab9d1cabb2f4c49d8c4c1e.png图 2连续时间状态空间模型的框图。图片由 Sascha Kirch 提供。这是 SSM 的连续时间表示。然而有两个困难找到模型状态h(t)的解析解具有挑战性由于我们在计算机上工作我们通常处理离散信号而不是连续信号。2.2 状态空间模型的离散化因此我们实际上是在将一个函数x(t)映射到另一个函数y(t)而不是将一个函数映射到另一个函数。这意味着我们需要一个可以处理离散信号的离散化状态空间模型。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/9d3e2bccb55815656e84af05a9b9a71e.png图 3离散状态空间模型的高级功能。图片由 Sascha Kirch 提供。要做到这一点我们需要对矩阵A和B进行离散化。这导致了我们状态空间模型的离散化表示现在包括一个离散化的状态h[k]。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/7e4d8ba1d7f13ea4950e09dfd9f1e3ce.png方程 2离散化 SSM 的状态和输出方程。为了离散化我们的状态空间模型我们使用了一种在系统理论中经常使用的技巧我们使用一个零阶保持ZOH模块将我们的离散信号x[k]转换为连续时间信号x(t)然后将其输入到连续的 SSM 中。SSM 将生成一个连续输出y(t)然后我们对其进行采样以获得离散输出序列y[k]https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/6ee465a4db89f6566cd55487791ce9c3.png图 4离散化状态空间模型的框图。图片由 Sascha Kirch 提供。为了离散化状态空间模型我们必须对矩阵A和B进行离散化。最初大多数基于深度学习的状态空间模型如LSSL 和 S4使用了双线性离散化而后来的一些如DSS 和 Mamba则使用了零阶保持ZOH离散化。目前我们关心的是我们需要进行离散化并且可能存在多种方式进行离散化。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/1b3f17bec6b7a0735b8e73ab3b9206ac.png方程 3零阶保持ZOH离散化与双线性离散化的比较。A和B是 SSM 的连续时间矩阵I是单位矩阵Δ 是表示输入分辨率的步长在深度学习设置中是学习的。注意离散化使我们能够使用 SSM 的时间连续公式并将其应用于计算机中使用的离散信号。这意味着我们变得对分辨率不变因为我们学习了底层的时间信号。查看 S4ND 论文以获取更多关于该内容的阅读材料。Sascha Kirch 的论文解读2.3 SSM 的循环和卷积表示你有没有注意到状态方程中的循环性因为h[k]依赖于h[k-1]事实上从这种意义上讲状态空间模型非常类似于 RNNhttps://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/3e31a46e1712a05f9b7899d399c3cecb.png图 5离散状态空间模型的循环表示。图片由Sascha Kirch提供。矩阵A、B和C是时间不变的。这意味着每个输入和每个状态都使用相同的参数进行转换无论时间步长如何。线性时不变系统或简称为 LTI 系统具有一个非常整洁的特性它们可以表示为卷积https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/796037dd74ccf7776a8d82b190282883.png方程 4离散状态空间模型的卷积核。核由矩阵C和离散矩阵A和B构建。请注意核的长度L与输入序列相同。核与输入x卷积以获得输出y。为了在输出端保持与输入端相同的序列长度通常需要对输入进行填充。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/9f2125122ea66dda6ea1f21584dffd73.png图 6离散状态空间模型的卷积表示示例。图片由Sascha Kirch提供。这真是太棒了因为我们知道卷积是可并行的而且 GPU 和其他硬件加速器都经过优化可以快速计算它们。为了总结本节利用线性时不变系统的特性我们最终得到了三种不同的表示 SSM 的方法https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/b8e552f8bacb5943542562fde1f398d4.png图 7状态空间模型的不同表示。图片由Sascha Kirch提供。SSM 的循环和卷积表示是相同的因此我们可以决定使用其中之一。在训练期间我们会选择卷积表示因为我们可以通过训练数据提前访问整个序列x[k]和标签y[k]。在推理期间我们会选择循环表示因为我们通常只能访问过去的样本例如自回归文本生成。注意由于我们的下一个输出只依赖于前一个状态和当前输入推理的缩放与序列长度成线性关系。现在我们有一个具有高效训练和高效推理的模型。太棒了是 LSSL线性状态空间层论文 展示了我们可以训练一个 SSM并强调了不同表示之间的联系。每当 Sascha Kirch 发布内容时都收到一封电子邮件 _ 每当 Sascha Kirch 发布内容时都收到一封电子邮件 想了解更多关于深度学习或只是保持最新状态…medium.com3. 启用长距离关系矩阵A负责从上一个状态创建下一个状态。系统的状态是考虑过去状态和新输入的多次更新后的结果。可以说状态包含了整个输入历史的信息。不难想象对于非常长的序列可能很难准确地记住所有输入。那么我们如何处理长距离依赖3.1 状态大小很重要让我们先谈谈状态它的压缩以及推理的效率。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/df87c9ded69107403e70630f9cb81647.png图 8RNN、SSM、Transformer 和期望模型的效率与性能对比。图片由 Sascha Kirch 提供。RNN 和 Transformer 都可以从理论上处理无限长的数据序列有时人们说这些模型是无界的。RNN通过将整个历史压缩成一个低维状态来实现这一点。这对于推理来说非常高效因为现在下一个状态只依赖于上一个状态和当前输入。RNN 的性能取决于模型学习进行这种压缩的好坏。此外对于长序列尤其是由于梯度消失即使对于使用门控机制如 GRUs 和 LSTMs的更高级 RNN 来说这也是一项不容易完成的任务。Transformer的另一方面则完全不压缩历史。它通过其注意力机制考虑整个历史如果不是像自回归生成中那样掩码则包括未来。这允许 Transformer 并行训练但同时也随着序列长度的增加推理时的内存和计算需求激增。❓ 所以再次问我们如何处理具有高度压缩状态的长距离依赖同时在训练过程中避免困难并实现高性能一种解决方案我们使用 HiPPO高阶多项式投影算子框架初始化我们的矩阵。3.2 HiPPO 框架主要思想是将状态方程中的矩阵A和B参数化以便我们可以在任何给定点重建历史即从时间t0到tnow的输入信号。具体来说如果状态表示为某些正交多项式的某些系数例如勒让德多项式我们只需评估多项式即可重建直到当前时间的输入信号。例如我们有一个输入函数x(t)。在任意给定的时间步ti我们想要找到一个重建g(ti)它近似于tti之前的所有t的x(t)。近似g(ti)完全由可以插入勒让德多项式以重建g(ti)的系数c(ti)描述。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/1f55c8ff10a1d43d02b9afca96445caf.png图 9HiPPO 框架中输入函数 x(t) 的近似 g(t)。图片由 Sascha Kirch 提供。这一点很重要我们不仅可以访问输入的压缩表示我们还可以在任何给定时间重建输入。尽管如此它仍然是一个具有线性复杂度所有优点的循环模型。这些系数随时间演变因为对于每个ti我们都有不同的系数c(ti)。那么我们如何得到这些系数呢实际上c(t)可以定义为常微分方程ODEhttps://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/0e4df073d832ed4b5803855d36302fc9.png方程 5HiPPO 框架系数的常微分方程ODE。通过一些数学技巧例如假设x和g的正交性以及选择不同时间步长的加权µ(t)这个问题就被完全定义了。我们可以求解常微分方程ODE并得到满足我们加权选择的A(t)和B(t)然后在我们状态空间模型SSM中使用它们。已提供了不同的 HiPPO 矩阵作为不同过去样本加权和不同多项式基的解决方案。LegT具有滑动窗口的勒让德多项式LegS具有恒定加权的勒让德多项式。LagT具有指数衰减的拉格朗日多项式LegS 是后续工作中使用最突出的一个。以下是矩阵A的解决方案https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/1354cb00cf6b97a2a60bdcc2541087ec.png方程 6N4 的 HiPPO 矩阵定义和示例。一个简单的 Python 实现可能看起来像这样defget_hippo_A(N:int)-np.ndarray:Anp.zeros((N,N))forninrange(N):forkinrange(N):ifnk:A[n,k](2*n1)**0.5*(2*k1)**0.5elifnk:A[n,k](n1)returnA之前提到的 LSSL 显示如果用 HiPPO 矩阵初始化可以训练一个 SSM并且它在所展示的实验中甚至超过了 Transformer 的性能。然而其昂贵的计算和内存需求使得 LSSL 作为通用序列模型不可行。直到 S4结构化状态空间序列模型的引入这个问题才得到解决我们将在下一章中讨论。4. 对状态矩阵施加结构以提高效率现在是时候讨论结构化状态空间模型中的结构是什么以及为什么它很重要了。我们首先尝试获得一些关于结构是什么以及为什么它很重要的非常基本的直觉然后我们将更详细地研究 S4 模型。4.1 结构化矩阵是什么为什么它们很重要让我们考虑一个标准的NxN元素方阵。我们将关注两个方面我们需要多少内存来表示这个矩阵我们能有多快地乘以两个矩阵这里的“快”可以由几个因素决定比如能够并行化计算、利用冗余以及最初处理的值更少。我们需要多少内存让我们从存储矩阵的内存需求开始。我们可以有不同的矩阵类型。有时它们已经以某种形式存在有时我们可以利用线性代数的规则将它们转换成某种形式。让我们看看一些常见的矩阵类型https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/e7985566a1d7138200fcb0f8cb882d96.png图 10几种常见矩阵类型需要存储的值数量。图片由 Sascha Kirch 提供。我们看到在最坏的情况下对于一个密集矩阵我们需要在内存中存储N²个值。不仅如此我们还必须从 CPU 加载N²个值到 GPU然后在 GPU 上应用某些矩阵运算。注意我们在 CPU 端即在 RAM 或磁盘上需要存储的值数量与我们在 GPU 上用于计算的内存占用之间存在差异。为了使这一点更清晰让我们看看低秩分解。一些矩阵可以被分解为 2 个NxR的低秩矩阵其中R是这些矩阵的秩。我们现在只需要为每个矩阵存储NR个值因此总共需要2NR个值。这也意味着我们只需要从 CPU 加载2NR个值而不是N²个值到 GPU。但是如果我们想用原始矩阵进行计算我们首先需要在 GPU 内部重建原始矩阵有些人称之为实体化它们这意味着我们仍然需要在 GPU 缓存中存储N²个值。另一方面像对角矩阵或稀疏矩阵这样的其他表示可以减少两个方面磁盘上存储的值以及需要在 GPU 上执行矩阵运算所需的值的数量。需要多少计算量现在让我们考虑一个标准的矩阵乘法CAB其中A和B要么是密集矩阵要么是对角矩阵而C是矩阵乘法的结果https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/6641a0d6c9b6097989854956ae8a78cd.png图 11不同矩阵类型矩阵乘法的操作数。图片由 Sascha Kirch 提供。对于A和B都是密集矩阵的情况使用我们所有人都知道的矩阵乘法标准算法我们需要N³次乘法和(N-1)N²次加法。在A是密集矩阵而B是对角矩阵的情况下这减少到N²次乘法和0次加法。在A和B都是对角矩阵的情况下我们只需要N次乘法即对角线元素的逐元素乘法。虽然这在硬件资源方面是可取的但请注意对角矩阵可能不如密集矩阵表达性强因为它有更少的值。像所有事情一样这是一个权衡。计算的并行化如果你仔细观察密集矩阵乘法你可能会注意到计算输出C的单行涉及到B的所有值但只有A的单行。天真地你可以在 GPU 的不同核心上并行计算C的每一行来加速计算。这意味着每个核心只需要N²次乘法和(N-1)N次加法而不是单核心情况下的N³次乘法和(N-1)N²次加法因为我们现在有一个向量-矩阵乘法而不是矩阵-矩阵乘法。简而言之为了总结以NxN的方阵为例通过尝试利用矩阵的某些特性我们希望将它们带入一种结构化的形式这样我们就不需要存储和加载超过N²个值并且对于矩阵乘法我们需要的乘法次数不超过N³加法次数不超过(N-1)N²。当然还有许多其他我们可以利用的矩阵结构它们各自都有优缺点。在整个关于 Mamba 状态空间模型的系列中我们将遇到许多这样的结构它们都服务于同一个目的在保持最佳性能的同时减少资源需求。4.2 将所有内容整合 – S4结构化状态空间序列模型S4 通过对涉及的矩阵施加一定的结构旨在使 SSMs 更高效。通过这样做它可以应用已研究的代数规则来降低涉及计算的复杂性并节省内存特别是对于卷积视图的核K的重复计算。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/2c807f740d7cd24a0e45bfea09281978.png图 12S4 结构化状态序列模型。图片由 Sascha Kirch 提供。与 LSSL 相比S4 已经显示出 30 倍的速度提升和 400 倍的内存使用减少同时性能也得到了提升。S4 论文涉及许多高级线性代数在这篇博客文章中我将主要跳过。如果你想要深入了解数学和代码实现我推荐你查看 The Annotated S4。回想一下为了计算 SSM 的卷积表示中的卷积核K需要对离散化方阵A进行重复的矩阵乘法。记住核的长度L与输入序列x相同并且随着核中每个元素的加入矩阵A的幂次会增加。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/796037dd74ccf7776a8d82b190282883.png方程 6离散化状态空间模型的卷积核。为了明确起见长度为L101的序列将会有矩阵的100次幂以及9998……以此类推。当然我们可以重用以前的结果但这仍然需要大量的计算。S4 的主要目标是通过在矩阵A上施加某种结构来减少计算核K所需的资源数量。理想情况下矩阵A将是一个对角矩阵因为正如我们在图 11 中之前看到的两个对角矩阵相乘会产生另一个对角矩阵只需要N次乘法。由于 HiPPO 矩阵不能以数值稳定的方式对角化如 S4 论文中所述接下来最好的办法是将它分解为正加低秩NPLR形式然后可以通过共轭进一步简化为对角加低秩DPLR形式。请注意我们希望保留 HiPPO 矩阵不将其转换为另一个可对角化的矩阵因为我们需要它来处理长距离依赖关系。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/59956aef50be463c0b68fc80c0273e6f.png图 13比较矩阵 A 的不同表示及其计算卷积核 K 的复杂度。图片由 Sascha Kirch 提供。当矩阵A处于 DPLR 形式时我们可以应用许多复杂的技巧。总之核是在频域中计算的而 DPLR 形式允许将矩阵重新排列成柯西矩阵对于这种矩阵存在许多硬件优化的算法。此外使用 Woodburry 标识符可以将A的多个幂次的计算转换为单个逆矩阵。最后应用快速傅里叶变换FFT以获得最终的卷积滤波器K。因此之前是A, B, C的函数的SSM(A,B,C)现在变成了SSM(**Λ-*PP,B,C)其中Λ,P, B和C是 S4 的可学习参数。如果你对原始实现感兴趣这里提供了dplr()函数和nplr()函数它们位于S4 仓库中。5. 总结5.1 摘要状态空间模型SSMs是序列模型。SSMs 可以被描述为循环模型或卷积模型由于它们只依赖于当前输入和前一个状态它们在循环表示中快速推断并且因为这种表示线性扩展所以它们是快速的。由于它们的卷积表示可以并行化因此它们训练得很快。由于它们是 LTI 系统任何时间步的每个输入都被同等对待这意味着通过相同的矩阵进行转换。结构化状态空间序列模型即 S4在其矩阵上使用一定的结构使 SSM 更快允许它们在计算和内存需求上不那么严格。感谢我想要感谢你我故事书的读者感谢你们的持续支持为我点赞并关注我以便不错过我的任何最新文章。正是你们激励我继续写作深入研究复杂话题并通过写作成为更好的工程师/研究人员。你们太棒了5.2 结构化状态空间模型Structured SSMs的局限性然而有一个大问题由于矩阵A、B和C是不变的意味着对所有输入都是相同的因此 SSM 无法对不同时间步的不同输入进行推理。5.3 继续阅读第三部分Mamba选择性状态空间模型在第三部分中我们将发现 Mamba 如何通过向模型架构添加选择性来解决结构化状态空间模型在输入序列中无法区分样本的问题。❓但我们得到了什么代价是什么6. 进一步阅读与资源Mamba 状态空间模型用于图像、视频和时间序列Sascha Kirch 的论文解读博客文章Mamba 和状态空间模型的视觉指南S4 的注释HiPPO具有最优多项式投影的循环记忆论文S4使用结构化状态空间高效建模长序列 by A. Gu et.al.2021 年 10 月 31 日S5简化状态空间层用于序列建模 by J. Smith et. al.2022 年 8 月 9 日HiPPO: 具有最优多项式投影的循环记忆 by A. Gu et.al., 17 Aug. 2020LSSL: 结合循环、卷积和连续时间模型与线性状态空间层 by A. Gu et. al., 26 Oct. 2021DSS: 对角状态空间与结构化状态空间一样有效 by A. Gupta et. al., 27 Mar. 2022S4D: 对对角状态空间模型的参数化和初始化 by A. Gu et. al., 23 Jun. 2022S4ND: 使用状态空间将图像和视频建模为多维信号 by E. Nguyen et. al., 12 Oct. 2020代码S4 代码仓库