Проблемы с правильным отображением изображений, созданных EnlightenGAN.Python

Программы на Python
Ответить
Anonymous
 Проблемы с правильным отображением изображений, созданных EnlightenGAN.

Сообщение Anonymous »

Я загрузил предварительно обученную модель EnlightenGAN из этого репозитория — https://github.com/VITA-Group/EnlightenGAN. Я пытаюсь использовать эту модель в своем приложении Flask для улучшения загруженных изображений, но не могу понять, как правильно нормализовать изображение. Вот результаты, которые я получаю:
Изображение

Когда тензор входного изображения находится в [0; 1] я получаю результат 1, очень размытую версию входного изображения (выходной тензор в диапазоне [0; 4,37]). Когда я нормализую тензор до [-1; 1] перед тем, как пропустить его через генератор, я получаю выход 2 (выходной тензор в диапазоне [-1,01, 5,086]), и это едва улучшенная версия входа. Когда я протестировал ту же модель с использованием сценария Predict.py из исходного репозитория, я получил желаемое выходное изображение, чего я и пытаюсь достичь. Кажется, проблема в том, как я предварительно обрабатываю/нормализую/постобрабатываю изображения. Вот мой код:

Код: Выделить всё

opt = TrainOpt()

# Initialize generator
netG = networks_eg.define_G(input_nc=3, output_nc=3, ngf=64, which_model_netG='sid_unet_resize', norm='instance', skip=True, opt=opt)

# Load your pretrained weights
state_dict = torch.load('model/200_net_G_A.pth', map_location='cpu')

# Remove 'module.' prefix if present
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('module.'):
k = k.replace('module.', '', 1)
new_state_dict[k] = v

netG.load_state_dict(new_state_dict)
netG.eval()

def tensor2im(image_tensor, imtype=np.uint8):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
image_numpy = np.maximum(image_numpy, 0)
image_numpy = np.minimum(image_numpy, 255)
return image_numpy.astype(imtype)

def enhance_image(input_image):

transform = transforms.ToTensor()
inverse_transform = transforms.ToPILImage()
# Convert input image to tensor
input_tensor = transform(input_image).unsqueeze(0)
# If the model was trained with im in [-1, 1], normalize:
input_tensor = input_tensor * 2 - 1

# Convert grayscale image to tensor
gray_image = input_image.convert('L')
gray_tensor = transform(gray_image).unsqueeze(0)

real_A = Variable(input_tensor, volatile=True)
real_A_gray = Variable(gray_tensor, volatile=True)

with torch.no_grad():
fake_B, latent_real_A = netG(real_A, real_A_gray)

print("Raw output min:", fake_B.min().item())
print("Raw output max:", fake_B.max().item())

fake_B = inverse_transform(tensor2im(fake_B.data))

return fake_B
Функция tensor2im взята из исходного репозитория. Что я здесь делаю не так? Как нормализовать изображение?

Подробнее здесь: https://stackoverflow.com/questions/798 ... lightengan
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»