Linear Transformer

Summary

I was impressed by this fantastic paper[1]. It applies a feature mapping inside self-attention and reduces the computation cost to linear in the sequence length.

Please read the paper (even if you don't have any questions)!

Problem with Transformer Model

As the sequence length grows, the computation and memory cost increase dramatically.


The transformer model has high computation and memory cost. Since it has to calculate the similarity for every query-key embedding pair, the computation cost is $O(N^2 \cdot \max(D, M))$. The transformer also has to store all of these pairwise similarities for the softmax computation, which brings the memory cost to $O(N^2)$. ($N$ is the sequence length.)

We can rewrite the self-attention calculation as follows:

$$Q \in \mathbb{R}^{N\times D}, \quad K \in \mathbb{R}^{N\times D}, \quad V \in \mathbb{R}^{N\times M}$$

The notation $V_i, Q_i, K_i$ means the $i$-th row of $V, Q, K$.

$$V_i' = \frac{\sum_{j}^{N} \mathrm{sim}(Q_i, K_j) \cdot V_j}{\sum_{j}^{N} \mathrm{sim}(Q_i, K_j)} \quad \ldots (1)$$

The computation cost is $O(N^2 \cdot \max(D, M))$.

You may be curious how $O(N^2 \cdot \max(D, M))$ comes up.

For the denominator, each $\mathrm{sim}(Q_i, K_j)$ takes $O(D)$ to compute. We cache these $\mathrm{sim}(\cdot)$ results and reuse them when calculating the numerator.

For the numerator, each $\mathrm{sim}(Q_i, K_j) \cdot V_j$ product takes $O(M)$ (reusing the $\mathrm{sim}(\cdot)$ values computed for the denominator).

Since the denominator and numerator can be computed in parallel over all $N^2$ query-key pairs, the overall computation cost is $O(N^2 \cdot \max(D, M))$.
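To make the quadratic cost concrete, here is a minimal NumPy sketch of equation (1) computed directly, assuming the usual exponentiated scaled dot product as $\mathrm{sim}(\cdot)$ (i.e. ordinary softmax attention); the function name and shapes are my own illustration, not code from the paper.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Equation (1) computed directly: O(N^2 * max(D, M)).

    Q: (N, D), K: (N, D), V: (N, M).
    Softmax attention corresponds to sim(q, k) = exp(q . k / sqrt(D)).
    """
    N, D = Q.shape
    sim = np.exp(Q @ K.T / np.sqrt(D))      # (N, N): N^2 pairs, O(D) each
    denom = sim.sum(axis=1, keepdims=True)  # (N, 1): denominator of (1)
    return (sim @ V) / denom                # numerator costs O(N^2 * M)

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(8, 4)), rng.normal(size=(8, 4)), rng.normal(size=(8, 5))
print(naive_attention(Q, K, V).shape)       # (8, 5)
```

Note that the full `sim` matrix has to be materialized before the softmax normalization, which is exactly where the $O(N^2)$ memory comes from.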

Using feature-mapping (kernel function)

Linear Transformer is closely related to the kernel trick in SVM[2].

If we assume $\mathrm{sim}(x, y)$ is non-negative, we can separate it into two feature mappings (Mercer's theorem):

$$\mathrm{sim}(Q_i, K_j) = \phi(Q_i)^T \phi(K_j)$$

Then we can rewrite equation (1) as follows:

$$V_i' = \frac{\sum_j^N \phi(Q_i)^T \phi(K_j) V_j}{\sum_j^N \phi(Q_i)^T \phi(K_j)} = \frac{\phi(Q_i)^T \sum_j^N \phi(K_j) V_j}{\phi(Q_i)^T \sum_j^N \phi(K_j)} \quad \ldots (2)$$

The terms $\sum_j^N \phi(K_j) V_j$ and $\sum_j^N \phi(K_j)$ can be reused for every $i$. Here $C$ is the dimension of $\phi(\cdot)$.

Then the computation cost for self-attention is $O(NCM)$.

How does $O(NCM)$ come up? Computing $\sum_j^N \phi(K_j) V_j$ takes $O(NCM)$ operations, but this term is computed once and cached, so it is not repeated for every query.

Given the cached sums, calculating the numerator of each $V_i'$ takes $O(CM)$.

Calculating the denominator takes $O(C)$.

Eventually, the total cost is $O(NCM)$, because producing $V'$ means computing every $V_i'$ for $i = 0 \ldots N-1$.
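Here is a minimal NumPy sketch of equation (2), assuming the feature map $\phi(x) = \mathrm{elu}(x) + 1$ used in the paper; the function names are my own. The check at the end confirms it matches the quadratic form computed with the same $\mathrm{sim}(q, k) = \phi(q)^T \phi(k)$.

```python
import numpy as np

def elu_feature_map(x):
    """phi(x) = elu(x) + 1: a simple, strictly positive feature map."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V, phi=elu_feature_map):
    """Equation (2): O(N * C * M) instead of O(N^2 * max(D, M)).

    Q: (N, D), K: (N, D), V: (N, M); phi maps D -> C (here C = D).
    """
    phi_Q, phi_K = phi(Q), phi(K)                # (N, C)
    KV = phi_K.T @ V                             # (C, M): sum_j phi(K_j) V_j, computed once
    Z = phi_K.sum(axis=0)                        # (C,):   sum_j phi(K_j),     computed once
    return (phi_Q @ KV) / (phi_Q @ Z)[:, None]   # per row: O(C*M) and O(C)

# Sanity check against the O(N^2) form with the same similarity function.
rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(8, 4)), rng.normal(size=(8, 4)), rng.normal(size=(8, 5))
sim = elu_feature_map(Q) @ elu_feature_map(K).T
ref = (sim @ V) / sim.sum(axis=1, keepdims=True)
assert np.allclose(linear_attention(Q, K, V), ref)
```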

Transformer is an RNN under Causal Masking

The transformer can be used as an autoregressive model by masking the attention computation so that the $i$-th position can only be influenced by the $j$-th position when $j \le i$.

Then the self-attention calculation can be written as follows:

$$V_i' = \frac{\sum_j^i \mathrm{sim}(Q_i, K_j) V_j}{\sum_j^i \mathrm{sim}(Q_i, K_j)} = \frac{\phi(Q_i)^T \sum_j^i \phi(K_j) V_j}{\phi(Q_i)^T \sum_j^i \phi(K_j)} \quad \ldots (3)$$

If we set $S_i$ and $Z_i$ as follows:

$$S_i = \sum_j^i \phi(K_j) V_j$$

$$Z_i = \sum_j^i \phi(K_j)$$

then we can rewrite equation (3) as follows:

$$V_i' = \frac{\phi(Q_i)^T S_i}{\phi(Q_i)^T Z_i}$$

where $S_i$ and $Z_i$ can be computed from $S_{i-1}$ and $Z_{i-1}$ in constant time. This recurrence makes the causally masked transformer behave exactly like a recurrent neural network.
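The recurrent view translates directly into code. Below is a minimal NumPy sketch (my own illustration, using the same $\mathrm{elu}(x)+1$ feature map as in the previous sketch) where $S_i$ and $Z_i$ are carried as running sums and updated in constant time per token.

```python
import numpy as np

def phi(x):
    """Same non-negative feature map as before: elu(x) + 1."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def causal_linear_attention_rnn(Q, K, V):
    """Causal linear attention as an RNN (equation (3)).

    State carried across steps:
        S_i = S_{i-1} + outer(phi(K_i), V_i)   -> (C, M)
        Z_i = Z_{i-1} + phi(K_i)               -> (C,)
    Output per step: V_i' = phi(Q_i)^T S_i / (phi(Q_i)^T Z_i)
    """
    N, M = V.shape
    C = Q.shape[1]                      # here C = D, since phi is elementwise
    S, Z = np.zeros((C, M)), np.zeros(C)
    out = np.empty((N, M))
    for i in range(N):
        q, k = phi(Q[i]), phi(K[i])     # (C,)
        S += np.outer(k, V[i])          # O(1) update w.r.t. sequence length
        Z += k
        out[i] = (q @ S) / (q @ Z)
    return out
```

During autoregressive generation only $S$ and $Z$ need to be kept around, so the per-token cost and memory stay constant instead of growing with the length of the generated prefix.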

References

[1] Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. https://arxiv.org/abs/2006.16236

[2] SVM. https://app.gitbook.com/o/rQcI92Wl1ZpMF9SYbg9M/s/41WpIJevPqDKpTFpP3Zp/~/changes/71/machine-learning/svm