S4 (Structured State Space for Sequence Model)


Prerequisites

To fully understand S4, I recommend some reading materials.

Please read them before reading the paper (https://arxiv.org/abs/2111.00396) or the blog post (The Annotated S4, https://srush.github.io/annotated-s4/).

  1. FFT - https://www.youtube.com/watch?v=spUNpyF58BY&ab_channel=3Blue1Brown

  2. Generating Function - https://www.youtube.com/watch?v=bOXCLR3Wric&ab_channel=3Blue1Brown

  3. LSSL - https://app.gitbook.com/o/rQcI92Wl1ZpMF9SYbg9M/s/nHC3k6mrFavPwYubW6hY/~/changes/16/machine-learning-mamba/linear-state-space-layer/~/overview

What makes S4 special?

  1. Addressing long-range dependencies using the HIPPO framework.

  2. Instead of directly calculating powers of $\bar{A}$, S4 uses a truncated generating function and the IFFT (Inverse Fast Fourier Transform) to generate the convolution kernel $\bar{K}$.

  3. When the $A$ matrix is DPLR (Diagonal Plus Low-Rank), the Woodbury identity lets us compute the generating function $\hat{K}_L$ using an efficient Cauchy dot-product.

I will go through all the mathematical parts.

Applying HIPPO framework

Previous SSMs struggled because the hidden state $x$ couldn't store the past history. So the authors built the HIPPO framework, an SSM that models the past history using an orthonormal basis.

In the HIPPO framework, there are four matrices $A, B, C, D$, but the most important one is $A$ (called the HIPPO matrix). So S4 brings the HIPPO matrix over.

Using truncated generating function and IFFT to generate $\bar{K}$

Computing $\bar{K}$ directly takes huge computational resources:

$$\bar{K}=(\bar{C}^*\bar{A}^0\bar{B},\ \bar{C}^*\bar{A}^1\bar{B},\ \dots,\ \bar{C}^*\bar{A}^{L-1}\bar{B})$$

From $\bar{K}$, let's create a generating function $\hat{K}_L$:

$$\hat{K}_L(z;\bar{A},\bar{B},\bar{C})=\sum_{k=0}^{L-1}\bar{C}^*\bar{A}^k\bar{B}\,z^k$$

Now let's think of the roots-of-unity filter. The $L$-th roots of unity are

$$\omega_j=\exp\left(-i2\pi\frac{j}{L}\right),\quad j=0,1,2,\dots,L-1$$

If we substitute $z \to \omega_j$, we get the following equation:

$$\hat{K}_L(\omega_j)=\sum_{k=0}^{L-1}(\bar{C}^*\bar{A}^k\bar{B})\cdot\omega_j^k=\sum_{k=0}^{L-1}\bar{C}^*\bar{A}^k\bar{B}\cdot\exp\left(-i2\pi\frac{jk}{L}\right)$$

This is exactly the DFT (Discrete Fourier Transform): think of $j$ as frequency and $k$ as time.

This means that if we can get the generating function $\hat{K}_L(z)$, we can easily recover the convolution kernel $\bar{K}$ using the IFFT.
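To make the DFT/IFFT connection concrete, here is a minimal numpy sketch. The matrices are random stand-ins (not real HIPPO/S4 parameters): evaluating $\hat{K}_L$ at the $L$-th roots of unity is exactly the DFT of $\bar{K}$, so one inverse FFT recovers the kernel.

```python
import numpy as np

# Toy sizes: state dim N, kernel length L. All matrices are random stand-ins.
N, L = 4, 8
rng = np.random.default_rng(0)
Abar = 0.9 * np.eye(N) + 0.05 * rng.standard_normal((N, N))  # plays A-bar
Bbar = rng.standard_normal((N, 1))                           # plays B-bar
Cbar = rng.standard_normal((N, 1))                           # plays C-bar

# Direct (expensive) kernel: K[k] = C* A^k B
K_direct = np.array(
    [(Cbar.conj().T @ np.linalg.matrix_power(Abar, k) @ Bbar).item() for k in range(L)]
)

# Evaluate the generating function at the roots of unity w_j = exp(-i 2pi j / L).
# K_hat(w_j) = sum_k K[k] w_j^k, which is exactly the DFT of K.
omega = np.exp(-2j * np.pi * np.arange(L) / L)
K_hat = np.array([sum(K_direct[k] * w**k for k in range(L)) for w in omega])

# Recover the kernel with the inverse FFT.
K_recovered = np.fft.ifft(K_hat).real
print(np.allclose(K_direct, K_recovered))  # True
```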

How to get the generating function $\hat{K}_L(z)$

You may think that we need all the $\bar{C}^*\bar{A}^k\bar{B}$ terms to get the generating function $\hat{K}_L(z)$.

Well, actually we don't. Let's look at some tricks for computing $\hat{K}_L$ at low cost.

Converting into inverse form

$$\hat{K}_L(z;\bar{A},\bar{B},\bar{C})=\sum_{k=0}^{L-1}\bar{C}^*\bar{A}^k\bar{B}\,z^k=\bar{C}^*(I-\bar{A}^Lz^L)(I-\bar{A}z)^{-1}\bar{B}=\tilde{C}^*(I-\bar{A}z)^{-1}\bar{B}$$

Since we only evaluate at $z \in \{\exp(-i2\pi\frac{k}{L}) : k\in[L]\}$, $z^L$ is always $1$, so the truncation factor can be folded into $\tilde{C}^*=\bar{C}^*(I-\bar{A}^L)$.
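This step is easy to sanity-check numerically. A small sketch with random toy matrices, assuming the fold $\tilde{C}^* = \bar{C}^*(I-\bar{A}^L)$ above:

```python
import numpy as np

# Check the inverse form at a root of unity (z^L = 1). Toy matrices only.
N, L = 4, 8
rng = np.random.default_rng(1)
Abar = 0.9 * np.eye(N) + 0.05 * rng.standard_normal((N, N))
Bbar = rng.standard_normal((N, 1))
Cbar = rng.standard_normal((N, 1))

z = np.exp(-2j * np.pi * 3 / L)  # an arbitrary L-th root of unity (j = 3)
I = np.eye(N)

# Truncated sum: sum_{k<L} C* A^k B z^k
direct = sum(
    (Cbar.conj().T @ np.linalg.matrix_power(Abar, k) @ Bbar).item() * z**k
    for k in range(L)
)

# Inverse form, with the z^L = 1 trick folded into C-tilde.
C_tilde = (I - np.linalg.matrix_power(Abar, L)).conj().T @ Cbar
closed = (C_tilde.conj().T @ np.linalg.inv(I - Abar * z) @ Bbar).item()

print(np.allclose(direct, closed))  # True
```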

Convert $\bar{A}, \bar{B}$ into $A, B$

Previously, we discretized $A, B$ into $\bar{A}, \bar{B}$ with the bilinear transform. But now we are doing the reverse, substituting the discretization formulas back in:

$$\bar{A}=\left(I-\frac{\Delta}{2}A\right)^{-1}\left(I+\frac{\Delta}{2}A\right)$$

$$\bar{B}=\left(I-\frac{\Delta}{2}A\right)^{-1}\Delta B$$
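For reference, here is what that discretization looks like in numpy. A minimal sketch with a made-up toy $A$, not a real HIPPO matrix:

```python
import numpy as np

def discretize_bilinear(A, B, delta):
    """Bilinear discretization, exactly the formulas above:
    A_bar = (I - delta/2 A)^{-1} (I + delta/2 A)
    B_bar = (I - delta/2 A)^{-1} delta B
    """
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - (delta / 2) * A)
    return inv @ (I + (delta / 2) * A), inv @ (delta * B)

# Toy usage with random continuous-time parameters.
rng = np.random.default_rng(2)
A = -np.eye(3) + 0.1 * rng.standard_normal((3, 3))
B = rng.standard_normal((3, 1))
Abar, Bbar = discretize_bilinear(A, B, delta=0.1)
```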

We can convert $\hat{K}_L(z)$ as follows:

$$\begin{aligned}\hat{K}_L&=\tilde{C}^*(I-\bar{A}z)^{-1}\bar{B}\\&=\tilde{C}^*\left(I-\left(I-\tfrac{\Delta}{2}A\right)^{-1}\left(I+\tfrac{\Delta}{2}A\right)z\right)^{-1}\left(I-\tfrac{\Delta}{2}A\right)^{-1}\Delta B\\&=\tilde{C}^*\left(\left(I-\tfrac{\Delta}{2}A\right)^{-1}\left[\left(I-\tfrac{\Delta}{2}A\right)-\left(I+\tfrac{\Delta}{2}A\right)z\right]\right)^{-1}\left(I-\tfrac{\Delta}{2}A\right)^{-1}\Delta B\\&=\tilde{C}^*\left[I(1-z)-\tfrac{\Delta}{2}A(1+z)\right]^{-1}\Delta B\\&=\frac{2}{1+z}\tilde{C}^*\left[\frac{2(1-z)}{\Delta(1+z)}I-A\right]^{-1}B\end{aligned}$$

If we assume $A$ is DPLR (Diagonal Plus Low-Rank), we can write $A$ as follows (using the S4 paper's sign convention), where $\Lambda \in \mathbb{C}^{N\times N}$ is a diagonal matrix and $P, Q \in \mathbb{C}^{N\times 1}$:

$$A=\Lambda-PQ^*$$

The Woodbury identity converts the inverse of a DPLR matrix into a simpler form:

$$(\Lambda+PQ^*)^{-1}=\Lambda^{-1}-\Lambda^{-1}P(1+Q^*\Lambda^{-1}P)^{-1}Q^*\Lambda^{-1}$$
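The identity itself is easy to verify numerically. A small sketch with a random diagonal $\Lambda$ and random rank-1 vectors:

```python
import numpy as np

# Numeric sanity check of the Woodbury identity for a rank-1 DPLR matrix.
N = 5
rng = np.random.default_rng(3)
Lam = np.diag(rng.standard_normal(N) + 1j * rng.standard_normal(N))  # diagonal
P = rng.standard_normal((N, 1)) + 1j * rng.standard_normal((N, 1))
Q = rng.standard_normal((N, 1)) + 1j * rng.standard_normal((N, 1))

lhs = np.linalg.inv(Lam + P @ Q.conj().T)        # (Lambda + P Q*)^{-1}

Lam_inv = np.diag(1.0 / np.diag(Lam))            # inverting a diagonal is cheap
core = 1.0 + (Q.conj().T @ Lam_inv @ P).item()   # the scalar (1 + Q* L^{-1} P)
rhs = Lam_inv - (Lam_inv @ P) @ (Q.conj().T @ Lam_inv) / core

print(np.allclose(lhs, rhs))  # True
```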

By applying it (with $\Lambda \to R(z)$, since $\frac{2(1-z)}{\Delta(1+z)}I-A=R(z)+PQ^*$), we get the following equation for $\hat{K}_L(z)$. Note that $R(z)$ is a diagonal matrix!

$$\hat{K}_L=\frac{2}{1+z}\left[\tilde{C}^*R(z)^{-1}B-\tilde{C}^*R(z)^{-1}P(1+Q^*R(z)^{-1}P)^{-1}Q^*R(z)^{-1}B\right]$$

$$\text{where}\quad R(z)=\frac{2(1-z)}{\Delta(1+z)}I-\Lambda$$

Background: Cauchy matrix

A Cauchy matrix is defined as follows. With elements $\Omega=(\omega_i)\in\mathbb{C}^M$ and $\Lambda=(\lambda_j)\in\mathbb{C}^N$:

$$M(\Omega,\Lambda)=(M_{ij})_{i\in[M],\ j\in[N]}\in\mathbb{C}^{M\times N},\qquad M_{ij}=\frac{1}{\omega_i-\lambda_j}$$

Background: Cauchy kernel (Cauchy dot-product)

The Cauchy kernel is an efficient way to compute the following form, where $A\in\mathbb{C}^{M\times 1}$, $C\in\mathbb{C}^{N\times 1}$, and $B\in\mathbb{C}^{M\times N}$ is a Cauchy matrix:

$$A^TBC$$

You don't have to understand why this form can be computed efficiently. For simplicity, we will write the Cauchy dot-product as

$$A^TBC=k_{\Omega,\Lambda}(A,C)$$
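As a sketch of what is being computed (not the fast algorithm S4 actually relies on), here is the naive $O(MN)$ Cauchy dot-product with made-up nodes:

```python
import numpy as np

def cauchy_dot(A, C, omega, lam):
    """Naive Cauchy dot-product k_{Omega,Lambda}(A, C) = A^T M C,
    where M[i, j] = 1 / (omega[i] - lam[j]).
    S4 relies on fast near-linear-time Cauchy algorithms instead; this
    naive version just shows what quantity they compute.
    """
    M = 1.0 / (omega[:, None] - lam[None, :])  # the Cauchy matrix
    return A.T @ M @ C

# Toy usage.
rng = np.random.default_rng(4)
omega = np.exp(-2j * np.pi * np.arange(8) / 8)            # evaluation nodes
lam = rng.standard_normal(4) + 1j * rng.standard_normal(4)
a = rng.standard_normal(8)
c = rng.standard_normal(4)
print(cauchy_dot(a, c, omega, lam))
```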

Applying the Cauchy dot-product to the $\hat{K}_L$ calculation

Stacking the diagonal $R(z)^{-1}$ over all evaluation points $z$ gives exactly a Cauchy matrix, so we can apply the Cauchy dot-product!

$$\begin{aligned}\hat{K}_L&=\frac{2}{1+z}\left[\tilde{C}^*R(z)^{-1}B-\tilde{C}^*R(z)^{-1}P(1+Q^*R(z)^{-1}P)^{-1}Q^*R(z)^{-1}B\right]\\&=c(z)\left[k_{z,\Lambda}(\tilde{C},B)-k_{z,\Lambda}(\tilde{C},P)(1+k_{z,\Lambda}(Q,P))^{-1}k_{z,\Lambda}(Q,B)\right]\end{aligned}$$

where $c(z)=\frac{2}{1+z}$.

Since we can now calculate the generating function $\hat{K}_L$ at low cost, we no longer need huge computation to generate the convolution kernel $\bar{K}$.
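Putting it all together, here is a minimal end-to-end numpy sketch with made-up toy parameters (not the real HIPPO/S4 initialization), using the $A=\Lambda-PQ^*$ convention above. It builds the kernel once with explicit matrix powers and once through the generating function (Woodbury plus Cauchy sums plus IFFT), and checks that they match.

```python
import numpy as np

# Toy sizes and made-up DPLR parameters (not the real S4 initialization).
N, L, delta = 4, 16, 0.1
rng = np.random.default_rng(5)
lam = -0.5 + 1j * rng.standard_normal(N)                      # diagonal of Lambda
P = (rng.standard_normal((N, 1)) + 1j * rng.standard_normal((N, 1))) * 0.1
Q = (rng.standard_normal((N, 1)) + 1j * rng.standard_normal((N, 1))) * 0.1
B = rng.standard_normal((N, 1)) + 0j
C = rng.standard_normal((N, 1)) + 0j

A = np.diag(lam) - P @ Q.conj().T                             # A = Lambda - P Q*
I = np.eye(N)

# Bilinear discretization (same formulas as earlier).
inv = np.linalg.inv(I - (delta / 2) * A)
Abar = inv @ (I + (delta / 2) * A)
Bbar = inv @ (delta * B)

# 1) Direct kernel via matrix powers (the expensive way).
K_direct = np.array(
    [(C.conj().T @ np.linalg.matrix_power(Abar, k) @ Bbar).item() for k in range(L)]
)

# 2) Generating function at the roots of unity, via Woodbury + Cauchy sums.
C_tilde = (I - np.linalg.matrix_power(Abar, L)).conj().T @ C  # C~* = C*(I - Abar^L)
z = np.exp(-2j * np.pi * np.arange(L) / L)
g = 2.0 * (1 - z) / (delta * (1 + z))                         # R(z) = g(z) I - Lambda

def cauchy(v, w, g):
    # sum_i conj(v_i) w_i / (g_j - lam_i), for every evaluation point g_j
    return ((v.conj() * w)[:, 0][None, :] / (g[:, None] - lam[None, :])).sum(axis=1)

k00 = cauchy(C_tilde, B, g)                                   # C~* R^{-1} B
k01 = cauchy(C_tilde, P, g)                                   # C~* R^{-1} P
k10 = cauchy(Q, P, g)                                         # Q*  R^{-1} P
k11 = cauchy(Q, B, g)                                         # Q*  R^{-1} B
K_hat = (2.0 / (1 + z)) * (k00 - k01 * k11 / (1.0 + k10))

K_gen = np.fft.ifft(K_hat)
print(np.allclose(K_direct, K_gen))  # True
```

In a real implementation the four Cauchy sums are the only expensive pieces, and they can be evaluated with fast structured-matrix algorithms, which is where the asymptotic savings come from.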

Wrapup

This math journey is everything for S4. It was exciting to study the FFT and generating functions. I hope this post guides you toward the ultimate goal, Mamba.

References

[1] The Annotated S4, https://srush.github.io/annotated-s4/

[2] Albert Gu, Karan Goel, Christopher Ré. Efficiently Modeling Long Sequences with Structured State Spaces. https://arxiv.org/abs/2111.00396