Code:
def make_env():
    env = RacingEnvironment(render_mode = "rgb_array")
    env = NormalizeReward(env, gamma = 0.99, epsilon = 1e-8)
    env = DilatedFrameStack(env, num_stack = 4, dilation = 12)
    return env
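DilatedFrameStack is a small custom wrapper: it keeps a history of recent frames and stacks every dilation-th one, so the 4 stacked frames cover a longer time window than a plain frame stack. A simplified sketch of the idea (not my exact code, just to show the behaviour):

Code:
import numpy as np
import gymnasium as gym
from collections import deque

class DilatedFrameStack(gym.Wrapper):
    """Illustrative sketch only: stack num_stack frames taken dilation steps apart.
    Assumes the wrapped env already returns a single grayscale (H, W) frame."""
    def __init__(self, env, num_stack=4, dilation=12):
        super().__init__(env)
        self.num_stack = num_stack
        self.dilation = dilation
        self.frames = deque(maxlen=num_stack * dilation)  # rolling frame history

    def _observation(self):
        # newest frame last; pick every dilation-th frame going back in time
        picked = [self.frames[-1 - i * self.dilation] for i in range(self.num_stack)]
        return np.stack(picked[::-1], axis=0)  # (num_stack, H, W), oldest first

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.frames.maxlen):
            self.frames.append(obs)
        return self._observation(), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(obs)
        return self._observation(), reward, terminated, truncated, info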
Here is the model I use:
Code:
class DQN(nn.Module):
    def __init__(self, action_size):
        super(DQN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(4, 16, kernel_size=3, stride=3),   # (16, 41, 41)
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, stride=2),  # (32, 20, 20)
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(32, 64, kernel_size=3, stride=2),  # (64, 9, 9)
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(64, 96, kernel_size=3, stride=2),  # (96, 4, 4)
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.flatten = nn.Flatten()
        self.mlp = nn.Sequential(
            nn.Linear(96 * 4 * 4, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, action_size)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.flatten(x)
        return self.mlp(x)
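A quick sanity check of the convolutional stack, assuming the (4, 84, 84) input described below, is to push a dummy tensor through it and compare the flattened size with what nn.Linear(96 * 4 * 4, 256) expects:

Code:
import torch

# Sanity check (illustrative): run a dummy (4, 84, 84) observation through the
# conv stack and compare the flattened size with 96 * 4 * 4.
model = DQN(action_size=3).eval()
with torch.no_grad():
    feats = model.features(torch.zeros(1, 4, 84, 84))
print(feats.shape, feats.flatten(1).shape[1], 96 * 4 * 4)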
I stack 4 grayscale frames with temporal dilation (taking n steps between each frame) as the input, shaped (4, 84, 84).
The action space is discretized: 0 - forward, 1 - right, 2 - left.
I use a standard DQN training setup.
The reward function looks like this:
def _calculate_reward(self, collision: bool) -> float:
    progress_reward = self.car.checkpoint_index * 0.5
    collision_penalty = -30.0 if collision else 0.0
    _, distances = self.car.get_rays_and_distances(self.TRACK_BORDER_MASK)
    wall_penalty = -5.0 if distances[6] or distances[0] or distances[4]
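By "standard DQN training setup" I mean the usual one-step TD update with a replay buffer and a periodically synchronized target network; agent.replay(BATCH_SIZE) computes a loss along these lines (simplified sketch, not the exact implementation):

Code:
import torch
import torch.nn.functional as F

def td_loss(Q_model, Q_target, states, actions, rewards, next_states, dones, gamma=0.99):
    # Illustrative sketch of the usual DQN loss, not the actual agent.replay code.
    # Q(s, a) for the actions that were actually taken
    q_sa = Q_model(states).gather(1, actions.long().unsqueeze(1)).squeeze(1)
    # Bootstrapped one-step target from the frozen target network
    with torch.no_grad():
        next_q = Q_target(next_states).max(dim=1).values
        target = rewards + gamma * next_q * (1.0 - dones.float())
    return F.smooth_l1_loss(q_sa, target)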
Training loop:
EPOCHS = 500_000
BATCH_SIZE = 32

agent.decay_steps = 350_000
state = env.reset()

progress_bar = trange(0, EPOCHS)
for i in progress_bar:
    agent.update_epsilon_value(i)
    _, state = play_and_record(state, agent, env, n_steps = 1)
    loss, grad_norm = agent.replay(BATCH_SIZE)

    if i % 50 == 0:
        td_loss_history.append(loss)
        grad_norm_history.append(grad_norm)

    if i % 10_000 == 0:
        agent.synchronize()

    if i % 1000 == 0:
        torch.save(agent.Q_model.state_dict(), save_path)
        mean_rw_history.append(evaluate(make_env(), agent, n_games = 2, greedy = True, t_max = 500))

        clear_output(True)
        print("Buffer size = %i, epsilon = %.5f" % (len(agent.memory), agent.epsilon))

        plt.figure(figsize = [15, 4])
        plt.subplot(1, 3, 1)
        plt.title("Mean reward per game")
        plt.plot(mean_rw_history, color = 'dodgerblue')
        plt.grid(color = 'black', ls = '--', alpha = 0.5)

        assert not np.isnan(td_loss_history[-1])
        plt.subplot(1, 3, 2)
        plt.title("TD loss history")
        plt.plot(smoothen(td_loss_history), color = 'crimson')
        plt.grid(color = 'black', ls = '--', alpha = 0.5)

        plt.subplot(1, 3, 3)
        plt.title("Grad norm history")
        plt.plot(smoothen(grad_norm_history), color = 'lime')
        plt.grid(color = 'black', ls = '--', alpha = 0.5)

        plt.tight_layout()
        plt.show()
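agent.update_epsilon_value(i) anneals the exploration rate over agent.decay_steps steps, roughly like this (the start/end values here are only illustrative):

Code:
def update_epsilon_value(self, step, eps_start=1.0, eps_final=0.05):
    # Illustrative sketch of a linear epsilon schedule; the actual start/end values may differ.
    frac = min(step / self.decay_steps, 1.0)
    self.epsilon = eps_start + frac * (eps_final - eps_start)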
Other functions:
def play_and_record(initial_state, agent, env, n_steps = 1):
    """Play the game for exactly n steps, record every (s, a, r, s', done) to replay buffer.
    Return sum of rewards over time and the state in which the env stays."""
    s = initial_state
    sum_rewards = 0.0

    for _ in range(n_steps):
        a = agent.get_action(s)
        next_s, r, terminated, truncated, _ = env.step(a)
        done = (terminated or truncated)
        agent.remember(s, a, r, next_s, done)
        sum_rewards += r
        s = env.reset() if done else next_s

    return sum_rewards, s


def evaluate(env, agent, n_games = 1, greedy = False, t_max = 10_000):
    """Plays n_games full games. If greedy, picks best actions as argmax(qvalues). Returns mean reward."""
    rewards = []
    for _ in range(n_games):
        s = env.reset()
        s = s / 255.0
        reward = 0
        for _ in range(t_max):
            action = agent.get_best_action(s) if greedy else agent.get_action(s)
            s, r, done, truncated, _ = env.step(action)
            s = s / 255.0
            reward += r
            if done or truncated:
                break
        rewards.append(reward)
    return np.mean(rewards)
What could be going wrong here? Are there architectural problems in my model for this task? Or are there known issues with using DQN for car racing that I should take into account?
(Training curves: https://i.sstatic.net/ajjyezw8.png)
More details here: https://stackoverflow.com/questions/796 ... g-progress