Функция потерь StyleGAN почти равна 0 в черно-белых изображенияхPython

Программы на Python
Ответить
Anonymous
 Функция потерь StyleGAN почти равна 0 в черно-белых изображениях

Сообщение Anonymous »

Я обучаю StyleGAN для проекта, в котором хочу создавать черно-белые изображения для увеличения данных. У меня недостаточно знаний о StyleGAN, поэтому я ищу пример в Интернете в PyTorch.
Я нашел этот: https://www.kaggle.com/code/tauilabdelilah/stylegan -implementation-from-scratch-pytorch/notebook
Я провел несколько тестов с кодом, и кажется, что он генерирует хорошие изображения. Проблема в том, что график функции потерь кажется странным. Потери дискриминатора и генератора достигают почти 0 в определенную эпоху и остаются стабильными на этом значении.
[img]https://i.sstatic .net/LS5QL6dr.png[/img]

Вот как я определяю функцию обучения и штраф за градиент:

Код: Выделить всё

def gradient_penalty(dis, real, fake, alpha, train_step):
BATCH_SIZE, C, H, W = real.shape

beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
interpolated_images = real * beta + fake.detach() * (1 - beta)
interpolated_images.requires_grad_(True)

# Calculate discriminator scores
mixed_scores = dis(interpolated_images, alpha, train_step)

# Take the gradient of the scores with respect to the images
gradient = torch.autograd.grad(
inputs=interpolated_images,
outputs=mixed_scores,
grad_outputs=torch.ones_like(mixed_scores),
create_graph=True,
retain_graph=True,
)[0]
gradient = gradient.view(gradient.shape[0], -1)
gradient_norm = gradient.norm(2, dim=1)
gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
return gradient_penalty

def train_fn(dis, gen, loader, dataset, step, alpha, opt_dis, opt_gen, device):

loop = tqdm(loader, leave=True)

for i, (imgs, labels, _) in enumerate(loop):
imgs = imgs.to(device)
cur_batch_size = imgs.shape[0]
noise = torch.randn(cur_batch_size, Z_DIM).to(device)
fake  = gen(noise, alpha, step)
dis_real = dis(imgs, alpha, step)
dis_fake = dis(fake.detach(), alpha, step)
gp = gradient_penalty(dis, imgs, fake, alpha, step)

dis_loss = (
-(torch.mean(dis_real) - torch.mean(dis_fake))
+ LAMBDA_GP * gp
+ (0.001) * torch.mean(dis_real ** 2)
)

dis.zero_grad()
dis_loss.backward()
opt_dis.step()

gen_fake = dis(fake, alpha, step)
gen_loss = -torch.mean(gen_fake)

gen.zero_grad()
gen_loss.backward()
opt_gen.step()

alpha += cur_batch_size / (
PROGRESSIVE_EPOCHS[step] * 0.5 * len(dataset)
)
alpha = min(alpha,1)

loop.set_postfix(
{'Dis Loss': dis_loss.item(),
'Gen Loss': gen_loss.item(),
'Gradient Penalty:': gp.item()}
)

return alpha, dis_loss.item(), gen_loss.item(), imgs, fake
Я хотел бы знать, является ли это нормальной функцией потерь StyleGAN (я не нашел ни одного графика StyleGAN), а если нет, возможно, кто-нибудь сможет помочь мне улучшить потери или штраф за градиент кода.
Изображения одноканальные (очевидно) и 128x128.
Большое спасибо.

Подробнее здесь: https://stackoverflow.com/questions/791 ... -bw-images
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»