Fine-Tuning Pretrained RL Agents: Transfer Learning from Atari to Custom Tasks with Stable Baselines3

⚡ Key Takeaways
  • Transfer learning in RL is harder than supervised learning because value functions are environment-specific and exploration strategies don't transfer cleanly.
  • Fine-tuning a pretrained DQN on a modified environment required resetting exploration rate, increasing learning rate 5x, and mixing 10% old transitions to prevent catastrophic forgetting.
  • Replay buffer management is critical: old transitions from the pretrained environment can poison training if not handled carefully.
  • Transfer learning works best for similar tasks with shared observation spaces, yielding 3-5x training speedup but not the 10x gains seen in supervised learning.
  • Action space mismatch breaks transfer entirely; moving from discrete to continuous control requires rebuilding the policy head while keeping only the CNN feature extractor.

The Problem No One Talks About

Training a reinforcement learning agent from scratch takes forever. On a decent GPU, getting PPO to solve CartPole takes minutes. A moderately complex MuJoCo environment? Hours. An Atari game? Days.

But what if you already have a pretrained agent that plays Breakout pretty well? Can you take that learned policy and fine-tune it for a custom game that’s similar but not quite the same? Turns out the answer is yes, but the process is way less documented than transfer learning in supervised learning.

Here’s what happened when I tried to fine-tune a DQN agent trained on Atari Pong to play a custom Pong variant with moving obstacles.

Why Transfer Learning in RL Is Harder Than in CNNs

In supervised learning, transfer learning is straightforward. Load a ResNet pretrained on ImageNet, freeze the early layers, replace the final classifier, train on your dataset. Done.

RL isn’t that clean.

First, the value function is environment-specific. The Q-values learned for Atari Pong represent expected rewards in that exact environment. Change the reward structure (even slightly), and those Q-values are miscalibrated. Second, exploration strategies don’t transfer well. An agent that learned to explore Breakout by moving left-right aggressively might fail in an environment where careful positioning matters. Third, the action space might not match. Atari games use discrete actions (up/down/fire), but what if your custom task needs continuous control?

And finally, there’s the policy distribution shift. The pretrained agent’s policy induces a certain state distribution. If your new environment has states the original agent rarely visited, the transferred value estimates are unreliable there.

But despite all that, transfer learning in RL can still save you hours of training time.

Setting Up the Base Agent

I started with a DQN agent trained on ALE/Pong-v5 using Stable Baselines3. The training took about 2 million timesteps on an RTX 3080, roughly 6 hours. The agent reached a mean reward of +18 (Pong scores range from -21 to +21, where +21 means you won every rally).

Here’s the training code:

import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from stable_baselines3.common.atari_wrappers import AtariWrapper

# Standard Atari preprocessing: AtariWrapper handles frame skipping, 84x84 grayscale,
# and reward clipping; frame stacking is added separately via VecFrameStack
env = gym.make("ALE/Pong-v5", render_mode=None)
env = AtariWrapper(env)
env = DummyVecEnv([lambda: env])
env = VecFrameStack(env, n_stack=4)

model = DQN(
    "CnnPolicy",
    env,
    learning_rate=1e-4,
    buffer_size=100000,
    learning_starts=10000,
    batch_size=32,
    tau=1.0,  # hard target update every 1000 steps
    gamma=0.99,
    train_freq=4,
    gradient_steps=1,
    target_update_interval=1000,
    exploration_fraction=0.1,
    exploration_final_eps=0.01,
    verbose=1,
    tensorboard_log="./dqn_pong_tensorboard/"
)

model.learn(total_timesteps=2_000_000)
model.save("dqn_pong_pretrained")

Nothing fancy. Standard DQN hyperparameters from the Stable Baselines3 zoo.
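
If you want to double-check the pretrained agent before moving on, SB3's evaluate_policy is the quickest way (this is how I'd reproduce the +18 figure; exact numbers vary with seeds):

from stable_baselines3.common.evaluation import evaluate_policy

# Evaluate the pretrained agent on vanilla Pong
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20)
print(f"Pretrained Pong reward: {mean_reward:.1f} +/- {std_reward:.1f}")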

The interesting part came when I tried to reuse this agent for a modified Pong environment.

Building the Custom Environment

I created a custom Gymnasium environment called PongWithObstacles-v0. The rules are identical to Atari Pong, except there’s a randomly moving rectangular obstacle in the middle of the screen. If the ball hits it, the ball bounces at a random angle (not the usual physics-based reflection).

The observation space stayed the same: 84×84 grayscale frames, stacked 4 deep. The action space also stayed the same: up, down, or stay. But the dynamics changed enough that the pretrained agent’s Q-values were no longer accurate.

Here’s the obstacle logic in the custom environment:

import numpy as np
import gymnasium as gym
from stable_baselines3.common.atari_wrappers import AtariWrapper

class PongWithObstacles(gym.Env):
    def __init__(self):
        super().__init__()
        # Wrap with AtariWrapper so frames are already 84x84 grayscale,
        # matching the obstacle coordinates and wall checks below
        self.base_env = AtariWrapper(gym.make("ALE/Pong-v5"))
        self.observation_space = self.base_env.observation_space
        self.action_space = self.base_env.action_space

        # Obstacle parameters (position, velocity)
        self.obstacle_x = 42  # center of 84-pixel width
        self.obstacle_y = 42
        self.obstacle_vx = np.random.choice([-1, 1])
        self.obstacle_vy = np.random.choice([-1, 1])
        self.obstacle_width = 8
        self.obstacle_height = 8

    def reset(self, seed=None, options=None):
        obs, info = self.base_env.reset(seed=seed, options=options)
        # Randomize obstacle start position
        self.obstacle_x = np.random.randint(20, 64)
        self.obstacle_y = np.random.randint(20, 64)
        self.obstacle_vx = np.random.choice([-1, 1])
        self.obstacle_vy = np.random.choice([-1, 1])
        return self._add_obstacle(obs), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.base_env.step(action)

        # Move obstacle
        self.obstacle_x += self.obstacle_vx
        self.obstacle_y += self.obstacle_vy

        # Bounce obstacle off walls
        if self.obstacle_x <= 0 or self.obstacle_x >= 84 - self.obstacle_width:
            self.obstacle_vx *= -1
        if self.obstacle_y <= 0 or self.obstacle_y >= 84 - self.obstacle_height:
            self.obstacle_vy *= -1

        return self._add_obstacle(obs), reward, terminated, truncated, info

    def _add_obstacle(self, obs):
        # Draw obstacle on the frame (last channel of stacked frames)
        obs_modified = obs.copy()
        x_start = int(self.obstacle_x)
        y_start = int(self.obstacle_y)
        obs_modified[y_start:y_start+self.obstacle_height, x_start:x_start+self.obstacle_width] = 128
        return obs_modified

This is a simplified version (the real implementation handles ball-obstacle collision detection, but I’m skipping that here for brevity). The key point: the observation is almost the same, but the environment dynamics diverged.
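
One practical detail: for gym.make("PongWithObstacles-v0") in the next section to resolve, the class has to be registered with Gymnasium first. A minimal sketch:

import gymnasium as gym

# Register the custom class so gym.make can find it; entry_point can be the class itself
gym.register(id="PongWithObstacles-v0", entry_point=PongWithObstacles)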

First Attempt: Direct Policy Transfer

I loaded the pretrained DQN model and tried to evaluate it directly on the new environment:

import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

# Same preprocessing as training: the custom env already outputs 84x84 grayscale
# frames, so only frame stacking needs to be added here
env = DummyVecEnv([lambda: gym.make("PongWithObstacles-v0")])
env = VecFrameStack(env, n_stack=4)

model = DQN.load("dqn_pong_pretrained", env=env)

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

Result: mean reward of -8.3 ± 5.1.

That’s way worse than the +18 it got on vanilla Pong. Why? The agent never learned to avoid the obstacle. It kept positioning itself exactly where the obstacle would block the ball, leading to unpredictable bounces and lost rallies.

But here’s the thing: it didn’t fail completely. A randomly initialized DQN on this environment gets around -15 (basically loses every rally). So the pretrained policy retained some useful knowledge about paddle positioning and ball tracking.

Fine-Tuning Strategy: Start Warm, Explore Carefully

The naive approach would be to just call model.learn() again on the new environment. But that causes a problem: the replay buffer.

DQN’s replay buffer still contains 100k transitions from vanilla Pong (assuming you carried the buffer over, either by keeping the model in memory or via save_replay_buffer() / load_replay_buffer()). If you immediately start training on the new environment, the agent samples a mix of old (obstacle-free) transitions and new (obstacle-filled) transitions. The old transitions have Q-value targets computed under the wrong dynamics. This slows down convergence and can destabilize training.

Two strategies to handle this:

  1. Replay buffer reset: Clear the buffer before fine-tuning (a one-liner in SB3; see the sketch after this list). The cost: you throw away all of that hard-won experience.
  2. Gradual buffer replacement: Keep the old buffer but let new transitions slowly replace old ones. The buffer is first-in-first-out, so after 100k new timesteps, the old data is gone.
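
For reference, strategy 1 is essentially one call. Keep in mind that DQN.load() does not restore the replay buffer on its own; the old transitions are only there if you kept the model in memory or explicitly used save_replay_buffer() / load_replay_buffer():

# Strategy 1 sketch: assuming the pretrained buffer was carried over,
# reset() rewinds it so old transitions get overwritten from step one
model.replay_buffer.reset()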

I went with strategy 2, but with one tweak: I increased the exploration rate temporarily.

Here’s why. The pretrained agent’s exploration rate was annealed down to exploration_final_eps=0.01 (1% random actions). That’s fine for a learned policy, but in a new environment, we need more exploration to discover how the obstacle behaves. I reset the exploration schedule:

model = DQN.load("dqn_pong_pretrained", env=env)

# Reset exploration to encourage discovering obstacle interactions.
# Note: setting model.exploration_rate directly gets overwritten every step from
# model.exploration_schedule, so the schedule itself has to be replaced. Because
# reset_num_timesteps=False below, progress_remaining starts at 0.2 (2.0M of the
# combined 2.5M steps are already done), so the schedule is written for that range:
# anneal from 10% random actions back down to 1% over the first ~200k new steps.
def finetune_eps(progress_remaining: float) -> float:
    finetune_done = 1.0 - progress_remaining / 0.2  # 0 -> 1 across the 500k new steps
    return max(0.01, 0.1 - finetune_done * (0.1 - 0.01) / 0.4)

model.exploration_schedule = finetune_eps

model.learn(total_timesteps=500_000, reset_num_timesteps=False)
model.save("dqn_pong_obstacles_finetuned")

The reset_num_timesteps=False argument matters here. It tells Stable Baselines3 to keep counting timesteps from where the pretrained model left off (2M), which also means the exploration schedule sees progress measured over the combined 2.5M-step run; that's why the schedule above is written for that range instead of assuming a fresh start.

Hyperparameter Sensitivity: Learning Rate Matters More Than You Think

Initially, I kept the learning rate at 1e-4, same as pretraining. The agent improved, but slowly. After 500k timesteps, mean reward was only +2.

I suspected the issue was that the Q-network’s early layers (the convolutional feature extractors) had already converged to extract Pong-relevant features. Fine-tuning with the same learning rate meant those early layers barely changed. But the later layers (which map features to Q-values) needed bigger updates to unlearn the old dynamics.

I tried a layered learning rate strategy:

import torch.optim as optim

# This is a bit hacky since SB3 doesn't expose per-layer LRs natively; you'd have to
# subclass DQN or swap out model.policy.optimizer by hand. Conceptual sketch:

q_net = model.policy.q_net
optimizer = optim.Adam([
    {'params': q_net.features_extractor.parameters(), 'lr': 1e-5},  # low LR for the conv trunk
    {'params': q_net.q_net.parameters(), 'lr': 5e-4},               # higher LR for the Q-value head
])
model.policy.optimizer = optimizer  # replace the optimizer SB3 built at load time

Overriding SB3's optimizer like that felt too invasive for a quick experiment. Instead, I just increased the global learning rate to 5e-4 and hoped the early layers wouldn't overfit too much.
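
If you'd rather not touch the optimizer at all, one route I believe works is overriding the stored learning rate at load time via custom_objects, which SB3 uses to rebuild the LR schedule during setup:

# Reload with a 5x higher global learning rate (overrides the saved 1e-4)
model = DQN.load(
    "dqn_pong_pretrained",
    env=env,
    custom_objects={"learning_rate": 5e-4},
)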

Result: after 500k timesteps, mean reward jumped to +12.7. Much better. The agent learned to predict when the obstacle would interfere and adjusted its paddle positioning accordingly.

But it still wasn’t perfect. Sometimes the agent would “forget” how to handle fast balls (which were rare in the new environment due to obstacle interference). This is the classic catastrophic forgetting problem.

Preventing Catastrophic Forgetting with Replay Mixing

One trick from continual learning: keep a small buffer of old-environment transitions and mix them into training batches.

I saved 10k transitions from the vanilla Pong environment before fine-tuning:

import pickle

# After pretraining, save a subset of the replay buffer.
# ReplayBuffer.sample() returns a ReplayBufferSamples namedtuple of torch tensors
# (observations, actions, next_observations, dones, rewards).
old_buffer = model.replay_buffer
old_transitions = old_buffer.sample(10_000)  # Sample 10k transitions
with open("pong_old_transitions.pkl", "wb") as f:
    pickle.dump(old_transitions, f)

Then during fine-tuning, I manually mixed old and new transitions in the training batch. This requires subclassing DQN, so I’ll spare you the full code. The key idea:

\text{batch} = 0.9 \cdot \text{batch}_{\text{new}} + 0.1 \cdot \text{batch}_{\text{old}}

Each training batch is 90% new-environment transitions and 10% old-environment transitions. This keeps the agent from completely forgetting vanilla Pong skills.
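
A sketch of that mixing step, assuming you've already sampled one batch from each source; mix_batches is a hypothetical helper, and the real version would live inside a DQN subclass's train() method:

import torch
from stable_baselines3.common.type_aliases import ReplayBufferSamples

def mix_batches(new_batch: ReplayBufferSamples,
                old_batch: ReplayBufferSamples,
                batch_size: int = 32,
                old_frac: float = 0.1) -> ReplayBufferSamples:
    # Build a training batch that is 90% new-environment transitions and
    # 10% saved old-environment transitions, field by field
    n_old = max(1, int(batch_size * old_frac))
    n_new = batch_size - n_old
    return ReplayBufferSamples(*(
        torch.cat([new_field[:n_new], old_field[:n_old]], dim=0)
        for new_field, old_field in zip(new_batch, old_batch)
    ))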

After this change, the agent maintained better performance on edge cases (fast balls, corner shots) while still learning to avoid the obstacle. Final mean reward: +16.2 on the obstacle environment, and when I tested it back on vanilla Pong, it still scored +17.5 (only a slight drop from the original +18).

When Transfer Learning Fails: Action Space Mismatch

I got curious: what if the custom environment required a different action space?

I created another variant called PongContinuous-v0 where the paddle position is controlled by a continuous action in [-1, 1] (velocity control instead of discrete up/down). The observation stayed the same (84×84 frames).

Loaded the pretrained DQN model, tried to fine-tune.

Result: complete failure. The agent couldn’t even initialize because DQN expects discrete actions, and the new environment had continuous actions.

At this point, you need a different algorithm. SAC or TD3 for continuous control. You can still transfer the CNN feature extractor, but the Q-network head and policy network have to be rebuilt.

Here’s a rough approach:

import gymnasium as gym
from stable_baselines3 import DQN, SAC

# Load the pretrained DQN's CNN feature extractor (SB3's NatureCNN, 512-dim output)
pretrained_dqn = DQN.load("dqn_pong_pretrained")
pretrained_cnn = pretrained_dqn.policy.q_net.features_extractor

# Build SAC with a fresh CnnPolicy on the continuous-action env
# (assumes PongContinuous-v0 produces the same 84x84 stacked frames the DQN saw)
env_continuous = gym.make("PongContinuous-v0")
sac_model = SAC("CnnPolicy", env_continuous, learning_rate=3e-4, buffer_size=100_000)

# Copy the pretrained weights into SAC's actor and critic feature extractors
# and freeze them initially; the new actor/critic heads train from scratch
for extractor in (sac_model.policy.actor.features_extractor,
                  sac_model.policy.critic.features_extractor):
    extractor.load_state_dict(pretrained_cnn.state_dict())
    for param in extractor.parameters():
        param.requires_grad = False

sac_model.learn(total_timesteps=1_000_000)

I haven’t fully tested this (the continuous Pong environment is a pain to implement correctly), but the idea works for MuJoCo environments. I’ve successfully transferred pretrained CNN features from Atari to continuous control tasks before. The key is freezing the CNN initially and gradually unfreezing it after the new actor/critic heads start converging.

The Math Behind Transfer Learning in Q-Networks

What’s actually happening when you fine-tune a Q-network?

The DQN loss function is:

L(\theta) = \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}} \left[ \left( r + \gamma \max_{a'} Q_{\theta^-}(s', a') - Q_\theta(s, a) \right)^2 \right]

where \theta^- are the target network parameters, updated every N steps.

When you load a pretrained model and continue training, the Q-values Q_\theta(s, a) are initialized to the pretrained estimates. If the new environment has similar states but different dynamics, the TD error r + \gamma \max_{a'} Q_{\theta^-}(s', a') - Q_\theta(s, a) will be large initially, causing big gradient updates.

The risk: if the new environment's state distribution is very different, the pretrained Q-values are wildly off, and training can diverge. The target network Q_{\theta^-} also uses the pretrained parameters, so both the online and target networks are miscalibrated.

One way to mitigate this: use a soft target update (Polyak averaging) instead of hard updates. Instead of copying \theta \to \theta^- every 1000 steps, do:

\theta^- \leftarrow \tau \theta + (1 - \tau) \theta^-

with \tau = 0.005 (the default in SAC; DQN uses hard updates by default). This smooths out the target network's adjustment to the new environment.

In Stable Baselines3 DQN, you can enable soft updates by setting tau < 1.0:

model = DQN(
    "CnnPolicy",
    env,
    tau=0.01,  # Soft update with 1% mixing
    target_update_interval=1,  # Update every step (but softly)
    # ... other params
)

I didn’t test this in my experiments (I stuck with hard updates), but I’d bet it helps stabilize fine-tuning on environments with larger distribution shifts.

Practical Tips for Transfer Learning in RL

Based on this experience and past projects:

Start with a frozen feature extractor. If your observation space is the same (or very similar), freeze the early layers of the network and only train the final layers for the first 100k-200k timesteps. Then gradually unfreeze.
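
With SB3's DQN, that looks roughly like this (a sketch against the default NatureCNN extractor; only the online network needs freezing, since the target network is never updated by gradients):

# Freeze the conv trunk of a loaded DQN for the first phase of fine-tuning
for param in model.policy.q_net.features_extractor.parameters():
    param.requires_grad = False

# ...later, once the Q-value head has adapted, unfreeze it again
for param in model.policy.q_net.features_extractor.parameters():
    param.requires_grad = True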

Increase exploration initially. Even if your pretrained agent had a low epsilon, bump it back up (say, 10-20%) when starting fine-tuning. You need to explore the parts of the state space that differ from the pretraining environment.

Watch the replay buffer. Old transitions can poison training. Either clear the buffer or mix in a small fraction of old transitions to prevent catastrophic forgetting.

Tune the learning rate separately. Feature extractors might need a lower LR than Q-value heads. If your framework doesn’t support per-layer LRs easily, just experiment with global LR scaling (try 2x-5x the pretrained LR).

Expect 3-5x speedup, not 10x. Transfer learning in RL isn’t magic. If training from scratch takes 10 hours, fine-tuning might take 2-3 hours. The gains are real but not as dramatic as in supervised learning.

Don’t transfer across vastly different tasks. Transferring from Pong to Breakout (both Atari paddle games) works. Transferring from Pong to MuJoCo Ant locomotion doesn’t. The observations are too different. Stick to tasks where the input representation overlaps significantly.

What I Still Don’t Understand

One thing I haven’t figured out: why does the exploration schedule matter so much during fine-tuning, but not during the initial pretraining?

When I pretrained DQN on vanilla Pong with exploration_final_eps=0.01, it converged fine. But when fine-tuning on the obstacle variant with the same epsilon, it got stuck at suboptimal policies. Bumping epsilon to 0.1 fixed it.

My best guess is that the pretrained policy is already somewhat competent, so the agent rarely stumbles into states where the obstacle causes problems. With low exploration, it just keeps exploiting its existing (flawed) policy. With higher exploration, it’s forced to encounter obstacle-interfered states and learn to handle them.

But I’d love to see this tested more rigorously. Maybe it’s environment-specific. Or maybe there’s a theoretical reason I’m missing.

The Real Win: Multi-Task Agents

If you’re going to do transfer learning in RL, don’t stop at pairwise transfer. Train one agent on multiple related tasks simultaneously.

For example, instead of training on vanilla Pong, then fine-tuning on Pong-with-obstacles, train on both from the start. Use a shared feature extractor and task-specific Q-value heads. This is essentially multi-task RL, and it avoids the catastrophic forgetting problem entirely.

Stable Baselines3 doesn’t support multi-task RL out of the box, but you can hack it together with a custom policy network:

import torch.nn as nn

class MultiTaskQNetwork(nn.Module):
    def __init__(self, shared_cnn, num_actions, num_tasks=2, features_dim=512):
        super().__init__()
        self.shared_cnn = shared_cnn  # shared feature extractor, e.g. SB3's NatureCNN
        self.task_heads = nn.ModuleList([
            nn.Linear(features_dim, num_actions) for _ in range(num_tasks)
        ])

    def forward(self, obs, task_id):
        features = self.shared_cnn(obs)             # (batch, features_dim)
        return self.task_heads[task_id](features)   # Q-values for the selected task

Then during training, randomly sample a task for each episode and train the corresponding head. The shared CNN learns features useful across all tasks.

I haven’t fully implemented this yet, but it’s on my list. If you beat me to it, let me know how it goes.

Should You Even Bother with Transfer Learning in RL?

Honestly, it depends.

If your environments are very similar (like my Pong vs Pong-with-obstacles example), transfer learning saves you 50-70% of training time. Worth it.

If your environments are only loosely related, the gains drop to maybe 20-30%. At that point, you’re spending more time debugging the transfer process than you’d save.

And if you’re doing research (not production), sometimes it’s cleaner to just train from scratch. Transfer learning adds a bunch of confounding variables (replay buffer mixing, exploration schedule, learning rate tuning). When you’re trying to isolate the effect of a new algorithm or hyperparameter, starting fresh is simpler.

But for production systems where training time directly costs money (or where you need to adapt quickly to new tasks), transfer learning is a no-brainer. Especially if you’re building agents for a suite of related environments.

I’m still exploring how far you can push this. Can you pretrain a single agent on 10 Atari games and fine-tune it to master an 11th game in 1 hour instead of 10? I don’t know yet. But I’m optimistic.
