
Когда тензор входного изображения находится в [0; 1] я получаю результат 1, очень размытую версию входного изображения (выходной тензор в диапазоне [0; 4,37]). Когда я нормализую тензор до [-1; 1] перед тем, как пропустить его через генератор, я получаю выход 2 (выходной тензор в диапазоне [-1,01, 5,086]), и это едва улучшенная версия входа. Когда я протестировал ту же модель с использованием сценария Predict.py из исходного репозитория, я получил желаемое выходное изображение, чего я и пытаюсь достичь. Кажется, проблема в том, как я предварительно обрабатываю/нормализую/постобрабатываю изображения. Вот мой код:
Код: Выделить всё
opt = TrainOpt()
# Initialize generator
netG = networks_eg.define_G(input_nc=3, output_nc=3, ngf=64, which_model_netG='sid_unet_resize', norm='instance', skip=True, opt=opt)
# Load your pretrained weights
state_dict = torch.load('model/200_net_G_A.pth', map_location='cpu')
# Remove 'module.' prefix if present
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('module.'):
k = k.replace('module.', '', 1)
new_state_dict[k] = v
netG.load_state_dict(new_state_dict)
netG.eval()
def tensor2im(image_tensor, imtype=np.uint8):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
image_numpy = np.maximum(image_numpy, 0)
image_numpy = np.minimum(image_numpy, 255)
return image_numpy.astype(imtype)
def enhance_image(input_image):
transform = transforms.ToTensor()
inverse_transform = transforms.ToPILImage()
# Convert input image to tensor
input_tensor = transform(input_image).unsqueeze(0)
# If the model was trained with im in [-1, 1], normalize:
input_tensor = input_tensor * 2 - 1
# Convert grayscale image to tensor
gray_image = input_image.convert('L')
gray_tensor = transform(gray_image).unsqueeze(0)
real_A = Variable(input_tensor, volatile=True)
real_A_gray = Variable(gray_tensor, volatile=True)
with torch.no_grad():
fake_B, latent_real_A = netG(real_A, real_A_gray)
print("Raw output min:", fake_B.min().item())
print("Raw output max:", fake_B.max().item())
fake_B = inverse_transform(tensor2im(fake_B.data))
return fake_B
Подробнее здесь: https://stackoverflow.com/questions/798 ... lightengan
Мобильная версия