반응형
Notice
Recent Posts
Recent Comments
Link
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | ||
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 | 31 |
Tags
- python
- 전국국밥
- 카트폴
- 데이터분석
- TeachagleMachine
- selenium
- Ros
- redux
- 강화학습
- 사이드프로젝트
- App
- Instagrame clone
- GYM
- ReactNative
- Reinforcement Learning
- expo
- 머신러닝
- pandas
- FirebaseV9
- 조코딩
- coding
- JavaScript
- 클론코딩
- 딥러닝
- clone coding
- 앱개발
- 리액트네이티브
- kaggle
- React
- 강화학습 기초
Archives
- Today
- Total
qcoding
[강화학습-8] 딥 강화학습(Deep RL) 기초 – DQN 본문
반응형
8. 딥 강화학습(Deep RL) 기초 – DQN
2015년 DeepMind의 DQN(Deep Q-Network)은 픽셀 입력만으로 Atari 2600 게임들을 사람 수준으로 플레이하며
“딥러닝 + 강화학습” 시대를 열었습니다. 핵심은 표준 Q-Learning에 두 가지 안정화 기법을 더한 것입니다.
8-1. DQN 아키텍처
블록 | 설명 |
---|---|
입력 | state $s$ (원본 픽셀·연속 벡터 모두 가능) |
피처 추출 CNN / MLP | 2~3 Conv + ReLU + FC (Atari) 또는 2~3 FC (CartPole 등 저차원) |
출력층 | $\hat Q(s,a;\theta)$ – 각 행동에 대한 Q-값 |
$$L(\theta)=\mathbb{E}_{(s,a,r,s')\sim\mathcal{D}} \Bigl[\bigl( r+\gamma\max_{a'}\hat Q(s',a';\theta^{-}) -\hat Q(s,a;\theta) \bigr)^{2}\Bigr]$$
- $\theta$ : 온라인(학습) 네트워크 파라미터
- $\theta^{-}$ : 타깃 네트워크 파라미터(주기적 고정)
- $\mathcal{D}$ : 경험 재플레이(Replay Buffer)에서 샘플링한 미니배치
8-2. 핵심 안정화 기법
기법 | 아이디어 | 효과 |
---|---|---|
Experience Replay (재플레이 버퍼) |
최근 경험 $(s,a,r,s')$ 를 큰 버퍼 𝔻에 저장 ⟶ 학습 시 무작위 미니배치 추출 |
① 데이터 상관성 ↓ ② 샘플 효율 ↑ ③ 배치 학습(벡터화) 가능 |
Target Network (타깃 네트워크) |
Q̂ 값 계산용 네트워크(파라미터 $\theta^{-}$)를 주기 $C$ 스텝마다 온라인 $\theta$로 하드 복사 |
목표값 변화 완만 ⟶ 학습 안정성 ↑ |
두 기법이 없으면 비선형 함수근사 + 부트스트랩 조합이 발산하기 쉽습니다.
훗날 나온 DDQN, Dueling DQN, PER(우선순위 버퍼) 등도 이 틀 위에서 개선되었습니다.
8-3. 구현 실습 – 간단한 DQN으로 CartPole 학습
PyTorch 2.x 기반 200줄 미만 코드. (구현 간결성을 위해 PER·DDQN 생략)
"""
pip install gymnasium torch numpy
"""
import gymnasium as gym
import numpy as np
import torch, torch.nn as nn, torch.optim as optim
from collections import deque, namedtuple
import random
# --- Hyperparameters ----------------------------------------------------
ENV_ID = "CartPole-v1"
EPISODES = 500
GAMMA = 0.99
LR = 1e-3
BATCH_SIZE = 64
REPLAY_SIZE = 50_000
START_EPS = 1.0
END_EPS = 0.05
EPS_DECAY = 500 # decay episodes
TARGET_SYNC = 10 # episodes
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# --- Environment --------------------------------------------------------
env = gym.make(ENV_ID)
n_state = env.observation_space.shape[0]
n_action = env.action_space.n
print("state:", n_state, "actions:", n_action)
# --- Network ------------------------------------------------------------
class DQN(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(n_state, 128), nn.ReLU(),
nn.Linear(128, 128), nn.ReLU(),
nn.Linear(128, n_action)
)
def forward(self, x):
return self.net(x)
# main / target
policy_net = DQN().to(DEVICE)
target_net = DQN().to(DEVICE)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
optimizer = optim.Adam(policy_net.parameters(), lr=LR)
# --- Replay Buffer ------------------------------------------------------
Transition = namedtuple("Transition", "s a r s_ done")
buffer = deque(maxlen=REPLAY_SIZE)
def push(*args):
buffer.append(Transition(*args))
def sample(batch):
t = random.sample(buffer, batch)
s = torch.tensor([tr.s for tr in t], dtype=torch.float32, device=DEVICE)
a = torch.tensor([tr.a for tr in t], dtype=torch.int64, device=DEVICE).unsqueeze(1)
r = torch.tensor([tr.r for tr in t], dtype=torch.float32, device=DEVICE).unsqueeze(1)
s_ = torch.tensor([tr.s_ for tr in t], dtype=torch.float32, device=DEVICE)
d = torch.tensor([tr.done for tr in t], dtype=torch.float32, device=DEVICE).unsqueeze(1)
return s, a, r, s_, d
# --- ε-greedy -----------------------------------------------------------
def epsilon_by_episode(ep):
eps = END_EPS + (START_EPS - END_EPS) * np.exp(-ep / EPS_DECAY)
return eps
def select_action(state, eps):
if random.random() < eps:
return env.action_space.sample()
with torch.no_grad():
state_v = torch.tensor(state, dtype=torch.float32, device=DEVICE)
q_values = policy_net(state_v)
return int(torch.argmax(q_values).item())
# --- Training Loop ------------------------------------------------------
for ep in range(1, EPISODES+1):
state, _ = env.reset()
ep_reward = 0
done = False
while not done:
action = select_action(state, epsilon_by_episode(ep))
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
push(state, action, reward, next_state, done)
state = next_state
ep_reward += reward
# 학습 스텝
if len(buffer) >= BATCH_SIZE:
s, a, r, s_, d = sample(BATCH_SIZE)
q_vals = policy_net(s).gather(1, a)
with torch.no_grad():
q_next = target_net(s_).max(1, keepdim=True).values
target = r + GAMMA * q_next * (1 - d)
loss = nn.functional.mse_loss(q_vals, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 타깃 네트워크 동기화
if ep % TARGET_SYNC == 0:
target_net.load_state_dict(policy_net.state_dict())
if ep % 10 == 0:
print(f"Ep {ep:4d} | R = {ep_reward:4.0f} | ε = {epsilon_by_episode(ep):.3f}")
print("학습 완료!")
env.close()
실행 결과 예시
Ep 10 | R = 35 | ε = 0.905
...
Ep 200 | R = 500 | ε = 0.083
Ep 500 | R = 500 | ε = 0.050
- CartPole 성공 기준:
avg_reward ≥ 475
(500 만점) over 100 episodes. - GPU가 없어도 MLPCartPole 4-입력을 2-레이어 FC 128×128이면 수초 내 수렴.
- 더 어려운 환경은 DDQN, Dueling DQN, PER, NoisyNet 등을 이어 붙여 성능 향상.
8-4. 요약 & 다음 편 예고
- DQN은 Replay Buffer + Target Network로 Q-Learning 발산 문제를 해결.
- MLP/CNN을 함수 근사기로 활용해 고차원 관측도 직접 처리.
- CartPole 같은 저차원 태스크는 ~200 에피소드 내에 안정적 수렴.
다음 글 : 정책 자체를 근사하는 정책 기반 방법과
Advantage를 사용하는 Actor-Critic 아키텍처(A2C·A3C·PPO)를 다룹니다.
참고 자료
- Mnih et al., “Human-level control through deep reinforcement learning,” Nature, 2015
- Sutton & Barto, Reinforcement Learning: An Introduction, Ch. 11
- “Playing Atari with Deep Reinforcement Learning” (2013) – DQN 원논문 arXiv:1312.5602
반응형
'머신러닝 딥러닝' 카테고리의 다른 글
[강화학습-10] 액터-크리틱(Actor-Critic) (0) | 2025.05.28 |
---|---|
[강화학습-9] 정책 기반 방법 (Policy Gradient) – REINFORCE (0) | 2025.05.28 |
[강화학습-7] 함수 근사 (Function Approximation) (0) | 2025.05.28 |
[강화학습-6] 시간차 학습 (Temporal-Difference, TD) (0) | 2025.05.28 |
[강화학습-5] 몬테카를로 방법 (Monte Carlo Methods) (0) | 2025.05.28 |
Comments