Anonymous
Trained and loaded CycleGAN model is giving distorted output images.
Post
Anonymous » 09 Feb 2026, 23:34
I trained a CycleGAN model in Google Colab using this repository —
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
The model is supposed to enhance dark images. I tested it on my test dataset in Google Colab using the test.py script from the GitHub repository, and the images came out fine. Now I am trying to load the saved model weights in a Jupyter Notebook and use the model to enhance a single input image. Here is the result I get:
[img]https://i.sstatic.net/Wj9N9swX.png[/img]
The output looks completely different from the images I got during testing. Here is my code:
Code: Select all
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import functools

class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        conv_block = []
        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]
        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type="reflect"):
        super(ResnetGenerator, self).__init__()
        use_bias = norm_layer == nn.InstanceNorm2d or isinstance(norm_layer, functools.partial)
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)

def load_generator(checkpoint_path):
    netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type="reflect")
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    netG.load_state_dict(state_dict, strict=False)
    netG.eval()
    netG.to('cpu')
    return netG

checkpoint_path = './IdeaProjects/GAN_project/model/latest_net_G_A2.pth'
netG_A2B = load_generator(checkpoint_path)

transform_input = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

def denormalize(tensor):
    return tensor * 0.5 + 0.5

input_image_path = './2015_00019_jpg.rf.a11c63668bdd756b639f780d06c38a31.jpg'
input_image = Image.open(input_image_path).convert('RGB')
input_tensor = transform_input(input_image).unsqueeze(0)

with torch.no_grad():
    output_tensor = netG_A2B(input_tensor)

output_image = denormalize(output_tensor.squeeze(0)).cpu()
output_image = torch.clamp(output_image, 0, 1)
output_image = transforms.ToPILImage()(output_image)

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Original")
plt.imshow(input_image)
plt.axis('off')
plt.subplot(1, 2, 2)
plt.title("Enhanced")
plt.imshow(output_image)
plt.axis('off')
plt.show()
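Since load_state_dict(..., strict=False) silently skips every key that does not match, I suspect part of the checkpoint may simply never be applied. A tiny self-contained demonstration of that behavior (a toy module I made up for illustration, nothing from the repository):

Code: Select all

import torch
import torch.nn as nn

# strict=False ignores both missing and unexpected keys and leaves the
# unmatched parameters at their random initialization.
net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
ckpt = {'0.weight': torch.zeros(4, 4)}  # only one key matches the model

result = net.load_state_dict(ckpt, strict=False)
print(result.missing_keys)        # every key except '0.weight'
print(result.unexpected_keys)     # []
print(net[0].weight.abs().sum())  # sums to 0 -> the one matching key was applied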
If I try to load the model without strict=False:
Code: Select all
netG.load_state_dict(state_dict)
I get the following error:
Code: Select all
RuntimeError: Error(s) in loading state_dict for ResnetGenerator:
Missing key(s) in state_dict: "model.2.weight", "model.2.bias", "model.2.running_mean", "model.2.running_var", "model.2.num_batches_tracked", "model.5.weight", "model.5.bias", "model.5.running_mean", "model.5.running_var", "model.5.num_batches_tracked", "model.8.weight", "model.8.bias", "model.8.running_mean", "model.8.running_var", "model.8.num_batches_tracked", "model.10.conv_block.2.weight", "model.10.conv_block.2.bias", "model.10.conv_block.2.running_mean", "model.10.conv_block.2.running_var", "model.10.conv_block.2.num_batches_tracked", "model.10.conv_block.6.weight", "model.10.conv_block.6.bias", "model.10.conv_block.6.running_mean", "model.10.conv_block.6.running_var", "model.10.conv_block.6.num_batches_tracked", "model.11.conv_block.2.weight", "model.11.conv_block.2.bias", "model.11.conv_block.2.running_mean", "model.11.conv_block.2.running_var", "model.11.conv_block.2.num_batches_tracked", "model.11.conv_block.6.weight", "model.11.conv_block.6.bias", "model.11.conv_block.6.running_mean", "model.11.conv_block.6.running_var", "model.11.conv_block.6.num_batches_tracked", "model.12.conv_block.2.weight", "model.12.conv_block.2.bias", "model.12.conv_block.2.running_mean", "model.12.conv_block.2.running_var", "model.12.conv_block.2.num_batches_tracked", "model.12.conv_block.6.weight", "model.12.conv_block.6.bias", "model.12.conv_block.6.running_mean", "model.12.conv_block.6.running_var", "model.12.conv_block.6.num_batches_tracked", "model.13.conv_block.2.weight", "model.13.conv_block.2.bias", "model.13.conv_block.2.running_mean", "model.13.conv_block.2.running_var", "model.13.conv_block.2.num_batches_tracked", "model.13.conv_block.6.weight", "model.13.conv_block.6.bias", "model.13.conv_block.6.running_mean", "model.13.conv_block.6.running_var", "model.13.conv_block.6.num_batches_tracked", "model.14.conv_block.2.weight", "model.14.conv_block.2.bias", "model.14.conv_block.2.running_mean", "model.14.conv_block.2.running_var", "model.14.conv_block.2.num_batches_tracked", "model.14.conv_block.6.weight", "model.14.conv_block.6.bias", "model.14.conv_block.6.running_mean", "model.14.conv_block.6.running_var", "model.14.conv_block.6.num_batches_tracked", "model.15.conv_block.2.weight", "model.15.conv_block.2.bias", "model.15.conv_block.2.running_mean", "model.15.conv_block.2.running_var", "model.15.conv_block.2.num_batches_tracked", "model.15.conv_block.6.weight", "model.15.conv_block.6.bias", "model.15.conv_block.6.running_mean", "model.15.conv_block.6.running_var", "model.15.conv_block.6.num_batches_tracked", "model.16.conv_block.2.weight", "model.16.conv_block.2.bias", "model.16.conv_block.2.running_mean", "model.16.conv_block.2.running_var", "model.16.conv_block.2.num_batches_tracked", "model.16.conv_block.6.weight", "model.16.conv_block.6.bias", "model.16.conv_block.6.running_mean", "model.16.conv_block.6.running_var", "model.16.conv_block.6.num_batches_tracked", "model.17.conv_block.2.weight", "model.17.conv_block.2.bias", "model.17.conv_block.2.running_mean", "model.17.conv_block.2.running_var", "model.17.conv_block.2.num_batches_tracked", "model.17.conv_block.6.weight", "model.17.conv_block.6.bias", "model.17.conv_block.6.running_mean", "model.17.conv_block.6.running_var", "model.17.conv_block.6.num_batches_tracked", "model.18.conv_block.2.weight", "model.18.conv_block.2.bias", "model.18.conv_block.2.running_mean", "model.18.conv_block.2.running_var", "model.18.conv_block.2.num_batches_tracked", "model.18.conv_block.6.weight", "model.18.conv_block.6.bias", 
"model.18.conv_block.6.running_mean", "model.18.conv_block.6.running_var", "model.18.conv_block.6.num_batches_tracked", "model.20.weight", "model.20.bias", "model.20.running_mean", "model.20.running_var", "model.20.num_batches_tracked", "model.23.weight", "model.23.bias", "model.23.running_mean", "model.23.running_var", "model.23.num_batches_tracked".
Unexpected key(s) in state_dict: "model.1.bias", "model.4.bias", "model.7.bias", "model.10.conv_block.1.bias", "model.10.conv_block.5.bias", "model.11.conv_block.1.bias", "model.11.conv_block.5.bias", "model.12.conv_block.1.bias", "model.12.conv_block.5.bias", "model.13.conv_block.1.bias", "model.13.conv_block.5.bias", "model.14.conv_block.1.bias", "model.14.conv_block.5.bias", "model.15.conv_block.1.bias", "model.15.conv_block.5.bias", "model.16.conv_block.1.bias", "model.16.conv_block.5.bias", "model.17.conv_block.1.bias", "model.17.conv_block.5.bias", "model.18.conv_block.1.bias", "model.18.conv_block.5.bias", "model.19.bias", "model.22.bias".
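To see the mismatch more compactly, here is a small sketch I can run to diff the checkpoint keys against a freshly built model (it assumes the .pth file is a plain state dict, which is how the repository saves generators, and it reuses the ResnetGenerator class defined above):

Code: Select all

import torch
import torch.nn as nn

# Rebuild the generator exactly as in load_generator() above.
netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64,
                       norm_layer=nn.BatchNorm2d, use_dropout=False,
                       n_blocks=9, padding_type="reflect")

# Load the checkpoint as a raw dict, without touching the model.
state_dict = torch.load('./latest_net_G_A2.pth', map_location='cpu')

model_keys = set(netG.state_dict().keys())
ckpt_keys = set(state_dict.keys())

# Keys the model expects but the checkpoint lacks (all the BatchNorm
# parameters and running statistics), and keys the checkpoint carries
# that the model does not (conv biases).
print('missing from checkpoint:', sorted(model_keys - ckpt_keys))
print('unexpected in checkpoint:', sorted(ckpt_keys - model_keys))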
It is worth noting that I trained and tested the model on a GPU, while in the Jupyter Notebook I ran inference on the CPU. I thought this might be the problem, so I tried running the code in Google Colab on a GPU, but I got the same distorted output image:
Code: Select all
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import functools

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        conv_block = []
        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]
        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type="reflect"):
        super(ResnetGenerator, self).__init__()
        use_bias = norm_layer == nn.InstanceNorm2d or isinstance(norm_layer, functools.partial)
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)

def load_generator(checkpoint_path):
    # Adjust parameters as per your training setup
    netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type="reflect")
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cuda'))
    netG.load_state_dict(state_dict, strict=False)
    netG.eval()
    netG.to(device)
    return netG

def denormalize(tensor):
    return (tensor * 0.5) + 0.5

checkpoint_path = './latest_net_G_A2.pth'
netG_A2B = load_generator(checkpoint_path)

transform_input = transforms.Compose([
    transforms.ToTensor(),  # [0,1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

input_image_path = './2015_00019_jpg.rf.a11c63668bdd756b639f780d06c38a31.jpg'
input_image = Image.open(input_image_path).convert('RGB')
input_tensor = transform_input(input_image).unsqueeze(0).to(device)  # Add batch dim

with torch.no_grad():
    output_tensor = netG_A2B(input_tensor)

output_image = denormalize(output_tensor.squeeze(0))
output_image = transforms.ToPILImage()(output_image.to(device))

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Original")
plt.imshow(input_image)
plt.axis('off')
plt.subplot(1, 2, 2)
plt.title("Enhanced")
plt.imshow(output_image)
plt.axis('off')
plt.show()
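One thing I have not ruled out: the missing keys are all norm-layer weights and running statistics, and the unexpected keys are conv biases, which looks like the checkpoint was trained with the repository's default --norm instance setting (InstanceNorm2d with no affine parameters or running stats, and biased convolutions) rather than the BatchNorm2d I hard-coded. A sketch of that variant — this is an assumption about my training flags, not something I have confirmed:

Code: Select all

import functools
import torch
import torch.nn as nn

# Assumption: training used the repository's default --norm instance flag.
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64,
                       norm_layer=norm_layer, use_dropout=False,
                       n_blocks=9, padding_type="reflect")
state_dict = torch.load('./latest_net_G_A2.pth', map_location='cpu')
# With the right norm layer this should load strictly, with no skipped keys.
netG.load_state_dict(state_dict)
netG.eval()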
What am I doing wrong here? Why is the output image green and distorted?
More details here:
https://stackoverflow.com/questions/79886152/trained-and-loaded-cyclegan-model-is-giving-distorted-output-images