Обученная и загруженная модель CycleGAN дает искаженные выходные изображения.Python

Программы на Python
Ответить
Anonymous
 Обученная и загруженная модель CycleGAN дает искаженные выходные изображения.

Сообщение Anonymous »

Я обучил модель CycleGAN в Google Colab, используя этот репозиторий — https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
Модель должна улучшать темные изображения. Я протестировал модель на своем тестовом наборе данных в Google Colab, используя скрипт test.py из репозитория Github, и изображения получились в порядке. Я пытаюсь загрузить сохраненные веса модели в Jupyter Notebook и использовать ее для улучшения одного входного изображения. Вот результат, который я получаю:
Изображение

Вывод выглядит совершенно иначе, чем изображения, которые я получил после тестирования. Вот мой код:

Код: Выделить всё

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import functools

class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
conv_block = []
p = 0
if padding_type == "reflect":
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == "replicate":
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == "zero":
p = 1
else:
raise NotImplementedError("padding [%s] is not implemented"  % padding_type)

conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
if use_dropout:
conv_block += [nn.Dropout(0.5)]

p = 0
if padding_type == "reflect":
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == "replicate":
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == "zero":
p = 1
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
return nn.Sequential(*conv_block)

def forward(self, x):
return x + self.conv_block(x)

class ResnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type="reflect"):
super(ResnetGenerator, self).__init__()
use_bias = norm_layer == nn.InstanceNorm2d or isinstance(norm_layer, functools.partial)
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
norm_layer(ngf),
nn.ReLU(True)]

n_downsampling = 2
for i in range(n_downsampling):
mult = 2 ** i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU(True)]

mult = 2 ** n_downsampling
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

for i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]

self.model = nn.Sequential(*model)

def forward(self, input):
return self.model(input)

def load_generator(checkpoint_path):
netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type="reflect")
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
netG.load_state_dict(state_dict, strict=False)
netG.eval()
netG.to('cpu')
return netG

checkpoint_path = './IdeaProjects/GAN_project/model/latest_net_G_A2.pth'
netG_A2B = load_generator(checkpoint_path)

transform_input = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

def denormalize(tensor):
return tensor * 0.5 + 0.5

input_image_path = './2015_00019_jpg.rf.a11c63668bdd756b639f780d06c38a31.jpg'
input_image = Image.open(input_image_path).convert('RGB')

input_tensor = transform_input(input_image).unsqueeze(0)

with torch.no_grad():
output_tensor = netG_A2B(input_tensor)

output_image = denormalize(output_tensor.squeeze(0)).cpu()
output_image = torch.clamp(output_image, 0, 1)
output_image = transforms.ToPILImage()(output_image)

plt.figure(figsize=(12,6))
plt.subplot(1,2,1)
plt.title("Original")
plt.imshow(input_image)
plt.axis('off')

plt.subplot(1,2,2)
plt.title("Enhanced")
plt.imshow(output_image)
plt.axis('off')
plt.show()
Если я попытаюсь загрузить модель без strict=False:

Код: Выделить всё

netG.load_state_dict(state_dict)
Я получаю следующую ошибку:

Код: Выделить всё

RuntimeError: Error(s) in loading state_dict for ResnetGenerator:
Missing key(s) in state_dict:  "model.2.weight", "model.2.bias", "model.2.running_mean", "model.2.running_var", "model.2.num_batches_tracked", "model.5.weight", "model.5.bias", "model.5.running_mean", "model.5.running_var", "model.5.num_batches_tracked", "model.8.weight", "model.8.bias", "model.8.running_mean", "model.8.running_var", "model.8.num_batches_tracked", "model.10.conv_block.2.weight", "model.10.conv_block.2.bias", "model.10.conv_block.2.running_mean", "model.10.conv_block.2.running_var", "model.10.conv_block.2.num_batches_tracked", "model.10.conv_block.6.weight", "model.10.conv_block.6.bias", "model.10.conv_block.6.running_mean", "model.10.conv_block.6.running_var", "model.10.conv_block.6.num_batches_tracked", "model.11.conv_block.2.weight", "model.11.conv_block.2.bias", "model.11.conv_block.2.running_mean", "model.11.conv_block.2.running_var", "model.11.conv_block.2.num_batches_tracked", "model.11.conv_block.6.weight", "model.11.conv_block.6.bias", "model.11.conv_block.6.running_mean", "model.11.conv_block.6.running_var", "model.11.conv_block.6.num_batches_tracked", "model.12.conv_block.2.weight", "model.12.conv_block.2.bias", "model.12.conv_block.2.running_mean", "model.12.conv_block.2.running_var", "model.12.conv_block.2.num_batches_tracked", "model.12.conv_block.6.weight", "model.12.conv_block.6.bias", "model.12.conv_block.6.running_mean", "model.12.conv_block.6.running_var", "model.12.conv_block.6.num_batches_tracked", "model.13.conv_block.2.weight", "model.13.conv_block.2.bias", "model.13.conv_block.2.running_mean", "model.13.conv_block.2.running_var", "model.13.conv_block.2.num_batches_tracked", "model.13.conv_block.6.weight", "model.13.conv_block.6.bias", "model.13.conv_block.6.running_mean", "model.13.conv_block.6.running_var", "model.13.conv_block.6.num_batches_tracked", "model.14.conv_block.2.weight", "model.14.conv_block.2.bias", "model.14.conv_block.2.running_mean", "model.14.conv_block.2.running_var", "model.14.conv_block.2.num_batches_tracked", "model.14.conv_block.6.weight", "model.14.conv_block.6.bias", "model.14.conv_block.6.running_mean", "model.14.conv_block.6.running_var", "model.14.conv_block.6.num_batches_tracked", "model.15.conv_block.2.weight", "model.15.conv_block.2.bias", "model.15.conv_block.2.running_mean", "model.15.conv_block.2.running_var", "model.15.conv_block.2.num_batches_tracked", "model.15.conv_block.6.weight", "model.15.conv_block.6.bias", "model.15.conv_block.6.running_mean", "model.15.conv_block.6.running_var", "model.15.conv_block.6.num_batches_tracked", "model.16.conv_block.2.weight", "model.16.conv_block.2.bias", "model.16.conv_block.2.running_mean", "model.16.conv_block.2.running_var", "model.16.conv_block.2.num_batches_tracked", "model.16.conv_block.6.weight", "model.16.conv_block.6.bias", "model.16.conv_block.6.running_mean", "model.16.conv_block.6.running_var", "model.16.conv_block.6.num_batches_tracked", "model.17.conv_block.2.weight", "model.17.conv_block.2.bias", "model.17.conv_block.2.running_mean", "model.17.conv_block.2.running_var", "model.17.conv_block.2.num_batches_tracked", "model.17.conv_block.6.weight", "model.17.conv_block.6.bias", "model.17.conv_block.6.running_mean", "model.17.conv_block.6.running_var", "model.17.conv_block.6.num_batches_tracked", "model.18.conv_block.2.weight", "model.18.conv_block.2.bias", "model.18.conv_block.2.running_mean", "model.18.conv_block.2.running_var", "model.18.conv_block.2.num_batches_tracked", "model.18.conv_block.6.weight", "model.18.conv_block.6.bias", "model.18.conv_block.6.running_mean", "model.18.conv_block.6.running_var", "model.18.conv_block.6.num_batches_tracked", "model.20.weight", "model.20.bias", "model.20.running_mean", "model.20.running_var", "model.20.num_batches_tracked", "model.23.weight", "model.23.bias", "model.23.running_mean", "model.23.running_var", "model.23.num_batches_tracked".
Unexpected key(s) in state_dict: "model.1.bias", "model.4.bias", "model.7.bias", "model.10.conv_block.1.bias", "model.10.conv_block.5.bias", "model.11.conv_block.1.bias", "model.11.conv_block.5.bias", "model.12.conv_block.1.bias", "model.12.conv_block.5.bias", "model.13.conv_block.1.bias", "model.13.conv_block.5.bias", "model.14.conv_block.1.bias", "model.14.conv_block.5.bias", "model.15.conv_block.1.bias", "model.15.conv_block.5.bias", "model.16.conv_block.1.bias", "model.16.conv_block.5.bias", "model.17.conv_block.1.bias", "model.17.conv_block.5.bias", "model.18.conv_block.1.bias", "model.18.conv_block.5.bias", "model.19.bias", "model.22.bias".
Важно отметить, что я обучал и тестировал модель на графическом процессоре. В Jupyter Notebook я использовал процессор для создания изображения. Я подумал, что, возможно, это может быть проблемой, поэтому попробовал запустить этот код в Google Colab с использованием графического процессора, однако получил такое же искаженное выходное изображение:

Код: Выделить всё

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import functools

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
conv_block = []
p = 0
if padding_type == "reflect":
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == "replicate":
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == "zero":
p = 1
else:
raise NotImplementedError("padding [%s] is not implemented"  % padding_type)

conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
if use_dropout:
conv_block += [nn.Dropout(0.5)]

p = 0
if padding_type == "reflect":
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == "replicate":
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == "zero":
p = 1
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
return nn.Sequential(*conv_block)

def forward(self, x):
return x + self.conv_block(x)

class ResnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type="reflect"):
super(ResnetGenerator, self).__init__()
use_bias = norm_layer == nn.InstanceNorm2d or isinstance(norm_layer, functools.partial)
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
norm_layer(ngf),
nn.ReLU(True)]

n_downsampling = 2
for i in range(n_downsampling):
mult = 2 ** i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU(True)]

mult = 2 ** n_downsampling
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

for i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]

self.model = nn.Sequential(*model)

def forward(self, input):
return self.model(input)

def load_generator(checkpoint_path):
# Adjust parameters as per your training setup
netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type="reflect")
state_dict = torch.load(checkpoint_path, map_location=torch.device('cuda'))
netG.load_state_dict(state_dict, strict=False)
netG.eval()
netG.to(device)
return netG

def denormalize(tensor):
return (tensor * 0.5) + 0.5

checkpoint_path = './latest_net_G_A2.pth'
netG_A2B = load_generator(checkpoint_path)

transform_input = transforms.Compose([
transforms.ToTensor(),  # [0,1]
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

input_image_path = './2015_00019_jpg.rf.a11c63668bdd756b639f780d06c38a31.jpg'
input_image = Image.open(input_image_path).convert('RGB')

input_tensor = transform_input(input_image).unsqueeze(0).to(device)  # Add batch dim

with torch.no_grad():
output_tensor = netG_A2B(input_tensor)

output_image = denormalize(output_tensor.squeeze(0))
output_image = transforms.ToPILImage()(output_image.to(device))

plt.figure(figsize=(12,6))
plt.subplot(1,2,1)
plt.title("Original")
plt.imshow(input_image)
plt.axis('off')

plt.subplot(1,2,2)
plt.title("Enhanced")
plt.imshow(output_image)
plt.axis('off')
plt.show()
Что я здесь делаю не так? Почему выходное изображение зеленое и искаженное?

Подробнее здесь: https://stackoverflow.com/questions/798 ... put-images
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»