Linear Transformer
I was impressed by this fantastic paper, Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. It uses a feature mapping in self-attention and brings the computation cost down to linear in the sequence length.
Please read the paper! (even if you don't have any questions)
The transformer model has high computation and memory cost. Since it has to calculate the similarity for every query–key embedding pair, the computation cost scales as $O(N^2)$. The transformer also has to store all of these similarity scores for the softmax computation, which brings the memory cost to $O(N^2)$ as well. ($N$ is the sequence length.) As the sequence length grows, the computation and memory cost therefore increase dramatically.

We can rewrite the self-attention calculation as follows:

$$V'_i = \frac{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)\, V_j}{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)} \tag{1}$$

The notation $X_i$ means the $i$-th row in $X$. For the standard transformer, $\mathrm{sim}(q, k) = \exp\!\big(q^\top k / \sqrt{D}\big)$, where $D$ is the dimension of the queries and keys.
The computation cost of equation (1) would be $O(N^2 \max(D, M))$, where $M$ is the dimension of the values.

You may be curious how the $\max(D, M)$ term came up.

In the denominator, it takes $O(N^2 D)$ to multiply every $Q_i$ with every $K_j$. We cache these similarity results and use them when calculating the numerator.

In the numerator, it takes $O(N^2 M)$ to multiply the cached similarities with the $V_j$ part. (We are reusing the similarity calculation from the denominator.)

Since calculating the denominator and the numerator can be done in parallel, the overall computation cost is $O(N^2 \max(D, M))$.
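To make this breakdown concrete, here is a minimal NumPy sketch of equation (1). The function name `softmax_attention` and the toy shapes are my own choices, not from the paper's code; the point is that the $N \times N$ similarity matrix is computed once and reused for both the numerator and the denominator.

```python
import numpy as np

def softmax_attention(Q, K, V):
    """Equation (1): row-wise softmax attention.

    Q: (N, D) queries, K: (N, D) keys, V: (N, M) values.
    Computing sim for every (i, j) pair costs O(N^2 * D); multiplying the
    cached sims with V costs O(N^2 * M), hence O(N^2 * max(D, M)) overall.
    """
    D = Q.shape[-1]
    # Similarity for every query-key pair: an (N, N) matrix -> O(N^2) memory.
    sim = np.exp(Q @ K.T / np.sqrt(D))
    numerator = sim @ V                           # (N, M), reuses the cached sims
    denominator = sim.sum(axis=1, keepdims=True)  # (N, 1)
    return numerator / denominator

# Tiny usage example.
rng = np.random.default_rng(0)
N, D, M = 6, 4, 5
Q, K, V = rng.normal(size=(N, D)), rng.normal(size=(N, D)), rng.normal(size=(N, M))
print(softmax_attention(Q, K, V).shape)  # (6, 5)
```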
Linear Transformer is closely related to the kernel trick in SVMs. If we assume $\mathrm{sim}(\cdot,\cdot)$ is non-negative, we can separate it into two feature mappings (Mercer's theorem):

$$\mathrm{sim}(q, k) = \phi(q)^\top \phi(k)$$

Then we can rewrite equation (1) as follows:

$$V'_i = \frac{\sum_{j=1}^{N} \phi(Q_i)^\top \phi(K_j)\, V_j}{\sum_{j=1}^{N} \phi(Q_i)^\top \phi(K_j)} \tag{2}$$

Since $\phi(Q_i)$ does not depend on $j$, we can pull it out of both sums, and the self-attention calculation can be written as follows:

$$V'_i = \frac{\phi(Q_i)^\top \sum_{j=1}^{N} \phi(K_j)\, V_j^\top}{\phi(Q_i)^\top \sum_{j=1}^{N} \phi(K_j)} \tag{3}$$

The terms $\sum_{j} \phi(K_j)\, V_j^\top$ and $\sum_{j} \phi(K_j)$ can be computed once and reused for every query. $C$ is the dimension of $\phi(x)$.
Then the computation cost for self-attention is $O(NCM)$.

How does $O(NCM)$ come up? Computing $\sum_{j} \phi(K_j)\, V_j^\top$ would take $O(NCM)$ calculation. This term is cached and computed only once for all queries, which means it does not push the overall computation cost beyond $O(NCM)$.

Calculating the numerator $\phi(Q_i)^\top \sum_{j} \phi(K_j)\, V_j^\top$ takes $O(CM)$ per row.

Calculating the denominator $\phi(Q_i)^\top \sum_{j} \phi(K_j)$ takes $O(C)$ per row.

Eventually, the calculation cost is $O(NCM)$, because getting $V'$ means getting every $V'_i$, i.e. repeating the per-row work $N$ times.
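Here is a minimal NumPy sketch of equation (3), assuming the feature map $\phi(x) = \mathrm{elu}(x) + 1$ that the paper proposes; the helper names `elu_feature_map` and `linear_attention` are mine. It also checks that the factored form matches the quadratic kernelized sum of equation (2).

```python
import numpy as np

def elu_feature_map(x):
    # phi(x) = elu(x) + 1, which is strictly positive, so sim is non-negative.
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V, feature_map=elu_feature_map):
    """Equation (3): linearized attention.

    phi(K)^T V (a C x M matrix) and sum_j phi(K_j) (a C-vector) are computed
    once in O(N*C*M) and O(N*C), then reused for every query row.
    """
    phi_Q = feature_map(Q)   # (N, C)
    phi_K = feature_map(K)   # (N, C)
    KV = phi_K.T @ V         # (C, M): sum_j phi(K_j) V_j^T
    Z = phi_K.sum(axis=0)    # (C,):   sum_j phi(K_j)
    numerator = phi_Q @ KV   # (N, M), O(C*M) per row
    denominator = phi_Q @ Z  # (N,),   O(C) per row
    return numerator / denominator[:, None]

# Check the factored form against the quadratic kernelized sum of equation (2).
rng = np.random.default_rng(0)
N, D, M = 6, 4, 5
Q, K, V = rng.normal(size=(N, D)), rng.normal(size=(N, D)), rng.normal(size=(N, M))
sim = elu_feature_map(Q) @ elu_feature_map(K).T            # explicit (N, N) kernel
quadratic = (sim @ V) / sim.sum(axis=1, keepdims=True)
print(np.allclose(linear_attention(Q, K, V), quadratic))   # True
```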
Transformer can be used as an autoregressive model by masking the attention computation such that the $i$-th position can only influence the $j$-th position with $i \le j$, i.e. each position only attends to itself and the positions before it.
If we set $S_i$ and $Z_i$ as follows:

$$S_i = \sum_{j=1}^{i} \phi(K_j)\, V_j^\top, \qquad Z_i = \sum_{j=1}^{i} \phi(K_j)$$

then we can rewrite equation (3), with the sums restricted by the causal mask, as follows:

$$V'_i = \frac{\phi(Q_i)^\top S_i}{\phi(Q_i)^\top Z_i}$$

where $S_i$ and $Z_i$ can be computed from $S_{i-1}$ and $Z_{i-1}$ in constant time: $S_i = S_{i-1} + \phi(K_i)\, V_i^\top$ and $Z_i = Z_{i-1} + \phi(K_i)$. This propagating property makes the transformer the same as a recurrent neural network.
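To illustrate the recurrent view, here is a sketch of a hypothetical `causal_linear_attention_rnn` that carries the state $(S_i, Z_i)$ and updates it in constant time per position. It is my illustration of the recurrence above, not the paper's reference implementation.

```python
import numpy as np

def elu_feature_map(x):
    # phi(x) = elu(x) + 1, strictly positive.
    return np.where(x > 0, x + 1.0, np.exp(x))

def causal_linear_attention_rnn(Q, K, V, feature_map=elu_feature_map):
    """Causal linear attention computed as a recurrence over positions.

    State: S (C x M) = sum_{j<=i} phi(K_j) V_j^T and Z (C,) = sum_{j<=i} phi(K_j).
    Each step updates the state in O(C*M) time, independent of sequence length,
    so autoregressive generation costs O(1) per new token with respect to N.
    """
    phi_Q, phi_K = feature_map(Q), feature_map(K)
    N, C = phi_K.shape
    M = V.shape[-1]
    S = np.zeros((C, M))
    Z = np.zeros(C)
    out = np.zeros((N, M))
    for i in range(N):
        S += np.outer(phi_K[i], V[i])   # S_i = S_{i-1} + phi(K_i) V_i^T
        Z += phi_K[i]                   # Z_i = Z_{i-1} + phi(K_i)
        out[i] = (phi_Q[i] @ S) / (phi_Q[i] @ Z)
    return out

# Tiny usage example.
rng = np.random.default_rng(0)
N, D, M = 6, 4, 5
Q, K, V = rng.normal(size=(N, D)), rng.normal(size=(N, D)), rng.normal(size=(N, M))
print(causal_linear_attention_rnn(Q, K, V).shape)  # (6, 5)
```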