Mamba Model
Mamba is an advanced version of S4 that adds a selection mechanism and a hardware-aware algorithm.
The selection mechanism lets the model decide which inputs to ignore and which to keep, which makes Mamba content-aware.
Attention in the Transformer already has this property: attending to specific parts of the sequence based on their content is exactly what a selection mechanism does.
To maximize parallelism in the recurrent computation, Mamba uses a parallel-scan algorithm.
On modern GPUs, the main bottleneck for AI workloads is memory bandwidth to HBM: computation runs in "fast" SRAM, while data is stored in "slow" HBM.
The typical GPU bottleneck comes from frequent data copies between SRAM and HBM. To address this, Mamba keeps the hidden state only temporarily in SRAM: whenever the hidden state is consumed by an addition or multiplication, it is overwritten in place in SRAM instead of being written back to HBM.
You might think the intermediate hidden states are needed for backpropagation, but Mamba shows that simply recomputing them during the backward pass is enough. This avoids SRAM-HBM copies and keeps the GPU fully utilized.
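Mamba performs this recomputation inside its fused CUDA kernel, but the same trade-off (discard intermediates in the forward pass, recompute them in the backward pass) can be sketched in plain PyTorch with gradient checkpointing. The toy `RecurrentBlock` below is a hypothetical stand-in, not Mamba's actual scan:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class RecurrentBlock(nn.Module):
    """Toy recurrent block; stands in for one selective-scan step."""
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, h, x):
        # One recurrence step: combine previous hidden state with the input.
        return torch.tanh(self.linear(h) + x)

block = RecurrentBlock(dim=16)
x = torch.randn(8, 64, 16, requires_grad=True)   # (batch, seq_len, dim)
h = torch.zeros(8, 16)

outputs = []
for t in range(x.shape[1]):
    # checkpoint() discards intermediate activations after the forward pass
    # and recomputes them during backward, trading compute for memory --
    # the same idea Mamba applies to its intermediate hidden states.
    h = checkpoint(block, h, x[:, t], use_reentrant=False)
    outputs.append(h)

loss = torch.stack(outputs, dim=1).sum()
loss.backward()
```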
I made a simple Mamba tutorial with the CIFAR-10 dataset. You can find the code in the following repo:
The concept of Mamba is similar to the Transformer, but it is based on an SSM.
I strongly believe Mamba will contribute to the path toward AGI.
I recommend reading about Mamba.
To add the selection mechanism to S4, Mamba makes the SSM parameters depend on the input: it generates $B_t$, $C_t$, and $\Delta_t$ for every sequence step $t$.
Since the parameters differ at every time step, we can no longer precompute a single convolution kernel as in S4, so we use recurrent computation instead.
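Here is a minimal sketch of this idea in PyTorch. The projection names and the simplified (Euler-style) discretization are illustrative assumptions, not Mamba's exact implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveSSM(nn.Module):
    """Toy selective SSM: B, C and the step size delta are functions of the input."""
    def __init__(self, d_model, d_state):
        super().__init__()
        self.A = nn.Parameter(-torch.rand(d_model, d_state))  # A stays input-independent
        self.to_B = nn.Linear(d_model, d_state)
        self.to_C = nn.Linear(d_model, d_state)
        self.to_delta = nn.Linear(d_model, d_model)

    def forward(self, x):                                  # x: (batch, seq_len, d_model)
        B_t = self.to_B(x)                                 # (b, l, d_state)
        C_t = self.to_C(x)                                 # (b, l, d_state)
        delta = F.softplus(self.to_delta(x))               # (b, l, d_model), positive step size

        # Discretize per step: A_bar and B_bar now differ at every t,
        # which is why a single precomputed convolution kernel no longer works.
        A_bar = torch.exp(delta.unsqueeze(-1) * self.A)    # (b, l, d_model, d_state)
        B_bar = delta.unsqueeze(-1) * B_t.unsqueeze(2)     # (b, l, d_model, d_state)

        h = x.new_zeros(x.shape[0], x.shape[2], self.A.shape[1])  # (b, d_model, d_state)
        ys = []
        for t in range(x.shape[1]):                        # sequential recurrence for clarity
            h = A_bar[:, t] * h + B_bar[:, t] * x[:, t].unsqueeze(-1)
            ys.append((h * C_t[:, t].unsqueeze(1)).sum(-1))  # y_t = C_t · h_t
        return torch.stack(ys, dim=1)                      # (batch, seq_len, d_model)

y = SelectiveSSM(d_model=16, d_state=8)(torch.randn(2, 32, 16))
print(y.shape)  # torch.Size([2, 32, 16])
```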
To explain the parallel scan in Mamba: the input is first multiplied by $\bar{B}_t$. In the figure, an arrow denotes addition, and $\bar{A}$ denotes multiplying $\bar{A}$ with the first element in the box.
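In code, the scan works because the per-step update $h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t$ composes associatively. The sketch below shows that combine operator and checks it against the plain recurrence; a sequential loop stands in for the log-depth parallel tree a real kernel would use:

```python
import torch

def combine(left, right):
    """Associative operator for the scan.
    Each element is a pair (a, b) representing the map h -> a * h + b."""
    a1, b1 = left
    a2, b2 = right
    return (a2 * a1, a2 * b1 + b2)

# Per-step coefficients of the recurrence h_t = a_t * h_{t-1} + b_t,
# where in Mamba a_t = A_bar_t and b_t = B_bar_t * x_t.
a = torch.rand(6)
b = torch.randn(6)

# Left-to-right scan using the associative combine (a parallel implementation
# would instead merge pairs tree-wise across GPU threads).
acc = (a[0], b[0])
hs = [acc[1]]                      # h_0 = a_0 * 0 + b_0, assuming h_{-1} = 0
for t in range(1, len(a)):
    acc = combine(acc, (a[t], b[t]))
    hs.append(acc[1])

# Check against the plain sequential recurrence.
h, ref = torch.tensor(0.0), []
for t in range(len(a)):
    h = a[t] * h + b[t]
    ref.append(h)
assert torch.allclose(torch.stack(hs), torch.stack(ref))
```

Because `combine` is associative, the pairs can be merged in any grouping, which is what lets the recurrence be split across parallel workers instead of being computed strictly step by step.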