CacheBlend


Problem Definition

Many existing inference runtimes reuse the KV cache. Reusing the KV cache cuts the computation cost and time of the prefilling phase. However, existing methods cannot be applied when RAG is used.

RAG

RAG stands for 'Retrieval-Augmented Generation'. It prepends several retrieved prefix chunks to the user sequence and then generates the output sequence, using sentence embeddings to find the chunks that best match the given user sequence.
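
As a rough illustration (not the paper's code), here is a minimal retrieval sketch in Python; `embed` is a toy stand-in for a real sentence encoder, and `retrieve_chunks` / `build_prompt` are hypothetical helpers:

```python
import numpy as np

def embed(text: str, dim: int = 64) -> np.ndarray:
    """Toy stand-in for a sentence embedding (hashed bag-of-words).
    A real RAG system would use a trained sentence encoder."""
    v = np.zeros(dim)
    for tok in text.lower().split():
        v[hash(tok) % dim] += 1.0
    return v / (np.linalg.norm(v) + 1e-8)

def retrieve_chunks(query: str, chunks: list[str], top_k: int = 3) -> list[str]:
    """Return the top-k chunks most similar to the query (cosine similarity)."""
    q = embed(query)
    scores = [float(q @ embed(c)) for c in chunks]
    order = np.argsort(scores)[::-1][:top_k]
    return [chunks[i] for i in order]

def build_prompt(query: str, chunks: list[str]) -> str:
    """Prepend the retrieved chunks (the 'prefix chunks') to the user sequence."""
    return "\n\n".join(retrieve_chunks(query, chunks)) + "\n\n" + query
```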

Looking at the paper's figure of the different prefilling scenarios [1]:

At (a), a default LLM fully recomputes the KV cache during the prefilling phase.

At (b), only the KV cache of the first chunk is reused.

At (c), the KV caches of all chunks are reused, but there is no cross-attention between them, so the output quality becomes very poor.

At (d), the paper proposes a new method that reuses the KV caches of all chunks and also applies cross-attention between them with negligible overhead.

Solution

The main problem is that we have to apply cross-attention between the prefix chunks' KV caches, and doing that cross-attention in full is exactly case (a).

The paper starts from the following idea:

If we want to fuse KV-cache-1 and KV-cache-2, select only a small fraction (about 15%) of the tokens from chunk-1 and apply cross-attention between those selected tokens and all the tokens in chunk-2.

If we can select the important tokens from chunk-1, this selective cross-attention yields a KV cache close to the fully recomputed one.
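
A minimal numpy sketch of this selective fusion, under simplifying assumptions (single attention head, no causal mask, illustrative names; how the tokens are chosen is the topic of the next section):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def selective_fuse(hidden, K_cached, V_cached, selected, Wq, Wk, Wv):
    """Recompute K/V only for `selected` token positions; every other
    position keeps its reused per-chunk KV cache.

    hidden:             (T, d) layer input for the concatenated chunks
    K_cached, V_cached: (T, d) reused KV cache (no cross-chunk attention yet)
    selected:           indices of the tokens chosen for recomputation (~15%)
    """
    K, V = K_cached.copy(), V_cached.copy()
    # Refresh K/V of the selected tokens from the current hidden states.
    K[selected] = hidden[selected] @ Wk
    V[selected] = hidden[selected] @ Wv
    # The selected tokens' queries attend over the whole concatenated prefix,
    # which injects the missing cross-chunk attention at a fraction of the cost.
    Q_sel = hidden[selected] @ Wq
    scores = Q_sel @ K.T / np.sqrt(K.shape[-1])
    out_sel = softmax(scores) @ V          # (len(selected), d)
    return K, V, out_sel
```

Only the selected rows of the attention matrix are ever formed, which is where the savings over case (a) come from.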

How can we choose important tokens?

To answer this question, we have to take a look at case (c).

Case (c) produces poor-quality results because its KV cache deviates heavily from case (a)'s KV cache.

As a result, it fails to link the meanings of the prefix chunks. In the paper's figure comparing the attention maps of full KV recompute and full KV reuse [1], the cross-attention blocks between the prefix chunks are all zeros under full reuse.

So the paper came up with the following idea:

We should make the KV cache similar to the fully recomputed KV cache, without actually recomputing everything.

An important token here means a token with a large KV deviation.

KV deviation is the gap between the reused KV cache and the fully recomputed KV cache.

So the plan is: pick the tokens whose KV deviates most from the ideal (fully recomputed) cache and update them via cross-attention.

But this seems circular: to measure the gap from the ideal KV cache, we would have to know the ideal KV cache, which means fully recomputing it. The paper resolves this with a clever observation:

Tokens with the highest KV deviations on one layer are likely to have the highest KV deviations on the next layer.

So, as the KV caches are fused layer by layer, we measure how much the KV cache of the current layer changed, pick the ~15% HKVD (Highest-KV-Deviation) tokens, and apply cross-attention only to them on the next layer.
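
A sketch of what the per-layer selection could look like, assuming the deviation is measured with an L2 norm over the K and V vectors (the names and exact metric here are illustrative, not the paper's code):

```python
import numpy as np

def kv_deviation(K_reused, V_reused, K_recomputed, V_recomputed):
    """Per-token gap between the reused KV cache and the recomputed one."""
    dK = np.linalg.norm(K_recomputed - K_reused, axis=-1)
    dV = np.linalg.norm(V_recomputed - V_reused, axis=-1)
    return dK + dV                                   # shape (T,)

def select_hkvd(deviation, ratio=0.15):
    """Pick the ~15% of tokens with the highest KV deviation on this layer.
    Since HKVD tokens on one layer tend to stay HKVD on the next layer,
    this set decides which tokens get recomputed one layer deeper."""
    k = max(1, int(len(deviation) * ratio))
    return np.argsort(deviation)[::-1][:k]
```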

Please read the paper for more information.

Pipelining to hide recompute time

Because the KV cache is reused, it first has to be loaded onto the GPU. If we pipeline the KV cache loading with the selective KV recomputation, the recompute time can be hidden behind the loading time.

The paper argues that recomputing only about 15% of the tokens for cross-attention during fusion keeps the recompute time small enough to be hidden this way.
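
A toy Python sketch of this overlap, loading the next layer's cached KV while the current layer's selected tokens are recomputed (purely illustrative; a real serving system would use CUDA streams or asynchronous I/O):

```python
from concurrent.futures import ThreadPoolExecutor

def pipelined_blend(num_layers, load_kv, recompute_selected):
    """Overlap KV-cache loading of layer i+1 with selective recompute of layer i.

    load_kv(i)                -> cached K/V of layer i   (I/O bound)
    recompute_selected(i, kv) -> fused K/V of layer i    (compute bound)
    """
    fused = []
    with ThreadPoolExecutor(max_workers=1) as io:
        pending = io.submit(load_kv, 0)                  # prefetch layer 0
        for i in range(num_layers):
            kv = pending.result()                        # wait for this layer's cache
            if i + 1 < num_layers:
                pending = io.submit(load_kv, i + 1)      # start loading the next layer
            fused.append(recompute_selected(i, kv))      # recompute overlaps the load
    return fused
```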

References

[1] CacheBlend: https://arxiv.org/abs/2405.16444