Понимание градиентов Pytorch и обратной функции, когда он обратно не раз в обратном направленииPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Понимание градиентов Pytorch и обратной функции, когда он обратно не раз в обратном направлении

Сообщение Anonymous »

Я пытаюсь добавить более одного шага обучения генератора за цикл в GAN, то есть я хочу, чтобы мой генератор обновлял его параметры n раз каждый m обновления дискриминатора, где n> m .

Код: Выделить всё

for epoch in range(num_epochs):
for batch_idx, (real, _) in enumerate(loader):
real = real.view(-1, 784).to(device)
batch_size = real.shape[0]

# Training Generator
for i in range(gen_advantage):
noise = torch.randn(batch_size, z_dim).to(device)
fake = gen(noise)
output = disc(fake).view(-1)
lossG = criterion(output, torch.ones_like(output))
lossG.backward()
opt_gen.step()
gen.zero_grad()

# Training Discriminator
for i in range(disc_advantage):
disc_real = disc(real).view(-1)
lossD_real = criterion(disc_real, torch.ones_like(disc_real))
disc_fake = disc(fake).view(-1)
lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
lossD = (lossD_real + lossD_fake) * 0.5
lossD.backward() # Breaks here
opt_disc.step()
disc.zero_grad()

Для контекста, критерий - это bceloss , opt_gen и opt_disc оптимальные. Adam , Disc и Gen являются дискриминатором и генераторами и изображениями в кодовом коде> 28x28. Lossd.backward () Line, даже если disc_advantage == 1 :

Код: Выделить всё

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
I can't get why, since in my understanding, I'm neither accessing freed tensors nor backwarding the lossD multiple times.
Anyhow, i tried as suggested to put retain_graph=True in the lossG.backward() line (in the generator loop), but it throws another different error:

Код: Выделить всё

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [512, 784]], which is output 0 of AsStridedBackward0, is at version 15; expected version 14 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
, который я действительно не могу понять, так как ошибка выбрасывается в той же строке, что и раньше, то есть hourdd.backward () .
это все. Я попытался выяснить это в одиночку, соскребая сеть для объяснений того, как работают градиенты питорха, но я нашел только некоторые теоретические статьи о том, как вычисляются градиенты, что, хотя и интересно, а не то, что мне нужно.
Так что помогите.

Подробнее здесь: https://stackoverflow.com/questions/796 ... g-more-tha
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»