Интеграция GNN с PPO для управления роботами в Ant-v4 от GymnasiumPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Интеграция GNN с PPO для управления роботами в Ant-v4 от Gymnasium

Сообщение Anonymous »

В настоящее время я изучаю интеграцию графовых нейронных сетей (GNN) и обучения с подкреплением (RL) с использованием стабильных_базовых линий3 и Gymnasium, специально для управления роботами, вдохновленного подходом в NerveNet. В моем проекте используется среда Ant-v4, в которой я пытаюсь смоделировать среду в виде графа, где каждый узел представляет собой шарнир в структуре робота.
Я решил использовать алгоритм оптимизации проксимальной политики (PPO) для оптимизации процесса RL. Несмотря на несколько недель усилий, мне так и не удалось успешно объединить эти концепции. Основная задача, по-видимому, заключается в эффективной интеграции графовой модели с алгоритмом RL для улучшения обучения.
Может ли кто-нибудь поделиться информацией или опытом по следующим вопросам:
  • Лучшие практики моделирования физической среды (например, робота Ant) в виде графиков для использования в GNN?
  • Советы по эффективной интеграции этих представлений графов с PPO или другими алгоритмами RL?
Любые рекомендации или ссылки на подобные проекты будут очень признательны!
Заранее благодарим за помощь!
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

from torch_geometric.nn import GATv2Conv

from gymnasium import spaces

import gymnasium as gym

from gymnasium import ActionWrapper, ObservationWrapper, RewardWrapper, Wrapper

import torch as th

from torch import nn

from torch_geometric.data import Data

from stable_baselines3 import PPO

from stable_baselines3.common.policies import ActorCriticPolicy

import torch.nn.functional as F

import numpy as np

from torch.distributions import Normal

edge_index = th.tensor([[0, 1], [0, 3], [0, 5], [0, 7], [1, 2], [3, 4], [5, 6], [7, 8]], dtype=th.long).t().contiguous()

class GATv2ConvWrapper(nn.Module):
def __init__(self, in_channels, out_channels, edge_index):
super(GATv2ConvWrapper, self).__init__()
self.conv = GATv2Conv(in_channels, out_channels)
self.edge_index = edge_index # Storing edge_index as an attribute

def forward(self, x):
# Use the stored edge_index for the graph convolution
return self.conv(x, self.edge_index)


class CustomNetwork(nn.Module):

def __init__(
self,
feature_dim: int = 11,
last_layer_dim_pi: int = 64,
last_layer_dim_vf: int = 64,
):
super().__init__()

# IMPORTANT:
# Save output dimensions, used to create the distributions
self.latent_dim_pi = last_layer_dim_pi
self.latent_dim_vf = last_layer_dim_vf

# Policy Network

self.policy_net = nn.Sequential(
GATv2ConvWrapper(feature_dim, last_layer_dim_pi, edge_index),
nn.BatchNorm1d(64),
GATv2ConvWrapper(last_layer_dim_pi, last_layer_dim_pi, edge_index),
nn.BatchNorm1d(64),
)

# Value network components
self.value_net = nn.Sequential(
GATv2ConvWrapper(feature_dim, last_layer_dim_pi, edge_index),
nn.BatchNorm1d(64),
GATv2ConvWrapper(last_layer_dim_pi, last_layer_dim_pi, edge_index),
nn.BatchNorm1d(64),
)
# Log Probs
self.log_std = nn.Parameter(th.zeros(1, 8)) # Log standard deviations for actions

def forward(self, features):
# Compute policy (mean action values) and sample actions
actions = self.forward_actor(features).mean(dim=1)
values = self.forward_critic(features).mean(dim=1)
print("ACtions Mean: ", actions)
return actions, values

def forward_actor(self, features: Data):
print("Actore giren:", features)
x, edge_index = features.x, features.edge_index
return self.policy_net(x)

def forward_critic(self, features: Data):
print("Critice giren:", features)
x, edge_index = features.x, features.edge_index
return self.value_net(x)

def model_ant_as_graph(state: th.Tensor) -> Data:
# Check if state is a batch of observations
if state.dim() == 2:
state = state[0] # Take the first observation in the batch

node_indices = {
'torso': slice(0, 5),
'hip_1': slice(5, 6), 'ankle_1': slice(6, 7),
'hip_2': slice(7, 8), 'ankle_2': slice(8, 9),
'hip_3': slice(9, 10), 'ankle_3': slice(10, 11),
'hip_4': slice(11, 12), 'ankle_4': slice(12, 13),
}
velocity_indices = {
'torso': slice(13, 19),
'hip_1': slice(19, 20), 'ankle_1': slice(20, 21),
'hip_2': slice(21, 22), 'ankle_2': slice(22, 23),
'hip_3': slice(23, 24), 'ankle_3': slice(24, 25),
'hip_4': slice(25, 26), 'ankle_4': slice(26, 27),
}

node_features = []
for key in node_indices.keys():
combined_features = th.cat((state[node_indices[key]], state[velocity_indices[key]]))
padded_features = F.pad(combined_features, (0, 11 - combined_features.size(0)), "constant", 0)
node_features.append(padded_features)

node_features_tensor = th.stack(node_features).float()
graph = Data(x=node_features_tensor, edge_index=edge_index, num_nodes=9)
return graph

class CustomActorCriticPolicy(ActorCriticPolicy):
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Callable[[float], float],
*args,
**kwargs,
):
kwargs["ortho_init"] = False
super().__init__(
observation_space,
action_space,
lr_schedule,
# Pass remaining arguments to base class
*args,
**kwargs,
)

# In the forward method of your policy:
def forward(self, obs):

graph_data = model_ant_as_graph(obs) # Convert observations to graph data here
actions, values = self.mlp_extractor(graph_data)
print("Actions:", actions)

return actions, values

def _build_mlp_extractor(self) -> None:
self.mlp_extractor = CustomNetwork(11)
env = gym.make("Ant-v4")
model = PPO(CustomActorCriticPolicy, env, verbose=1)
print(model.policy)

Вывод оператора печати:
CustomActorCriticPolicy(
(features_extractor): FlattenExtractor(
(flatten): Flatten(start_dim=1, end_dim=-1)
)
(pi_features_extractor): FlattenExtractor(
(flatten): Flatten(start_dim=1, end_dim=-1)
)
(vf_features_extractor): FlattenExtractor(
(flatten): Flatten(start_dim=1, end_dim=-1)
)
(mlp_extractor): CustomNetwork(
(policy_net): Sequential(
(0): GATv2ConvWrapper(
(conv): GATv2Conv(11, 64, heads=1)
)
(1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): GATv2ConvWrapper(
(conv): GATv2Conv(64, 64, heads=1)
)
(3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(value_net): Sequential(
(0): GATv2ConvWrapper(
(conv): GATv2Conv(11, 64, heads=1)
)
(1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): GATv2ConvWrapper(
(conv): GATv2Conv(64, 64, heads=1)
)
(3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(action_net): Linear(in_features=64, out_features=8, bias=True)
(value_net): Linear(in_features=64, out_features=1, bias=True)
)


Подробнее здесь: https://stackoverflow.com/questions/787 ... ums-ant-v4
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение
  • Как обучать GNN на огромных наборах данных?
    Гость » » в форуме Python
    0 Ответы
    33 Просмотры
    Последнее сообщение Гость
  • Gymnasium env.render() запускается, но ничего не появляется
    Гость » » в форуме Python
    0 Ответы
    25 Просмотры
    Последнее сообщение Гость
  • CUDAError: недостаточно памяти для среды RL с использованием Gymnasium.
    Anonymous » » в форуме Python
    0 Ответы
    21 Просмотры
    Последнее сообщение Anonymous
  • Ошибка пользовательской среды Gymnasium «слишком много значений для распаковки»
    Anonymous » » в форуме Python
    0 Ответы
    9 Просмотры
    Последнее сообщение Anonymous
  • Ошибка пользовательской среды Gymnasium «слишком много значений для распаковки»
    Anonymous » » в форуме Python
    0 Ответы
    9 Просмотры
    Последнее сообщение Anonymous

Вернуться в «Python»