[Paper Note] Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

https://arxiv.org/abs/2006.16236

Our Approach

The purpose of softmax is to identify the tokens most relevant to the current token.
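For reference, standard attention turns those relevance scores into a softmax-weighted average of the value vectors. A minimal NumPy sketch (shapes and names are mine, not from the paper):

```python
import numpy as np

def softmax_attention(Q, K, V):
    """Standard attention: each output row is a softmax-weighted
    average of the value rows. Q, K: (N, d); V: (N, d_v)."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])           # (N, N) relevance scores
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)    # softmax over the keys
    return weights @ V                                # (N, d_v)
```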

Linear Attention

What can replace softmax?

Example of $\text{sim}(Q, K)$: $\text{sim}(Q, K) = Q^T K$

We can write a generalized attention equation for any similarity function as follows:

$$V'_i = \frac{\sum_{j=1}^N \text{sim}\left(Q_i, K_j\right) V_j}{\sum_{j=1}^N \text{sim}\left(Q_i, K_j\right)}$$
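As a sanity check, this generalized form can be computed directly with a loop over queries and keys (a sketch with my own naming; with $\text{sim}(q, k) = \exp(q^T k / \sqrt{d})$ it reproduces softmax attention):

```python
import numpy as np

def generalized_attention(Q, K, V, sim):
    """V'_i = sum_j sim(Q_i, K_j) V_j / sum_j sim(Q_i, K_j)."""
    out = np.zeros((Q.shape[0], V.shape[1]))
    for i in range(Q.shape[0]):
        w = np.array([sim(Q[i], K[j]) for j in range(K.shape[0])])
        out[i] = w @ V / w.sum()
    return out

# Choosing the exponential of the scaled dot product recovers softmax attention.
softmax_sim = lambda q, k: np.exp(q @ k / np.sqrt(len(q)))
```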

Define a kernel function $\phi: \mathbb{R}^d \rightarrow \mathbb{R}^C$ that maps to a $C$-dimensional feature space.

$$V'_i = \frac{\sum_{j=1}^N \phi\left(Q_i\right)^T \phi\left(K_j\right) V_j}{\sum_{j=1}^N \phi\left(Q_i\right)^T \phi\left(K_j\right)}$$
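In code, this is the same computation with $\text{sim}(q, k) = \phi(q)^T \phi(k)$, still evaluated the quadratic way, one weight per (query, key) pair. A sketch, where `phi` is any suitable feature map applied row-wise:

```python
def kernel_attention(Q, K, V, phi):
    """Attention with sim(q, k) = phi(q)^T phi(k), computed explicitly
    as an (N, N) weight matrix before normalizing."""
    Qf, Kf = phi(Q), phi(K)          # (N, C) feature-mapped queries and keys
    weights = Qf @ Kf.T              # weights[i, j] = phi(Q_i)^T phi(K_j)
    return (weights @ V) / weights.sum(axis=-1, keepdims=True)
```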

To simplify this expression, let’s temporarily ignore the $\phi$ function and directly use $Q_i^T K_j$.

$$V'_i = \frac{\sum_{j=1}^N Q_i^T K_j V_j}{\sum_{j=1}^N Q_i^T K_j}$$

Then, it can be rewritten as:

$$V'_i = \frac{Q_i^T \left(\sum_{j=1}^N K_j V_j^T\right)}{Q_i^T \left(\sum_{j=1}^N K_j\right)}$$

This is possible because $\sum_{j=1}^N c\, x_j y_j = c \sum_{j=1}^N x_j y_j$, and $Q_i$ is constant with respect to the summation over $j$, so it can be factored out.

By expressing the equation in this form, we avoid recomputing $\sum_{j=1}^N K_j V_j^T$ and $\sum_{j=1}^N K_j$ for every query position $i$: they can be computed once and reused across all queries.

Note that $Q_i$ cannot be canceled out from the numerator and denominator because $Q_i$ is a vector, not a scalar.
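A small sketch of the factored computation (my own code, non-causal case): the two sums over $j$ are computed once and shared across every query, and the result matches the quadratic form up to floating-point error.

```python
import numpy as np

def linear_attention(Q, K, V):
    """Factored form: sum_j K_j V_j^T and sum_j K_j are computed once,
    then reused for every query Q_i."""
    kv = K.T @ V                     # (d, d_v)  = sum_j K_j V_j^T
    k_sum = K.sum(axis=0)            # (d,)      = sum_j K_j
    return (Q @ kv) / (Q @ k_sum)[:, None]

def quadratic_attention(Q, K, V):
    """Same quantity computed with the explicit (N, N) weight matrix."""
    w = Q @ K.T
    return (w @ V) / w.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
Q, K, V = rng.random((5, 4)), rng.random((5, 4)), rng.random((5, 3))
assert np.allclose(linear_attention(Q, K, V), quadratic_attention(Q, K, V))
```

The factored version scales linearly in the sequence length $N$, versus quadratically for the explicit weight matrix, which is where "linear attention" gets its name.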

We define $\phi(x) = \text{elu}(x) + 1$ as the kernel function.

ELU (Exponential Linear Unit) is defined as:

$$\text{ELU}(x) = \begin{cases} x & \text{if } x \geq 0 \\ \alpha (e^x - 1) & \text{if } x < 0 \end{cases}$$

where $\alpha$ is a hyperparameter, typically set to 1.
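A sketch of this kernel function in NumPy (with $\alpha = 1$); the $+1$ keeps every feature strictly positive, so the attention weights $\phi(Q_i)^T \phi(K_j)$ and the normalizer stay positive:

```python
import numpy as np

def elu(x, alpha=1.0):
    """ELU: x for x >= 0, alpha * (exp(x) - 1) for x < 0.
    The minimum keeps exp from overflowing on large positive inputs."""
    return np.where(x >= 0, x, alpha * (np.exp(np.minimum(x, 0.0)) - 1.0))

def phi(x):
    """Feature map phi(x) = elu(x) + 1, which is always greater than 0."""
    return elu(x) + 1.0
```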

An Interesting Question

Doesn’t linear attention still require reading all of the model’s weights to generate each next token? If the bottleneck is mainly memory bandwidth, then it doesn’t offer much advantage.

Related Work