A New Approach to Attention — Differential Transformers | Paper Walkthrough and PyTorch Implementation
Today, we are looking into a very different approach to the transformer architecture.
So far we've implemented the attention mechanism in over half a dozen deep learning architectures. The variations we generally see are standard attention, window attention, convolutional attention, etc., but another recent (October 2024) advancement in the field comes from the paper titled Differential Transformers [1].
To save you time, here is what the Differential Transformer achieves in general.

To refresh your memory on attention scoring, here is the standard computation:
wei = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)  # shape (B, T, T); T: sequence length
# upper-triangular mask, because the decoder model only looks at past tokens
wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
attn_score = softmax(wei, dim=-1)  # shape stays the same
out = attn_score @ V  # (B, T, T) @ (B, T, E) -> (B, T, E)
In Figure 1, the first diagram is an experiment that shows how much attention weight/score is given to each token: the wei matrix simply correlates each token with every other token (how much attention it pays to each), and is then normalized through a softmax so that each row of wei sums to one.
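If you want to run this refresher end to end, here is a minimal self-contained sketch; the shapes, the tril buffer, and the random tensors are illustrative assumptions rather than the exact module used later in this post.
import math
import torch
import torch.nn.functional as F

B, T, E = 2, 8, 16                               # batch, sequence length, embedding dim (illustrative)
q = torch.randn(B, T, E)
k = torch.randn(B, T, E)
v = torch.randn(B, T, E)

wei = q @ k.transpose(-2, -1) / math.sqrt(E)     # (B, T, T)
tril = torch.tril(torch.ones(T, T))              # causal mask: ones on and below the diagonal
wei = wei.masked_fill(tril == 0, float("-inf"))  # hide future tokens
attn_score = F.softmax(wei, dim=-1)              # each row sums to 1
out = attn_score @ v                             # (B, T, T) @ (B, T, E) -> (B, T, E)
print(out.shape)                                 # torch.Size([2, 8, 16])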
The figure shows plots for the query retrieval task, which uses the following prompt structure:
<Beginning of Sentence> <Paragraphs> <Query>
We see this as an autoregressive task: the prompt contains a paragraph, and we ask the model to retrieve some information that sits somewhere in that context. Thus, the x-axis of the left plot in Figure 1 is:
<BOS> <Context> <Answer which we’re looking for> <Context> <Query>
The problem with the standard attention architecture is that it assigns a significant attention score to every token in the context, even when a token is irrelevant. We refer to this as attention noise, which the Differential Transformer aims to mitigate.
Differential Transformer Architecture
The paper demonstrates how standard transformers assign attention to tokens that overshadow the relevant ones, and draws inspiration from Active Noise Cancellation (ANC) headphones, which operate on the principle of wave interference: when two noise waves collide, their common phases (analogous to background noise, or irrelevant tokens) cancel each other out, leaving only the desired signal. The same principle is applied here to filter out attention noise.
Here is the core idea: we split the Query and Key vectors into two equal halves [Q (n, d) → Q₁ (n, d/2), Q₂ (n, d/2)], perform separate attention with each pair, and subtract the results. This subtraction is analogous to the noise cancellation above. Quoting the paper:
Specifically, we partition the query and key vectors into two groups and compute two separate softmax attention maps. Then the result of subtracting these two maps is regarded as attention scores. The differential attention mechanism eliminates attention noise, encouraging models to focus on critical information

The figure above clearly illustrates how this takes place within the attention module, and that is exactly what this post dives into!
Initially, we have the input X, but instead of producing one Query and one Key vector, we generate two!
Q₁, Q₂ = split(XW_q), where Q₁, Q₂ ∈ ℝ(n×d/2) and W_q ∈ ℝ(d_model×d)
K₁, K₂ = split(XW_k), where K₁, K₂ ∈ ℝ(n×d/2) and W_k ∈ ℝ(d_model×d)
V = XW_v ∈ ℝ(n×d), where W_v ∈ ℝ(d_model×d)
Where n = sequence length, d_model = embedding dimension, d = head_dim (for V), and d/2 is the head_dim for each Q-K pair (check the implementation below).
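As a quick sanity check of these projections, here is a small sketch with illustrative sizes (embed_dim = 64, head_dim = 16); the layer names simply mirror the implementation further below.
import torch
import torch.nn as nn

n, embed_dim, head_dim = 10, 64, 16                      # illustrative sizes
x = torch.randn(1, n, embed_dim)

q_linear = nn.Linear(embed_dim, head_dim, bias=False)    # produces [Q1; Q2]
k_linear = nn.Linear(embed_dim, head_dim, bias=False)    # produces [K1; K2]
v_linear = nn.Linear(embed_dim, head_dim, bias=False)    # produces V

q1, q2 = torch.chunk(q_linear(x), 2, dim=-1)             # each (1, n, head_dim/2) = (1, 10, 8)
k1, k2 = torch.chunk(k_linear(x), 2, dim=-1)
v = v_linear(x)                                          # (1, n, head_dim) = (1, 10, 16)
print(q1.shape, k1.shape, v.shape)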
Now we perform the standard attention operation for each Q-K pair separately and take the difference by subtracting them. This cancels out the common attention noise and amplifies attention on the relevant context.
A₁ = softmax(Q₁K₁ᵀ / √(d/2)), where A₁ ∈ ℝ(n×n)
A₂ = softmax(Q₂K₂ᵀ / √(d/2)), where A₂ ∈ ℝ(n×n)
DiffAttn(X) = (A₁ − λA₂)V, where DiffAttn(X) ∈ ℝ(n×d)
Where lambda (λ) is a learnable scalar, re-parameterized as λ = exp(λ_q1 · λ_k1) − exp(λ_q2 · λ_k2) + λ_init (more on this below).

One of the improvements we see in Differential Transformers is activation stability. The dot product Q@K.T can produce very large or small values depending on the magnitudes of Q and K, and in a softmax very large values dominate, pushing the probabilities of the other entries toward zero (e.g. softmax([100, 4, 5]) -> [1, ~0, ~0]). This can allocate extreme scores to irrelevant tokens; however, the subtraction cancels the noise patterns shared between attn1 and attn2, suppressing these extreme outliers.
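To make this concrete, here is a tiny toy example (the numbers are made up, not taken from the paper): a single extreme logit saturates the softmax, and subtracting two attention maps that share the same spike largely removes it while the relevant entry survives.
import torch
import torch.nn.functional as F

# A single extreme logit saturates the softmax, starving every other entry.
logits = torch.tensor([100.0, 4.0, 5.0])
print(F.softmax(logits, dim=-1))           # ~[1., 0., 0.]

# Toy attention rows (made-up numbers): both maps share a spike on an
# irrelevant token (index 0); only the first map attends to the answer (index 2).
a1 = torch.tensor([0.60, 0.05, 0.30, 0.05])
a2 = torch.tensor([0.62, 0.18, 0.02, 0.18])
lmbda = 0.8
diff = a1 - lmbda * a2
print(diff)  # the shared spike at index 0 mostly cancels; index 2 now dominates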
The most exciting part of it is that the DIFF Transformer with 4-bit quantization outperformed [2] a regular Transformer with 6-bit quantization.
What is lambda (λ)?
Returning to our earlier signal-processing analogy, where common noise in a signal is cancelled to amplify the desired component (attention scores in our case), the paper introduces lambda as a learnable parameter of the model.
λ is a learnable damping factor that balances the noise-cancellation process (A₁ − λA₂) by controlling the weight of the second softmax term relative to the first. It ensures that the model doesn't overly suppress tokens that are important for context while cancelling out irrelevant noise.
By dynamically adjusting λ, the model improves the signal-to-noise ratio in attention scores, focusing on key information (signal) and minimizing distractions (noise), leading to more accurate and context-aware attention mechanisms.
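For intuition, here is a small sketch of the λ re-parameterization mentioned above, showing how λ is computed from four learnable vectors; the half dimension and layer index are illustrative assumptions.
import math
import torch
import torch.nn as nn

half_dim, depth = 8, 0                           # illustrative half head-dim and layer index
lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)

# four learnable vectors, initialised from N(0, 0.1) as in the paper
lambda_q1 = nn.Parameter(torch.randn(half_dim) * 0.1)
lambda_k1 = nn.Parameter(torch.randn(half_dim) * 0.1)
lambda_q2 = nn.Parameter(torch.randn(half_dim) * 0.1)
lambda_k2 = nn.Parameter(torch.randn(half_dim) * 0.1)

# lambda = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
lmbda = (torch.exp(torch.dot(lambda_q1, lambda_k1))
         - torch.exp(torch.dot(lambda_q2, lambda_k2))
         + lambda_init)
print(float(lmbda))                              # starts close to lambda_init (0.2 at depth 0)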

headᵢ = (1 − λ_init) · LN(DiffAttn(X)); MultiHead(X) = Concat(head₁, …, head_h) · W_O
The formulation above mirrors the standard multi-head attention pseudo-code; the only addition here is the scaling factor (1 − λ_init), which multiplies the output after the layer normalization LN(·).
With this scaling factor, the authors ensure that the scaling is inversely related to the initial noise-suppression strength. This helps balance the contributions of the differential attention and the standard Transformer components and maintains the same gradient flow [3].
Everything else from here is pretty straightforward and follows the standard decoder architecture. The paper only changes the base attention module; every other aspect of the decoder transformer model remains the same, except that it takes inspiration from LLaMA by using RMSNorm and SwiGLU activations.
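This post doesn't implement the feed-forward block, so here is a minimal SwiGLU FFN sketch in the LLaMA style as a reference; the class name, hidden size, and layer names are my assumptions, not the paper's exact configuration.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    # minimal SwiGLU feed-forward block in the LLaMA style (sizes are illustrative)
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.w_gate = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.w_up = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, embed_dim, bias=False)

    def forward(self, x):
        # SwiGLU: silu(gate(x)) * up(x), then project back down to embed_dim
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

ffn = SwiGLUFFN(embed_dim=64, hidden_dim=171)    # hidden ~ 8/3 * embed_dim, a common choice
print(ffn(torch.randn(2, 10, 64)).shape)         # torch.Size([2, 10, 64])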

Implementing Differential Transformer
The code below is quite different from the official implementation: the authors combine the MultiHead and DiffAttn parts in a single class, and I will not use Grouped Query Attention, as the main focus is to convey the key idea. If you are interested in using FlashAttention or GQA, you can check out their code here.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # use apex's fused kernel when it is available
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    # fall back to PyTorch's built-in RMSNorm (torch >= 2.4)
    from torch.nn import RMSNorm


class DiffAttn(nn.Module):
    def __init__(self, num_heads, embed_dim, depth):
        super().__init__()
        self.head_dim = embed_dim // num_heads

        self.q_linear = nn.Linear(embed_dim, self.head_dim, bias=False)
        self.k_linear = nn.Linear(embed_dim, self.head_dim, bias=False)
        self.v_linear = nn.Linear(embed_dim, self.head_dim, bias=False)

        # lambda_init depends only on the layer index (depth)
        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)

        # four learnable vectors of size head_dim/2, initialised with mean 0 (default) and std 0.1
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim // 2, dtype=torch.float32))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim // 2, dtype=torch.float32))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim // 2, dtype=torch.float32))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim // 2, dtype=torch.float32))
        nn.init.normal_(self.lambda_q1, std=0.1)
        nn.init.normal_(self.lambda_q2, std=0.1)
        nn.init.normal_(self.lambda_k1, std=0.1)
        nn.init.normal_(self.lambda_k2, std=0.1)

        self.ln = RMSNorm(self.head_dim, eps=1e-5, elementwise_affine=True)

    def forward(self, x):
        b, t, _ = x.shape           # t: token/sequence length
        q = self.q_linear(x)        # (b, t, embed_dim) -> (b, t, head_dim)
        k = self.k_linear(x)
        v = self.v_linear(x)

        # split q and k into two halves: (b, t, head_dim) -> 2 x (b, t, head_dim/2)
        q1, q2 = torch.chunk(q, 2, dim=-1)
        k1, k2 = torch.chunk(k, 2, dim=-1)

        # compute scaled dot-product scores for each pair
        attn1 = q1 @ k1.transpose(-2, -1) / math.sqrt(self.head_dim / 2)
        attn2 = q2 @ k2.transpose(-2, -1) / math.sqrt(self.head_dim / 2)

        # causal mask: the Diff Attn paper trains a decoder-only model
        attn_mask = torch.triu(
            torch.full((t, t), float("-inf"), device=x.device), diagonal=1
        )

        # compute the two separate softmax attention maps
        a1 = F.softmax(attn1 + attn_mask, dim=-1)
        a2 = F.softmax(attn2 + attn_mask, dim=-1)

        # compute lambda dynamically from the learnable vectors (re-parameterization)
        lmbda = (
            torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1))
            - torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1))
            + self.lambda_init
        )

        diffattn = (a1 - lmbda * a2) @ v               # (b, t, t) @ (b, t, head_dim) -> (b, t, head_dim)
        attn = (1 - self.lambda_init) * self.ln(diffattn)
        return attn
Given our earlier explanation, the code above should be pretty self-explanatory, but I'll still walk through it briefly:
- We receive X of shape (b, t, embed_dim) in the forward pass and project it to a query and a key of shape (b, t, head_dim); we then split the Q and K vectors to create pairs of shape (b, t, head_dim/2), and project the value vector from (b, t, embed_dim) to (b, t, head_dim).
- We get the attention matrix (causally masked, since this is a decoder-only model) for each Q-K pair, and subtract them using the re-parameterized lambda strategy discussed earlier.
- We apply the RMS norm and the constant (1 − λ_init). In the paper, λ_init is the initial value of lambda (λ), which stays fixed within a layer (it depends only on the layer index, or depth).
- It was empirically found that λ_init = 0.8 − 0.6 × exp(−0.3 · (depth − 1)) works best for their implementation, where depth ∈ [1, L] and L is the number of decoder layers. (In our code, depth runs from 0 to L − 1, so the (depth − 1) offset is already absorbed.)
- The lambda parameters are vectors of size head_dim/2 (or d/2), as proposed in the paper. Since lambda is learnable, we wrap them in nn.Parameter() and initialize them from a normal distribution with mean 0 and std 0.1. A quick shape check of the head follows below.
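Here is the quick shape check mentioned above, with illustrative sizes (it assumes the DiffAttn class from the code block above is in scope).
# quick sanity check of the DiffAttn head above (illustrative sizes)
num_heads, embed_dim, depth = 4, 64, 0
head = DiffAttn(num_heads, embed_dim, depth)
x = torch.randn(2, 10, embed_dim)                # (batch, tokens, embed_dim)
print(head(x).shape)                             # torch.Size([2, 10, 16]), i.e. (b, t, head_dim)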
The implementation then follows the standard MultiHead concatenation.
class MultiHead(nn.Module):
    def __init__(self, num_heads, embed_dim, depth):
        super().__init__()
        self.attn_heads = nn.ModuleList(
            [DiffAttn(num_heads, embed_dim, depth) for _ in range(num_heads)]
        )
        self.o_linear = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        # each head returns (b, t, head_dim); concatenating num_heads of them gives (b, t, embed_dim)
        x = torch.cat([attn_head(x) for attn_head in self.attn_heads], dim=-1)
        out = self.o_linear(x)  # final output projection
        return out
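A quick usage sketch with illustrative sizes: each head outputs head_dim = embed_dim / num_heads, so concatenating num_heads heads and applying the output projection restores the embedding dimension.
# putting it together (illustrative sizes)
mha = MultiHead(num_heads=4, embed_dim=64, depth=0)
x = torch.randn(2, 10, 64)
print(mha(x).shape)                              # torch.Size([2, 10, 64])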
And there it is…
The Differential Transformer is explained along with the implementation. I hope this one was a fun read, that you learned something new, or at the very least got some revision in. With this, I can't thank you enough for reading so far! If you have any questions, no matter how small, please don't hesitate to share them in the comments; I'd be happy to help!
The entire code is available on my GitHub repo, with various other deep learning architectures you might want to look into.
If you found this article helpful or learned something please consider dropping some big Claps. This motivates me to continue writing quality content for you guys.
Thanks for reading!
See you in the next one…
References:
[1] The original arXiv paper: https://arxiv.org/pdf/2410.05258
[2] Quantization maps high-precision values (e.g., 32-bit floats) to low-precision representations (e.g., 8-bit integers). The representable range is finite, so extreme values (outliers) consume much of it, leaving less resolution for the smaller values; these lose precision, and the quantization error increases.
Differential Attention cancels shared extreme values (noise). This reduces outliers, producing smoother activations (attention scores), and enabling more accurate low-bit quantization while preserving smaller values effectively. See detailed results in the paper.
[3] One of the reasons that the constant (1 − λ_init) is multiplied in after the layer norm is to keep the gradient flow of the differential attention the same as that of standard attention.
Supplementary Wisdom
In deep learning architecture development, maintaining a consistent gradient flow is crucial when introducing modifications. Researchers aim to preserve similar weight update mechanisms for two key reasons:
- If a modification introduces a drastically different gradient flow, then any challenges that arise during training require developing new mitigation strategies, because existing methods from previous models become ineffective.
- On the other hand, architectures with similar gradient flows can be trained using identical hyperparameters. For instance, the differential transformer leveraged the original transformer’s hyperparameters due to their nearly identical gradient characteristics.
Thus, the authors use the fixed multiplier (1 − λ_init) as the scale of LN(·) to align the gradient flow of the Differential Transformer with that of the standard Transformer.
The paper also provides a scratch derivation of the gradients at the end; you can check it out if you are interested in the calculus behind it.