Я обучаю агента PPO на flappy_bird_gymnasium.
Настройка
- Алгоритм: PPO (stable-baselines3)
- Среда: FlappyBird-v0
- Два варианта наблюдения:
Случай A: Высокоразмерный лидар
- Наблюдение в стиле лидара с частотой 180–2000 лучей
- На основе любого спортзала Приведение лучей на основе лидара или CV
- Результат: агент может достичь очень высоких результатов (от сотен до тысяч)
Случай Б: Низкомерный
12-мерный вектор:- положение птицы по оси y
- Вертикальная скорость птицы
- Положения трубы x
- Центры зазоров
- Размеры зазоров
- Относительные смещения (птица y − зазор y)
Проблема
При 12D-наблюдении:
- PPO постоянно стабилизируется в районе значения ~80–120.
- Обучение стабильно (нет коллапса, нет NaN)
- Увеличение:
размера сети - шагов обучения
- гаммы (до 0,999)
- формирование вознаграждения
не значительно улучшается производительность
С
- Агент работает намного лучше, но поведение кажется хрупким и не соответствует визуальным эффектам.
Почему PPO работает намного хуже с компактным, полностью наблюдаемым 12D-состоянием по сравнению с многомерный входной сигнал лидара?
В частности:
- Является ли это проблемой частичной наблюдаемости/временного кредитования?
- Требует ли PPO явных функций, связанных со временем (например, время до следующего-
- Почему может высокомерное наблюдение неявно решает эту проблему, а низкомерное - не может?
import mss
import cv2
import numpy as np
import pydirectinput
import time
from stable_baselines3 import PPO
# ==========================
# 1. 遊戲視窗截圖區域設定
# ==========================
REGION = {"top": 50, "left": 100, "width": 500, "height": 800} # 遊戲實際視窗位置
SCREEN_W = REGION["width"]
SCREEN_H = REGION["height"]
# ==========================
# 2. 載入 PPO 模型
# ==========================
model = PPO.load("ppo_flappy_cv12_final") # 之前訓練的 12 維模型
# ==========================
# 3. CV 偵測參數
# ==========================
bird_tpl = cv2.imread("bird.png", 0)
pipe_upper_tpl = cv2.imread("pipe_upper.png", 0)
pipe_lower_tpl = cv2.imread("pipe_lower.png", 0)
pydirectinput.PAUSE = 0 # 移除延遲
last_bird_y = REGION["height"] / 2 # 初始鳥 y
last_pipe1_x = 0
last_pipe2_x = 0
# ==========================
# 4. 影像處理函式 (生成 12 維 observation)
# ==========================
def get_cv12_obs(frame, last_bird_y, last_pipe1_x, last_pipe2_x):
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
# --- 偵測鳥 ---
res_bird = cv2.matchTemplate(gray, bird_tpl, cv2.TM_CCOEFF_NORMED)
_, max_val_b, _, max_loc_b = cv2.minMaxLoc(res_bird)
if max_val_b < 0.4:
return None, last_bird_y, last_pipe1_x, last_pipe2_x
bird_y = max_loc_b[1] + bird_tpl.shape[0] / 2
bird_vy = bird_y - last_bird_y
# --- 偵測水管(只取前兩根) ---
res_up = cv2.matchTemplate(gray, pipe_upper_tpl, cv2.TM_CCOEFF_NORMED)
res_lo = cv2.matchTemplate(gray, pipe_lower_tpl, cv2.TM_CCOEFF_NORMED)
_, max_val_u, _, max_loc_u = cv2.minMaxLoc(res_up)
_, max_val_l, _, max_loc_l = cv2.minMaxLoc(res_lo)
if max_val_u < 0.4 or max_val_l < 0.4:
return None, last_bird_y, last_pipe1_x, last_pipe2_x
pipe1_x = max_loc_u[0]
gap1_top = max_loc_u[1] + pipe_upper_tpl.shape[0]
gap1_bottom = max_loc_l[1]
# 第二根管道可以用 0 填充,如果沒偵測到
pipe2_x = last_pipe2_x
gap2_top = 0
gap2_bottom = 0
# --- 計算相對位移 ---
pipe1_vx = pipe1_x - last_pipe1_x
pipe2_vx = pipe2_x - last_pipe2_x
gap1_cy = (gap1_top + gap1_bottom) / 2
gap2_cy = (gap2_top + gap2_bottom) / 2
dy1 = bird_y - gap1_cy
dy2 = bird_y - gap2_cy
# --- 更新 last 值 ---
last_bird_y = bird_y
last_pipe1_x = pipe1_x
last_pipe2_x = pipe2_x
# --- 生成 12 維 observation ---
obs12 = np.zeros((12,), dtype=np.float32)
obs12[0] = bird_y / SCREEN_H
obs12[1] = bird_vy / 10.0
obs12[2] = pipe1_x / SCREEN_W
obs12[3] = gap1_cy / SCREEN_H
obs12[4] = (gap1_bottom - gap1_top) / SCREEN_H
obs12[5] = pipe1_vx / 10.0
obs12[6] = pipe2_x / SCREEN_W
obs12[7] = gap2_cy / SCREEN_H
obs12[8] = (gap2_bottom - gap2_top) / SCREEN_H
obs12[9] = pipe2_vx / 10.0
obs12[10] = dy1 / SCREEN_H
obs12[11] = dy2 / SCREEN_H
return obs12, last_bird_y, last_pipe1_x, last_pipe2_x
# ==========================
# 5. 主循環
# ==========================
with mss.mss() as sct:
print("
while True:
screenshot = sct.grab(REGION)
frame = cv2.cvtColor(np.array(screenshot), cv2.COLOR_BGRA2BGR)
obs, last_bird_y, last_pipe1_x, last_pipe2_x = get_cv12_obs(
frame, last_bird_y, last_pipe1_x, last_pipe2_x
) if 'get_cv12_obs' in locals() else (None, last_bird_y, last_pipe1_x, last_pipe2_x)
if obs is not None:
action, _ = model.predict(obs, deterministic=True)
if action == 1:
pydirectinput.press('space')
print("
# 可選:顯示 debug 畫面
cv2.imshow("Debug", cv2.resize(frame, (480, 270)))
if cv2.waitKey(1) & 0xFF == ord('q'):
break
import gymnasium as gym
import flappy_bird_gymnasium
import numpy as np
import torch
import random
import os
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
# =========================================================
# 1. 全域隨機種子設定 (確保實驗可重複性)
# =========================================================
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
# 讓 cuDNN 運算結果確定化 (會稍微犧牲一點點速度,但對比最準)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
SEED = 42
set_seed(SEED)
# =========================================================
# 2. 環境 Wrapper (修正 Reset 漏洞與獎勵縮放)
# =========================================================
class CV12FlappyEnv(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.observation_space = spaces.Box(
low=-2.0, high=2.0, shape=(12,), dtype=np.float32
)
self.last_score = 0
def reset(self, seed=None, options=None):
# 核心修正:必須將 seed 傳遞給底層環境
obs, info = self.env.reset(seed=seed, options=options)
self.last_score = 0
return self._build_obs(obs), info
def step(self, action):
obs, _, terminated, truncated, info = self.env.step(action)
# --- 獎勵設計區 (建議將數值縮小 10 倍以提升穩定性) ---
reward = 0.01 # 生存獎勵 (原 0.05 -> 0.01)
# 死亡懲罰
if terminated or truncated:
reward = -2.0 # 原 -20.0 -> -2.0
# 過關獎勵
score = info.get("score", 0)
if score > self.last_score:
reward += 1.5 # 原 15.0 -> 1.5
self.last_score = score
# 引導獎勵:鼓勵鳥靠近水管缺口中心 (可選,有助於突破 50 步)
# bird_y = obs[9]
# gap1_cy = (obs[4] + obs[5]) / 2.0
# reward -= 0.001 * abs(bird_y - gap1_cy)
return self._build_obs(obs), reward, terminated, truncated, info
def _build_obs(self, obs):
# 官方索引對齊
p1_x, p1_y_t, p1_y_b = obs[3], obs[4], obs[5]
p2_x, p2_y_t, p2_y_b = obs[6], obs[7], obs[8]
bird_y, bird_v, bird_r = obs[9], obs[10], obs[11]
gap1_cy = (p1_y_t + p1_y_b) / 2.0
gap2_cy = (p2_y_t + p2_y_b) / 2.0
return np.array([
bird_y / 512.0,
bird_v / 10.0,
p1_x / 288.0,
gap1_cy / 512.0,
(p1_y_b - p1_y_t) / 512.0,
p2_x / 288.0,
gap2_cy / 512.0,
(p2_y_b - p2_y_t) / 512.0,
(bird_y - gap1_cy) / 512.0,
(bird_y - gap2_cy) / 512.0,
bird_r / 90.0,
(p1_x - p2_x) / 288.0
], dtype=np.float32)
# =========================================================
# 3. 訓練主程式
# =========================================================
if __name__ == "__main__":
# 建立目錄
os.makedirs("./models/best/", exist_ok=True)
os.makedirs("./tb_logs/", exist_ok=True)
env_kwargs = {"render_mode": None, "use_lidar": False}
policy_kwargs = dict(net_arch=[128, 128, 128])
# 訓練環境 (加入 Seed)
train_env = make_vec_env(
lambda: CV12FlappyEnv(gym.make("FlappyBird-v0", **env_kwargs)),
n_envs=4,
seed=SEED
)
# 評估環境
eval_env = CV12FlappyEnv(gym.make("FlappyBird-v0", **env_kwargs))
eval_env.reset(seed=SEED)
# PPO 模型設定
model = PPO(
"MlpPolicy",
train_env,
policy_kwargs=policy_kwargs,
learning_rate=2e-3, # 建議降至 3e-4,2e-3 在 RL 中極容易跑飛
n_steps=2048,
batch_size=1024,
gamma=0.99,
ent_coef=0.01, # 稍微提高探索,防止太快變成 PPO_19 的死腦筋
clip_range=0.3, # 標準 PPO 常用 0.2
verbose=1,
seed=SEED, # 核心:固定模型種子
tensorboard_log="./tb_logs/"
)
# Callbacks
eval_callback = EvalCallback(
eval_env,
best_model_save_path="./models/best/",
log_path="./tb_logs/",
eval_freq=10000,
deterministic=True,
render=False
)
checkpoint_callback = CheckpointCallback(save_freq=50000, save_path="./models/")
print(f"
model.learn(
total_timesteps=500_000,
callback=[eval_callback, checkpoint_callback],
tb_log_name="PPO"
)
model.save("ppo_flappy_final_v1")
print("
Подробнее здесь: https://stackoverflow.com/questions/798 ... im-lidar-t
Мобильная версия