Why JAX Feels Faster Than PyTorch (and When It Isn’t)

⚡ Key Takeaways
  • JAX achieves 2-3x speedups via XLA compilation and functional purity, but only when your code fits the JIT-friendly style.
  • vmap for automatic vectorization and composable gradient transforms (grad of grad) are JAX's killer features for scientific ML.
  • PyTorch wins on ecosystem maturity, debugging UX, and model surgery — better for standard deep learning and production pipelines.
  • Compilation overhead and shape-dependent retracing make JAX painful for variable-length sequences and interactive development.
  • Use JAX for numerical optimization and physics simulations; use PyTorch for vision, NLP, and anything with complex data loading.

The JIT Wall

Run this PyTorch code and time it:

import torch
import time

def matmul_chain(x):
    for _ in range(100):
        x = x @ x
    return x

x = torch.randn(512, 512, device='cuda')
torch.cuda.synchronize()  # don't let setup work leak into the measurement
start = time.perf_counter()
y = matmul_chain(x)
torch.cuda.synchronize()  # wait for the GPU to finish before stopping the clock
print(f"PyTorch: {time.perf_counter() - start:.4f}s")

Now the JAX equivalent:

import jax
import jax.numpy as jnp
import time

@jax.jit
def matmul_chain(x):
    for _ in range(100):
        x = x @ x
    return x

x = jax.random.normal(jax.random.PRNGKey(0), (512, 512))
matmul_chain(x).block_until_ready()  # warm-up: the first call traces and compiles
start = time.perf_counter()
y = matmul_chain(x).block_until_ready()
print(f"JAX: {time.perf_counter() - start:.4f}s")

On my RTX 3090, PyTorch takes ~0.0087s. JAX clocks in at ~0.0031s.

That’s nearly 3x faster. But here’s the thing: strip off the @jax.jit decorator and run it again, and you’ll see ~0.0095s, actually slower than PyTorch. The speedup isn’t magic. It’s XLA, JAX’s compiler, turning the Python loop into a single compiled program, so there’s no Python dispatch overhead between the hundred matmuls. PyTorch 2.0 has torch.compile() now, which uses TorchInductor to pull off similar tricks, but JAX had JIT compilation baked in from day one.
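
For comparison, here’s a minimal sketch of the torch.compile() route, reusing the matmul_chain and x from the PyTorch snippet above; treat it as illustrative rather than a tuned benchmark:

compiled_chain = torch.compile(matmul_chain)  # TorchDynamo traces the loop, Inductor emits fused kernels
y = compiled_chain(x)   # first call pays the compile cost
y = compiled_chain(x)   # later calls hit the compiled artifact
torch.cuda.synchronize()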

The performance gap you feel when using JAX comes down to how often the compiler can help you, and how much overhead you’re willing to tolerate.

Functional Purity as a Compiler Hint

JAX enforces functional programming. No in-place mutations, no hidden state. Every function is a pure mapping from inputs to outputs:

# PyTorch: in-place update is fine
model.weight.data += learning_rate * grad

# JAX: you get a new copy
params = jax.tree_map(lambda p, g: p + learning_rate * g, params, grads)

This feels annoying at first. But it’s exactly what XLA needs to optimize aggressively. When the compiler knows a function has no side effects, it can reorder operations, fuse kernels, and eliminate redundant computation without worrying about breaking your code.

PyTorch’s eager execution model prioritizes flexibility. You can mutate tensors, use Python control flow, and debug with print statements. JAX’s jax.jit can trace through Python control flow as long as it depends only on shapes and static values, but a data-dependent conditional (like if x.sum() > 0) fails at trace time; you either rewrite it with jax.lax.cond or mark the argument static, in which case JAX recompiles whenever that value changes. PyTorch doesn’t care: it just runs whatever you wrote.
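
If you genuinely need a data-dependent branch inside jit, the usual move is jax.lax.cond, which compiles both branches into the program. A minimal sketch:

import jax
import jax.numpy as jnp

@jax.jit
def scale(x):
    # A plain Python `if x.sum() > 0:` would fail here at trace time;
    # lax.cond makes the branch part of the compiled program instead.
    return jax.lax.cond(x.sum() > 0,
                        lambda v: v * 2.0,
                        lambda v: v * 0.5,
                        x)

print(scale(jnp.array([1.0, -0.2])))   # positive sum: doubled
print(scale(jnp.array([-1.0, -2.0])))  # non-positive sum: halved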

In practice, JAX’s purity constraint pays off when you’re writing tight numerical loops: gradient steps, ODE solvers, ray tracing, PDE discretizations. For messy research code with lots of conditional logic and debugging prints, PyTorch’s flexibility wins.

VMAP: The Killer Feature

jax.vmap is the reason I keep coming back to JAX for certain tasks. It vectorizes a function over a batch dimension without you writing a single loop:

import jax
import jax.numpy as jnp

def single_particle_update(state, force):
    # state: (2,) position vector
    # force: (2,) force vector
    # One Euler step: nudge the position along the force direction
    return state + force * 0.01

# 10000 particles, each a (2,) position
states = jax.random.normal(jax.random.PRNGKey(0), (10000, 2))
forces = jax.random.normal(jax.random.PRNGKey(1), (10000, 2))

# Vectorize over axis 0
batch_update = jax.vmap(single_particle_update)
new_states = batch_update(states, forces)

No explicit batching logic. JAX figures out that states and forces both have a leading dimension of 10000 and broadcasts the operation. Combine this with jax.jit and you get a fused kernel that runs as fast as hand-written batched CUDA.

PyTorch has torch.vmap, introduced in PyTorch 1.11 (still prototype as of 2.0). It works, but the API is less mature and the error messages are cryptic when you hit edge cases. I’ve had better luck just writing explicit batch dimensions in PyTorch.

The real win for vmap is when you’re computing per-example gradients for differential privacy or Hessian-vector products. In PyTorch, you’d loop over the batch and backprop individually (slow), or use hooks and gradient accumulation tricks (fragile). In JAX:

def loss_fn(params, x, y):
    pred = model_apply(params, x)
    return jnp.mean((pred - y) ** 2)

# Per-example gradients via vmap
per_example_grads = jax.vmap(
    jax.grad(loss_fn, argnums=0), 
    in_axes=(None, 0, 0)
)(params, batch_x, batch_y)

This is clean, composable, and fast. PyTorch’s functorch was built to replicate this (and it worked decently), but its transforms have since been folded into core PyTorch as torch.func, so the API is still in flux.
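
For what it’s worth, the torch.func version looks roughly like this; a sketch assuming a toy nn.Linear model, not a drop-in recipe:

import torch
from torch import nn
from torch.func import functional_call, grad, vmap

model = nn.Linear(4, 1)
params = dict(model.named_parameters())

def loss_fn(params, x, y):
    pred = functional_call(model, params, (x,))  # stateless call with explicit params
    return torch.mean((pred - y) ** 2)

batch_x, batch_y = torch.randn(32, 4), torch.randn(32, 1)
# vmap over the batch dimension of x and y; params are shared (in_dims=None)
per_example_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, batch_x, batch_y)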

Memory Layout Surprises

Here’s a gotcha I hit when porting a transformer from PyTorch to JAX: memory consumption spiked by 40% on the first forward pass.

JAX’s XLA backend pre-allocates 75% of GPU memory as soon as the first array touches the device (controllable via XLA_PYTHON_CLIENT_PREALLOCATE=false, but then you pay allocator overhead every time a buffer is created or freed). PyTorch uses CUDA’s caching allocator, which grows as needed.
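
Both knobs below are standard XLA client flags; set whichever one you want before the first import of jax touches the GPU:

import os

# Option 1: skip preallocation entirely and allocate on demand
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Option 2 (alternative): keep preallocation but cap it at half the card
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"

import jax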

For interactive development in Jupyter, PyTorch’s approach feels nicer. You can run multiple experiments in separate cells without restarting the kernel. JAX will hog VRAM even if you’re not using it.

But JAX’s preallocation avoids fragmentation. On long training runs (12+ hours), I’ve seen PyTorch hit out-of-memory errors after thousands of iterations because the allocator fragmented the heap. JAX doesn’t have this problem.

Another memory quirk: JAX arrays are immutable, so operations like x = x + 1 create a new array. The old one gets garbage-collected, but if you’re in a jitted function, XLA’s optimizer often eliminates the copy. Outside of JIT, you’re burning memory on every update. PyTorch’s in-place ops (x += 1) avoid this entirely.

When PyTorch Wins

I tried writing a vision transformer training loop in JAX for a project last year. The model definition was fine, but handling data loading, learning rate schedules, and checkpoint saving felt like duct-taping three libraries together (tensorflow_datasets, custom LR schedules via optax, orbax.checkpoint).

PyTorch’s ecosystem is just more mature. torchvision.datasets, torch.optim schedulers, torch.utils.checkpoint for gradient checkpointing, torch.distributed for multi-GPU — all of it works out of the box. JAX’s equivalents exist (Flax, Optax, Orbax), but you’re often reading GitHub issues to figure out why something broke.

Debugging is another pain point. JAX’s tracing means your breakpoints fire during compilation, not execution. If you drop a breakpoint() inside a jitted function, it stops once during tracing, shows you abstract shapes, and never fires again on subsequent calls (jax.debug.print and jax.debug.breakpoint exist for this, but they’re extra machinery). PyTorch’s eager mode just works: set a breakpoint, inspect tensors, modify values interactively.

Model surgery is trivial in PyTorch. Want to freeze the first 10 blocks of a model whose blocks sit in a layers ModuleList? for param in model.layers[:10].parameters(): param.requires_grad = False. JAX doesn’t have mutable parameters, so you’d filter the pytree manually:

from flax import traverse_util

# Keep only the parameters outside layers_0 .. layers_9 (the frozen prefix)
def freeze_prefix(params):
    flat = traverse_util.flatten_dict(params)  # keys are tuples of path components
    frozen_names = {f'layers_{i}' for i in range(10)}  # exact names, so layers_10+ stay trainable
    frozen = {k: v for k, v in flat.items() if any(part in frozen_names for part in k)}
    trainable = {k: v for k, v in flat.items() if k not in frozen}
    return traverse_util.unflatten_dict(trainable)

It works, but it’s verbose.

Gradient Computation Differences

JAX’s jax.grad returns a function that computes gradients. PyTorch’s tensor.backward() populates .grad attributes. This seems like a minor API difference, but it changes how you think about optimization.

In JAX, you explicitly compute gradients and apply them:

grad_fn = jax.grad(loss_fn)
for step in range(1000):
    grads = grad_fn(params, batch)
    params = jax.tree_map(lambda p, g: p - 0.001 * g, params, grads)

In PyTorch, gradients accumulate in-place:

for step in range(1000):
    optimizer.zero_grad()
    loss = loss_fn(model, batch)
    loss.backward()
    optimizer.step()

JAX’s functional style makes it trivial to compute higher-order gradients. Want the Hessian? jax.jacfwd(jax.grad(f)). Gradient of a gradient for meta-learning? jax.grad(jax.grad(f)). PyTorch supports this via torch.autograd.grad with create_graph=True, but it’s clunkier and you have to manually manage the graph.
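A minimal sketch of how the transforms stack, on a toy scalar function rather than a real model:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(3.0)
g = jax.grad(f)(x)              # gradient, shape (3,)
H = jax.jacfwd(jax.grad(f))(x)  # Hessian via forward-over-reverse, shape (3, 3)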

On the flip side, JAX’s jax.grad only differentiates with respect to the first argument by default. If your loss takes (params, data, labels) and you want gradients w.r.t. params, you’re fine. Want gradients w.r.t. data for adversarial examples? Pass argnums=1. Forget that and you’ll silently get gradients w.r.t. params instead, with no error to tip you off.

Compilation Overhead

The first time you call a jitted JAX function, it traces and compiles. For a small MLP, this takes ~0.5 seconds. For a ViT-Large, I’ve seen 8+ seconds. Subsequent calls are instant because JAX caches the compiled XLA binary.

But if you pass a different input shape, JAX recompiles. This is brutal for variable-length sequences. In NLP, you often pad to the longest sequence in a batch. If batch 1 has max length 128 and batch 2 has max length 256, JAX recompiles. PyTorch doesn’t care — it just allocates more memory.

You can mitigate this with static arguments (jax.jit(f, static_argnums=(1,))) or by padding all batches to a fixed length, but both feel like workarounds.
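
You can see the retracing directly with a toy jitted function; each new padded shape pays the compile cost again:

import jax
import jax.numpy as jnp

@jax.jit
def score(batch):
    return jnp.sum(batch ** 2, axis=-1)

score(jnp.ones((32, 128)))  # first call: trace + compile for (32, 128)
score(jnp.ones((32, 128)))  # same shape: cached, fast
score(jnp.ones((32, 256)))  # new shape: trace + compile all over again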

PyTorch 2.0’s torch.compile() has similar recompilation issues, though the caching is smarter. I haven’t hit as many shape-related recompiles in PyTorch as I have in JAX.

Random Number Generation

JAX’s RNG system is stateless. Every random call requires an explicit key:

key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (100,))

This is mathematically sound — you can reproduce any random sequence by splitting the same root key — but it’s tedious. Forget to split the key and you’ll reuse the same random numbers every iteration.

PyTorch uses a global RNG state:

torch.manual_seed(42)
x = torch.randn(100)

Easier to use, harder to parallelize safely. If you’re running multiple experiments in threads, PyTorch’s global state can cause race conditions. JAX’s explicit keys avoid this.

I’ve debugged training runs where I forgot to split the JAX key inside a loop. The model trained, but used identical dropout masks every step. Loss plateaued after a few epochs. Took me an hour to realize the bug because there was no error message — just silently wrong behavior.
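
The fix is mechanical once you’ve been burned: split inside the loop, every iteration. A minimal sketch of the pattern:

import jax

key = jax.random.PRNGKey(0)
for step in range(3):
    # Split every step; reusing `key` directly would repeat the same mask forever.
    key, dropout_key = jax.random.split(key)
    mask = jax.random.bernoulli(dropout_key, p=0.9, shape=(4,))
    print(step, mask)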

Optimizer State Management

PyTorch optimizers hold internal state (momentum buffers, Adam’s first/second moments) and mutate it in-place:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer.step()

JAX optimizers (via Optax) are pure functions:

import optax

tx = optax.adam(1e-3)
opt_state = tx.init(params)

for step in range(1000):
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

You thread opt_state through every iteration. This is more verbose, but it makes checkpointing trivial — just save (params, opt_state) and you can resume exactly where you left off. PyTorch can do this via optimizer.state_dict(), but JAX’s approach is more explicit.
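
A minimal sketch of that checkpoint story, assuming plain pickle is acceptable for your project (a real pipeline would reach for orbax.checkpoint):

import pickle
import jax

# Save: the whole training state is just two pytrees of arrays
with open("ckpt.pkl", "wb") as f:
    pickle.dump(jax.device_get({"params": params, "opt_state": opt_state}), f)

# Resume: load them back and keep threading them through tx.update
with open("ckpt.pkl", "rb") as f:
    ckpt = pickle.load(f)
params, opt_state = ckpt["params"], ckpt["opt_state"]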

The downside: if you accidentally overwrite opt_state or forget to return it from a function, your optimizer resets. I’ve lost hours to this mistake.

Mixed Precision Training

PyTorch’s torch.cuda.amp.autocast() automatically casts ops to FP16 where safe:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for batch in dataloader:
    optimizer.zero_grad()          # clear stale gradients before each step
    with autocast():
        loss = model(batch)
    scaler.scale(loss).backward()  # scale the loss to avoid FP16 underflow
    scaler.step(optimizer)
    scaler.update()

JAX doesn’t have built-in AMP. You manage precision manually, usually by casting params and activations in the model definition. Libraries like Flax provide helpers (flax.linen.DenseGeneral has a dtype argument), but you’re responsible for loss scaling.

I’ve found PyTorch’s AMP more ergonomic for mixed precision. JAX gives you more control (useful for custom quantization schemes), but the boilerplate adds up.
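
To make that concrete, here’s a hand-rolled sketch of the manual approach on a single dense layer; the dense_bf16 helper is made up for this example, not a library API:

import jax.numpy as jnp

# Run the matmul in bfloat16, keep float32 params as the master copy for updates.
def dense_bf16(params, x):
    w16 = params["w"].astype(jnp.bfloat16)
    b16 = params["b"].astype(jnp.bfloat16)
    y = x.astype(jnp.bfloat16) @ w16 + b16
    return y.astype(jnp.float32)  # cast back up so the loss is computed in float32

params = {"w": jnp.ones((8, 4)), "b": jnp.zeros(4)}
y = dense_bf16(params, jnp.ones((2, 8)))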

Multi-GPU Training

PyTorch’s DistributedDataParallel is the de facto standard:

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend='nccl')
model = DDP(model, device_ids=[local_rank])

JAX uses pmap for data parallelism:

from functools import partial

@partial(jax.pmap, axis_name='batch')
def train_step(state, batch):
    grads = jax.grad(loss_fn)(state.params, batch)
    grads = jax.lax.pmean(grads, axis_name='batch')  # average gradients across devices
    return state.apply_gradients(grads=grads)

# Replicate state across 8 GPUs
state = jax.device_put_replicated(initial_state, jax.devices())
for batch in dataloader:
    # batch needs a leading device axis: (n_devices, per_device_batch, ...)
    state = train_step(state, batch)  # runs on all devices in lockstep

pmap is elegant for single-node multi-GPU, but multi-node training is rougher. You need to set up jax.distributed.initialize() and manage collectives manually. PyTorch’s tooling (torchrun, RANK, WORLD_SIZE env vars) is more mature.

For model parallelism (sharding large models), JAX’s pjit is powerful but complex. PyTorch’s FSDP (Fully Sharded Data Parallel) is easier to use for transformer training.

When JAX Actually Wins

Despite the rough edges, JAX shines in a few scenarios:

  1. Scientific computing with complex gradients. If you’re differentiating through ODEs (neural ODEs, Hamiltonian dynamics), computing Fisher information matrices, or doing Bayesian inference with variational objectives, JAX’s functional transformations (grad, vmap, jacfwd) compose beautifully. PyTorch can do this, but the code gets messy fast.

  2. Numerical optimization loops. Iterative algorithms like L-BFGS, trust-region methods, or ray marching benefit massively from JIT. I ported a photometric stereo optimizer from NumPy to JAX and saw a 15x speedup (from 3 minutes to 12 seconds per scene on an A100) with minimal code changes.

  3. Research on new architectures. If you’re implementing a paper from scratch and need to tweak gradients (e.g., straight-through estimators, stop-gradient tricks), JAX’s explicit gradient API is clearer than PyTorch’s hooks.

  4. Embedded differentiable physics. Simulating fluid dynamics, rigid body collisions, or cloth with autodiff? JAX’s performance and functional style make it easier to integrate physics engines. Projects like Brax (RL physics) and JAX-MD (molecular dynamics) wouldn’t work as well in PyTorch.

But if you’re training ResNets on ImageNet, fine-tuning LLMs, or doing standard supervised learning, PyTorch’s ecosystem wins. The performance gap has narrowed with torch.compile(), and the UX is just smoother.

FAQ

Q: Should I switch from PyTorch to JAX for deep learning?

Probably not, unless you’re doing scientific ML (physics simulations, ODE-based models) or need per-example gradients regularly. PyTorch’s ecosystem is more polished, and torch.compile() closes the performance gap. JAX is great for research on novel algorithms, but PyTorch is better for production pipelines.

Q: Does JAX work on Apple Silicon (M1/M2)?

Kind of. JAX runs on M1/M2 via the CPU backend (no Metal GPU support as of JAX 0.4.23). PyTorch has native MPS (Metal Performance Shaders) support since 1.12, so it’s faster on Apple Silicon. If you’re on a Mac, stick with PyTorch.

Q: Can I mix JAX and PyTorch in the same project?

Yes, but it’s awkward. You can convert arrays with jax.dlpack.to_dlpack() and torch.from_dlpack(). I’ve done this to use PyTorch dataloaders with JAX models. Works, but adds boilerplate. Better to pick one framework and commit.
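
A minimal sketch of the round trip (zero-copy only when both frameworks see the same device; otherwise expect a transfer):

import jax
import jax.numpy as jnp
import torch
from torch.utils import dlpack as torch_dlpack

# JAX -> PyTorch
x_jax = jnp.arange(8.0)
x_torch = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))

# PyTorch -> JAX
y_torch = torch.arange(8, dtype=torch.float32)
y_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(y_torch))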

Pick Your Pain

Use JAX if you’re writing custom numerical algorithms, need higher-order gradients, or value functional purity. Accept that you’ll fight the compiler occasionally, debug via print-in-tracing, and stitch together libraries for training infrastructure.

Use PyTorch if you’re training standard architectures, need a mature ecosystem, or want to debug interactively. You’ll miss vmap and sometimes wish torch.compile() was smarter, but you’ll ship models faster.

I keep both in my toolbox. For RL research (where I’m writing custom policy gradients and value function targets), JAX’s grad/vmap combo is unbeatable. For vision baselines and transfer learning, PyTorch is my default.

The one thing I’m still unsure about: whether JAX’s functional style is genuinely better for large teams, or just a trade-off that makes some bugs harder and others easier. My best guess is that it helps when your codebase is mostly math (loss functions, optimizers, simulation), but hurts when you’re gluing together datasets, logging, and deployment code. The ecosystem will decide which philosophy wins, and right now PyTorch has momentum.
