Absorption of $\mathbf{W}^{UK}$ into $\mathbf{W}^{Q}$
Step 1: Original MLA equations for KV compression
- $\mathbf{c}_t^{KV} = \mathbf{W}^{DKV} \mathbf{x}_t$ (Compressed KV latent vector)
- $\mathbf{k}_t = \mathbf{W}^{UK} \mathbf{c}_t^{KV}$ (Key projection)
- $\mathbf{v}_t = \mathbf{W}^{UV} \mathbf{c}_t^{KV}$ (Value projection)
- $\mathbf{q}_t = \mathbf{W}^Q \mathbf{x}_t$ (Query projection; RoPE is applied only to the decoupled query/key components and is omitted from this derivation)
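To make the shapes concrete, here is a minimal single-head NumPy sketch of these four projections. The dimensions (`d_model = 16`, `d_c = 4`, `d_h = 8`), the random weights, and the variable names are illustrative assumptions, not DeepSeek's actual configuration, and RoPE is ignored.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_c, d_h = 16, 4, 8   # hidden size, compressed latent size, head size -- illustrative only

# Weight matrices from Step 1 (single head, RoPE ignored)
W_DKV = rng.standard_normal((d_c, d_model))   # down-projection to the compressed KV latent
W_UK  = rng.standard_normal((d_h, d_c))       # key up-projection
W_UV  = rng.standard_normal((d_h, d_c))       # value up-projection
W_Q   = rng.standard_normal((d_h, d_model))   # query projection

x_t = rng.standard_normal(d_model)            # hidden state at position t

c_t = W_DKV @ x_t                             # compressed KV latent (this is what gets cached)
k_t = W_UK @ c_t                              # key
v_t = W_UV @ c_t                              # value
q_t = W_Q @ x_t                               # query

print(c_t.shape, k_t.shape, v_t.shape, q_t.shape)   # (4,) (8,) (8,) (8,)
```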
Step 2: Attention score computation with absorbed key weights
The attention score between query at position t and key at position m:
- $\text{score}_{t,m} = \frac{\mathbf{q}_t^T\mathbf{k}_m}{\sqrt{d_h}} = \frac{\mathbf{q}_t^T(\mathbf{W}^{UK} \mathbf{c}_m^{KV})}{\sqrt{d_h}} = \frac{(\mathbf{W}^Q \mathbf{x}_t)^T\mathbf{W}^{UK} \mathbf{c}_m^{KV}}{\sqrt{d_h}} = \frac{\mathbf{x}_t^T(\mathbf{W}^Q)^T\mathbf{W}^{UK} \mathbf{c}_m^{KV}}{\sqrt{d_h}}$
Defining $\mathbf{W}^{Q'} = (\mathbf{W}^Q)^T\mathbf{W}^{UK}$, we get:
- $\text{score}_{t,m} = \frac{\mathbf{x}_t^T\mathbf{W}^{Q'} \mathbf{c}_m^{KV}}{\sqrt{d_h}}$
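The absorption can be checked numerically. Under the same illustrative assumptions as above (random single-head weights, RoPE ignored), the score computed with the precomputed $\mathbf{W}^{Q'}$ matches the score computed by first materializing $\mathbf{k}_m$:

```python
import numpy as np

rng = np.random.default_rng(1)
d_model, d_c, d_h = 16, 4, 8                  # illustrative sizes

W_DKV = rng.standard_normal((d_c, d_model))
W_UK  = rng.standard_normal((d_h, d_c))
W_Q   = rng.standard_normal((d_h, d_model))

x_t = rng.standard_normal(d_model)            # query-side hidden state at position t
x_m = rng.standard_normal(d_model)            # key-side hidden state at position m
c_m = W_DKV @ x_m                             # cached compressed latent for position m

# Original form: up-project the latent to a key, then dot with the query
score_orig = (W_Q @ x_t) @ (W_UK @ c_m) / np.sqrt(d_h)

# Absorbed form: fold W_UK into the query path, never materialize k_m
W_Q_prime = W_Q.T @ W_UK                      # (d_model, d_c)
score_abs = x_t @ W_Q_prime @ c_m / np.sqrt(d_h)

assert np.allclose(score_orig, score_abs)
```

The practical payoff is that $\mathbf{k}_m$ never needs to be reconstructed at inference time; only the compressed latents $\mathbf{c}_m^{KV}$ are kept in the cache.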
Absorption of $\mathbf{W}^{UV}$ into $\mathbf{W}^{O}$
Step 1: Output computation with attention weights
- $\mathbf{o}_t = \sum_{m=1}^{t} \alpha_{t,m}\mathbf{v}_m$ where $\alpha_{t,m} = \text{Softmax}_m(\text{score}_{t,m})$
- $\mathbf{o}_t = \sum_{m=1}^{t} \alpha_{t,m}(\mathbf{W}^{UV} \mathbf{c}_m^{KV})$
- $\mathbf{o}_t = \mathbf{W}^{UV} \sum_{m=1}^{t} \alpha_{t,m} \mathbf{c}_m^{KV}$ (Taking $\mathbf{W}^{UV}$ outside the sum by linearity)
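The only property used in the last step is the linearity of $\mathbf{W}^{UV}$, which a small numerical check makes explicit. The dimensions, cache length, and attention weights below are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(2)
d_c, d_h, t = 4, 8, 5                         # latent size, head size, number of cached positions

W_UV  = rng.standard_normal((d_h, d_c))
C     = rng.standard_normal((t, d_c))         # cached latents c_1..c_t (one per row)
alpha = rng.random(t); alpha /= alpha.sum()   # attention weights (already softmaxed, sum to 1)

# Left side: up-project every cached latent to a value, then take the weighted sum
o_orig = sum(a * (W_UV @ c) for a, c in zip(alpha, C))

# Right side: weighted-sum the latents first, then apply W_UV once (linearity)
o_lin = W_UV @ (alpha @ C)

assert np.allclose(o_orig, o_lin)
```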
Step 2: Final output with absorbed value weights
Defining $\mathbf{o}_t' = \sum_{m=1}^{t} \alpha_{t,m} \mathbf{c}_m^{KV}$:
- $\mathbf{o}_t = \mathbf{W}^{UV} \mathbf{o}_t'$
- $\mathbf{u}_t = \mathbf{W}^O\mathbf{o}_t = \mathbf{W}^O\mathbf{W}^{UV}\mathbf{o}_t'$
Defining $\mathbf{W}^{O'} = \mathbf{W}^O\mathbf{W}^{UV}$, we get:
- $\mathbf{u}_t = \mathbf{W}^{O'}\mathbf{o}_t'$
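As a final sanity check, applying the precomputed $\mathbf{W}^{O'}$ to the latent-space output $\mathbf{o}_t'$ gives the same $\mathbf{u}_t$ as applying $\mathbf{W}^{UV}$ and then $\mathbf{W}^O$; again the dimensions and random weights are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(3)
d_model, d_c, d_h = 16, 4, 8                  # illustrative sizes

W_UV = rng.standard_normal((d_h, d_c))
W_O  = rng.standard_normal((d_model, d_h))    # output projection back to the hidden size

o_prime = rng.standard_normal(d_c)            # weighted sum of cached latents, sum_m alpha_{t,m} c_m

# Original form: up-project to values first, then apply the output projection
u_orig = W_O @ (W_UV @ o_prime)

# Absorbed form: precompute W_O' = W_O W_UV and apply it to the latent-space output directly
W_O_prime = W_O @ W_UV                        # (d_model, d_c)
u_abs = W_O_prime @ o_prime

assert np.allclose(u_orig, u_abs)
```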