Why Cache 32 Heads When One Latent Variable Suffices? A Theory-to-Code Guide to DeepSeek’s MLA for KV-Cache [En/中]

从多头共享到潜变量:DeepSeek的MLA在低秩投影与按需解压中重新定义 KV-Cache

Posted by Yihua Zhang on February 2, 2025

MLA: Redefining KV-Cache Through Low-Rank Projections and On-Demand Decompression

Introduction

As Large Language Models (LLMs) continue to thrive, hardware resources remain a daunting “ceiling”—especially with limited GPU memory (VRAM). The question of how to achieve extended context lengths and faster inference under constrained resources has long been a key focus for both engineering and research communities. Aside from the common techniques of quantization and pruning, there is a growing emphasis on “reducing the KV-Cache footprint at inference time”.

This article first revisits how MHA (Multi-Head Attention), MQA (Multi-Query Attention), and GQA (Grouped-Query Attention) handle or reduce K/V storage, and then centers on the approach introduced by DeepSeek, MLA (Multi-Head Latent Attention). Unlike those earlier methods, which work mainly at the level of K/V sharing, MLA combines low-rank projection with on-demand decompression so that multi-head K/V never needs to be stored directly: the cache holds only compact "latent vectors," and a further matrix-merging trick lets the attention mechanism run with minimal VRAM at inference time.

It is worth noting that MLA, when deployed in real systems, often needs to accommodate RoPE (Rotary Position Embedding). To keep things clear, we will first explain the core MLA idea (low-rank projection) and then discuss how to integrate RoPE. We hope this structured approach provides insight into the reasoning and nuances behind MLA’s design.

Special Thanks: Part of the inspiration in this article comes from Su Jianlin’s blog post 缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA (in Chinese). We extend our respect to him for his work.


1. Why Reduce the KV-Cache?

1.1 The “Invisible Bottleneck” in Long-Context Inference

During autoregressive generation with a Transformer, each newly generated token attends to the historical Key/Value ($K, V$) vectors of all previous tokens. These stored vectors constitute the KV-Cache during inference. If the sequence length is $L$, there are $h$ attention heads, and each head has key dimensionality $d_k$ and value dimensionality $d_v$, then the KV-Cache per layer scales roughly as \(L \times h \times (d_k + d_v),\) growing linearly with $L$. Once the context reaches thousands or even tens of thousands of tokens, the KV-Cache can dominate VRAM usage, surpassing even the model's own activation storage.
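For a rough sense of scale, here is a back-of-the-envelope sketch; all sizes below (sequence length, head count, head dimensions, layer count, fp16 storage) are illustrative assumptions rather than the configuration of any particular model:

```python
# Rough KV-Cache size estimate: L * h * (d_k + d_v) elements per token per layer.
# All numbers are illustrative assumptions for a hypothetical 32-layer model in fp16.
def kv_cache_bytes(seq_len, n_heads, d_k, d_v, n_layers, bytes_per_elem=2):
    per_token_per_layer = n_heads * (d_k + d_v)   # cached elements per token per layer
    return seq_len * per_token_per_layer * n_layers * bytes_per_elem

size = kv_cache_bytes(seq_len=32_768, n_heads=32, d_k=128, d_v=128, n_layers=32)
print(f"{size / 2**30:.1f} GiB")   # ~16.0 GiB for a single 32k-token sequence
```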

1.2 Constraints of VRAM and Bandwidth

When the sequence is very long, fitting all KV-Cache on a single GPU becomes unfeasible. Splitting the cache across multiple GPUs or machines leads to bandwidth bottlenecks—inter-device communication is typically much slower than on-device memory access. Simply put, if we can handle longer contexts on fewer GPUs, we minimize communication overhead and improve throughput. This is precisely why MQA, GQA, and MLA have emerged and keep evolving.


2. MHA → MQA → GQA: Simplifying K/V in Multi-Head Attention

Before introducing MLA, let’s briefly overview conventional Multi-Head Attention (MHA) as well as the sharing-based approaches MQA and GQA, which aim to reduce K/V storage. This context sets the stage for comparing MLA to prior work.

2.1 Multi-Head Attention (MHA) Foundations

2.1.1 Classic Attention Formulas

In a Transformer, a token sequence $\mathbf{x}_1,\dots,\mathbf{x}_l$ is projected into multiple sets of $(Q, K, V)$ for attention computation. For the $s$-th head, assuming hidden dimensionality $d$: \(\mathbf{q}_i^{(s)} = \mathbf{x}_i \mathbf{W}_q^{(s)},\quad \mathbf{k}_i^{(s)} = \mathbf{x}_i \mathbf{W}_k^{(s)},\quad \mathbf{v}_i^{(s)} = \mathbf{x}_i \mathbf{W}_v^{(s)}.\) Under autoregressive decoding, the attention score at step $t$ is often written as \(\alpha_{t,i}^{(s)} = \mathbf{q}_t^{(s)} \,\mathbf{k}_i^{(s)\top}, \quad\text{for } i \le t.\) To speed up inference, we cache the computed $\mathbf{k}_i^{(s)}$ and $\mathbf{v}_i^{(s)}$ in VRAM for later tokens, a storage referred to as the KV-Cache.
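To make the caching mechanics concrete, here is a minimal PyTorch sketch of autoregressive decoding under MHA; the head count and dimensions are arbitrary illustrations, not a real model configuration:

```python
import torch

d, h, d_k = 64, 4, 16                                  # illustrative sizes
Wq, Wk, Wv = (torch.randn(h, d, d_k) for _ in range(3))
k_cache, v_cache = [], []                              # the per-token KV-Cache

def decode_step(x_t):                                  # x_t: (d,), one new token
    q = torch.einsum("d,hdk->hk", x_t, Wq)
    k_cache.append(torch.einsum("d,hdk->hk", x_t, Wk)) # cache K for this token
    v_cache.append(torch.einsum("d,hdk->hk", x_t, Wv)) # cache V for this token
    K, V = torch.stack(k_cache), torch.stack(v_cache)  # (t, h, d_k), grows with t
    attn = (torch.einsum("hk,thk->ht", q, K) / d_k ** 0.5).softmax(dim=-1)
    return torch.einsum("ht,thk->hk", attn, V)         # (h, d_k)

for t in range(5):
    out = decode_step(torch.randn(d))
print(out.shape, "cached steps:", len(k_cache))        # torch.Size([4, 16]) cached steps: 5
```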

2.1.2 Pressures on VRAM

Because MHA typically retains distinct K/V for each head, if $h$ is large, you end up storing $h$ sets of Key/Value, which can quickly blow up VRAM usage. Researchers therefore wondered: can we make multi-head attention share or compress these K/V representations?

2.2 MQA: Extreme K/V Sharing

MQA (Multi-Query Attention) lets every head share a single K/V pair: \(\mathbf{k}_i = \mathbf{x}_i \mathbf{W}_k,\quad \mathbf{v}_i = \mathbf{x}_i \mathbf{W}_v,\) while each head still retains its own $\mathbf{q}_i^{(s)}$. The KV-Cache then holds 1 set of K/V instead of $h$, cutting this portion of VRAM usage to $1/h$ of the original. Implementations such as PaLM and StarCoder have adopted MQA. However, because all heads share the same K/V, certain tasks may see degraded quality unless additional training strategies are used.

2.3 GQA: Grouping Heads

If MQA feels overly aggressive, GQA (Grouped-Query Attention) provides a middle ground: group the $h$ heads into $g$ clusters, each cluster sharing one set of K/V. Hence, the KV-Cache shrinks to $g$ sets (rather than $h$). Examples include LLaMA2-70B and ChatGLM2/3. GQA retains greater variety than MQA but still saves VRAM over standard MHA.
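To see how the cached footprint per token differs across the three schemes, here is a small sketch; the hidden size, head count, and group count are illustrative assumptions:

```python
import torch

d, h, g, d_k = 512, 16, 4, 64                # illustrative: 16 heads, 4 GQA groups
x = torch.randn(1, 10, d)                    # (batch, seq, hidden)

W_k = {
    "MHA": torch.randn(d, h * d_k),          # one Key projection per head
    "MQA": torch.randn(d, 1 * d_k),          # a single shared Key projection
    "GQA": torch.randn(d, g * d_k),          # one Key projection per group
}
for name, W in W_k.items():
    cached_k = x @ W                         # what would be written into the K cache
    print(f"{name}: {cached_k.shape[-1]} cached K elements per token")
# MHA: 1024, MQA: 64, GQA: 256 (Values scale the same way)
```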

2.4 Comparison of MHA / MQA / GQA

| Method | KV-Cache Storage | K/V Sharing | VRAM Savings | Head Diversity |
| --- | --- | --- | --- | --- |
| MHA | Store $h$ copies | Independent per head | Low (baseline) | High |
| MQA | Store 1 copy | Fully shared K/V | High | Lower |
| GQA | Store $g$ copies | Shared within groups | Moderate | Fairly high |

Whether MQA or GQA, they both revolve around the question of “how much to share K/V across heads.” In contrast, MLA rethinks “what we actually store at inference time”: It shifts most of the K/V content into a latent vector, reconstructing it only on demand. Let’s explore how MLA uses low-rank projection and on-demand decompression, initially without worrying about RoPE.


3. The MLA Core: Low-Rank Projections and On-Demand Reconstruction (Without RoPE)

3.1 Key Idea: Switching from “Store Multi-Head K/V” to “Store a Low-Dimensional Latent Vector”

In MLA (Multi-Head Latent Attention), we still project each input into Key and Value at training time, but we introduce a low-rank latent vector $\mathbf{c}_i$. During inference, instead of caching high-dimensional multi-head K/V, we only store this compact $\mathbf{c}_i$, then merge matrices when needed to reconstruct the multi-head K/V. Concretely, ignoring RoPE for the moment, we can represent MLA’s training step like so:

\[\mathbf{c}_i = \mathbf{x}_i \mathbf{W}_c \quad (\text{a low-rank projection, } \mathbf{W}_c \in \mathbb{R}^{d \times d_c}),\]

and for each head $s$ we define projection matrices $\mathbf{W}_{kc}^{(s)}, \mathbf{W}_v^{(s)}$ such that \(\mathbf{k}_i^{(s)} = \mathbf{c}_i \mathbf{W}_{kc}^{(s)}, \quad \mathbf{v}_i^{(s)} = \mathbf{c}_i \mathbf{W}_v^{(s)}.\) As a result, no matter how many heads we have, only training time sees explicit multi-head K/V. At inference, we simply cache the latent vector $\mathbf{c}_i$, reconstructing K/V on-the-fly using matrix combinations. This is how MLA largely decouples the KV-Cache cost from the number of heads.
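A minimal sketch of this factorization, with illustrative dimensions and matrix names (they do not correspond to DeepSeek's actual hyperparameters):

```python
import torch

d, d_c, h, d_k, d_v = 4096, 512, 32, 128, 128   # illustrative sizes
x_i = torch.randn(d)

W_c  = torch.randn(d, d_c)                      # low-rank down-projection
W_kc = torch.randn(h, d_c, d_k)                 # per-head Key up-projection
W_v  = torch.randn(h, d_c, d_v)                 # per-head Value up-projection

c_i = x_i @ W_c                                 # the only thing cached at inference: (d_c,)
k_i = torch.einsum("c,hck->hk", c_i, W_kc)      # reconstructed multi-head Keys: (h, d_k)
v_i = torch.einsum("c,hcv->hv", c_i, W_v)       # reconstructed multi-head Values: (h, d_v)

print(c_i.numel(), "cached elements vs", k_i.numel() + v_i.numel(), "reconstructed")
# 512 cached elements vs 8192 reconstructed
```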

3.2 On-Demand Decompression: How VRAM is Saved

During inference, whenever we generate token $t$ and need to compute the dot product $\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top}$ for previous tokens $i < t$, a conventional approach would read $\mathbf{k}_i^{(s)}$ from VRAM. MLA, however, merges them:

\[\mathbf{k}_i^{(s)} = \mathbf{c}_i\,\mathbf{W}_{kc}^{(s)}\quad\Longrightarrow\quad\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top}=(\mathbf{x}_t \mathbf{W}_q^{(s)})(\mathbf{c}_i \mathbf{W}_{kc}^{(s)})^\top.\]

Because matrix multiplication is associative, the two projections can be merged:

\[(\mathbf{x}_t \mathbf{W}_q^{(s)}) (\mathbf{c}_i \mathbf{W}_{kc}^{(s)})^\top = \mathbf{x}_t (\mathbf{W}^{(s)}_q \mathbf{W}^{(s)\top}_{kc}) (\mathbf{c}_i)^\top = \mathbf{x}_t \mathbf{W}_{\mathrm{merged}}^{(s)} (\mathbf{c}_i)^\top\]

Hence, we only keep $\mathbf{c}_i$ in VRAM, never storing all the $\mathbf{k}_i^{(s)}$.
If $\mathbf{c}_i$ is $d_c$-dimensional, with $d_c \ll h \times d_k$, then the KV-Cache cost transitions from $h \times d_k$ down to $d_c$.
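The merge is an exact consequence of associativity; the following sketch verifies it numerically in float64 (all shapes are illustrative):

```python
import torch

torch.manual_seed(0)
d, d_c, d_k = 64, 16, 8
x_t = torch.randn(d, dtype=torch.float64)        # current token's hidden state
c_i = torch.randn(d_c, dtype=torch.float64)      # a cached latent vector
W_q  = torch.randn(d, d_k, dtype=torch.float64)
W_kc = torch.randn(d_c, d_k, dtype=torch.float64)

separate = (x_t @ W_q) @ (c_i @ W_kc)            # reconstruct k_i, then dot with q_t
W_merged = W_q @ W_kc.T                          # fold W_kc into the query-side projection
merged   = x_t @ W_merged @ c_i                  # touches only the cached latent c_i

print(torch.allclose(separate, merged))          # True (up to floating-point rounding)
```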

3.3 How Low-Rank Projection Drastically Reduces Storage

One might ask how large the compression ratio can get. Suppose $d=4096$, $h=32$, and the per-head dimension is $d_k=128$. In standard MHA, each token’s Key occupies $32 \times 128 = 4096$ elements (and the Value another 4096). MLA might set the latent vector $\mathbf{c}_i$ to 512 elements, cutting this storage from 4096 to 512, an 8× reduction. In more aggressive configurations, the compression factor can reach tens or even hundreds.

Of course, this is the rosy scenario without position encoding. In practice, Transformers commonly use RoPE (Rotary Position Embedding), which modifies how $Q$ and $K$ are projected. So we’ll first clarify MLA’s fundamental low-rank approach before examining how RoPE fits in. Next, we illustrate MLA’s workflow with an “intelligent photo album” analogy, and then return to RoPE afterward.


4. Understanding MLA as a “Low-Rank Thumbnail” in an Intelligent Photo Album

Even if you appreciate MLA’s formulas, you might still find them abstract. Let’s use a more intuitive metaphor: treat each “Token” as a “photograph,” “multi-head attention” as “filters applied to photos,” and “KV Cache” as “album storage.” This analogy shows how MLA achieves compressed storage and on-demand decompression.

4.1 Photo Storage: A Low-Rank Thumbnail

Imagine every time you snap a photo (process a token $\mathbf{x}_i$), instead of saving the “full-resolution image plus all filters,” you only keep a “small-yet-informative thumbnail”—the latent vector $\mathbf{c}_i$. For instance, if the original image resolution is 4096 (like $d=4096$ in MLA), the thumbnail might be 512 in size, achieving roughly 1/8 the original.
Mathematically, \(\mathbf{c}_i = \mathbf{x}_i \mathbf{W}_c, \quad \mathbf{W}_c \in \mathbb{R}^{4096 \times 512}.\) This is akin to “downsampling the photo at capture time,” drastically cutting storage overhead.

4.2 Viewing Photos: Real-Time Decompression

When you “view” a photo with a certain filter—corresponding to attention heads that generate Key/Value—MLA does: \(\mathbf{k}_i^{(s)} = \mathbf{c}_i \mathbf{W}_{kc}^{(s)}, \quad \mathbf{v}_i^{(s)} = \mathbf{c}_i \mathbf{W}_v^{(s)}.\) So no matter how many filters (heads) exist, you only keep the thumbnail ($\mathbf{c}_i$) rather than multiple versions of the same image. At inference, each filter’s parameters can reconstruct the Key or Value from that thumbnail, resulting in massive storage reduction.

4.3 On-Demand Reconstruction: Merging at Calculation Time

In actual inference, a “merge” step is performed right when computing the dot product:

\[\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top} = (\mathbf{x}_t \mathbf{W}_q^{(s)})(\mathbf{c}_i \mathbf{W}_{kc}^{(s)})^\top = \mathbf{x}_t \mathbf{W}_{\mathrm{merged}}^{(s)} \mathbf{c}_i^\top.\]

Hence, the filter parameters $\mathbf{W}_{kc}^{(s)}$ do not have to be stored in the “album” for each picture (head). Only the compact thumbnail $\mathbf{c}_i$ remains. This approach omits the repeated Key/Value overhead that standard multi-head attention requires.


5. The RoPE Challenge: Why Add a “Position Sticker”?

Now that we’ve covered MLA’s low-rank approach and on-demand reconstruction, we must address a practical reality. Transformers often rely on RoPE (Rotary Position Embedding) to incorporate positional information into Key/Query. RoPE complicates the straightforward “latent vector only” approach, since each token’s position $i$ introduces a rotation matrix $\mathbf{\mathcal{R}}_i$ that affects the dot product differently.

5.1 RoPE: Timestamp and GPS Coordinates

Returning to our album analogy, RoPE becomes the “timestamp or GPS location” on each photo. Besides the core visual content ($\mathbf{c}_i$), we retain a small sticker that encodes when and where the photo was taken. If we tried to embed the time data directly into the thumbnail, relative distances (time differences) might be lost. Thus, in MLA, a portion of K/Query dimension remains explicitly multiplied by $\mathbf{\mathcal{R}}_i$, preserving relative position even under our low-rank scheme.

5.2 Split Strategy: $\mathbf{c}_i$ + a Small RoPE Dimension

In formal terms, to preserve the rotation property $\mathbf{\mathcal{R}}_m \mathbf{\mathcal{R}}_n^\top = \mathbf{\mathcal{R}}_{m-n}$, MLA splits each Key (and similarly each Query) into two parts:

\[\mathbf{k}_i^{(s)}=\underbrace{\bigl(\mathbf{c}_i \mathbf{W}_{kc}^{(s)}\bigr)}_{\text{compressed portion}}\;\oplus\;\underbrace{\bigl(\mathbf{x}_i \mathbf{W}_{kr}\,\mathbf{\mathcal{R}}_i\bigr)}_{\text{positional portion}},\]

so the KV-Cache only grows by a modest “position dimension,” while the main storage remains the latent vector $\mathbf{c}_i$. This design deftly fuses low-rank projection with rotary embeddings for relative positions.
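A toy sketch of this split key layout; the rotate-by-halves RoPE here and all dimensions are illustrative simplifications (the code in Section 9 uses a complex-number formulation instead):

```python
import torch

d, d_c, d_nope, d_rope = 256, 64, 32, 16            # illustrative sizes
x_i, c_i = torch.randn(d), torch.randn(d_c)
W_kc = torch.randn(d_c, d_nope)                     # decompresses the latent into K
W_kr = torch.randn(d, d_rope)                       # produces the small positional part

def rope(v, pos, base=10000.0):                     # minimal rotary embedding on a vector
    half = v.shape[-1] // 2
    ang = pos * base ** (-torch.arange(half) / half)
    v1, v2 = v[:half], v[half:]
    return torch.cat([v1 * ang.cos() - v2 * ang.sin(),
                      v1 * ang.sin() + v2 * ang.cos()])

pos = 7
k_i = torch.cat([c_i @ W_kc, rope(x_i @ W_kr, pos)])   # compressed part ⊕ positional part
print(k_i.shape)   # torch.Size([48]); only c_i and the small RoPE part are cached per token
```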


6. The Comprehensive Benefits of MLA: Storage Innovation, Flexible Retrieval, and Spatial-Temporal Fidelity

After breaking MLA down into core steps, we can see three major upsides:

  • Storage: Traditional multi-head attention must store K/V for every head. MLA instead uses a latent vector $\mathbf{c}_i$ (low-rank) plus a small dimension for RoPE if needed. VRAM can shrink by factors of several times or more.
  • Computation: With on-demand decompression, Key/Value are reconstructed only when necessary—this can be further optimized by merging them into Query. For very long sequences, memory bandwidth is the real bottleneck, so reducing repeated K/V significantly speeds inference.
  • Position: When a model requires relative position (RoPE), MLA can keep a separate “position sticker” dimension. This preserves temporal/spatial information without forcing the entire K space to be stored separately.

7. From an Engineering Angle: Key Considerations

7.1 Balancing VRAM and Speed

If your application involves sequences of thousands or tens of thousands of tokens, MLA’s latent compression helps reduce the KV-Cache drastically and allows a single GPU to process more tokens or bigger batch sizes, thus boosting throughput.

7.2 Tuning RoPE Dimensions

When splitting K into a low-rank zone ($\mathbf{c}_i$) vs. RoPE zone, if the RoPE dimension is too small, extremely long contexts may not get enough positional signal. Conversely, if it’s too large, MLA’s compression advantages diminish. The optimal trade-off typically emerges from empirical experiments.

7.3 Numerical Stability and Precision

Because MLA merges weight matrices at inference time, using BF16/FP16 can introduce small accumulative errors from changing multiply orders. Generally, it’s acceptable. If your application is ultra-sensitive to accuracy, consider higher precision accumulators or partial float32 fallback.


8. Overall Summary and Future Directions

MLA (Multi-Head Latent Attention) is more than merely “low-rank factorization of K/V.” It pushes further by caching only a latent vector $\mathbf{c}_i$ in inference, reconstructing multi-head K/V via on-demand decompression and matrix merging—cutting KV-Cache usage drastically. Then, by applying a split strategy for RoPE, MLA retains relative positional information without forcing the entire K/V to remain explicit.

From an engineering perspective, MLA’s VRAM efficiency for long-context LLM inference is a big advantage, potentially expanding the number of tokens handled on a single card or small cluster. Deciding exactly how to partition the dimension between latent vectors and RoPE, however, depends on your task and model scale. This concept can also extend to other positional encodings (ALiBi, NTK Scaling, etc.) or specialized domains.

Regardless, MLA clearly shows a fresh and powerful path for reducing KV-Cache. We’ll likely see more variants of MLA combined with other attention optimizations, helping large models achieve even more performance under strict hardware constraints.


Appendix: Key Formulas and Their “Photo Album” Analogy

Below are the crucial formulas in MLA, alongside how they match the intelligent photo album metaphor:

  1. Low-Rank Projection

    \[\mathbf{c}_i = \mathbf{x}_i \mathbf{W}_c\] \[\quad\updownarrow\quad\] \[\text{“Store a thumbnail instead of a full photo,” drastically shrinking storage.}\]
  2. Dynamic Decompression

    \[\mathbf{k}_i^{(s)} = \mathbf{c}_i \mathbf{W}_{kc}^{(s)}, \quad \mathbf{v}_i^{(s)} = \mathbf{c}_i \mathbf{W}_v^{(s)}\] \[\quad\updownarrow\quad\] \[\text{“Generate filtered views (Key/Value) from the single thumbnail at runtime.”}\]
  3. On-Demand Reconstruction (Merging Matrices)

    \[\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top}=(\mathbf{x}_t \mathbf{W}_q^{(s)}) (\mathbf{c}_i \mathbf{W}_{kc}^{(s)})^\top\] \[\quad\updownarrow\quad\] \[\text{“Combine the filter math with the viewing step, further reducing storage.”}\]
  4. RoPE Split

    \[\mathbf{k}_i^{(s)} = (\mathbf{c}_i \mathbf{W}_{kc}^{(s)}) \;\oplus\; (\mathbf{x}_i \mathbf{W}_{kr}\mathbf{\mathcal{R}}_i)\] \[\quad\updownarrow\quad\] \[\text{“Keep a separate small label for timestamps or GPS, i.e., relative position.”}\]

Reading through these steps, you can see how MLA seamlessly progresses from low-rank projection to on-demand Key/Value restoration and finally to coexisting with RoPE for position encoding. As its name implies, MLA both retains the potent expressiveness of multi-head attention while delegating the main K/V storage burden to a smaller latent representation—enabling longer contexts under limited GPU memory.


9. A Minimal Working Example of MLA

To give a more hands-on feel for MLA’s core concepts, here is a minimal, runnable code sample of an MLA-based MoE Transformer, illustrating how “latent variables,” “on-demand reconstruction,” and “RoPE integration” show up in an actual code structure. In the ModelArgs data class, the field attn_impl: Literal["naive", "absorb"] = "absorb" governs whether we store classical K,V in a naive style or rely on MLA’s latent caching (absorb). For the complete functionality, see the DeepSeek Official Repo.

Below is the consolidated example:

import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal

import torch
import torch.nn.functional as F
from torch import nn

@dataclass
class ModelArgs:
    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 2
    n_dense_layers: int = 1
    n_heads: int = 16
    # moe
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.
    # mla
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    # rope
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.
    # kv-cache
    attn_impl: Literal["naive", "absorb"] = "absorb"

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    return F.linear(x, weight, bias)


class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_normal_(self.weight)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
            nn.init.zeros_(self.bias)
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias)


def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(min_val, max_val, dim):
        if min_val == max_val:
            max_val += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    dtype = x.dtype
    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    return y.to(dtype)


class MLA(nn.Module):
    """
    Multi-Head Latent Attention (MLA) layer.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = self.n_heads
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim
        self.attn_impl = args.attn_impl

        # Q
        if self.q_lora_rank == 0:
            self.wq = Linear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)

        # K,V
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = Linear(
            self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)
        )

        # Output
        self.wo = Linear(self.n_heads * self.v_head_dim, self.dim)

        self.softmax_scale = self.qk_head_dim ** -0.5
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale *= mscale * mscale

        # register different buffer based on the choice of "attn_impl"
        if self.attn_impl == "naive":
            self.register_buffer(
                "k_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim),
                persistent=False
            )
            self.register_buffer(
                "v_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim),
                persistent=False
            )
        else:
            self.register_buffer(
                "kv_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank),
                persistent=False
            )
            self.register_buffer(
                "pe_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim),
                persistent=False
            )


    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen

        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)

        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)

        if self.attn_impl == "naive":
            # naive mode
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))  # (bsz, seqlen, n_heads*(qk_nope_head_dim + v_head_dim))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)

            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)

            # write cache
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v

            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale

            if mask is not None:
                scores += mask.unsqueeze(1)

            scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
            out = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])

        else:
            # the absorb mode proposed in MLA
            wkv_b = self.wkv_b.weight 
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
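            # shape (n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank):
            # rows [:qk_nope_head_dim] decompress the latent into K, rows [-v_head_dim:] into V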

            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)

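            # absorb the K up-projection into the query: q_nope now lives in the
            # kv_lora_rank latent space and can be dotted directly with kv_cache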
            q_nope = torch.einsum(
                "bshd,hdc->bshc",
                q_nope,
                wkv_b[:, :self.qk_nope_head_dim] 
            )

            scores = (
                torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])
            ) * self.softmax_scale

            if mask is not None:
                scores += mask.unsqueeze(1)

            scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)

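            # weight the cached latents by attention, then decompress to per-head V on the fly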
            out = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            out = torch.einsum(
                "bshc,hdc->bshd",
                out,
                wkv_b[:, -self.v_head_dim:]
            )

        out = self.wo(out.flatten(2))
        return out


class MLP(nn.Module):

    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Expert(nn.Module):
    """
    Expert layer for Mixture-of-Experts (MoE) models.
    """

    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Gate(nn.Module):
    """
    Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.n_groups = args.n_expert_groups
        self.topk_groups = args.n_limited_groups
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        nn.init.xavier_normal_(self.weight)
        self.bias = nn.Parameter(torch.zeros(args.n_routed_experts)) if self.dim == 7168 else None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        scores = linear(x, self.weight, self.bias)

        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1, dtype=torch.float32)
        else:
            scores = scores.sigmoid()

        original_scores = scores
        if self.n_groups > 1:
            scores = scores.view(x.size(0), self.n_groups, -1)
            if self.bias is None:
                group_scores = scores.amax(dim=-1)
            else:
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
            indices_groups = group_scores.topk(self.topk_groups, dim=-1)[1]
            mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices_groups, True)
            scores = (scores * mask.unsqueeze(-1)).flatten(1)

        indices = torch.topk(scores, self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        if self.score_func == "sigmoid":
            weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale
        return weights.type_as(x), indices


class MoE(nn.Module):
    """
    Mixture-of-Experts (MoE) module.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_routed_experts = args.n_routed_experts
        self.n_activated_experts = args.n_activated_experts
        self.gate = Gate(args)

        self.experts = nn.ModuleList([
            Expert(args.dim, args.moe_inter_dim) for _ in range(self.n_routed_experts)
        ])
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x)
        y = torch.zeros_like(x)

        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for i in range(self.n_routed_experts):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx]) * weights[idx, top, None]

        z = self.shared_experts(x)
        return (y + z).view(shape)


class Block(nn.Module):
    """
    Transformer block combining attention and feed-forward layers.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attn = MLA(args)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor,
                mask: Optional[torch.Tensor]) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x


class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.embed = torch.nn.Embedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim)
        self.head = Linear(args.dim, args.vocab_size)

        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        seqlen = tokens.size(1)
        h = self.embed(tokens)
        freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)

        h = self.norm(h)[:, -1]
        logits = self.head(h)
        return logits


if __name__ == "__main__":
    torch.manual_seed(0)
    args = ModelArgs()
    x = torch.randint(0, args.vocab_size, (2, 128))
    model = Transformer(args)
    logits = model(x)
    print(logits.shape)  # (batch_size, vocab_size)
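A quick way to quantify the saving in this toy configuration: with the default ModelArgs above, the naive path caches n_heads * (qk_nope_head_dim + qk_rope_head_dim + v_head_dim) = 16 * (128 + 64 + 128) = 5120 elements per token per layer, while the absorb path caches only kv_lora_rank + qk_rope_head_dim = 512 + 64 = 576, roughly an 8.9× reduction before any further optimization.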

引言

在大语言模型繁荣的时代,硬件资源依然是绕不过去的“天花板”——特别是 GPU 显存有限,如何在有限资源下让模型拥有更长的上下文、更快的推理速度,一直是工程与研究领域关注的焦点。除了常见的量化、剪枝,越来越多人也将目光投向 “减少推理时 KV-Cache 占用” 这个方向。

本文将先回顾 MHA(Multi-Head Attention)、MQA(Multi-Query Attention)与 GQA(Grouped-Query Attention)在共享或减少 K/V 方面的思考与取舍,然后聚焦于最新的 MLA(Multi-Head Latent Attention)。不同于前者仅在“共享 K/V”层面做文章,MLA 通过低秩投影与按需解压的组合,让 KV-Cache 不再直接存储多头 K/V,而是改为 “潜变量” 形式;在此基础上,它还能进一步使用“合并矩阵”的技巧,在推理阶段仅依赖极少的显存来完成注意力计算。

需要说明的是,MLA 在实际落地时往往还需兼容 RoPE(Rotary Position Embedding)这一位置编码方案;但为了让读者先理解 MLA 的核心(低秩投影本身),我们会在 介绍完 MLA 的主干思路 后,再探讨 RoPE 与 MLA 的结合方式。希望这样的安排能使你感受到 MLA 设计背后的层次与巧思。

特别感谢:本文借鉴了苏剑林的博客《缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA》中的一些想法,向苏剑林老师致敬。


1. 为什么要减少 KV-Cache?

1.1 长序列推理中显存的“隐形杀手”

自回归 Transformer 每生成一个新 Token,都需要回顾之前所有 Token 的 Key / Value ($K, V$) 做注意力;这部分历史 K/V 在推理阶段被缓存下来,即所谓 KV-Cache。如果序列长度为 $L$、注意力头数为 $h$,单头维度为 $d_k$ 或 $d_v$,则 KV-Cache 的量级约为 \(L \times h \times (d_k + d_v),\) 并随 $L$ 线性增长。当上下文长度轻松达到数千到上万时,KV-Cache 成为显存占用的主要瓶颈,比模型本身的激活值更为庞大。

1.2 显存与带宽的双重约束

当序列过长,“一张卡”往往无法容纳全部 KV-Cache,需要多卡或多机分担。但跨卡和跨机的带宽远低于卡内带宽,导致推理速度显著变慢。简而言之,如果能在单卡或少卡上实现更长的上下文,就能节省通信开销,提升推理吞吐量。这也是近年在 MQA / GQA 以及 MLA 等思路上不断迭代的根源。


2. MHA → MQA → GQA:多头注意力中 K/V 的简化历程

在聚焦 MLA 之前,让我们先快速对多头注意力(MHA)以及后来的共享式注意力(MQA、GQA)做一个小结,为后续对比 MLA 奠定基础。

2.1 多头注意力(MHA)的原点

2.1.1 经典注意力公式

在 Transformer 中,一段序列 $\mathbf{x}_1,\dots,\mathbf{x}_l$ 会映射到多组 $(Q, K, V)$ 来计算注意力。以第 $s$ 个注意力头为例,令隐藏维度为 $d$: \(\mathbf{q}_i^{(s)} = \mathbf{x}_i \,\mathbf{W}_q^{(s)},\quad \mathbf{k}_i^{(s)} = \mathbf{x}_i \,\mathbf{W}_k^{(s)},\quad \mathbf{v}_i^{(s)} = \mathbf{x}_i \,\mathbf{W}_v^{(s)}.\) 此时若进行自回归计算,第 $t$ 步的注意力分数往往写作 \(\alpha_{t,i}^{(s)} = \mathbf{q}_t^{(s)} \,\mathbf{k}_i^{(s)\top}, \quad\text{其中 } i \le t.\) 在推理时,为加速计算,我们把已经算出的 $\mathbf{k}_i^{(s)},\mathbf{v}_i^{(s)}$ 暂存于显存里,供后续 Token 使用,这部分显存缓存就称为 KV-Cache

2.1.2 显存压力

由于 MHA 中每个 Head 都独立存储 K、V,若 Head 数 $h$ 较大,就得存下 $h$ 份相同长度的 Key/Value,显存消耗惊人。为此,研究者开始思考:是不是能让多头注意力在 K/V 层面也做一些共享、合并或压缩?

2.2 MQA:极端共享 K/V

MQA(Multi-Query Attention)在 2019 年提出,核心是让所有 Head 只用一套 K/V: \(\mathbf{k}_i = \mathbf{x}_i \mathbf{W}_k,\quad \mathbf{v}_i = \mathbf{x}_i \mathbf{W}_v,\) 而各 Head 仍可有不同 $\mathbf{q}_i^{(s)}$。这样,KV-Cache 只需存 1 份 K/V,显存占用是原来的 $\tfrac{1}{h}$。在 PaLM、StarCoder 等模型中有使用。不过由于 K、V 过度共享,有些任务上的精度可能下降,需要额外训练技巧弥补。

2.3 GQA:分组共享

如果觉得 MQA 太极端,可以用 GQA(Grouped-Query Attention)折中:把 $h$ 个 Head 划分为 $g$ 组,同一组内共享 K/V;这样 KV-Cache 缩减到原先的 $g/h$,保留一定多样性。LLaMA2-70B、ChatGLM2/3 等是这一路线的代表。

2.4 小对比:MHA / MQA / GQA

| 方案 | KV-Cache 存储 | 是否共享 K/V | 显存节省 | 多头灵活度 |
| --- | --- | --- | --- | --- |
| MHA | 存 $h$ 份 | 各 Head 独立 | 低(基线) | 高 |
| MQA | 存 1 份 | 所有 Head 完全共享 | 高 | 较低 |
| GQA | 存 $g$ 份 | 分组共享 | 中等 | 较高 |

无论是 MQA 还是 GQA,都还处在“是否共享 K/V”这个思路。MLA 则从根本上改变“推理时存什么”:它将大多数 K/V 信息转移到一个潜变量里,并在需要时才“还原”出来。下面让我们首先看看没有 RoPE 干扰时,MLA 如何实现低秩投影与按需解压。


3. MLA 的核心:低秩投影与按需还原(不含 RoPE)

3.1 基本思路:改“存多头 K/V”为“存低维潜变量”

MLA(Multi-Head Latent Attention)里,训练阶段依然会对输入做投影,得到各个 Head 的 Key、Value。但它引入一个低秩潜变量 $\mathbf{c}_i$,在推理时只需要缓存这个维度远小于 $d$ 的向量,就能借助一套“矩阵合并”来临时得到多头的 K、V。如果只考虑不带位置编码(RoPE)的核心公式,MLA 的训练过程可以描述为:

\[\mathbf{c}_i = \mathbf{x}_i \mathbf{W}_c\quad (\text{低秩投影, } \mathbf{W}_c \in \mathbb{R}^{d \times d_c}),\]

并且对每个 Head $s$,定义解压矩阵 $\mathbf{W}_{kc}^{(s)}, \mathbf{W}_v^{(s)}$,使得

\[\mathbf{k}_i^{(s)} = \mathbf{c}_i \mathbf{W}_{kc}^{(s)}, \quad\mathbf{v}_i^{(s)} = \mathbf{c}_i \mathbf{W}_v^{(s)}.\]

这样一来,若 Head 数再多,也只是训练时会显式生成多份 K/V;但在推理时,我们不必缓存所有 $\mathbf{k}_i^{(s)}, \mathbf{v}_i^{(s)}$,而只要保留潜变量 $\mathbf{c}_i$。需要计算时,通过合并矩阵乘法就能“还原”出各头的 K、V。由此可见,MLA 很大程度上摆脱了“多头数 $h$ 与 KV-Cache 成正比”这件事。

3.2 动态解压:显存怎么省?

推理中,每当生成第 $t$ 个 Token,需要与历史 $i < t$ 的 Key 做点积 $\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top}$。传统方法下要在显存中直接拿 $\mathbf{k}_i^{(s)}$,但在 MLA 里我们“合并”了

\[\mathbf{k}_i^{(s)}=\mathbf{c}_i\,\mathbf{W}_{kc}^{(s)}\implies\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top}=(\mathbf{x}_t \mathbf{W}_q^{(s)}) (\mathbf{c}_i \mathbf{W}_{kc}^{(s)})^\top.\]

根据矩阵乘法,我们可以做如下合并:

\[(\mathbf{x}_t \mathbf{W}_q^{(s)}) (\mathbf{c}_i \mathbf{W}_{kc}^{(s)})^\top = \mathbf{x}_t (\mathbf{W}^{(s)}_q \mathbf{W}^{(s)\top}_{kc}) (\mathbf{c}_i)^\top = \mathbf{x}_t \mathbf{W}_{\mathrm{merged}}^{(s)} (\mathbf{c}_i)^\top\]

从而只需要在显存中保留 $\mathbf{c}_i$,而无需分别保存 $\mathbf{k}_i^{(s)}$。显然,如果 $\mathbf{c}_i \in \mathbb{R}^{d_c}$ 且 $d_c \ll h \times d_k$,那么KV-Cache 的占用就从原先按 $h \times d_k$ 线性增长,变成按 $d_c$ 成本增长

3.3 低秩投影如何大幅压缩存储?

有些读者可能会好奇,这种低秩投影压缩率有多大?取一个示例,如果原模型维度 $d=4096$、多头数 $h=32$、单头 $d_k=128$;那么传统 MHA 每个 Token 要存下 $32 \times 128 = 4096$ 维 Key(加上 Value 也是同量级),而 MLA 可以把潜变量 $\mathbf{c}_i$ 设成 512 维,显存需求就从 4096 缩减到 512,几乎减少 8 倍。在某些更极端场景,压缩比例可高达几十甚至数百倍。

当然,这只是一个不带位置编码的美好设想。实际上,Transformer 常用的 RoPE(Rotary Position Embedding)会影响 K、Q 的投影方式,使单纯“潜变量 + 合并矩阵”无法简单替代。为了让读者先理解 MLA 的低秩核心,我们暂不展开 RoPE;接下来会从一个“相册系统”的比喻帮助你理清 MLA 的运作流程,然后再单独讨论 RoPE 问题。


4. 从智能相册系统看 MLA

在说完 MLA 的数学公式后,或许你对“潜变量 $\mathbf{c}_i$ + 投影矩阵”仍然停留在抽象层面。下面让我们用一个生活化的类比:把“Token”比做“拍摄的照片”,把“多头注意力”比做“给照片套滤镜”,而把“KV Cache”看作“相册存储空间”,展示 MLA 如何在这个体系里实现“压缩存储”与“动态解压”。

4.1 拍照存储:低秩投影

想象每次拍照(即处理一个 Token $\mathbf{x}_i$),我们并不把它保存为“多头滤镜后的完整分辨率原图”,而是只保留一个“体积小但保留关键信息”的智能缩略图(潜变量 $\mathbf{c}_i$)。类似地,若原始图像分辨率是 4096,对应 MLA 里 $d=4096$,缩略图可能只有 512,这会带来约 1/8 甚至更高的压缩比例。
在数学上,这就是

\[\mathbf{c}_i =\mathbf{x}_i \,\mathbf{W}_c, \quad \mathbf{W}_c \in \mathbb{R}^{4096 \times 512},\]

扮演了“拍照时做一次低秩投影”的作用。

4.2 浏览照片:实时动态解压

当我们要“浏览”某张照片,并试图把它变成“某种滤镜版本”时,可以理解为要计算特定的 Key/Value(比如复古滤镜、HDR 滤镜,对应不同 Head 的 K、V)。在 MLA 中,这个过程相当于

\[\mathbf{k}_i^{(s)} = \mathbf{c}_i \,\mathbf{W}_{kc}^{(s)}, \quad \mathbf{v}_i^{(s)} = \mathbf{c}_i \,\mathbf{W}_v^{(s)},\]

也就是把缩略图 $\mathbf{c}_i$ 投影到你想要的滤镜。因此,哪怕有十几种滤镜(多头),我们也不必重复存十几份滤镜版原图,而是只存一张缩略图($\mathbf{c}_i$),在“浏览”时按需解压。这样,就达成了“相册容量极度压缩”的目标。

4.3 动态解压的数学对应:按需还原

在实际推理中,“合并矩阵”的过程就像“把滤镜的计算提前或合并到查看行为”中,用公式来说: \(\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top} = \bigl(\mathbf{x}_t \mathbf{W}_q^{(s)}\bigr) \bigl(\mathbf{c}_i \mathbf{W}_{kc}^{(s)}\bigr)^\top \approx \mathbf{x}_t \,\mathbf{W}_{\mathrm{merged}}^{(s)}\,\mathbf{c}_i^\top.\) 换句话说,滤镜参数 $\mathbf{W}_{kc}^{(s)}$ 不再需要静态保存在相册中,而是只在查看时临时计算。最终,我们的显存里只保留这个“压缩向量” $\mathbf{c}_i$,由此大幅度减少存储成本。

通过这个相册类比,你会发现 MLA 本质上避免了对每个注意力头都存一整份 Key/Value,而是将 K/V 的大部分信息浓缩到一个低秩潜变量里,让“滤镜”(多头投影矩阵)在推理时即时生效。


5. RoPE 的挑战:为何要再加“位置贴纸”?

到这里,MLA 的低秩投影与按需解压思路已相对清晰。不过,在真实的 Transformer 应用中,常见的 RoPE(Rotary Position Embedding)会为 Key、Query 添加一个位置相关的旋转矩阵 $\mathbf{\mathcal{R}}_i$,以便实现相对位置的建模。但这会打破前文的“把所有头的 K 都只存潜变量”做法。简单来说,RoPE 需要显式知道某个 Token 的位置索引 $i$,并且这种位置索引会影响点积公式里的“旋转相减”性质。

5.1 RoPE:拍摄时间与 GPS 坐标

若再回到“相册系统”的比喻,可以把 RoPE 理解成“拍摄时间和 GPS 坐标”:除了基本图像信息($\mathbf{c}_i$)之外,我们还需要单独存一张小贴纸,用来记录照片的拍摄地点或时间。这样,在对不同照片进行对比或排序时,就能根据“时间差 / 距离差”进行更准确的比较。
一旦我们尝试把拍摄时间直接混进“缩略图”里,会让相对位置的计算变得复杂或失效。这就是为什么 MLA 在面对 RoPE 时,往往把 Key/Query 的维度分成两部分:一部分来自共享潜变量,另一部分仍显式乘以 RoPE 矩阵,以保留位置信息。

5.2 分治策略:$\mathbf{c}_i$ + RoPE 小维度

在数学层面,为了保留旋转矩阵 $\mathbf{\mathcal{R}}_i$ 的相减性质,MLA 会把 $\mathbf{k}_i^{(s)}$(以及 $\mathbf{q}_i^{(s)}$)写成: \(\mathbf{k}_i^{(s)}=\underbrace{\bigl(\mathbf{c}_i \,\mathbf{W}_{kc}^{(s)}\bigr)}_{\text{压缩区}} \;\oplus\;\underbrace{\bigl(\mathbf{x}_i \,\mathbf{W}_{kr} \,\mathbf{\mathcal{R}}_i\bigr)}_{\text{位置区}},\) 从而,在 KV-Cache 中只多出一个较小的“位置区”,而主部分的存储成本仍维持在潜变量 $\mathbf{c}_i$ 上,这就巧妙地把“低秩投影”与“RoPE 相对位置”结合起来。


6. MLA 的综合优势:存储革命、灵活查询、时空保真

经过前面两节的拆分介绍,我们可以再次回到一个全局视角,看看 MLA 在“存储革命”“动态解压”和“时空保真”三个层面所带来的提升。

从存储角度看,原本多头注意力在推理阶段必须为每个 Head 保留独立 K/V;而在 MLA 中,大部分信息用低秩向量 $\mathbf{c}_i$ 表示,尺寸极小,显存可以减少数倍甚至数十倍。真正做到了一种“拍照时直接存缩略图”的思路,节省了相册(KV-Cache)空间。

从计算角度看,MLA 借助“按需解压”,只在每次计算点积时根据 $\mathbf{c}_i$ 恢复 Key/Value(或者更进一步地把 Key 合并到 Query 一侧),显存-计算的折中比单纯缩放 Head 更灵活。而且在长序列场景中,带宽消耗往往才是主要瓶颈,所以 MLA 的“KV 缓存减量”常能带来显著推理提速。

从位置角度看,如果我们需要兼容相对位置编码(RoPE),可以让 Key/Query 的一部分维度显式乘上 $\mathbf{\mathcal{R}}_i$,对应“拍摄时间”的小贴纸;主维度仍由 $\mathbf{c}_i$ 提供语义特征。这样保证了时空信息依然得以保真,不会因全部压缩到 $\mathbf{c}_i$ 而丢失。


7. 工程视角:落地 MLA 时需注意的要点

7.1 显存 VS. 推理速度

如果你的应用上下文非常长(数千乃至上万 Token),MLA 的潜变量压缩思路能在不损失过多效果的前提下,显著减少 KV-Cache 占用,从而在单卡上就能跑更多 Token 或更大 batch size,进而提升吞吐量。

7.2 RoPE 维度调参

当把 K 分成 “$\mathbf{c}_i$ 区 + RoPE 区”,RoPE 区若定得过小,超长序列时相对位置信息可能不够充分;过大则会拉高显存占用,抵消 MLA 的收益。工程中往往通过实验来平衡这两者。

7.3 数值误差与精度

MLA 在推理阶段会做矩阵合并,精度较高时没问题,但在 BF16/FP16 下可能出现一些顺序误差积累。通常仍在可接受范围;若应用对精度极端敏感,可考虑混合精度或高精度 accumulators。


8. 整体总结与展望

MLA(Multi-Head Latent Attention)并不是单纯地“把 K/V 做低秩分解”,它更进一步地在推理阶段只缓存一个潜变量 $\mathbf{c}_i$,通过按需解压矩阵合并来还原各头所需的 K/V,大幅减少 KV-Cache。若再配合分治策略来处理 RoPE,那些相对位置信息也不会因为潜变量的介入而丢失。

从工程角度,MLA 带来的显存优势在长上下文推理中极为明显,能够拓展单卡或少卡可以承载的 Token 数量与并发规模。至于 RoPE 的兼容方式以及分治维度怎么选,还需结合任务需求与模型大小来调优。将这一思路推广到其他位置编码(如 ALiBi、NTK Scaling)或其他特殊场景,则是后续可能的研究方向。

无论如何,MLA 无疑已经在“如何减少 KV-Cache”这一问题上,给出了新颖且高效的答案。或许在不久的未来,我们会见到更多结合 MLA 与其他注意力优化方案的模型出现,让大模型的推理在硬件资源有限的条件下继续爆发潜能。


附录:核心公式与对应场景

在此列出 MLA 方案中几个最关键的公式与它们在“智能相册系统”类比中的对应,方便你回顾与对照。

  1. 低秩投影

    \[\mathbf{c}_i = \mathbf{x}_i \mathbf{W}_c\] \[\quad\updownarrow\quad\] \[\text{拍照时直接存缩略图,减少体积 } \propto \frac{d_c}{d}.\]
  2. 动态解压

    \[\mathbf{k}_i^{(s)} = \mathbf{c}_i \mathbf{W}_{kc}^{(s)}, \quad \mathbf{v}_i^{(s)} = \mathbf{c}_i \mathbf{W}_v^{(s)}\] \[\quad\updownarrow\quad\] \[\text{浏览时根据缩略图 + 不同滤镜生成多种版本}\]
  3. 按需还原(推理时合并矩阵)

    \[\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top}=\bigl(\mathbf{x}_t \mathbf{W}_q^{(s)}\bigr) \bigl(\mathbf{c}_i \mathbf{W}_{kc}^{(s)}\bigr)^\top\] \[\quad\updownarrow\quad\] \[\text{把滤镜的计算“合并”到需要时,节省存储}\]
  4. RoPE 分治

    \[\mathbf{k}_i^{(s)} =\bigl(\mathbf{c}_i \mathbf{W}_{kc}^{(s)}\bigr)\;\oplus\;\bigl(\mathbf{x}_i \mathbf{W}_{kr}\mathbf{\mathcal{R}}_i\bigr)\] \[\quad\updownarrow\quad\] \[\text{把拍摄时间或 GPS 坐标这类信息单独存放}\]

通过以上回顾,相信你已对 MLA 从低秩投影按需还原再到与 RoPE 协作的全过程有了更完整的理解。正如它的名字所暗示,MLA 既保留了“多头注意力”的强大表达力,又将主要存储负担隐含在一个更小的潜变量里,使长序列推理跨越了显存障碍。


9. 一个最小化 MLA 实现

为了让读者对 MLA 的核心原理有更多“落地感”,下面我们展示一个最小可运行的 MLA MoE-Transformer 实现,并点到为止地解释其中和“潜变量”“按需解压”“RoPE”密切相关的部分。 在参数类ModelArgs 中,attn_impl: Literal["naive", "absorb"] = "absorb" 参数的选择决定了KV-Cache是使用传统的存储 K、V (naive) 还是存储在 MLA 中定义的潜变量 (absorb)。功能完整代码详见 DeepSeek Official Repo

完整代码如下:

import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal

import torch
import torch.nn.functional as F
from torch import nn

@dataclass
class ModelArgs:
    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 2
    n_dense_layers: int = 1
    n_heads: int = 16
    # moe
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.
    # mla
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    # rope
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.
    # kv-cache
    attn_impl: Literal["naive", "absorb"] = "absorb"

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    return F.linear(x, weight, bias)


class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_normal_(self.weight)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
            nn.init.zeros_(self.bias)
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias)


def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(min_val, max_val, dim):
        if min_val == max_val:
            max_val += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    dtype = x.dtype
    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    return y.to(dtype)


class MLA(nn.Module):
    """
    Multi-Head Latent Attention (MLA) layer.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = self.n_heads  # 简化,不再做并行分割
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim
        self.attn_impl = args.attn_impl

        # Q
        if self.q_lora_rank == 0:
            self.wq = Linear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)

        # K,V
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = Linear(
            self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)
        )

        # Output
        self.wo = Linear(self.n_heads * self.v_head_dim, self.dim)

        self.softmax_scale = self.qk_head_dim ** -0.5
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale *= mscale * mscale

        # 根据 attn_impl 来注册不同缓存
        if self.attn_impl == "naive":
            self.register_buffer(
                "k_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim),
                persistent=False
            )
            self.register_buffer(
                "v_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim),
                persistent=False
            )
        else:
            self.register_buffer(
                "kv_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank),
                persistent=False
            )
            self.register_buffer(
                "pe_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim),
                persistent=False
            )


    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen

        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)

        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)

        if self.attn_impl == "naive":
            # naive 模式下,直接投影 kv
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))  # (bsz, seqlen, n_heads*(qk_nope_head_dim + v_head_dim))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)

            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)

            # 写入 cache
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v

            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale

            if mask is not None:
                scores += mask.unsqueeze(1)

            scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
            out = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])

        else:
            # absorb 模式
            # 将 kv_norm(kv) 存进 kv_cache, 同时将 k_pe 存进 pe_cache
            wkv_b = self.wkv_b.weight
            # q_nope 先和 wkv_b 的前 qk_nope_head_dim 行做乘积
            # 并将结果暂存在 q_nope
            # 如果要对 q_nope 做矩阵乘法,需要先转下维度
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)

            # 将 kv_norm(kv) 写进 kv_cache
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            # 将 k_pe 写进 pe_cache
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)

            q_nope = torch.einsum(
                "bshd,hdc->bshc",
                q_nope,
                wkv_b[:, :self.qk_nope_head_dim]  # 只取前 qk_nope_head_dim
            )

            scores = (
                torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])
            ) * self.softmax_scale

            if mask is not None:
                scores += mask.unsqueeze(1)

            scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)

            # 计算最终输出
            out = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            out = torch.einsum(
                "bshc,hdc->bshd",
                out,
                wkv_b[:, -self.v_head_dim:]  # 取 v_head_dim 对应部分
            )

        out = self.wo(out.flatten(2))
        return out


class MLP(nn.Module):

    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Expert(nn.Module):
    """
    Expert layer for Mixture-of-Experts (MoE) models.
    """

    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Gate(nn.Module):
    """
    Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.n_groups = args.n_expert_groups
        self.topk_groups = args.n_limited_groups
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        nn.init.xavier_normal_(self.weight)
        self.bias = nn.Parameter(torch.zeros(args.n_routed_experts)) if self.dim == 7168 else None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        scores = linear(x, self.weight, self.bias)

        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1, dtype=torch.float32)
        else:
            scores = scores.sigmoid()

        original_scores = scores
        if self.n_groups > 1:
            scores = scores.view(x.size(0), self.n_groups, -1)
            if self.bias is None:
                group_scores = scores.amax(dim=-1)
            else:
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
            indices_groups = group_scores.topk(self.topk_groups, dim=-1)[1]
            mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices_groups, True)
            scores = (scores * mask.unsqueeze(-1)).flatten(1)

        indices = torch.topk(scores, self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        if self.score_func == "sigmoid":
            weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale
        return weights.type_as(x), indices


class MoE(nn.Module):
    """
    Mixture-of-Experts (MoE) module.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_routed_experts = args.n_routed_experts
        self.n_activated_experts = args.n_activated_experts
        self.gate = Gate(args)

        self.experts = nn.ModuleList([
            Expert(args.dim, args.moe_inter_dim) for _ in range(self.n_routed_experts)
        ])
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x)
        y = torch.zeros_like(x)

        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for i in range(self.n_routed_experts):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx]) * weights[idx, top, None]

        z = self.shared_experts(x)
        return (y + z).view(shape)


class Block(nn.Module):
    """
    Transformer block combining attention and feed-forward layers.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attn = MLA(args)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor,
                mask: Optional[torch.Tensor]) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x


class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.embed = torch.nn.Embedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim)
        self.head = Linear(args.dim, args.vocab_size)

        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        seqlen = tokens.size(1)
        h = self.embed(tokens)
        freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)

        h = self.norm(h)[:, -1]
        logits = self.head(h)
        return logits


if __name__ == "__main__":
    torch.manual_seed(0)
    args = ModelArgs()
    x = torch.randint(0, args.vocab_size, (2, 128))
    model = Transformer(args)
    logits = model(x)
    print(logits.shape)  # (batch_size, vocab_size)