IndexError в SBX (Stable Baselines 3) с Flax: «индекс кортежа вне диапазона» во время инициализации сети субъектовPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 IndexError в SBX (Stable Baselines 3) с Flax: «индекс кортежа вне диапазона» во время инициализации сети субъектов

Сообщение Anonymous »

Ранее я реализовал SAC со стабильными базовыми линиями3 в специальной среде Gymnasium, и это сработало. Теперь я пытаюсь использовать JAX стабильного базового уровня3 (SBX) в той же среде, но сталкиваюсь с этой ошибкой во время инициализации модели SAC:

Код: Выделить всё

"/workspaces/ros2_ws_humble/src/rl_node/training_loops_method1/run_method1.py", line 158, in run_test
model = SAC(
File "/usr/local/lib/python3.10/dist-packages/sbx/sac/sac.py", line 112, in __init__
self._setup_model()
File "/usr/local/lib/python3.10/dist-packages/sbx/sac/sac.py", line 127, in _setup_model
self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
File "/usr/local/lib/python3.10/dist-packages/sbx/sac/policies.py", line 120, in build#
params=self.actor.init(actor_key, obs),
File "/usr/local/lib/python3.10/dist-packages/sbx/sac/policies.py", line 35, in __call__
x = nn.Dense(n_units)(x)
File "/usr/local/lib/python3.10/dist-packages/flax/linen/linear.py", line 237, in __call__
(jnp.shape(inputs)[-1], self.features),
IndexError: tuple index out of range
Здесь я инициализирую модель SAC:

Код: Выделить всё

def run_test(config_, rl_node, run_count):
np.seterr(all='raise')
th.autograd.set_detect_anomaly(True)
env = None
mode = None
env = make_env(config_, rl_node)
env = RecordEpisodeStatistics(env, buffer_length=100)
env = DummyVecEnv([lambda: env])
model_name = "Agent_Long_absolut_05"
policy_kwargs = {"activation_fn": th.nn.Mish,"net_arch": {"pi":[32,32],"qf": [64, 64, 64]}}
model = SAC("MultiInputPolicy", env, learning_rate=0.01, gamma=0.8, batch_size=128, verbose=1, policy_kwargs=policy_kwargs,tensorboard_log=f"{parent_dir_path}/logs",device="cuda")

Ниже показано пространство наблюдения:

Код: Выделить всё

 obs_space = {
'obs_long_acc': gym.spaces.Box(low=-1000, high=1000, shape=(1,), dtype=np.float32),
'obs_long_jerk': gym.spaces.Box(low=-1000, high=1000, shape=(1,), dtype=np.float32),
'obs_relative_speed': gym.spaces.Box(low=-1000, high=1000, shape=(1,), dtype=np.float32),
'obs_relative_distance': gym.spaces.Box(low=-1000, high=1000, shape=(1,), dtype=np.float32),
'obs_time_elapsed' : gym.spaces.Box(low=0, high=100000, shape=(1,), dtype=np.float32),
'obs_parameters_valid' : gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
'obs_first_loop_err' : gym.spaces.Box(low=0, high=10, shape=(1,), dtype=np.float32),
'obs_second_loop_err': gym.spaces.Box(low=0, high=10, shape=(1,), dtype=np.float32),
'obs_parameter_score_1': gym.spaces.Box(low=-100, high=0, shape=(24,), dtype=np.float32),
'obs_parameter_score_2': gym.spaces.Box(low=-100, high=0, shape=(25,), dtype=np.float32),
'obs_crash': gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
}

self.observation_space = gym.spaces.Dict(obs_space)
Что может быть причиной ошибки IndexError в этом случае? Существует ли несоответствие между пространством наблюдения и архитектурой политики в SBX? Я некоторое время застрял в этой проблеме и был бы очень признателен за любые советы по ее решению. Спасибо.
Я попробовал использовать операцию выравнивания, чтобы обеспечить правильное выравнивание пространства наблюдения, но ошибка по-прежнему сохраняется во время инициализации.

Подробнее здесь: https://stackoverflow.com/questions/793 ... -range-dur
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»