[Paper Note] Titans: Learning to Memorize at Test Time

https://arxiv.org/abs/2501.00663

Attention mechanisms and recurrent models each have their strengths and weaknesses: attention models dependencies across the whole context window accurately but scales quadratically with sequence length, while recurrent models compress the context into a fixed-size state, giving linear cost at the price of losing information over long contexts.

This work aims to combine the best of both worlds by using a recurrent neural memory module to capture long-term memory and attention to model short-term memory.

Current approaches often lack some key components of human learning, such as:

Human Learning Mechanisms

Why Linear Transformers Outperform RNNs:

This leads to five critical questions:

  1. What structure should memory take?
  2. What is a proper memory update mechanism?
  3. What is a good memory retrieval process?
  4. How can different memory components be integrated?
  5. Is a deep memory module needed to effectively store long-term memory?

We introduce a system that can efficiently and effectively learn to memorize at test time, addressing the above questions and combining the strengths of attention and recurrent models.

Background

Long-Term Memory

Typically, memory is seen as a double-edged sword. If a model memorizes the training data too well, it may suffer from poor generalization and struggle with out-of-distribution data.

We need a model that can dynamically memorize or forget information at test time. The model therefore learns a function that is capable of memorization: an online meta-model that learns how to memorize and forget data at test time.

How to Train a Memory Module $\mathcal{M}$?

The memory module $\mathcal{M}$ needs to learn how to update its memory.

Without Momentum:

$$\mathcal{M}_{t} = \mathcal{M}_{t-1} - \theta_t \underbrace{\nabla \ell(\mathcal{M}_{t-1}; x_{t})}_{\text{Surprise}}$$

With Momentum:

$$\mathcal{M}_{t} = \mathcal{M}_{t-1} + S_{t}, \qquad S_{t} = \eta_t \underbrace{S_{t-1}}_{\text{Past Surprise}} - \theta_t \, \underbrace{\nabla \ell\left(\mathcal{M}_{t-1}; x_{t}\right)}_{\text{Momentary Surprise}}$$
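
As a purely illustrative sketch, the two update rules can be written as follows, treating the memory state, its momentum, and the gradient as plain tensors; the per-token coefficients $\theta_t$ and $\eta_t$ are assumed to be given (all names are mine, not the paper's code):

```python
import torch

def update_without_momentum(M, grad, theta):
    # M_t = M_{t-1} - theta_t * grad   (the "surprise" term)
    return M - theta * grad

def update_with_momentum(M, S, grad, theta, eta):
    # S_t = eta_t * S_{t-1} (past surprise) - theta_t * grad (momentary surprise)
    S = eta * S - theta * grad
    # M_t = M_{t-1} + S_t
    return M + S, S
```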

Objective Function for $\mathcal{M}$

To formalize test-time learning as a gradient descent optimization task, we design an objective function based on key-value (associative) memory. The objective function is defined as:

$$\ell(\mathcal{M}_{t-1}; x_t) = \left\| \mathcal{M}_{t-1}\left(\mathbf{k}_t\right) - \mathbf{v}_t \right\|_2^2$$

Where the keys and values are computed as:

$$\mathbf{k}_t = x_t W_K, \qquad \mathbf{v}_t = x_t W_V$$

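As a minimal sketch (with illustrative sizes and names of my own, not the paper's code), the loss could be computed like this, assuming the memory $\mathcal{M}$ is a small MLP and $W_K$, $W_V$ are linear projections:

```python
import torch
import torch.nn as nn

dim = 64                                   # illustrative width
memory = nn.Sequential(                    # a (deep) memory module M
    nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim)
)
W_K = nn.Linear(dim, dim, bias=False)      # key projection
W_V = nn.Linear(dim, dim, bias=False)      # value projection

def memory_loss(memory, x_t):
    # k_t = x_t W_K,  v_t = x_t W_V
    k_t, v_t = W_K(x_t), W_V(x_t)
    # ||M(k_t) - v_t||_2^2 : how well M recalls the value associated with the key
    return ((memory(k_t) - v_t) ** 2).sum()
```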

Forgetting Mechanism

To prevent memory overload, we introduce a decay factor:

$$\mathcal{M}_{t} = (1 - \alpha_t)\, \mathcal{M}_{t-1} + S_{t}$$
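
Putting the pieces together and continuing the sketch above, one test-time update step, combining the surprise gradient, momentum, and the forgetting factor, might look like this (gradients are taken only with respect to the memory parameters; $\alpha_t$, $\eta_t$, $\theta_t$ are assumed to be given per token):

```python
import torch

# momentum state: one tensor per memory parameter
S = [torch.zeros_like(p) for p in memory.parameters()]

def memory_step(memory, S, x_t, alpha_t, eta_t, theta_t):
    """One illustrative inner-loop step: surprise, momentum, and forgetting."""
    loss = memory_loss(memory, x_t)                    # ||M(k_t) - v_t||^2 from above
    grads = torch.autograd.grad(loss, list(memory.parameters()))
    with torch.no_grad():
        for p, g, s in zip(memory.parameters(), grads, S):
            s.mul_(eta_t).add_(g, alpha=-theta_t)      # S_t = eta_t * S_{t-1} - theta_t * grad
            p.mul_(1 - alpha_t).add_(s)                # M_t = (1 - alpha_t) * M_{t-1} + S_t
    return S
```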

Memory Architecture Design

Memory Retrieval Process

Parallelizing Long-Term Memory Training

As discussed above, the design of our long-term memory module is equivalent to training a meta-model by optimizing the associative memory loss $\ell(\mathcal{M}_{t-1}; x_t) = \left\| \mathcal{M}_{t-1}(\mathbf{k}_t) - \mathbf{v}_t \right\|_2^2$ using gradient descent with momentum and weight decay.

Chunk-Based Parallel Processing

To enable efficient parallel computation at test time, the inner-loop gradient descent (with momentum and weight decay) is carried out chunk by chunk rather than strictly token by token:
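
My reading of how the chunk-wise computation works, as a simplified sketch: assume a linear (matrix-valued) memory and ignore the momentum term; within a chunk, every gradient is taken with respect to the chunk-start state $\mathcal{M}_0$ (mini-batch rather than fully sequential gradient descent), so all per-token gradients reduce to batched matrix products:

```python
import torch

def chunk_update(M0, K, V, theta, alpha):
    """Advance a linear matrix memory over one chunk (simplified: no momentum).

    M0: (d, d) memory state at the start of the chunk.
    K, V: (b, d) keys/values for the b tokens in the chunk.
    theta, alpha: (b,) per-token learning rates and forgetting factors.
    """
    b, d = K.shape
    # gradients of ||k_i M0 - v_i||^2 w.r.t. M0 for all tokens at once: 2 k_i^T (k_i M0 - v_i)
    err = K @ M0 - V                                     # (b, d)
    grads = 2 * K.unsqueeze(-1) * err.unsqueeze(1)       # (b, d, d)
    # cumulative decay beta_t = prod_{j<=t} (1 - alpha_j)
    beta = torch.cumprod(1 - alpha, dim=0)               # (b,)
    # M_b = beta_b * M0 - sum_i theta_i * (beta_b / beta_i) * grad_i
    weights = (theta * beta[-1] / beta).view(b, 1, 1)
    return beta[-1] * M0 - (weights * grads).sum(dim=0)
```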

Persistent Memory

Persistent memory differs from long-term memory in that it contains task-relevant context that is input-independent. It is learned during training, then kept fixed and prepended to the input during inference.
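
In code, persistent memory amounts to a set of learnable, input-independent vectors that are prepended to the sequence; a small sketch (sizes and names are illustrative):

```python
import torch
import torch.nn as nn

class PersistentMemory(nn.Module):
    """Learnable, input-independent memory tokens prepended to the input."""
    def __init__(self, num_tokens: int, dim: int):
        super().__init__()
        self.tokens = nn.Parameter(torch.randn(num_tokens, dim) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim) -> (batch, num_tokens + seq_len, dim)
        p = self.tokens.unsqueeze(0).expand(x.shape[0], -1, -1)
        return torch.cat([p, x], dim=1)
```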

There are three key motivations for using persistent memory:

Integrating Memory Module into Transformers

1. Memory as Context (MAC)

In the MAC approach, memory is used as input to the Transformer’s attention mechanism:
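
My understanding of the MAC data flow, sketched below; `retrieve` and `update` are hypothetical stand-ins for the memory module's read and write operations:

```python
import torch

def mac_block(segment, persistent, memory, attention):
    # 1. read historical context from the long-term memory using the current segment
    h = memory.retrieve(segment)                      # (batch, n_h, dim)
    # 2. memory as *context*: prepend persistent and retrieved tokens to the segment
    ctx = torch.cat([persistent, h, segment], dim=1)
    # 3. attention (short-term memory) over the extended context
    y = attention(ctx)
    # 4. write the attention output back into the long-term memory (test-time update;
    #    the exact interface here is hypothetical)
    memory.update(y)
    return y
```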

2. Memory as Gate (MAG)

The MAG approach uses the memory module as a gating mechanism for the attention output:
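
A rough sketch of the MAG branch structure as I understand it; the sigmoid gating and the prefix slicing are illustrative choices, not necessarily the paper's exact operations:

```python
import torch

def mag_block(x, persistent, memory, attention):
    n_p = persistent.shape[1]
    # short-term branch: attention over the persistent-prefixed input
    # (drop the prefix positions so the two branches align in length)
    attn_out = attention(torch.cat([persistent, x], dim=1))[:, n_p:]
    # long-term branch: the neural memory processes the same input
    mem_out = memory(x)
    # memory as *gate*: combine the two branches elementwise
    return attn_out * torch.sigmoid(mem_out)
```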

3. Memory as Layer Before Attention (MAL)

In the MAL setting, the memory module acts as a processing layer before the attention mechanism:
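
A minimal sketch of the MAL layout: the neural memory runs first as an ordinary layer, and attention operates on its output (names are illustrative):

```python
import torch

def mal_block(x, persistent, memory, attention):
    # memory as a *layer*: the neural memory first processes the (prefixed) sequence
    h = memory(torch.cat([persistent, x], dim=1))
    # attention then runs on top of the memory layer's output
    return attention(h)
```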

Limitations of MAL:

Results

Questions I want the experiments to answer:

Training Configuration

Performance

Long-Range Dependencies

These are the results on the BABILong benchmark after fine-tuning.

Memory Module Depth

Ablation Study

Related Work

My Thought

Why Use $\mathcal{M}$ When $\mathcal{M}(\mathbf{k}_t)$ Aims to Approximate $\mathbf{v}_t$?

At first glance, it might seem redundant to use $\mathcal{M}$ since its output $\mathcal{M}(\mathbf{k}_t)$ is trained to approximate $\mathbf{v}_t = x_t W_V$. Why not just use $\mathbf{v}_t$ directly instead of $\mathcal{M}$'s output?

  • While $\mathcal{M}(\mathbf{k}_t)$ is trained to approximate $\mathbf{v}_t$, it does not simply replicate $\mathbf{v}_t$. Instead, it combines $\mathbf{v}_t$ with the historical memory stored in $\mathcal{M}$.
  • The attention mechanism uses $\mathcal{M}(\mathbf{k}_t)$ as input, not $\mathbf{v}_t$ directly. This ensures that attention is computed on a representation that incorporates both current and historical context.