RuntimeError: «элемент 0 тензоров не требует градации» при использовании пользовательской функции автоградации с create_Python

Программы на Python
Ответить
Anonymous
 RuntimeError: «элемент 0 тензоров не требует градации» при использовании пользовательской функции автоградации с create_

Сообщение Anonymous »

Я реализую специальную функцию активации (вариант Swish) в PyTorch для оптимизации использования памяти. Я реализовал его с помощью torch.autograd.Function, определив как прямые, так и обратные статические методы.
Стандартное обучение (производные первого порядка) работает отлично. Однако сейчас я пытаюсь использовать этот пользовательский слой в настройке WGAN-GP (Wasserstein GAN с градиентным штрафом). Для этого требуется вычислить градиент градиентов (двойное обратное свойство) с помощью torch.autograd.grad(..., create_graph=True).
Как только я включаю create_graph=True, я получаю следующую ошибку: RuntimeError: элемент 0 тензоров не требует grad и не имеет grad_fn.
Что я пробовал:
  • I проверил, что x.requires_grad имеет значение True.
  • Я попробовал заменить свою пользовательскую функцию на torch.nn.SiLU() (встроенную в
    Swish), и двойная обратная связь работает отлично, так что проблема
    определенно в моем классе CustomSwish.
  • Я удалил ctx.save_for_backward и пересчитал вводит данные в обратном направлении,
    но ошибка сохраняется.
Вопрос: Почему моя пользовательская функция autograd.Function нарушает граф вычислений во время второго обратного прохода, хотя я использую дифференцируемые операции PyTorch внутри в обратном направлении? Как правильно реализовать пользовательскую функцию, поддерживающую производные более высокого порядка?

Подробнее здесь: https://stackoverflow.com/questions/798 ... custom-aut
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»