CachedAttention
To understand this paper, you should have a solid understanding of Attention.
This paper proposes a KV caching system that reduces computation during LLM inference.
The paper divides LLM inference into two phases:
Prefilling: compute the KV caches for all prompt tokens and generate the first output token
Decoding: compute the KV cache for the newly generated token and generate the next token
The inference engine first runs the prefilling phase, then iterates the decoding phase until the model outputs an EOS token or reaches the maximum generation length.
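A minimal sketch of this loop, assuming a hypothetical `model.forward(tokens, kv_cache)` that appends the new keys/values to `kv_cache` and returns logits for the last position (this is illustrative, not the paper's actual interface):

```python
def generate(model, prompt_tokens, eos_id, max_new_tokens):
    kv_cache = []  # one (K, V) pair per transformer layer, filled by the model

    # Prefilling: build the KV cache for all prompt tokens at once
    # and produce the first generated token.
    logits = model.forward(prompt_tokens, kv_cache)
    next_token = logits.argmax()
    output = [next_token]

    # Decoding: feed one token at a time, extending the KV cache,
    # until EOS or the generation limit is reached.
    while next_token != eos_id and len(output) < max_new_tokens:
        logits = model.forward([next_token], kv_cache)
        next_token = logits.argmax()
        output.append(next_token)

    return output
```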
Most real-world tasks involve multi-turn conversations, and this is where the problem arises.
For every subsequent turn (Turn 2, Turn 3, ...), the engine has to run the prefilling phase over the whole conversation history again, which is duplicated computation. As the conversation grows, this recomputation in the prefilling phase can take up to 99% of the inference computation.
The paper proposes a solution to this problem, CachedAttention:
Cache the KV values from previous turns and reuse them in future turns.
Overlap KV cache saving/loading with Transformer computation (see the sketch after this list).
Design hierarchical KV cache placement and a positional-encoding-decoupled KV cache scheme.
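A rough sketch of how the per-session cache and the overlapped loading might fit together. `SessionStore`, `load_to_gpu`, and `layer.prefill` are hypothetical placeholders standing in for the paper's actual components, not its real interfaces:

```python
from concurrent.futures import ThreadPoolExecutor


class SessionStore:
    """Keeps per-session KV caches in a slower storage tier (e.g. host memory or disk)."""
    def __init__(self):
        self._store = {}                      # session_id -> list of per-layer (K, V)

    def get(self, session_id):
        return self._store.get(session_id)

    def put(self, session_id, kv_layers):
        self._store[session_id] = kv_layers


def load_to_gpu(layer_kv):
    """Stand-in for copying one layer's cached KV from the slower tier to GPU memory."""
    return layer_kv


io_pool = ThreadPoolExecutor(max_workers=2)   # small pool dedicated to cache I/O


def prefill_turn(model, session_id, new_tokens, store):
    cached = store.get(session_id)            # KV saved at the end of the previous turn
    kv_layers = []

    if cached is None:
        # First turn: no history, ordinary prefill over the prompt.
        for layer in model.layers:
            kv_layers.append(layer.prefill(new_tokens, past_kv=None))
    else:
        # Later turns: start loading every layer's cached KV up front, then
        # consume the results layer by layer so transfers overlap with computation.
        futures = [io_pool.submit(load_to_gpu, kv) for kv in cached]
        for i, layer in enumerate(model.layers):
            past_kv = futures[i].result()     # block only on this layer's KV
            kv_layers.append(layer.prefill(new_tokens, past_kv=past_kv))

    # Save the updated cache asynchronously so the next turn can reuse it.
    io_pool.submit(store.put, session_id, kv_layers)
    return kv_layers
```

The key point is that only the new tokens of the current turn need to be prefilled, and that loading layer i+1's cached KV can happen while layer i is computing, so the cache transfers largely hide behind the Transformer operations.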
The purpose of decoupling positional encoding is to handle token truncation. As the conversation grows, the token sequence eventually overflows the maximum context window and must be truncated. If the KV cache had positional encodings baked in, every cached entry would have to be invalidated and recomputed from scratch after each truncation. Since truncation happens frequently, CachedAttention decouples the positional encoding from the KV cache so the cache remains reusable even after truncation.
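A minimal sketch of what this decoupling looks like with rotary position embeddings (RoPE); the helper below is illustrative, not the paper's implementation. K is cached before any positional rotation, so truncating the cache only requires re-applying RoPE with the new positions:

```python
import torch


def apply_rope(x, positions, base=10000.0):
    """Rotate x (seq, dim) by angles determined by the given token positions."""
    seq, dim = x.shape
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    angles = positions.float()[:, None] * inv_freq[None, :]    # (seq, dim/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = torch.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out


# The cache stores K *without* RoPE applied (V never carries positional info).
k_cache = torch.randn(1000, 64)   # 1000 cached tokens, head dim 64
v_cache = torch.randn(1000, 64)

# Truncation: drop the oldest 200 tokens; the remaining entries stay valid
# because no positional information is baked into them.
k_cache, v_cache = k_cache[200:], v_cache[200:]

# At attention time, assign fresh positions 0..len-1 and apply RoPE only then.
positions = torch.arange(k_cache.shape[0])
k_with_pos = apply_rope(k_cache, positions)
```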
In the evaluation figures, RE means recomputation (the baseline) and CA means CachedAttention.
As the results show, CachedAttention significantly improves inference performance.
Please read the paper [1] for more details.