Я использую случайный cnn для дискриминатора, но для генератора я узнал об unet и пытаюсь попробовать.
Однако это совершенно не работает. Он всегда выводит очень маленькое число (от нуля до единицы), что не очень хорошо для RGB, и я не знаю, что делать.
Может быть, я неправильно его закодировал?
Дискриминатор всегда заканчивается выигрыш с большим отрывом и крах гана, действительно генератор не может вывести ни одного хорошего изображения!
Пожалуйста, научите меня, спасибо!
Вот изображение, которое я хочу масштабировать и результат напечатан ниже

Вот мой код:
import os
from torch import nn
import numpy as np
import math
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torchsummary import summary
import itertools
from tqdm import tqdm
from torch.nn.functional import relu
device = ""
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
X = []
Y = []
c = 0
data_size = 513
print(torch.cuda.mem_get_info())
dirname = "../input/mushrooms/train"
for filename in tqdm(os.listdir(dirname)):
if filename != "_classes.csv":
im = Image.open(os.path.join(dirname, filename))
im1 = im.resize((256,256))
Y.append(np.array(im1))
im2 = im.resize((128,128))
X.append(np.array(im2))
c+=1
if(c == data_size-1):
break
X = np.array(X, dtype = 'float32')
Y = np.array(Y, dtype = 'float32')
X = torch.tensor(X)
Y = torch.tensor(Y)
X = torch.transpose(X,1,3)
Y = torch.transpose(Y,1,3)
batch_size = 16
#attention c'est toujours le même ordre de data
low_loader = torch.utils.data.DataLoader(
X, batch_size=batch_size#, shuffle=True
)
high_loader = torch.utils.data.DataLoader(
Y, batch_size=batch_size#, shuffle=True
)
print("data loaded in the ram")
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
#256
nn.Conv2d(3,16,3), #254
nn.ReLU(),
nn.MaxPool2d(2), #127
nn.Conv2d(16,32,3), #125
nn.ReLU(),
nn.MaxPool2d(2), #62
nn.Conv2d(32,64,3), #60
nn.ReLU(),
nn.MaxPool2d(2), #30
nn.Conv2d(64,128,3), #28
nn.ReLU(),
nn.MaxPool2d(2), #14
nn.Flatten(),
nn.Linear(128*14*14, 1024),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid(),
)
def forward(self, x):
return self.model(x)
discriminator = Discriminator().to(device=device)
summary(discriminator.model, (3, 256, 256))
print("disriminator set")
class Generator(nn.Module):
def __init__(self):
super().__init__()
# Encoder
self.mp = nn.MaxPool2d(2,stride=2)
#input: 3x128x128
self.c1 = nn.Sequential(
nn.Conv2d(3,64,3,padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64,64,3,padding=1),
nn.ReLU())
#input: 64x64x64
self.c2 = nn.Sequential(
nn.Conv2d(64,128,3,padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Conv2d(128,128,3,padding=1),
nn.ReLU())
#input: 128x32x32
self.c3 = nn.Sequential(
nn.Conv2d(128,256,3,padding=1),
nn.ReLU(),
nn.Conv2d(256,256,3,padding=1),
nn.ReLU())
#input: 256x16x16
self.c4 = nn.Sequential(
nn.Conv2d(256,512,3,padding=1),
nn.ReLU(),
nn.Conv2d(512,512,3,padding=1),
nn.ReLU())
#input: 512x8x8
self.c5 = nn.Sequential(
nn.Conv2d(512,1024,3,padding=1),
nn.ReLU(),
nn.Conv2d(1024,1024,3,padding=1),
nn.ReLU())
# Decoder
#input: 1024x8x8
self.tc1 = nn.ConvTranspose2d(1024,512,2,stride=2)
#output: 512x16x16
#skip connection
#input: 1024x16x16
self.u1 = nn.Sequential(
nn.Conv2d(1024,512,3,padding=1),
nn.ReLU(),
nn.Conv2d(512,512,3,padding=1),
nn.ReLU())
#output: 512x16x16
#input: 512x16x16
self.tc2 = nn.ConvTranspose2d(512,256,2,stride=2)
#output: 256x32x32
#skip connection
#input: 512x32x32
self.u2 = nn.Sequential(
nn.Conv2d(512,256,3,padding=1),
nn.ReLU(),
nn.Conv2d(256,256,3,padding=1),
nn.ReLU())
#output: 256x32x32
#input: 256x32x32
self.tc3 = nn.ConvTranspose2d(256,128,2,stride=2)
#output: 128x64x64
#skip connection
#input: 256x64x64
self.u3 = nn.Sequential(
nn.Conv2d(256,128,3,padding=1),
nn.ReLU(),
nn.Conv2d(128,128,3,padding=1),
nn.ReLU())
#output: 128x64x64
#input: 128x64x64
self.tc4 = nn.ConvTranspose2d(128,64,2,stride=2)
#output: 64x128x128
#skip connection
#input: 128x128x128
self.u4 = nn.Sequential(
nn.Conv2d(128,64,3,padding=1),
nn.ReLU(),
nn.Conv2d(64,64,3,padding=1),
nn.ReLU())
#output: 64x128x128
#input: 64x128x128
self.tc5 = nn.ConvTranspose2d(64,64,2,stride=2)
#output: 64x256x256
#input: 64x256x256
self.u5 = nn.Sequential(
nn.Conv2d(64,64,3,padding=1),
nn.ReLU(),
nn.Conv2d(64,64,3,padding=1),
nn.ReLU())
#output: 32x256x256
#input: 64x512x512
self.last = nn.Sequential(
nn.Conv2d(64,3,1),
nn.ReLU())
def forward(self, x):
x1 = self.c1(x)
x1p = self.mp(x1)
x2 = self.c2(x1p)
x2p = self.mp(x2)
x3 = self.c3(x2p)
x3p = self.mp(x3)
x4 = self.c4(x3p)
x4p = self.mp(x4)
x5 = self.c5(x4p)
xt1 = self.tc1(x5)
xc1 = self.u1(torch.cat([x4, xt1], dim=1))
xt2 = self.tc2(xc1)
xc2 = self.u2(torch.cat([x3, xt2], dim=1))
xt3 = self.tc3(xc2)
xc3 = self.u3(torch.cat([x2, xt3], dim=1))
xt4 = self.tc4(xc3)
xc4 = self.u4(torch.cat([x1, xt4], dim=1))
xt5 = self.tc5(xc4)
xc5 = self.u5(xt5)
return self.last(xc5)
generator = Generator().to(device=device)
print("generator set")
lr = 0.0001
num_epochs = 4
loss_function = nn.BCELoss()
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)
for epoch in range(num_epochs):
for low_samples,high_samples in tqdm(itertools.zip_longest(enumerate(low_loader), enumerate(high_loader))):
n=low_samples[0]
low_samples = low_samples[1]
high_samples = high_samples[1]
generated_labels = torch.zeros((batch_size, 1))
generated_samples = generator(low_samples.to(device=device))
high_samples_labels = torch.ones((batch_size, 1))
all_samples = torch.cat((high_samples.to(device=device), generated_samples))
all_samples_labels = torch.cat((high_samples_labels, generated_labels))
# Training the discriminator
discriminator.zero_grad()
output_discriminator = discriminator(all_samples)
loss_discriminator = loss_function(output_discriminator, all_samples_labels.to(device=device))
loss_discriminator.backward() #compute gradient
optimizer_discriminator.step() #backpropagation
# Training the generator
generator.zero_grad()
generated_output = generator(low_samples.to(device=device))
plt.imshow(torch.transpose(generated_output.cpu().detach()[0],0,2).type(torch.int64))
plt.show()
discriminated_generated_output = discriminator(generated_output)
loss_generator = loss_function(discriminated_generated_output,high_samples_labels.to(device=device))
loss_generator.backward() #compute gradient
optimizer_generator.step() #backpropagation
# Show loss
if n == 3:
print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
print(f"Epoch: {epoch} Loss G.: {loss_generator}")
Подробнее здесь: https://stackoverflow.com/questions/793 ... -correctly
Мобильная версия