Как улучшить качество реконструкции с помощью VAE?Python

Программы на Python
Ответить
Anonymous
 Как улучшить качество реконструкции с помощью VAE?

Сообщение Anonymous »

Я обучаю архитектуре VAE на изображениях микроскопа. Набор данных из 1000 обучающих изображений, 253 тестовых изображений. Размер изображений изменяется до исходного разрешения 128x128 или до 256x256 по сравнению с исходным разрешением, которое составляет около 1024x720. Здесь реализация представляет собой ввод размером 256x256. Затем я пропустил их через свой VAE:

Код: Выделить всё

class SEMVAE(nn.Module):
def __init__(self, latent_dim=64):
super().__init__()
# ----------------- Encoder -----------------
self.enc_conv1 = nn.Conv2d(1, 32, 3, stride=2, padding=1)   # 256 -> 128
self.enc_bn1   = nn.BatchNorm2d(32)
self.enc_conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)  # 128 -> 64
self.enc_bn2   = nn.BatchNorm2d(64)
self.enc_conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1) # 64 -> 32
self.enc_bn3   = nn.BatchNorm2d(128)
self.enc_conv4 = nn.Conv2d(128, 128, 3, stride=2, padding=1)# 32 -> 16
self.enc_bn4   = nn.BatchNorm2d(128)

self.dropout  = nn.Dropout(0.05)
self.flatten  = nn.Flatten()
self.fc_mu    = nn.Linear(128*16*16, latent_dim)
self.fc_logvar= nn.Linear(128*16*16, latent_dim)
self.fc_dec   = nn.Linear(latent_dim, 128*16*16)

# ----------------- Decoder -----------------
self.dec_deconv1 = nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1) # 16 -> 32
self.dec_bn1     = nn.BatchNorm2d(128)
self.dec_deconv2 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)  # 32 -> 64
self.dec_bn2     = nn.BatchNorm2d(64)
self.dec_deconv3 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)   # 64 -> 128
self.dec_bn3     = nn.BatchNorm2d(32)
self.dec_deconv4 = nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1)    # 128 ->  256

def encode(self, x):
x = F.leaky_relu(self.enc_bn1(self.enc_conv1(x)), 0.1)
x = self.dropout(x)
x = F.leaky_relu(self.enc_bn2(self.enc_conv2(x)), 0.1)
x = self.dropout(x)
x = F.leaky_relu(self.enc_bn3(self.enc_conv3(x)), 0.1)
x = self.dropout(x)
x = F.leaky_relu(self.enc_bn4(self.enc_conv4(x)), 0.1)
x = self.dropout(x)
x = self.flatten(x)
mu = self.fc_mu(x)
logvar = self.fc_logvar(x)
return mu, logvar

def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std

def decode(self, z):
x = self.fc_dec(z).view(-1, 128, 16, 16)
x = F.leaky_relu(self.dec_bn1(self.dec_deconv1(x)), 0.1)
x = F.leaky_relu(self.dec_bn2(self.dec_deconv2(x)), 0.1)
x = F.leaky_relu(self.dec_bn3(self.dec_deconv3(x)), 0.1)
x = torch.sigmoid(self.dec_deconv4(x))  # [0,1] for BCE
return x

def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
recon = self.decode(z)
return recon, mu, logvar
Я использую базовые потери KL и потери при реконструкции в своей функции потерь, показанной ниже:

Код: Выделить всё

class VAELoss(nn.Module):
def __init__(self, recon_type = 'mse', beta = 1.0):
super().__init__()
self.recon_type = recon_type
self.beta = beta

def forward(self, recon, x, mu, logvar):
## Reconstruction Loss ##
if self.recon_type == 'mse':
recon_loss = F.mse_loss(recon, x, reduction='sum')
elif self.recon_type == 'bce':
recon_loss = F.binary_cross_entropy(recon, x, reduction='sum')
elif self.recon_type == 'l1':
recon_loss = F.smooth_l1_loss(recon, x, reduction='sum')
else:
raise ValueError("recon_type must be bce, mse, or l1")

## KL Divergence ##
kl_loss = - 0.5 * torch.sum(1+ logvar - mu.pow(2) - logvar.exp())

## Loss ##
loss = recon_loss + self.beta * kl_loss

return loss, recon_loss, kl_loss, self.beta * kl_loss
Вот цикл обучения, который я использую:

Код: Выделить всё

def train_vae(model, train_loader, test_loader, epochs=50, lr=1e-3, recon_type='mse', beta = 1, device = 'cuda'):
# Loading Model and Optimizer
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# VAELoss
VAELoss_Calculator = VAELoss(recon_type = recon_type, beta = beta)

# Tracking Loss
train_loss_history, test_loss_history = [], []
train_recon_loss, train_kl_loss, train_weighted_kl_loss = [], [], []

# Training Loop
for epoch in range(epochs):
model.train()
total_epoch_loss = 0
total_epoch_kl_loss = 0
total_epoch_recon_loss = 0
total_epoch_weighted_kl_loss = 0

for imgs in train_loader:
batch_size = imgs.size(0)
imgs = imgs.to(device)
optimizer.zero_grad()
recon, mu, logvar = model(imgs)
loss, recon_loss, kl_loss, weighted_kl_loss = VAELoss_Calculator(recon, imgs, mu, logvar)

total_epoch_loss += loss.item()
total_epoch_recon_loss += recon_loss.item()
total_epoch_kl_loss += kl_loss.item()
total_epoch_weighted_kl_loss += weighted_kl_loss.item()

loss.backward()
optimizer.step()

n = len(train_loader.dataset)
avg_train_loss = total_epoch_loss / n
avg_recon_loss = total_epoch_recon_loss / n
avg_kl_loss = total_epoch_kl_loss / n
avg_weighted_kl_loss = total_epoch_weighted_kl_loss / n

train_loss_history.append(avg_train_loss)
train_recon_loss.append(avg_recon_loss)
train_kl_loss.append(avg_kl_loss)
train_weighted_kl_loss.append(avg_weighted_kl_loss)

# Model Evaluation
model.eval()
test_loss = 0
with torch.no_grad():
for test_imgs in test_loader:
test_batch_size = test_imgs.size(0)
test_imgs = test_imgs.to(device)
test_recon, mu, logvar = model(test_imgs)
loss, _, _, _ = VAELoss_Calculator(test_recon, test_imgs, mu, logvar)
test_loss += loss.item()
n_test = len(test_loader.dataset)
avg_test_loss = test_loss / n_test
test_loss_history.append(avg_test_loss)

print(f"Epoch [{epoch+1}/{epochs}]  "
f"Train ELBO: {avg_train_loss:.4f} | Recon: {avg_recon_loss:.4f} | "
f"KL: {avg_kl_loss:.4f} | KL (Weighted):  {avg_weighted_kl_loss:.4f} | Test ELBO: {avg_test_loss:.4f}")

# ---- Plot all losses ----
plt.figure(figsize=(9, 6))
plt.plot(train_loss_history, label='Train Total (ELBO)')
plt.plot(train_recon_loss, label='Train Reconstruction')
plt.plot(train_kl_loss, label='Train KL Divergence')
plt.plot(train_weighted_kl_loss, label='Train KL Divergence (Weighted)')
plt.plot(test_loss_history, label='Test Total (ELBO)')
# plt.plot(train_grad_losses, label='Gradient Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('VAE Training Loss Components')
plt.show()
После использования реконструкции «BCE», размера пакета 32, скорости обучения 1e-4, значения бета 1, скрытого измерения 64 и обучения в течение 150 эпох я получаю следующие результаты:

Код: Выделить всё

Epoch [1/100]  Train ELBO: 43814.1424 | Recon: 43757.6568 | KL: 564.8580 | KL (Weighted): 56.4858 | Test ELBO: 40342.6853
Epoch [2/100]  Train ELBO: 38939.3434 | Recon: 38859.7775 | KL: 795.6566 | KL (Weighted): 79.5657 | Test ELBO: 37637.1522
Epoch [3/100]  Train ELBO: 37398.1756 | Recon: 37349.1395 | KL: 490.3624 | KL (Weighted): 49.0362 | Test ELBO: 37155.4560
Epoch [4/100]  Train ELBO: 36790.2883 | Recon: 36751.9760 | KL: 383.1231 | KL (Weighted): 38.3123 | Test ELBO: 37103.1546
Epoch [5/100]  Train ELBO: 36300.4348 | Recon: 36265.6324 | KL: 348.0279 | KL (Weighted): 34.8028 | Test ELBO: 36329.0702
Epoch [6/100]  Train ELBO: 35811.1598 | Recon: 35778.0838 | KL: 330.7587 | KL (Weighted): 33.0759 | Test ELBO: 36197.2446
Epoch [7/100]  Train ELBO: 35541.3350 | Recon: 35506.4722 | KL: 348.6268 | KL (Weighted): 34.8627 | Test ELBO: 35781.9881
Epoch [8/100]  Train ELBO: 35109.2664 | Recon: 35070.1327 | KL: 391.3363 | KL (Weighted): 39.1336 | Test ELBO: 35419.9568
Epoch [9/100]  Train ELBO: 34904.9988 | Recon: 34865.2861 | KL: 397.1254 | KL (Weighted): 39.7125 | Test ELBO: 35368.0089
Epoch [10/100]  Train ELBO: 34677.8689 | Recon: 34636.5064 | KL: 413.6217 | KL (Weighted): 41.3622 | Test ELBO: 35279.6201
Epoch [11/100]  Train ELBO: 34567.9672 | Recon: 34524.4573 | KL: 435.0997 | KL (Weighted): 43.5100 | Test ELBO: 35029.4027
Epoch [12/100]  Train ELBO: 34412.6462 | Recon: 34365.4724 | KL: 471.7382 | KL (Weighted): 47.1738 | Test ELBO: 34983.6670
Epoch [13/100]  Train ELBO: 34222.2194 | Recon: 34172.9446 | KL: 492.7466 | KL (Weighted): 49.2747 | Test ELBO: 34911.6089
Epoch [14/100]  Train ELBO: 34152.5446 | Recon: 34101.2380 | KL: 513.0639 | KL (Weighted): 51.3064 | Test ELBO: 34925.7342
Epoch [15/100]  Train ELBO: 34086.8908 | Recon: 34034.6688 | KL: 522.2204 | KL (Weighted): 52.2220 | Test ELBO: 34980.5862
Epoch [16/100]  Train ELBO: 33968.7971 | Recon: 33914.5703 | KL: 542.2691 | KL (Weighted): 54.2269 | Test ELBO: 34592.5254
Epoch [17/100]  Train ELBO: 33861.9486 | Recon: 33812.2085 | KL: 497.4020 | KL (Weighted): 49.7402 | Test ELBO: 34556.8503
Epoch [18/100]  Train ELBO: 33706.1630 | Recon: 33656.8590 | KL: 493.0414 | KL (Weighted): 49.3041 | Test ELBO: 34511.0188
Epoch [19/100]  Train ELBO: 33802.2230 | Recon: 33750.2045 | KL: 520.1836 | KL (Weighted): 52.0184 | Test ELBO: 34526.2105
Epoch [20/100]  Train ELBO: 33730.4231 | Recon: 33680.2796 | KL: 501.4360 | KL (Weighted):

....

Epoch [85/100]  Train ELBO: 32994.7386 | Recon: 32951.0619 | KL: 436.7674 | KL (Weighted): 43.6767 | Test ELBO: 34010.8639
Epoch [86/100]  Train ELBO: 32933.0684 | Recon: 32889.8569 | KL: 432.1167 | KL (Weighted): 43.2117 | Test ELBO: 34042.3177
Epoch [87/100]  Train ELBO: 32934.8738 | Recon: 32893.2703 | KL: 416.0365 | KL (Weighted): 41.6036 | Test ELBO: 33952.0976
Epoch [88/100]  Train ELBO: 32993.6404 | Recon: 32950.2150 | KL: 434.2497 | KL (Weighted): 43.4250 | Test ELBO: 33978.0052
Epoch [89/100]  Train ELBO: 32914.6798 | Recon: 32871.9780 | KL: 427.0166 | KL (Weighted): 42.7017 | Test ELBO: 33968.7740
Epoch [90/100]  Train ELBO: 32886.1937 | Recon: 32844.8323 | KL: 413.6142 | KL (Weighted): 41.3614 | Test ELBO: 34007.8426
Epoch [91/100]  Train ELBO: 32932.2369 | Recon: 32890.9989 | KL: 412.3811 | KL (Weighted): 41.2381 | Test ELBO: 33970.2540
Epoch [92/100]  Train ELBO: 32997.2572 | Recon: 32955.8671 | KL: 413.9017 | KL (Weighted): 41.3902 | Test ELBO: 33961.6801
Epoch [93/100]  Train ELBO: 32911.2220 | Recon: 32868.1441 | KL: 430.7813 | KL (Weighted): 43.0781 | Test ELBO: 34012.0193
Epoch [94/100]  Train ELBO: 32875.5987 | Recon: 32833.4682 | KL: 421.3034 | KL (Weighted): 42.1303 | Test ELBO: 33971.3513
Epoch [95/100]  Train ELBO: 32991.2952 | Recon: 32950.2401 | KL: 410.5496 | KL (Weighted): 41.0550 | Test ELBO: 33978.4424
Epoch [96/100]  Train ELBO: 32938.0615 | Recon: 32893.7618 | KL: 442.9997 | KL (Weighted): 44.3000 | Test ELBO: 33971.5015
Epoch [97/100]  Train ELBO: 32930.5121 | Recon:  32888.0903 | KL: 424.2162 | KL (Weighted): 42.4216 | Test ELBO: 34008.6855
Epoch [98/100]  Train ELBO: 32919.0241 | Recon: 32876.8566 | KL: 421.6755 | KL (Weighted): 42.1675 | Test ELBO: 34012.3174
Epoch [99/100]  Train ELBO: 32906.9627 | Recon: 32866.2580 | KL: 407.0465 | KL (Weighted): 40.7046 | Test ELBO: 34066.2537
Epoch [100/100]  Train ELBO: 32921.0352 | Recon: 32879.3460 | KL: 416.8920 | KL (Weighted): 41.6892 | Test ELBO: 34039.4429
И окончательное качество реконструкции следующее:
Изображение

Верхний ряд показывает исходное изображение, а нижний – реконструкцию. Точность реконструкции низкая и не отражает окончательных деталей. Скорее всего, на изображениях попадают более крупные объекты, но качество очень плохое и размытое. Реконструкции кажутся шумными и не всегда обеспечивают правильный контраст и правильный уровень интенсивности пикселей в изображениях.
Как мне улучшить качество и точность реконструкции, достаточно сильные для более мелких деталей моих изображений после обучения моего VAE?

Подробнее здесь: https://stackoverflow.com/questions/798 ... y-with-vae
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»