qcoding

[강화학습-8] 딥 강화학습(Deep RL) 기초 – DQN 본문

머신러닝 딥러닝

[강화학습-8] 딥 강화학습(Deep RL) 기초 – DQN

Qcoding 2025. 5. 28. 17:23
반응형
8. 딥 강화학습(Deep RL) 기초 – DQN

8. 딥 강화학습(Deep RL) 기초 – DQN

2015년 DeepMind의 DQN(Deep Q-Network)은 픽셀 입력만으로 Atari 2600 게임들을 사람 수준으로 플레이하며
“딥러닝 + 강화학습” 시대를 열었습니다. 핵심은 표준 Q-Learning두 가지 안정화 기법을 더한 것입니다.


8-1. DQN 아키텍처

블록설명
입력state $s$ (원본 픽셀·연속 벡터 모두 가능)
피처 추출 CNN / MLP2~3 Conv + ReLU + FC (Atari)
또는 2~3 FC (CartPole 등 저차원)
출력층$\hat Q(s,a;\theta)$ – 각 행동에 대한 Q-값

$$L(\theta)=\mathbb{E}_{(s,a,r,s')\sim\mathcal{D}} \Bigl[\bigl( r+\gamma\max_{a'}\hat Q(s',a';\theta^{-}) -\hat Q(s,a;\theta) \bigr)^{2}\Bigr]$$

  • $\theta$ : 온라인(학습) 네트워크 파라미터
  • $\theta^{-}$ : 타깃 네트워크 파라미터(주기적 고정)
  • $\mathcal{D}$ : 경험 재플레이(Replay Buffer)에서 샘플링한 미니배치

8-2. 핵심 안정화 기법

기법아이디어효과
Experience Replay
(재플레이 버퍼)
최근 경험 $(s,a,r,s')$ 를 큰 버퍼 𝔻에 저장 ⟶
학습 시 무작위 미니배치 추출
① 데이터 상관성 ↓
② 샘플 효율 ↑
③ 배치 학습(벡터화) 가능
Target Network
(타깃 네트워크)
Q̂ 값 계산용 네트워크(파라미터 $\theta^{-}$)를
주기 $C$ 스텝마다 온라인 $\theta$로 하드 복사
목표값 변화 완만 ⟶ 학습 안정성 ↑

두 기법이 없으면 비선형 함수근사 + 부트스트랩 조합이 발산하기 쉽습니다.
훗날 나온 DDQN, Dueling DQN, PER(우선순위 버퍼) 등도 이 틀 위에서 개선되었습니다.


8-3. 구현 실습 – 간단한 DQN으로 CartPole 학습

PyTorch 2.x 기반 200줄 미만 코드. (구현 간결성을 위해 PER·DDQN 생략)

"""
pip install gymnasium torch numpy
"""
import gymnasium as gym
import numpy as np
import torch, torch.nn as nn, torch.optim as optim
from collections import deque, namedtuple
import random

# --- Hyperparameters ----------------------------------------------------
ENV_ID       = "CartPole-v1"
EPISODES     = 500
GAMMA        = 0.99
LR           = 1e-3
BATCH_SIZE   = 64
REPLAY_SIZE  = 50_000
START_EPS    = 1.0
END_EPS      = 0.05
EPS_DECAY    = 500     # decay episodes
TARGET_SYNC  = 10      # episodes
DEVICE       = "cuda" if torch.cuda.is_available() else "cpu"

# --- Environment --------------------------------------------------------
env = gym.make(ENV_ID)
n_state  = env.observation_space.shape[0]
n_action = env.action_space.n
print("state:", n_state, "actions:", n_action)

# --- Network ------------------------------------------------------------
class DQN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_state, 128), nn.ReLU(),
            nn.Linear(128, 128),      nn.ReLU(),
            nn.Linear(128, n_action)
        )
    def forward(self, x):
        return self.net(x)

# main / target
policy_net = DQN().to(DEVICE)
target_net = DQN().to(DEVICE)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr=LR)

# --- Replay Buffer ------------------------------------------------------
Transition = namedtuple("Transition", "s a r s_ done")
buffer = deque(maxlen=REPLAY_SIZE)

def push(*args):
    buffer.append(Transition(*args))
def sample(batch):
    t = random.sample(buffer, batch)
    s  = torch.tensor([tr.s  for tr in t], dtype=torch.float32, device=DEVICE)
    a  = torch.tensor([tr.a  for tr in t], dtype=torch.int64,   device=DEVICE).unsqueeze(1)
    r  = torch.tensor([tr.r  for tr in t], dtype=torch.float32, device=DEVICE).unsqueeze(1)
    s_ = torch.tensor([tr.s_ for tr in t], dtype=torch.float32, device=DEVICE)
    d  = torch.tensor([tr.done for tr in t], dtype=torch.float32, device=DEVICE).unsqueeze(1)
    return s, a, r, s_, d

# --- ε-greedy -----------------------------------------------------------
def epsilon_by_episode(ep):
    eps = END_EPS + (START_EPS - END_EPS) * np.exp(-ep / EPS_DECAY)
    return eps
def select_action(state, eps):
    if random.random() < eps:
        return env.action_space.sample()
    with torch.no_grad():
        state_v = torch.tensor(state, dtype=torch.float32, device=DEVICE)
        q_values = policy_net(state_v)
        return int(torch.argmax(q_values).item())

# --- Training Loop ------------------------------------------------------
for ep in range(1, EPISODES+1):
    state, _ = env.reset()
    ep_reward = 0
    done = False
    while not done:
        action = select_action(state, epsilon_by_episode(ep))
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        push(state, action, reward, next_state, done)
        state = next_state
        ep_reward += reward

        # 학습 스텝
        if len(buffer) >= BATCH_SIZE:
            s, a, r, s_, d = sample(BATCH_SIZE)
            q_vals   = policy_net(s).gather(1, a)
            with torch.no_grad():
                q_next = target_net(s_).max(1, keepdim=True).values
                target = r + GAMMA * q_next * (1 - d)
            loss = nn.functional.mse_loss(q_vals, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # 타깃 네트워크 동기화
    if ep % TARGET_SYNC == 0:
        target_net.load_state_dict(policy_net.state_dict())

    if ep % 10 == 0:
        print(f"Ep {ep:4d} | R = {ep_reward:4.0f} | ε = {epsilon_by_episode(ep):.3f}")

print("학습 완료!")
env.close()

실행 결과 예시

Ep   10 | R =  35 | ε = 0.905
...
Ep  200 | R = 500 | ε = 0.083
Ep  500 | R = 500 | ε = 0.050
  • CartPole 성공 기준: avg_reward ≥ 475 (500 만점) over 100 episodes.
  • GPU가 없어도 MLPCartPole 4-입력을 2-레이어 FC 128×128이면 수초 내 수렴.
  • 더 어려운 환경은 DDQN, Dueling DQN, PER, NoisyNet 등을 이어 붙여 성능 향상.

8-4. 요약 & 다음 편 예고

  • DQN은 Replay Buffer + Target Network로 Q-Learning 발산 문제를 해결.
  • MLP/CNN을 함수 근사기로 활용해 고차원 관측도 직접 처리.
  • CartPole 같은 저차원 태스크는 ~200 에피소드 내에 안정적 수렴.

다음 글 : 정책 자체를 근사하는 정책 기반 방법
Advantage를 사용하는 Actor-Critic 아키텍처(A2C·A3C·PPO)를 다룹니다.


참고 자료

  • Mnih et al., “Human-level control through deep reinforcement learning,” Nature, 2015
  • Sutton & Barto, Reinforcement Learning: An Introduction, Ch. 11
  • “Playing Atari with Deep Reinforcement Learning” (2013) – DQN 원논문 arXiv:1312.5602
반응형
Comments