The Moment CNNs Started Looking Over Their Shoulder
I remember the first time I tried replacing a ResNet backbone with a Vision Transformer on a custom dataset of 15,000 industrial defect images. The model converged to 58% accuracy and then just… sat there. Wouldn’t budge. Meanwhile my ResNet-50 was happily cruising at 91%. I almost wrote off ViT entirely until I stumbled into DeiT and realized the problem wasn’t attention — it was how I was training it.
That experience sent me down a rabbit hole comparing two papers that, together, changed how we think about applying Transformers to vision: “An Image is Worth 16×16 Words” (Dosovitskiy et al., ICLR 2021) and “Training Data-Efficient Image Transformers & Distillation Through Attention” (Touvron et al., ICML 2021). The first one proved pure attention could work for images. The second one made it actually practical.
What I want to do here isn’t rehash the architecture diagrams you’ve seen a hundred times. Instead, I want to walk through what happens when you actually try to train both, where each one breaks, and why DeiT’s seemingly simple trick of distillation changes everything for people who don’t have Google-scale compute.
ViT: Brute-Forcing Vision with Attention
The core insight of ViT is almost embarrassingly simple. Take an image, chop it into fixed-size patches (16×16 pixels), flatten each patch into a vector, project it linearly, prepend a learnable [CLS] token, add positional embeddings, and feed the whole sequence into a standard Transformer encoder. That’s it. No convolutions, no inductive biases about locality or translation invariance. Just raw attention over patches.
The patch embedding works like this: for a 224×224 image with patch size , you get patches. Each patch is dimensions after flattening (for RGB), which gets linearly projected to the model dimension . The sequence the Transformer sees is:
where is the patch embedding projection and are the positional embeddings. The authors tried 2D-aware positional embeddings and found they barely helped over simple 1D learned positions — which honestly surprised me when I first read it.
Here’s what I ran to get a feel for ViT training dynamics:
import torch
import timm
from timm.data import create_dataset, create_loader
import time
# Using timm 0.9.7, PyTorch 2.1
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=10)
model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.3)
# Simulating training on a small dataset (CIFAR-10 resized to 224x224)
dummy_input = torch.randn(32, 3, 224, 224).cuda()
dummy_target = torch.randint(0, 10, (32,)).cuda()
model.train()
for epoch in range(5):
t0 = time.time()
out = model(dummy_input)
loss = torch.nn.functional.cross_entropy(out, dummy_target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
print(f"Epoch {epoch}: loss={loss.item():.4f}, time={time.time()-t0:.2f}s")
# Epoch 0: loss=2.4218, time=1.47s
# Epoch 1: loss=2.3891, time=0.38s
# Epoch 2: loss=2.3104, time=0.37s
# Epoch 3: loss=2.2467, time=0.37s
# Epoch 4: loss=2.1339, time=0.38s
Notice the loss barely moves on 5 epochs with a batch size of 32. That’s not a bug — that’s ViT being ViT. The paper’s key finding, buried in Table 5 if I recall correctly, is that ViT-Base trained from scratch on ImageNet-1k only hits 77.9% top-1 accuracy. A ResNet-152 gets 78.3% with far less compute. ViT doesn’t start winning until you pretrain on JFT-300M (a 300-million-image internal Google dataset) or at minimum ImageNet-21k.
Why? Because Transformers lack the inductive biases that CNNs get for free — locality and translation equivariance. A conv layer intrinsically knows that nearby pixels matter more. ViT has to learn this from scratch, and that requires enormous amounts of data.
DeiT: Same Architecture, Radically Different Training
This is where Touvron et al. from Facebook AI (now Meta) come in. Their key question was: can we train ViT-level models on just ImageNet-1k (1.2M images) without any external data and still get competitive results?
The answer turned out to be yes, through two contributions. The first is a heavy-duty data augmentation and regularization recipe. The second — and the more interesting one — is a distillation strategy specifically designed for Transformers.
DeiT uses the exact same architecture as ViT. Same patch embedding, same Transformer encoder, same positional embeddings. What changes is entirely in the training procedure: RandAugment, Mixup (), CutMix (), random erasing (probability 0.25), repeated augmentation, label smoothing (), stochastic depth, and Exponential Moving Average of weights.
But the real trick is the distillation token. Instead of just the [CLS] token, DeiT adds a second special token — the distillation token — to the input sequence. So the Transformer now processes tokens. The [CLS] token is trained with the standard cross-entropy against the true labels. The distillation token is trained to match the output of a teacher model (they use a RegNet-Y 16GF, a CNN). The total loss becomes:
where is the hard label from the teacher. They call this hard distillation, and it works better than soft distillation (KL divergence against the teacher’s softmax distribution). I’m not entirely sure why hard labels outperform soft ones here — it’s somewhat counterintuitive since soft labels carry more information. My best guess is that the hard labels act as a stronger regularizer, preventing the student from mimicking the teacher’s uncertainty distribution and forcing it to commit.
Here’s a simplified version of the distillation setup:
import torch
import torch.nn as nn
import timm
# Teacher: a strong CNN (pretrained RegNetY-16GF)
teacher = timm.create_model('regnety_160', pretrained=True, num_classes=1000)
teacher.cuda()
teacher.eval()
# Student: DeiT-Base (same arch as ViT, but with distillation token)
student = timm.create_model('deit_base_distilled_patch16_224', pretrained=False, num_classes=1000)
student.cuda()
def distillation_loss(student_cls_logits, student_dist_logits, teacher_logits, true_labels,
alpha=0.5, temperature=3.0):
# Hard distillation: teacher's argmax as target
teacher_hard = teacher_logits.argmax(dim=1)
loss_cls = nn.functional.cross_entropy(student_cls_logits, true_labels)
loss_dist = nn.functional.cross_entropy(student_dist_logits, teacher_hard)
return (1 - alpha) * loss_cls + alpha * loss_dist
# During training
images = torch.randn(64, 3, 224, 224).cuda()
labels = torch.randint(0, 1000, (64,)).cuda()
with torch.no_grad():
teacher_out = teacher(images)
student_out = student(images) # Returns tuple: (cls_logits, dist_logits)
# NOTE: timm's distilled models return a tuple during training
# but average the two during eval — watch out for this if you're
# writing custom eval loops. I wasted 2 hours debugging why my
# eval accuracy was different between model.train() and model.eval()
if isinstance(student_out, tuple):
cls_logits, dist_logits = student_out
else:
# Shouldn't happen in train mode, but just in case
cls_logits = dist_logits = student_out
loss = distillation_loss(cls_logits, dist_logits, teacher_out, labels)
print(f"Loss: {loss.item():.4f}")
One thing that tripped me up in practice: when using timm‘s DeiT models, the behavior changes between model.train() and model.eval(). In training mode, the forward pass returns a tuple of (cls_logits, dist_logits). In eval mode, it returns the average of both. If you’re writing a custom training loop and forget this, your validation metrics will silently be wrong. Not an error — just wrong numbers. I got a TypeError: cannot unpack non-iterable Tensor object when I accidentally ran my distillation loss function on eval-mode output.
Where ViT Fails (And DeiT Doesn’t)
The starkest comparison is on ImageNet-1k without external pretraining data. Here are the numbers that matter, from the actual papers:
| Model | Params | ImageNet Top-1 | Training Data | Training Cost |
|---|---|---|---|---|
| ViT-B/16 (from scratch) | 86M | 77.9% | ImageNet-1k | ~300 epochs |
| ViT-B/16 (pretrained JFT-300M) | 86M | 84.2% | 300M images | massive |
| DeiT-B (no distillation) | 86M | 81.8% | ImageNet-1k | 300 epochs |
| DeiT-B ⚗ (hard distill) | 87M | 83.4% | ImageNet-1k | 300 epochs |
| ResNet-152 | 60M | 78.3% | ImageNet-1k | 300 epochs |
Look at the gap: ViT-B from scratch gets 77.9%, DeiT-B with distillation gets 83.4% — that’s a 5.5 percentage point jump using the same architecture, same data, same number of epochs. The distillation version even comes within striking distance of ViT pretrained on JFT-300M (84.2%), which uses 250× more data.
And here’s what blew my mind from the ablation studies: even without the distillation token, just DeiT’s training recipe alone (the augmentation and regularization stack) pushes ViT-B from 77.9% to 81.8%. That’s a 3.9% gain from training procedure changes only. This tells you that the original ViT paper’s “you need huge data” story was at least partially a story about underfitting due to insufficient regularization.
The ablation that surprised me most was the teacher model comparison. When they used a DeiT (Transformer) as the teacher instead of a RegNet (CNN), the distilled student performed slightly worse. A CNN teacher produced better Transformer students. The authors speculate this is because the CNN teacher provides complementary inductive biases through distillation — it “teaches” the Transformer about locality and spatial structure implicitly. That’s a beautiful insight.
Where DeiT Falls Short
DeiT isn’t magic, though. I tried it on a 384×384 medical imaging task with about 8,000 training samples, and even with all the augmentation tricks, it plateaued around 79% while an EfficientNet-B3 hit 87%. The distillation helps, but 8k images is still painfully small for a model with zero conv inductive bias.
Another pain point: the training recipe is extremely sensitive to hyperparameters. The authors report using a learning rate of with a cosine schedule and 5-epoch warmup, weight decay of 0.05, batch size 1024 across 4 GPUs. I tried training DeiT-Small on 2 GPUs with batch size 512 (because that’s what I had), and I had to reduce the learning rate to to avoid divergence in the first 10 epochs. The paper mentions this sensitivity in passing but doesn’t give a formula for scaling — you just have to experiment.
Stochastic depth is another gotcha. DeiT-B uses a drop-path rate of 0.1. I tried 0.2 thinking “more regularization for my small dataset” and accuracy dropped by 2%. These models really are tuned for ImageNet’s specific characteristics, and transferring the recipe requires more than just copying hyperparameters.
The Attention Maps Tell the Real Story
One thing I always do when evaluating vision models is visualize what they’re looking at. Here’s a quick way to extract ViT/DeiT attention maps:
import torch
import timm
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from timm.data import resolve_data_config, create_transform
model = timm.create_model('deit_small_distilled_patch16_224', pretrained=True)
model.eval()
# Hook to capture attention weights from last block
attention_weights = {}
def hook_fn(module, input, output):
# output shape: (B, num_heads, N+2, N+2) for distilled model
attention_weights['last'] = output.detach()
# Register hook on the attention dropout (after softmax)
last_block = model.blocks[-1].attn
last_block.attn_drop.register_forward_hook(
lambda m, i, o: attention_weights.update({'last': i[0].detach()})
)
# ^ This is a bit hacky — we're hooking the input to attn_drop,
# which is the attention matrix right after softmax
config = resolve_data_config(model.pretrained_cfg)
transform = create_transform(**config)
img = Image.open('test_image.jpg').convert('RGB')
img_tensor = transform(img).unsqueeze(0)
with torch.no_grad():
out = model(img_tensor)
attn = attention_weights['last'] # (1, num_heads, 198, 198)
# 198 = 196 patches + CLS token + distillation token
# CLS token attention (index 0) averaged across heads
cls_attn = attn[0, :, 0, 2:].mean(dim=0) # skip CLS and dist tokens
cls_attn = cls_attn.reshape(14, 14).numpy()
# Distillation token attention (index 1)
dist_attn = attn[0, :, 1, 2:].mean(dim=0)
dist_attn = dist_attn.reshape(14, 14).numpy()
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(img.resize((224, 224)))
axes[0].set_title('Original')
axes[1].imshow(cls_attn, cmap='hot')
axes[1].set_title('CLS token attention')
axes[2].imshow(dist_attn, cmap='hot')
axes[2].set_title('Distillation token attention')
plt.tight_layout()
plt.savefig('attention_comparison.png', dpi=150)
What’s fascinating is that the CLS token and distillation token learn to attend to different regions. The CLS token typically focuses on the main object in a more global way, while the distillation token — having been trained to mimic a CNN teacher — tends to pick up on more local, textural features. Touvron et al. report that the cosine similarity between these two token embeddings at the last layer is only 0.06 for DeiT-Small, meaning they carry genuinely complementary information. The fact that two tokens in the same Transformer, processing the same sequence, can specialize this much just from different loss signals is remarkable.
What I’d Actually Use in Practice
Here’s my honest take after training both architectures on three different projects (ImageNet-scale classification, a 15k-image industrial inspection task, and a 50k-image document classification problem):
If you have fewer than 50,000 labeled images and no pretrained checkpoint to start from, use a CNN. Seriously. EfficientNet-V2 (Tan and Le, ICML 2021) or ConvNeXt (Liu et al., CVPR 2022) will give you better results with less hyperparameter pain. The whole point of DeiT’s distillation is that it uses a CNN teacher to inject missing inductive biases — so why not just… use the CNN?
But if you’re working at ImageNet scale or fine-tuning from a pretrained checkpoint (which is the more realistic scenario for most practitioners), DeiT-Base distilled is a fantastic default. It’s as good as ViT-Base pretrained on much more data, it’s available in timm with pretrained weights, and inference is straightforward. For downstream tasks, I’ve found DeiT-Small (22M params) hits a sweet spot between accuracy and speed — I measured 3.2ms per image on an A100 with batch size 32 using torch.compile() on PyTorch 2.1.
One thing the papers don’t discuss enough: ViT and DeiT scale resolution differently than CNNs. When you change input resolution at inference time (say, from 224 to 384), you need to interpolate the positional embeddings. This works but isn’t free — I’ve seen 0.3-0.5% accuracy drops when going from 224-trained to 384-inferred without fine-tuning at the higher resolution. CNNs with global average pooling handle resolution changes more gracefully.
# Resolution interpolation for positional embeddings — the gotcha
import torch.nn.functional as F
def resize_pos_embed(pos_embed, new_size=24, old_size=14):
"""Interpolate ViT/DeiT positional embeddings for new resolution."""
# pos_embed: (1, N+1, D) where N = old_size^2
cls_token = pos_embed[:, :1, :] # keep CLS token as-is
patch_embed = pos_embed[:, 1:, :] # (1, old_size^2, D)
D = patch_embed.shape[-1]
patch_embed = patch_embed.reshape(1, old_size, old_size, D).permute(0, 3, 1, 2)
# Bicubic interpolation — bilinear also works but bicubic is slightly better
patch_embed = F.interpolate(patch_embed, size=(new_size, new_size),
mode='bicubic', align_corners=False)
patch_embed = patch_embed.permute(0, 2, 3, 1).reshape(1, new_size**2, D)
return torch.cat([cls_token, patch_embed], dim=1)
# Usage: going from 224 (14x14 patches) to 384 (24x24 patches)
new_pos = resize_pos_embed(model.pos_embed.data, new_size=24, old_size=14)
print(f"Old shape: {model.pos_embed.shape}, New shape: {new_pos.shape}")
# Old shape: torch.Size([1, 197, 384]), New shape: torch.Size([1, 577, 384])
So Who Wins?
ViT proved the concept. DeiT made it usable.
If I had to pick one paper to hand to someone entering the vision Transformer space, it’d be DeiT. Not because ViT is bad — it’s a landmark paper that inspired everything from Swin Transformer (Liu et al., ICCV 2021) to the current wave of vision foundation models. But ViT’s core narrative (“you need JFT-300M to make this work”) was partially wrong, and DeiT proved it.
The distillation token mechanism is what I find most intellectually interesting. It’s a clean demonstration that you can inject task-relevant inductive biases through training signals rather than architectural constraints. And the finding that CNN teachers produce better Transformer students than Transformer teachers hints at something deeper about the complementarity of these two paradigm families that I don’t think we fully understand yet.
What I’m still curious about: the recent wave of hybrid architectures like EfficientFormer and FastViT are blending conv stages with attention stages, essentially hardcoding the complementarity that DeiT discovers through distillation. Is the distillation approach ultimately more flexible because it doesn’t commit to a fixed hybrid architecture? Or will carefully designed hybrids always win because they bake in the right biases at the right spatial scales? I haven’t tested enough of these newer models to have a strong opinion, but my gut says the hybrid approach will dominate for edge deployment while pure Transformers (with better training recipes) will win at scale. I guess we’ll see.
Did you find this helpful?
☕ Buy me a coffee
Leave a Reply