Einsum

Please read the referenced post [1] for a better understanding of einsum.

Problem with conventional tensor contraction codes

If you've ever written PyTorch code to build a neural network, then you have already written tensor contraction code. Tensor contraction is a fancy term for combining tensors to build a new tensor.

Dot product, cross product, matrix multiplication, element-wise multiplication, and so on: all of these familiar operations are actually special cases of tensor contraction.

In this post, I will give you a powerful tool that expresses a tensor contraction in a single line (without any messy unsqueeze, transpose, or axis swap).
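As a concrete illustration (the shapes and the attention-style contraction are a hypothetical example of mine, not from the original post), here is the kind of transpose-based code this section complains about, next to its one-line einsum equivalent:

```python
import torch

# Batched "query x key" contraction: (B, H, L, D) x (B, H, L, D) -> (B, H, L, L)
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)

# Conventional style: transpose the last two axes of k, then matmul.
scores = torch.matmul(q, k.transpose(-2, -1))          # (2, 4, 8, 8)

# The same contraction as a single einsum, spelled out with index names.
scores_einsum = torch.einsum("bhld,bhmd->bhlm", q, k)  # (2, 4, 8, 8)

assert torch.allclose(scores, scores_einsum, atol=1e-6)
```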

Einsum

Let's say we have two tensors $P \in \mathbb{R}^{p_1 \times p_2 \times \dots \times p_n}$ and $Q \in \mathbb{R}^{q_1 \times q_2 \times \dots \times q_m}$.

The notation $P(i_1, i_2, \dots, i_n)$ means the element of $P$ at index $(i_1, i_2, \dots, i_n)$.

Then we can express an einsum as follows:

$$\sum_a \sum_b P(i_1, i_2, \dots, i_{n-1}, a)\ Q(b, j_2, j_3, \dots, j_m) = i_1 i_2 \dots i_{n-1} a,\ b j_2 j_3 \dots j_m \rightarrow i_1 i_2 \dots i_{n-1} j_2 j_3 \dots j_m$$
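A minimal sketch of that formula in PyTorch, with small made-up shapes (the index letters in the einsum string are arbitrary): because the labels for the last axis of P and the first axis of Q do not appear in the output, einsum sums over each of them.

```python
import torch

# P has shape (i1, i2, a) = (3, 4, 5); Q has shape (b, j2, j3) = (6, 7, 2).
P = torch.randn(3, 4, 5)
Q = torch.randn(6, 7, 2)

# "a" and "b" are absent from the output indices, so both are summed over.
out = torch.einsum("xya,bjk->xyjk", P, Q)  # shape (3, 4, 7, 2)

# The same thing written as explicit sums over the dropped indices.
ref = P.sum(dim=-1)[:, :, None, None] * Q.sum(dim=0)[None, None, :, :]
assert torch.allclose(out, ref, atol=1e-5)
```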

Why do we use einsum?

Einsum may also be easier to optimize, because it compresses a sequence of operations into one compact expression. This gives the compiler (or the einsum backend) more opportunity to optimize.
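For example (a hypothetical chain of my own, not from the post), a single einsum over three operands describes the whole contraction at once, so the backend is free to pick a cheaper pairwise contraction order than a strict left-to-right chain of matmuls:

```python
import torch

A = torch.randn(100, 5)
B = torch.randn(5, 100)
C = torch.randn(100, 5)

# One expression for the whole chain; the backend may reorder the pairwise
# contractions (e.g. contract B with C first) to keep intermediates small.
out = torch.einsum("ij,jk,kl->il", A, B, C)  # shape (100, 5)

# Equivalent chained matmuls, evaluated strictly left to right.
ref = A @ B @ C
assert torch.allclose(out, ref, atol=1e-3)
```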

Also, it is beautiful!

Hands-on Einsum

You can try out some examples below!
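Here is a small collection of common patterns to experiment with (a sketch of mine; the shapes are arbitrary example values):

```python
import torch

A = torch.randn(3, 4)
B = torch.randn(4, 5)
v = torch.randn(4)
w = torch.randn(6)
M = torch.randn(5, 5)
batch = torch.randn(10, 3, 4)

torch.einsum("ij->ji", A)              # transpose                -> (4, 3)
torch.einsum("ij,jk->ik", A, B)        # matrix multiplication    -> (3, 5)
torch.einsum("ij,j->i", A, v)          # matrix-vector product    -> (3,)
torch.einsum("i,j->ij", v, w)          # outer product            -> (4, 6)
torch.einsum("ii->", M)                # trace                    -> scalar
torch.einsum("ij,ij->", A, A)          # sum of element-wise mul  -> scalar
torch.einsum("bij,jk->bik", batch, B)  # batched matmul           -> (10, 3, 5)
```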

References

[1] Einsum is All You Need, https://rockt.ai/2018/04/30/einsum