Weave-Head Attention: Taming Gradient Norm Spikes

A simple idea I prototyped back in graduate school. Over Labor Day weekend I finally had some personal time to write it up and share. Huge thanks to TPU Research Cloud (TRC) and Google Cloud research credits for supporting the experiments.

TL;DR

Weave-Head Attention adds bidirectional cross-head attention within each token to standard multi-head attention and fuses it with the usual causal attention via a single online softmax. The overhead is roughly \(\frac{H}{6D + S}\) extra FLOPs (about 0.1-0.5% in typical regimes) and zero extra memory. In our 4B and 7B runs it sharply reduces the frequency and amplitude of gradient norm spikes, with lower training loss and validation bits per byte; at 1B the effect is neutral.

Motivation

During large-scale LLM training, gradient norm spikes often appear and destabilize optimization: training jitters, convergence slows, and final quality suffers (see e.g. OPT [1], OpenLLaMA [2], PaLM [3], OLMo [4]). The problem becomes more severe as models scale up, making it a key challenge for scaling. Here is a concrete snapshot from training OpenLLaMA: both the 3B and 7B runs exhibit many gradient norm spikes.

OpenLLaMA: side-by-side loss and gradient norm trajectories (3B vs 7B)
OpenLLaMA training snapshot: loss vs gradient norm (3B and 7B). Gradient norm spikes are frequent, especially for 7B.

Stabilization tricks exist, such as gradient norm clipping, QK normalization, learning-rate warmup, data de-duplication, and removing low-quality data, but at large scale they may not fully eliminate spikes. This raised a question: perhaps the information flow across attention heads is too weak. What if we let heads attend to each other?

Introducing Weave-Head Attention

Weave-Head augments standard Multi-Head Attention (MHA) with bidirectional cross-head attention within each token. Intuitively, before a head commits to a causal distribution over the past, it first consults the other heads at the same token to aggregate their evidence. Concretely, at each token Weave-Head performs:

  • Cross-head attention: each head's query attends to every head's key and value at the same token (bidirectional across heads, no causal mask needed).
  • Causal attention: each head attends over past tokens exactly as in standard MHA.
  • Fusion: the cross-head and causal terms are normalized together in a single online softmax, so each head's output mixes same-token evidence from other heads with its own causal context.
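
One way to write the fused update for head \(h\) at position \(t\) (my notation; it mirrors the reference JAX code at the end of the post, where both sums are accumulated with the usual max-subtraction for numerical stability):

\[
o_{t,h} \;=\;
\frac{\displaystyle \sum_{h'=1}^{H} \exp\!\big(q_{t,h}\!\cdot\! k_{t,h'} / \sqrt{K}\big)\, v_{t,h'}
\;+\; \sum_{s \le t} \exp\!\big(q_{t,h}\!\cdot\! k_{s,h} / \sqrt{K}\big)\, v_{s,h}}
{\displaystyle \sum_{h'=1}^{H} \exp\!\big(q_{t,h}\!\cdot\! k_{t,h'} / \sqrt{K}\big)
\;+\; \sum_{s \le t} \exp\!\big(q_{t,h}\!\cdot\! k_{s,h} / \sqrt{K}\big)}
\]

Here \(K\) is the head dimension and \(q_{t,h}, k_{t,h}, v_{t,h}\) are head \(h\)'s query, key, and value at token \(t\).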

Why this helps (intuition)

Cross-head attention strengthens exactly the information flow across heads that the Motivation section suspects is too weak: each head aggregates the other heads' evidence at the same token before committing to a causal distribution over the past.

Near-Zero Overhead

Summary. Extra FLOPs ≈ \(\frac{H}{6D + S}\); memory ≈ 0 (continue the online softmax to fuse cross-head terms). Here \(H\)=heads, \(D\)=model dim, \(S\)=sequence length. In typical regimes this is ≈ 0.1-0.5% FLOPs.

Example

\(D{=}4096,\,H{=}32,\,S{=}2048\): Weave adds \(\frac{H}{6D + S} = \frac{32}{6 \times 4096 + 2048} \approx 0.12\%\) extra FLOPs. Memory-wise, Weave-Head uses zero extra memory: compute causal attention and continue the online softmax with the cross-head terms.

Derivation

Let \(B\) = batch, \(S\) = sequence length, \(D\) = model dim, \(H\) = number of heads, \(K\) = head dim. Count a multiply-add as 2 FLOPs.

Baseline (per layer, forward)

  • Projections \(Q,K,V\): \(3\) GEMMs \(\Rightarrow 6BSD^2\)
  • Scores \(QK^\top\): \(2BHS^2K = 2BS^2D\)
  • Apply weights to \(V\): \(2BS^2D\)
  • Output projection \(O\): \(2BSD^2\)

Attention total: \(8BSD^2 + 4BS^2D\). MLP (expansion 4): \(16BSD^2\).

Baseline FLOPs: \(\;24\,B S D^2 + 4\,B S^2 D\).

Weave-Head (additional, per layer, forward)

At each token, cross-head attention computes:

  • Cross-head scores: \(2\,B S H^2 K = 2\,B S H D\)
  • Weights-values product: \(2\,B S H^2 K = 2\,B S H D\)

Weave overhead: \(\approx 4\,B S H D\).

Overhead ratio (vs. baseline): \(\displaystyle \frac{4\,B S H D}{24\,B S D^2 + 4\,B S^2 D} = \frac{H}{6D + S}\).
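
As a quick sanity check on the ratio, here is a small Python snippet (illustrative only, plugging in the head counts and head dim 128 from the model configuration table below):

def weave_overhead_ratio(num_heads: int, head_dim: int, seq_len: int) -> float:
    # Extra forward FLOPs of Weave-Head relative to a baseline layer: H / (6D + S).
    d_model = num_heads * head_dim
    return num_heads / (6 * d_model + seq_len)

# Head counts from the model configuration table; head dim 128, sequence length 2048.
for name, heads in [("1B", 16), ("4B", 28), ("7B", 32)]:
    print(f"{name}: {100 * weave_overhead_ratio(heads, 128, 2048):.2f}% extra FLOPs")
# Prints roughly 0.11% (1B), 0.12% (4B), 0.12% (7B).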

Experiments

Setting. We train LLMs at 1.5x Chinchilla optimal [6] (≈30 tokens per parameter). We compare Baseline and Weave-Head across model sizes. The plots report training loss, gradient norm (average + per layer), and validation bits per byte (bpb).

The model architecture follows a LLaMA-2 style design (dense; one MLP and one self-attention per layer). We trained three model sizes: 1B, 4B, and 7B.
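
For a rough sense of scale, assuming exactly 30 tokens per parameter (illustrative arithmetic, not the exact budgets):

\[
1\text{B} \times 30 \approx 30\text{B tokens}, \qquad 4\text{B} \times 30 \approx 120\text{B tokens}, \qquad 7\text{B} \times 30 \approx 210\text{B tokens}.
\]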

Experimental setup

  • Global batch size: 512 sequences × 2048 tokens per step
  • Sequence length: 2048
  • Learning rate: 2e-4, linear warmup for 200 steps, cosine decay to 10% of peak
  • Hardware: Google Cloud TPU v4-512
  • Sharding: FSDP + SP (ring)
  • Normalization: pre RMS norm, QK RMS norm
  • Training data: mixture of FineWeb, StarCoder, Wikipedia, arXiv, and books; at most one epoch, no repeating
  • Validation data: hold-out split of the training data
  • Tokenizer: LLaMA 3 tokenizer
Model configurations

  Model   Layers   Heads   Head dim
  1B      20       16      128
  4B      30       28      128
  7B      34       32      128
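
As a concrete sketch of the learning-rate row above (assuming an optax-style schedule; total_steps is a placeholder, not the actual run length):

import optax

total_steps = 100_000  # hypothetical total number of training steps

lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,           # linear warmup starts from 0
    peak_value=2e-4,          # peak learning rate
    warmup_steps=200,         # linear warmup for 200 steps
    decay_steps=total_steps,  # total schedule length (warmup + cosine decay)
    end_value=2e-5,           # decay to 10% of the peak
)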

We now show results at 4B and 7B side-by-side from four angles: gradient norm, per-layer gradient norm, training loss, and validation bits per byte (bpb).

Mean gradient norm. Spikes are sharply reduced with Weave-Head.

4B gradient norm line plot
7B gradient norm line plot
Gradient norm (mean) over training. Left: 4B; Right: 7B.

Per-layer view. Baseline shows bursty multi-layer spikes; Weave-Head shows none.

4B per layer gradient norm plot
7B per layer gradient norm plot
Per-layer gradient norm by step. Each point represents one layer's gradient norm. Baseline shows bursty events where many layers spike together; Weave-Head shows no such bursts. Left: 4B; Right: 7B.

Training loss. Loss decreases faster and to a lower value with Weave-Head.

4B training loss
7B training loss
Training loss over steps. Stability translates into faster descent and a lower final value, with a clearer margin at 7B. Left: 4B; Right: 7B.

Bits per byte. Validation bpb is consistently lower with Weave-Head.

4B validation bits-per-byte
7B validation bits-per-byte
Validation bits-per-byte (↓ is better). Weave-Head is consistently lower for both sizes, indicating better compression and generalization. Left: 4B; Right: 7B.

Key takeaway. Weave-Head substantially reduces the frequency and amplitude of gradient norm spikes in both the 4B and 7B models, alongside improved training curves and lower validation bpb. Overall, gains grow with model size, matching the trend that gradient norm spikes worsen with scale. At 1B, effects are neutral.

Small scale (1B): minimal effect

Observation. At 1B, gradient norm spikes are rarer and Weave-Head shows limited benefit; training and validation metrics are effectively neutral.
1B per layer gradient norm plot
1B average gradient norm plot
1B gradient norm: per layer (left) and mean (right).

Discussion

JAX Code

The reference implementation below expresses Weave-Head Attention in JAX. XLA fuses the cross-head and causal attention well enough for prototyping, but to approach hardware-limited performance a dedicated Pallas or CUDA kernel is recommended.

import jax
import jax.numpy as jnp


def weave_head_attention(q_BHTK: jax.Array, k_BHTK: jax.Array, v_BHTK: jax.Array) -> jax.Array:
    """Same-token cross-head attention fused with causal attention via a shared online softmax."""
    B, H, T, K = q_BHTK.shape
    assert k_BHTK.shape == (B, H, T, K) and v_BHTK.shape == (B, H, T, K)

    # same-token cross-head (xh)
    q_bHK = q_BHTK.transpose(0, 2, 1, 3).reshape(B * T, H, K)
    k_bHK = k_BHTK.transpose(0, 2, 1, 3).reshape(B * T, H, K)
    v_bHK = v_BHTK.transpose(0, 2, 1, 3).reshape(B * T, H, K)

    scale = 1.0 / jnp.sqrt(K)
    logits_xh = jnp.einsum("bHK,bQK->bHQ", q_bHK, k_bHK) * scale
    m_xh = jnp.max(logits_xh, axis=-1, keepdims=True)
    w_xh = jnp.exp(logits_xh - m_xh)
    den_xh = jnp.maximum(jnp.sum(w_xh, axis=-1), 1e-9)
    num_xh_bHK = jnp.einsum("bHQ,bQK->bHK", w_xh, v_bHK)

    m_xh = m_xh.reshape(B, T, H, 1)
    den_xh = den_xh.reshape(B, T, H)
    num_xh = num_xh_bHK.reshape(B, T, H, K)

    # causal across tokens (per head)
    tri_TS = jnp.tril(jnp.ones((T, T), dtype=bool))
    logits_causal = jnp.einsum("BHTK,BHSK->BHTS", q_BHTK, k_BHTK) * scale
    logits_causal = jnp.where(tri_TS[None, None, :, :], logits_causal, -jnp.inf)

    m_causal = jnp.max(logits_causal, axis=-1, keepdims=True)
    w_causal = jnp.exp(logits_causal - m_causal)
    den_causal = jnp.maximum(jnp.sum(w_causal, axis=-1), 1e-9)
    num_causal = jnp.einsum("BHTS,BHSK->BHTK", w_causal, v_BHTK)

    m_causal = m_causal.transpose(0, 2, 1, 3)
    den_causal = den_causal.transpose(0, 2, 1)
    num_causal = num_causal.transpose(0, 2, 1, 3)

    # fuse via online softmax
    m_final = jnp.maximum(m_xh, m_causal)
    num_final = num_xh * jnp.exp(m_xh - m_final) + num_causal * jnp.exp(m_causal - m_final)
    den_final = (
        den_xh * jnp.exp(m_xh.squeeze(-1) - m_final.squeeze(-1)) +
        den_causal * jnp.exp(m_causal.squeeze(-1) - m_final.squeeze(-1))
    )
    den_final = jnp.maximum(den_final, 1e-9)

    return (num_final / den_final[..., None]).transpose(0, 2, 1, 3)
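
A minimal usage sketch (my own example with hypothetical shapes, not from the training code): random inputs, a jit-compiled call, and a shape check.

# Batch 2, 8 heads, 16 tokens, head dim 64.
key = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(key, 3)
q = jax.random.normal(kq, (2, 8, 16, 64))
k = jax.random.normal(kk, (2, 8, 16, 64))
v = jax.random.normal(kv, (2, 8, 16, 64))

out = jax.jit(weave_head_attention)(q, k, v)
assert out.shape == (2, 8, 16, 64)  # output keeps the (B, H, T, K) layout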

Acknowledgments

Thanks again to TPU Research Cloud (TRC) and Google Cloud research credits for making these experiments possible.

References

  1. OPT — Zhang et al., “OPT: Open Pre-trained Transformer Language Models,” (2022). arXiv:2205.01068
  2. OpenLLaMA — OpenLM Research, “OpenLLaMA: An Open Reproduction of LLaMA” (2023). Project
  3. PaLM — Chowdhery et al., “PaLM: Scaling Language Modeling with Pathways” (2022). arXiv:2204.02311
  4. OLMo 2 — AI2, “OLMo 2 Technical Report” (2025). Project
  5. Talking-Heads Attention — Shazeer et al., “Talking-Heads Attention” (2020). arXiv:2003.02436
  6. Chinchilla — Hoffmann et al., “Training Compute-Optimal Large Language Models” (2022). arXiv:2203.15556

Citing this blog post

If this blog is useful and you'd like to cite it:

@misc{liu2025weave,
  title        = {Weave-Head Attention: Taming Gradient Norm Spikes},
  author       = {Hao Liu},
  year         = {2025},
  howpublished = {\url{https://haoliu.ai/blog/weave-head.html}}
}
