World Model이란?
강화학습(RL)에서 가장 큰 문제 중 하나는 샘플 효율(Sample Efficiency) 입니다. 실제 환경과 상호작용하며 데이터를 수집하는 비용이 매우 크기 때문에, 적은 데이터로 효과적으로 학습하는 것이 핵심입니다.
World Model은 환경의 동역학(dynamics)을 학습한 모델로, 실제 환경 대신 시뮬레이션된 경험을 통해 정책을 학습할 수 있게 합니다.
이를 통해 Model-Based RL(MBRL) 은 Model-Free RL 대비 10배 이상의 샘플 효율을 달성할 수 있습니다.
Model-Based vs Model-Free RL
| 구분 | Model-Free RL | Model-Based RL |
|---|---|---|
| 대표 알고리즘 | PPO, SAC, DQN | Dreamer v3, MBPO, MuZero |
| 환경 모델 | 불필요 | 필요 (World Model 학습) |
| 샘플 효율 | 낮음 (수백만 step) | 높음 (수만~수십만 step) |
| 계산 비용 | 낮음 | 높음 (모델 학습 + 정책 학습) |
| 적용 난이도 | 쉬움 | 어려움 (모델 오차 관리 필요) |
Model-Free는 직접 보상을 최대화하는 정책을 학습하지만, Model-Based는 환경 모델을 먼저 학습한 후 상상 속에서(in imagination) 정책을 개선합니다.
Dreamer v3: 잠재 공간에서의 상상
Dreamer v3는 DeepMind의 Danijar Hafner가 개발한 MBRL 알고리즘으로, 이미지 기반 환경에서 뛰어난 성능을 보입니다.
핵심 구조
Dreamer는 다음 세 가지 요소로 구성됩니다:
- World Model (RSSM): 관측(observation)을 저차원 잠재 공간(latent space)으로 인코딩하고, 잠재 상태의 동역학을 학습합니다.
- Actor: 잠재 상태에서 행동을 선택하는 정책 네트워크
- Critic: 잠재 상태의 가치를 추정하는 가치 함수
RSSM (Recurrent State-Space Model)
Dreamer의 핵심은 RSSM입니다. 잠재 상태 는 결정론적 부분 와 확률적 부분 로 구성됩니다:
여기서:
– : GRU/LSTM으로 계산되는 결정론적 은닉 상태
– : 확률 분포에서 샘플링되는 확률적 잠재 변수
– : 이전 시점의 행동
– : 순환 신경망(RNN)
이를 통해 Dreamer는 미래를 예측하고, 예측된 궤적에서 정책을 학습합니다.
Dreamer v3의 개선점
Dreamer v3는 다음과 같은 개선을 도입했습니다:
- Symlog 예측: 보상과 가치를 로 변환하여 큰 값에도 안정적으로 학습
- World Model의 3가지 손실:
- 예측 손실:
- 동역학 손실:
- 표현 손실:
- Return 정규화: 학습 안정성을 위해 반환값(return)을 정규화
실전 구현 예시
import torch
import torch.nn as nn
class RSSM(nn.Module):
def __init__(self, action_dim, hidden_dim=200, stochastic_dim=30):
super().__init__()
self.gru = nn.GRUCell(stochastic_dim + action_dim, hidden_dim)
self.fc_prior = nn.Linear(hidden_dim, stochastic_dim * 2) # mean, std
self.fc_posterior = nn.Linear(hidden_dim + obs_embed_dim, stochastic_dim * 2)
def imagine_step(self, prev_state, prev_action):
"""상상 속에서 다음 상태 예측"""
h, s = prev_state
# 결정론적 상태 업데이트
h = self.gru(torch.cat([s, prev_action], -1), h)
# 확률적 상태 샘플링
prior_mean, prior_std = self.fc_prior(h).chunk(2, -1)
prior_std = F.softplus(prior_std) + 0.1
s = prior_mean + prior_std * torch.randn_like(prior_mean)
return (h, s)
def observe_step(self, prev_state, prev_action, obs_embed):
"""실제 관측을 반영한 상태 업데이트"""
h, _ = prev_state
h = self.gru(torch.cat([prev_state[1], prev_action], -1), h)
# Posterior 분포
post_mean, post_std = self.fc_posterior(
torch.cat([h, obs_embed], -1)
).chunk(2, -1)
post_std = F.softplus(post_std) + 0.1
s = post_mean + post_std * torch.randn_like(post_mean)
return (h, s)
MBPO: 모델 예측의 불확실성 관리
MBPO (Model-Based Policy Optimization) 는 UC Berkeley의 연구팀이 개발한 알고리즘으로, 짧은 롤아웃(short rollout) 을 통해 모델 오차를 최소화합니다.
핵심 아이디어
- 실제 환경에서 소량의 데이터 수집
- 앙상블 동역학 모델 학습 (보통 5~7개 모델)
- 모델에서 k-step 롤아웃 생성 (k=1~5)
- 생성된 가상 데이터로 Model-Free RL (SAC 등) 학습
핵심: 모델이 부정확할 수 있으므로, 짧은 예측만 사용하여 오차 누적을 방지합니다.
앙상블 모델의 불확실성 추정
앙상블 의 예측 분산을 통해 불확실성을 측정합니다:
불확실성이 높은 영역에서는 짧은 롤아웃을, 낮은 영역에서는 긴 롤아웃을 사용할 수 있습니다.
MBPO 구현 예시
import numpy as np
from stable_baselines3 import SAC
class MBPO:
def __init__(self, env, ensemble_size=5, rollout_length=1):
self.env = env
self.models = [DynamicsModel() for _ in range(ensemble_size)]
self.policy = SAC('MlpPolicy', env)
self.rollout_length = rollout_length
self.replay_buffer = ReplayBuffer()
def train_step(self):
# 1. 실제 환경에서 데이터 수집
real_data = self.collect_real_data(num_steps=1000)
self.replay_buffer.add(real_data)
# 2. 앙상블 모델 학습
for model in self.models:
model.train(real_data)
# 3. 모델에서 가상 롤아웃 생성
model_data = self.generate_rollouts(
start_states=real_data['states'],
rollout_length=self.rollout_length
)
self.replay_buffer.add(model_data)
# 4. SAC 정책 학습
for _ in range(20): # 20번의 gradient step
batch = self.replay_buffer.sample(256)
self.policy.train_on_batch(batch)
def generate_rollouts(self, start_states, rollout_length):
"""앙상블에서 랜덤 선택하여 k-step 롤아웃"""
rollouts = []
for s0 in start_states:
state = s0
for _ in range(rollout_length):
action = self.policy.predict(state)
# 앙상블에서 랜덤 선택
model = np.random.choice(self.models)
next_state, reward = model.predict(state, action)
rollouts.append((state, action, reward, next_state))
state = next_state
return rollouts
샘플 효율 비교
실제 벤치마크에서 MBRL의 샘플 효율을 확인할 수 있습니다:
| 환경 | SAC (Model-Free) | MBPO | Dreamer v3 |
|---|---|---|---|
| Walker2d | 300K steps | 30K steps (10배 개선) | 25K steps |
| Humanoid | 1M steps | 100K steps | 80K steps |
| Atari (100k) | 낮은 성능 | 중간 성능 | SOTA 성능 |
Dreamer v3는 Atari 100k 벤치마크에서 인간 수준의 52% 게임에서 인간을 능가하며, 10만 프레임만으로 학습합니다.
실무 적용 시 고려사항
1. 모델 오차 관리
- 앙상블 사용: 단일 모델보다 5~7개 앙상블이 안정적
- 짧은 롤아웃: k=1~5 step으로 제한
- 모델 재학습: 새로운 데이터마다 모델 업데이트
2. 계산 비용
- Dreamer v3: GPU 필수 (이미지 인코딩 + RSSM)
- MBPO: CPU로도 가능하지만 앙상블 학습 시간 소요
- 하이브리드 접근: 초기에는 MBRL로 빠르게 학습, 후반에는 Model-Free로 미세조정
3. 적용 도메인
| 도메인 | 추천 알고리즘 | 이유 |
|---|---|---|
| 로봇 제어 | MBPO | 실제 로봇 상호작용 비용 높음 |
| 게임 (이미지) | Dreamer v3 | 시각 입력에 최적화 |
| 산업 공정 | MBPO | 물리 법칙 기반 모델링 가능 |
| 자율주행 시뮬 | Dreamer v3 | 복잡한 시각 환경 |
시작하기: 코드 예시
Dreamer v3 (DreamerV3 라이브러리)
pip install dreamerv3
import dreamerv3
import gymnasium as gym
env = gym.make('HalfCheetah-v4')
config = dreamerv3.Config('small') # small/medium/large
agent = dreamerv3.Agent(env.observation_space, env.action_space, config)
for episode in range(100):
obs, _ = env.reset()
done = False
while not done:
action = agent.policy(obs)
obs, reward, done, truncated, info = env.step(action)
agent.observe(obs, reward, done)
agent.train_step()
MBPO (mbrl-lib)
pip install mbrl
import mbrl.algorithms.mbpo as mbpo
import gymnasium as gym
env = gym.make('Walker2d-v4')
cfg = mbpo.MBPOConfig(
ensemble_size=7,
rollout_length=1,
num_epochs=100
)
agent = mbpo.MBPO(env, cfg)
agent.train()
마무리
World Model 기반 강화학습은 샘플 효율이 중요한 실세계 문제에서 필수적인 기술입니다.
핵심 요약:
– Dreamer v3: 잠재 공간에서 상상을 통해 학습, 이미지 환경에 강점
– MBPO: 짧은 롤아웃으로 모델 오차 관리, 로봇 제어에 적합
– 샘플 효율: Model-Free 대비 10배 이상 개선 가능
– 트레이드오프: 계산 비용 증가, 모델 오차 관리 필요
– 실전 팁: 앙상블 + 짧은 롤아웃 + 하이브리드 접근
실제 환경에서 데이터 수집이 비싸다면(로봇, 산업 공정, 실험 등), MBRL은 더 이상 선택이 아닌 필수입니다. Dreamer v3와 MBPO로 시작해보세요!
Did you find this helpful?
☕ Buy me a coffee
Leave a Reply