Transformer NaN Loss: 7 Fixes That Actually Work

Updated Feb 13, 2026
⚡ Key Takeaways
  • Most NaN issues come from attention score overflow or mixed precision underflow, not learning rate.
  • Clamp attention scores before softmax and use -1e4 instead of -inf for masking to prevent numerical instability.
  • Switch from fp16 to bfloat16 on Ampere+ GPUs to eliminate overflow without needing gradient scaling.
  • Pre-LayerNorm architecture is far more stable than Post-LayerNorm for deep transformers.
  • Never skip learning rate warmup—AdamW's moment estimates need time to stabilize before hitting peak learning rate.

Most NaN Losses Aren’t Gradient Explosions

Here’s a hot take that might save you hours: when your transformer training hits NaN, your first instinct—lowering the learning rate—is usually wrong.

I’ve watched countless engineers immediately slash their learning rate from 3e-4 to 1e-5 when they see NaN. The training limps along for longer, sure, but it still diverges eventually. The real culprits are almost always elsewhere: mixed precision underflow, attention score overflow, or that one layer norm you forgot to initialize properly.

Let me walk you through the actual debugging process.


The Detection Problem

Before you can fix NaN, you need to know exactly where it appears. PyTorch’s default behavior is annoyingly silent—your loss goes NaN, and you’re left guessing which of your 124 million parameters exploded.

The torch.autograd.set_detect_anomaly(True) flag exists for this reason, but there’s a catch: it slows training by roughly 2-3x and doesn’t always pinpoint the exact operation. Here’s what actually works:

import torch
import torch.nn as nn

def register_nan_hooks(model: nn.Module):
    """Registers forward hooks that catch NaN the moment it appears."""

    def check_nan(module, input, output):
        if isinstance(output, torch.Tensor):
            if torch.isnan(output).any():
                raise RuntimeError(
                    f"NaN detected in {module.__class__.__name__}n"
                    f"Input stats: mean={input[0].mean():.4f}, "
                    f"max={input[0].abs().max():.4f}"
                )
        elif isinstance(output, tuple):
            for i, o in enumerate(output):
                if isinstance(o, torch.Tensor) and torch.isnan(o).any():
                    raise RuntimeError(
                        f"NaN in output {i} of {module.__class__.__name__}"
                    )

    handles = []
    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(check_nan))

    # Return the handles so the hooks can be removed once debugging is done
    return handles

Run this during the first epoch only—the overhead isn’t negligible. Once you’ve identified the problematic layer, remove the hooks and focus your debugging.
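
Minimal usage, assuming the handle-returning version above (model here is whatever module you’re training):

handles = register_nan_hooks(model)

# ... run the first training epoch; the hook raises as soon as a layer emits NaN ...

for handle in handles:
    handle.remove()  # strip the hooks once you've found the offending layer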

Attention Score Overflow: The Silent Killer

The attention mechanism computes $\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$, and this is where most NaN issues originate.

Why? Because $QK^T$ can produce absurdly large values. With embedding dimension $d_k = 512$ and sequence length 2048, you’re computing 2048 × 2048 ≈ 4 million dot products. When query and key vectors align strongly, individual elements of $QK^T$ can exceed 1000. After exponentiating in softmax, you get inf. And inf / inf = NaN.

The $\sqrt{d_k}$ scaling factor helps, but it’s not enough for long sequences or poorly initialized weights. Here’s the fix:

class SafeScaledDotProductAttention(nn.Module):
    def __init__(self, d_k: int, max_attn_value: float = 50.0):
        super().__init__()
        self.scale = d_k ** -0.5
        self.max_attn_value = max_attn_value  # Clamp before softmax

    def forward(self, q, k, v, mask=None):
        # Shape: (batch, heads, seq_len, d_k)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # This line saves your training
        attn_scores = torch.clamp(attn_scores, -self.max_attn_value, self.max_attn_value)

        if mask is not None:
            # Use -1e4 instead of -inf for numerical stability
            attn_scores = attn_scores.masked_fill(mask == 0, -1e4)

        attn_probs = torch.softmax(attn_scores, dim=-1)
        return torch.matmul(attn_probs, v)

Notice I’m using -1e4 instead of float('-inf') for masking. Using -inf seems mathematically correct, but in fp16, the softmax computation can produce NaN when the entire row is masked (which happens with causal attention at position 0 for some implementations). Using -1e4 produces effectively zero attention weights without the numerical instability.
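
You can verify this yourself in a couple of lines; a fully masked row breaks softmax regardless of dtype, while -1e4 degrades gracefully to uniform weights:

import torch

fully_masked = torch.full((1, 4), float('-inf'))
print(torch.softmax(fully_masked, dim=-1))  # tensor([[nan, nan, nan, nan]])

fully_masked = torch.full((1, 4), -1e4)
print(torch.softmax(fully_masked, dim=-1))  # tensor([[0.2500, 0.2500, 0.2500, 0.2500]])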

The Mixed Precision Minefield

Mixed precision training with fp16 cuts memory usage roughly in half and speeds up training on modern GPUs. It also introduces a whole new category of NaN bugs.

The fp16 format can represent values between roughly $6 \times 10^{-8}$ and $65504$. Anything outside this range becomes inf (overflow) or gets flushed to zero (underflow). The gradient scaler in torch.cuda.amp handles some of this, but not all.
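
Two quick lines show both cliff edges:

import torch

print(torch.tensor(70000.0, dtype=torch.float16))  # tensor(inf, dtype=torch.float16) -- overflow above 65504
print(torch.tensor(1e-8, dtype=torch.float16))     # tensor(0., dtype=torch.float16) -- underflow below ~6e-8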

Here’s what the standard training loop looks like:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for batch in dataloader:
    optimizer.zero_grad()

    with autocast():
        outputs = model(batch['input_ids'])
        loss = criterion(outputs, batch['labels'])

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)

    # Check for inf/nan before step
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    if not torch.isfinite(total_norm):
        print(f"Skipping batch: grad norm = {total_norm}")
        scaler.update()
        continue

    scaler.step(optimizer)
    scaler.update()

But here’s the thing the tutorials don’t mention: the GradScaler starts with a default scale of $2^{16} = 65536$. For some architectures, especially those with deep residual connections, this initial scale is too aggressive. The first few batches produce inf gradients, the scaler backs off, and eventually stabilizes—but sometimes it backs off too aggressively and never recovers.

Try starting with a lower initial scale:

scaler = GradScaler(init_scale=2**10)  # 1024 instead of 65536

My best guess is that the default value was tuned for vision models, not transformers. But I’m not entirely sure why the PyTorch team settled on $2^{16}$—the docs don’t explain.

bfloat16: The Boring Solution That Works

If you’re on an A100, H100, or any Ampere+ GPU, just use bfloat16 instead of fp16.

The bfloat16 format has the same exponent range as fp32 (roughly $10^{\pm 38}$), trading precision for range. This eliminates overflow almost entirely. The loss function $L = -\frac{1}{N}\sum_{i=1}^{N} y_i \log(\hat{y}_i)$ can handle much larger logits without exploding.

# PyTorch 2.0+
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    outputs = model(inputs)
    loss = criterion(outputs, labels)

loss.backward()  # No scaler needed for bf16
optimizer.step()

No GradScaler required. Just automatic mixed precision that doesn’t hate you.

The downside? About 10-15% slower than fp16 on consumer GPUs that lack native bf16 support (anything pre-Ampere). And reduced precision can affect convergence on very long training runs—I’ve seen models trained for 100k+ steps show slightly worse final loss compared to fp16, but the difference is usually negligible.


Layer Norm Initialization: The Sneaky Bug

Layer normalization computes $\text{LN}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$, where $\epsilon$ is a small constant (default $10^{-5}$) preventing division by zero.

Except when your input variance is already near zero, that $\epsilon$ isn’t enough.

This happens more often than you’d think. After many layers of residual connections with improper initialization, activations can collapse to near-constant values. When $\sigma^2 \approx 10^{-8}$ and $\epsilon = 10^{-5}$, you’re dividing by $\sqrt{10^{-5}} \approx 0.003$. In fp16, that intermediate result gets flushed to zero, and you get 0 / 0 = NaN.

# Standard LayerNorm - can fail in edge cases
nn.LayerNorm(hidden_dim, eps=1e-5)  # PyTorch default

# Safer version
nn.LayerNorm(hidden_dim, eps=1e-6)  # Still risky in fp16

# What actually works
nn.LayerNorm(hidden_dim, eps=1e-4)  # For fp16/bf16 training

The Llama architecture uses RMSNorm with $\epsilon = 10^{-6}$, but they’re running in bf16 where this isn’t an issue. If you’re porting Llama-style code to fp16, bump that epsilon up.
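
For reference, here’s a minimal RMSNorm sketch (the standard formulation, not Llama’s exact code); the same epsilon advice applies if you run it in fp16:

class RMSNorm(nn.Module):
    """Minimal RMSNorm sketch: normalize by root-mean-square, no mean subtraction."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # rsqrt of the mean square plus eps; bump eps toward 1e-4 for fp16 training
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * rms)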

The Embedding Table Problem

Large vocabulary embedding tables are another NaN source that doesn’t get enough attention. With vocab size 50,000 and embedding dim 768, that’s 38 million parameters. Some tokens appear once in your entire training set. Others appear millions of times.

The standard initialization $\mathcal{N}(0, 0.02)$ produces embeddings with magnitude roughly 0.02. After many gradient updates, popular tokens get pulled toward useful representations, but rare tokens barely move. Their gradients are sparse and massive when they do appear.

Watch this happen:

import torch
import torch.nn as nn

vocab_size, embed_dim = 50000, 768
embed = nn.Embedding(vocab_size, embed_dim)
optimizer = torch.optim.Adam(embed.parameters(), lr=1e-3)

# Simulate sparse access - token 49999 rarely appears
for step in range(1000):
    # Mostly use common tokens
    tokens = torch.randint(0, 100, (32,))  
    if step == 500:
        # Rare token finally appears
        tokens = torch.tensor([49999] * 32)

    loss = embed(tokens).sum()
    loss.backward()

    if step == 500:
        rare_grad = embed.weight.grad[49999]
        print(f"Rare token gradient norm: {rare_grad.norm():.2f}")
        # Output: Rare token gradient norm: 885.12

    optimizer.step()
    optimizer.zero_grad()

That gradient norm of 885 will absolutely destroy your training in fp16. The solution? Either clip gradients more aggressively or use embedding-specific learning rate scaling.
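
One way to do the latter is a separate optimizer parameter group for the embedding table. A sketch, where tok_embed is a hypothetical attribute name for your embedding module:

# `tok_embed` is a placeholder for wherever your model keeps its token embedding
embed_params = [p for n, p in model.named_parameters() if n.startswith('tok_embed')]
other_params = [p for n, p in model.named_parameters() if not n.startswith('tok_embed')]

optimizer = torch.optim.AdamW([
    {'params': other_params, 'lr': 3e-4},
    {'params': embed_params, 'lr': 3e-5},  # 10x smaller step for sparse, spiky embedding grads
])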

Pre-LayerNorm vs Post-LayerNorm

The original Transformer (Vaswani et al., 2017) placed LayerNorm after the residual connection:

$\text{output} = \text{LN}(x + \text{SelfAttn}(x))$

This is Post-LN. It’s numerically unstable.

Pre-LN moves normalization before the attention:

$\text{output} = x + \text{SelfAttn}(\text{LN}(x))$

The difference matters enormously for deep networks. With Post-LN, gradients must flow through the LayerNorm at every layer during backprop. The chain rule accumulates those divisions by standard deviation, and after 24+ layers, you either explode or vanish.

Pre-LN creates a clean residual path. Gradients can flow directly through the addition, with LayerNorm only affecting the branch.

class PreLNTransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x):
        # Pre-LN: normalize first, then residual add
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x

GPT-2, GPT-3, and most modern LLMs use Pre-LN. If you’re getting NaN with a Post-LN architecture, switching to Pre-LN often fixes it outright.

The Learning Rate Warmup You’re Probably Skipping

AdamW without warmup on transformers is asking for trouble.

The optimizer’s second moment estimates (the running average of squared gradients) start at zero. In the first few steps, the effective learning rate is much higher than you specified because you’re dividing by a small value. By the time the estimates stabilize, you’ve already pushed weights into bad regions.

The warmup schedule gradually increases the learning rate, giving the optimizer time to build accurate moment estimates:

import math

def get_lr(step: int, warmup_steps: int, total_steps: int, peak_lr: float) -> float:
    """Linear warmup then cosine decay."""
    if step < warmup_steps:
        # Linear warmup
        return peak_lr * step / warmup_steps
    else:
        # Cosine decay
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return peak_lr * 0.5 * (1 + math.cos(math.pi * progress))

# Typical values
warmup_steps = min(2000, total_steps // 10)  # Cap warmup at 2,000 steps or 10% of training, whichever is smaller

For a 100k-step training run, I’d use 2000-4000 warmup steps with peak learning rate around 3e-4 to 6e-4. Starting directly at 6e-4? That’s a recipe for step 50 NaN.
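
Plugging those numbers into the schedule above as a quick sanity check:

# 100k-step run, 2,000 warmup steps, peak lr 3e-4:
# lr ramps from ~1.5e-7 at step 1 up to 3e-4 at step 2,000, then decays toward 0
for step in (1, 1000, 2000, 50_000, 100_000):
    print(step, f"{get_lr(step, warmup_steps=2000, total_steps=100_000, peak_lr=3e-4):.2e}")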

Debug Checklist: When All Else Fails

If you’ve tried everything above and you’re still hitting NaN, here’s my systematic approach:

  1. Force fp32 and try again. If it works, the problem is numerical precision—go back to the mixed precision sections.

  2. Train on a single batch. Overfit to one batch of 4 examples. If you can’t do this without NaN, the architecture is broken.

  3. Remove components one by one. Disable dropout, remove attention masking, replace GELU with ReLU. Find the minimal config that reproduces the bug.

  4. Check your data. Seriously. I’ve seen NaN training caused by NaN in the input data. Always verify:

for batch in dataloader:
    for key, tensor in batch.items():
        if torch.is_floating_point(tensor):
            assert not torch.isnan(tensor).any(), f"NaN in {key}"
            assert not torch.isinf(tensor).any(), f"Inf in {key}"
    break  # Just check first batch
  5. Gradient checkpointing interactions. torch.utils.checkpoint recomputes activations during the backward pass. If your forward pass has stochastic elements (dropout, stochastic depth), the recomputed values can differ from the originals, causing a gradient mismatch. Use deterministic ops during checkpointed regions, as shown in the sketch below.
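
A hedged sketch of the non-reentrant checkpoint call (block and x are placeholders for your checkpointed module and its input); preserve_rng_state=True, the default, re-seeds dropout during recomputation so the two forward passes agree:

from torch.utils.checkpoint import checkpoint

# block: any nn.Module or callable; x: its input tensor (both hypothetical here)
out = checkpoint(block, x, use_reentrant=False, preserve_rng_state=True)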

FAQ

Q: My loss is NaN from the very first step—what’s most likely wrong?

This almost always points to initialization or input data issues. Check that your embedding weights aren’t all zeros, verify your input tokens are within vocabulary bounds, and ensure no NaN values exist in your training data. Also confirm your loss function handles edge cases (log(0) produces -inf, which can cascade to NaN).
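
For the vocabulary-bounds part, a quick assertion (assuming a batch dict with 'input_ids' and a known vocab_size, as in the loader check above):

ids = batch['input_ids']
assert ids.min() >= 0 and ids.max() < vocab_size, "token id outside the embedding table"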

Q: Training runs fine for hours then suddenly goes NaN—why?

Late-training NaN usually means your learning rate is too high for the current loss landscape, or you’ve hit an unusual batch of data. Reduce peak learning rate by 30-50%, increase warmup steps, or add gradient clipping at 1.0 if you haven’t already. The model’s weight magnitudes grow during training, making it increasingly sensitive to large updates.

Q: Is bfloat16 always better than fp16 for transformers?

On Ampere+ GPUs (A100, RTX 3090, H100), yes—bf16 eliminates most overflow issues without needing a gradient scaler. On older GPUs without native bf16 support, you’ll pay a 10-15% speed penalty since operations fall back to fp32. For those cards, fp16 with careful attention score clamping and appropriate epsilon values still works well.
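
To check whether your card has native bf16 support before committing to it:

import torch

# True on Ampere (SM 8.0) and newer GPUs
print(torch.cuda.is_bf16_supported())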

The Real Problem Nobody Talks About

Most transformer NaN debugging advice assumes you’re doing something wrong. But sometimes the architecture itself is pushing numerical limits.

Deep residual networks accumulate activation magnitudes. A 24-layer transformer with Pre-LN will have output magnitudes roughly 24x the input after the residual additions (assuming each branch contributes equally). At layer 96? You’re looking at activations that can exceed fp16 range even with perfectly initialized weights.

The fix isn’t a hack—it’s architectural. Scale your residual branches:

$\text{output} = x + \alpha \cdot \text{SelfAttn}(\text{LN}(x))$

where $\alpha = 1/\sqrt{N_{\text{layers}}}$ or learned per-layer. DeepNet (Wang et al., 2022 if I recall correctly) formalized this with their DeepNorm initialization.
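
Here’s a minimal sketch of that idea, using the fixed $\alpha = 1/\sqrt{N_{\text{layers}}}$ variant rather than DeepNorm’s full initialization scheme:

import math
import torch
import torch.nn as nn

class ScaledResidualAttention(nn.Module):
    """Pre-LN attention block with a down-scaled residual branch."""
    def __init__(self, d_model: int, n_heads: int, n_layers: int):
        super().__init__()
        self.alpha = 1.0 / math.sqrt(n_layers)  # shrink each branch's contribution
        self.norm = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, x):
        normed = self.norm(x)
        attn_out, _ = self.attn(normed, normed, normed)
        return x + self.alpha * attn_out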

But for most practical purposes? Use Pre-LN, use bf16 if you can, clamp attention scores, and don’t skip warmup. That handles 90% of cases.

What I’m still curious about: why do some random seeds diverge while others don’t, even with identical hyperparameters? I’ve seen runs where seed 42 trains perfectly and seed 43 goes NaN at step 8000. There’s probably something interesting happening with the interaction between weight initialization and early batch ordering, but I haven’t had time to dig into it systematically. If you figure it out, I’d love to hear about it.

