Я хочу:
- Когда я обновляю политику (актёр), я не хочу, чтобы параметры Q-сетей (критики) менялись.
- Однако мне все еще нужен градиент сигнал от Q-сетей, чтобы правильно обновить политику.
with torch.no_grad():
q_pi_1 = self.q1(obs, new_action)
q_pi_2 = self.q2(obs, new_action)
потому что я думал, что это предотвратит влияние Q-сетей на обновление политики.
Но теперь я понимаю, что это может быть неправильно, потому что torch.no_grad() отключает все отслеживание градиента, включая градиент Q(s,a) относительно действия a.
И этот градиент — именно то, что нужно политике во время обновления субъекта SAC.
/>Я прав или нет?
Мое решение этой проблемы (я не знаю, действительно ли это проблема) будет следующим:
с этим решением я думаю, что параметры не изменятся, но градиент все равно будет доступен.
Или я неправильно понимаю, как здесь следует обрабатывать градиенты?
for p in self.q1.parameters():
p.requires_grad = False
for p in self.q2.parameters():
p.requires_grad = False
# actor loss (Policy) uses q1/q2 forward pass here (without torch.no_grad)
for p in self.q1.parameters():
p.requires_grad = True
for p in self.q2.parameters():
p.requires_grad = True
Подробнее здесь: https://stackoverflow.com/questions/798 ... ementation
Мобильная версия