Код: Выделить всё
class SEMVAE(nn.Module):
def __init__(self, latent_dim=64):
super().__init__()
# ----------------- Encoder -----------------
self.enc_conv1 = nn.Conv2d(1, 32, 3, stride=2, padding=1) # 256 -> 128
self.enc_bn1 = nn.BatchNorm2d(32)
self.enc_conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=1) # 128 -> 64
self.enc_bn2 = nn.BatchNorm2d(64)
self.enc_conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1) # 64 -> 32
self.enc_bn3 = nn.BatchNorm2d(128)
self.enc_conv4 = nn.Conv2d(128, 128, 3, stride=2, padding=1)# 32 -> 16
self.enc_bn4 = nn.BatchNorm2d(128)
self.dropout = nn.Dropout(0.05)
self.flatten = nn.Flatten()
self.fc_mu = nn.Linear(128*16*16, latent_dim)
self.fc_logvar= nn.Linear(128*16*16, latent_dim)
self.fc_dec = nn.Linear(latent_dim, 128*16*16)
# ----------------- Decoder -----------------
self.dec_deconv1 = nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1) # 16 -> 32
self.dec_bn1 = nn.BatchNorm2d(128)
self.dec_deconv2 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1) # 32 -> 64
self.dec_bn2 = nn.BatchNorm2d(64)
self.dec_deconv3 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1) # 64 -> 128
self.dec_bn3 = nn.BatchNorm2d(32)
self.dec_deconv4 = nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1) # 128 -> 256
def encode(self, x):
x = F.leaky_relu(self.enc_bn1(self.enc_conv1(x)), 0.1)
x = self.dropout(x)
x = F.leaky_relu(self.enc_bn2(self.enc_conv2(x)), 0.1)
x = self.dropout(x)
x = F.leaky_relu(self.enc_bn3(self.enc_conv3(x)), 0.1)
x = self.dropout(x)
x = F.leaky_relu(self.enc_bn4(self.enc_conv4(x)), 0.1)
x = self.dropout(x)
x = self.flatten(x)
mu = self.fc_mu(x)
logvar = self.fc_logvar(x)
return mu, logvar
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
x = self.fc_dec(z).view(-1, 128, 16, 16)
x = F.leaky_relu(self.dec_bn1(self.dec_deconv1(x)), 0.1)
x = F.leaky_relu(self.dec_bn2(self.dec_deconv2(x)), 0.1)
x = F.leaky_relu(self.dec_bn3(self.dec_deconv3(x)), 0.1)
x = torch.sigmoid(self.dec_deconv4(x)) # [0,1] for BCE
return x
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
recon = self.decode(z)
return recon, mu, logvar
Код: Выделить всё
class VAELoss(nn.Module):
def __init__(self, recon_type = 'mse', beta = 1.0):
super().__init__()
self.recon_type = recon_type
self.beta = beta
def forward(self, recon, x, mu, logvar):
## Reconstruction Loss ##
if self.recon_type == 'mse':
recon_loss = F.mse_loss(recon, x, reduction='sum')
elif self.recon_type == 'bce':
recon_loss = F.binary_cross_entropy(recon, x, reduction='sum')
elif self.recon_type == 'l1':
recon_loss = F.smooth_l1_loss(recon, x, reduction='sum')
else:
raise ValueError("recon_type must be bce, mse, or l1")
## KL Divergence ##
kl_loss = - 0.5 * torch.sum(1+ logvar - mu.pow(2) - logvar.exp())
## Loss ##
loss = recon_loss + self.beta * kl_loss
return loss, recon_loss, kl_loss, self.beta * kl_loss
Код: Выделить всё
def train_vae(model, train_loader, test_loader, epochs=50, lr=1e-3, recon_type='mse', beta = 1, device = 'cuda'):
# Loading Model and Optimizer
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# VAELoss
VAELoss_Calculator = VAELoss(recon_type = recon_type, beta = beta)
# Tracking Loss
train_loss_history, test_loss_history = [], []
train_recon_loss, train_kl_loss, train_weighted_kl_loss = [], [], []
# Training Loop
for epoch in range(epochs):
model.train()
total_epoch_loss = 0
total_epoch_kl_loss = 0
total_epoch_recon_loss = 0
total_epoch_weighted_kl_loss = 0
for imgs in train_loader:
batch_size = imgs.size(0)
imgs = imgs.to(device)
optimizer.zero_grad()
recon, mu, logvar = model(imgs)
loss, recon_loss, kl_loss, weighted_kl_loss = VAELoss_Calculator(recon, imgs, mu, logvar)
total_epoch_loss += loss.item()
total_epoch_recon_loss += recon_loss.item()
total_epoch_kl_loss += kl_loss.item()
total_epoch_weighted_kl_loss += weighted_kl_loss.item()
loss.backward()
optimizer.step()
n = len(train_loader.dataset)
avg_train_loss = total_epoch_loss / n
avg_recon_loss = total_epoch_recon_loss / n
avg_kl_loss = total_epoch_kl_loss / n
avg_weighted_kl_loss = total_epoch_weighted_kl_loss / n
train_loss_history.append(avg_train_loss)
train_recon_loss.append(avg_recon_loss)
train_kl_loss.append(avg_kl_loss)
train_weighted_kl_loss.append(avg_weighted_kl_loss)
# Model Evaluation
model.eval()
test_loss = 0
with torch.no_grad():
for test_imgs in test_loader:
test_batch_size = test_imgs.size(0)
test_imgs = test_imgs.to(device)
test_recon, mu, logvar = model(test_imgs)
loss, _, _, _ = VAELoss_Calculator(test_recon, test_imgs, mu, logvar)
test_loss += loss.item()
n_test = len(test_loader.dataset)
avg_test_loss = test_loss / n_test
test_loss_history.append(avg_test_loss)
print(f"Epoch [{epoch+1}/{epochs}] "
f"Train ELBO: {avg_train_loss:.4f} | Recon: {avg_recon_loss:.4f} | "
f"KL: {avg_kl_loss:.4f} | KL (Weighted): {avg_weighted_kl_loss:.4f} | Test ELBO: {avg_test_loss:.4f}")
# ---- Plot all losses ----
plt.figure(figsize=(9, 6))
plt.plot(train_loss_history, label='Train Total (ELBO)')
plt.plot(train_recon_loss, label='Train Reconstruction')
plt.plot(train_kl_loss, label='Train KL Divergence')
plt.plot(train_weighted_kl_loss, label='Train KL Divergence (Weighted)')
plt.plot(test_loss_history, label='Test Total (ELBO)')
# plt.plot(train_grad_losses, label='Gradient Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('VAE Training Loss Components')
plt.show()
Эпоха [1/100] Train ELBO: 43814.1424 | Реконструкция: 43757.6568 | КЛ: 564.8580 | КЛ (взвешенный): 56,4858 | Тестовый ELBO: 40342.6853
Эпоха [2/100] Train ELBO: 38939.3434 | Разведка: 38859.7775 | КЛ: 795.6566 | КЛ (взвешенный): 79,5657 | Тестовый ELBO: 37637.1522
Эпоха [3/100] Train ELBO: 37398.1756 | Реконструкция: 37349.1395 | КЛ: 490.3624 | КЛ (взвешенный): 49,0362 | Тестовый ELBO: 37155.4560
Эпоха [4/100] Train ELBO: 36790.2883 | Реконструкция: 36751.9760 | КЛ: 383.1231 | КЛ (взвешенный): 38,3123 | Тестовый ELBO: 37103.1546
Эпоха [5/100] Train ELBO: 36300.4348 | Реконструкция: 36265.6324 | КЛ: 348.0279 | КЛ (взвешенный): 34,8028 | Тестовый ELBO: 36329.0702
Эпоха [6/100] Train ELBO: 35811.1598 | Реконструкция: 35778.0838 | КЛ: 330.7587 | КЛ (взвешенный): 33,0759 | Тестовый ELBO: 36197.2446
Эпоха [7/100] Train ELBO: 35541.3350 | Реконструкция: 35506.4722 | КЛ: 348.6268 | КЛ (взвешенный): 34,8627 | Тестовый ELBO: 35781.9881
Эпоха [8/100] Train ELBO: 35109.2664 | Реконструкция: 35070.1327 | КЛ: 391.3363 | КЛ (взвешенный): 39,1336 | Тестовый ELBO: 35419.9568
Эпоха [9/100] Train ELBO: 34904.9988 | Реконструкция: 34865.2861 | КЛ: 397.1254 | КЛ (взвешенный): 39,7125 | Тестовый ELBO: 35368.0089
Эпоха [10/100] Train ELBO: 34677.8689 | Реконструкция: 34636.5064 | КЛ: 413.6217 | КЛ (взвешенный): 41,3622 | Тестовый ELBO: 35279.6201
Эпоха [11/100] Train ELBO: 34567.9672 | Реконструкция: 34524.4573 | КЛ: 435.0997 | КЛ (взвешенный): 43,5100 | Тестовый ELBO: 35029.4027
Эпоха [12/100] Train ELBO: 34412.6462 | Реконструкция: 34365.4724 | КЛ: 471.7382 | КЛ (взвешенный): 47,1738 | Тестовый ELBO: 34983.6670
Эпоха [13/100] Train ELBO: 34222.2194 | Реконструкция: 34172.9446 | КЛ: 492.7466 | КЛ (взвешенный): 49,2747 | Тестовый ELBO: 34911.6089
Эпоха [14/100] Train ELBO: 34152.5446 | Реконструкция: 34101.2380 | КЛ: 513.0639 | КЛ (взвешенный): 51,3064 | Тестовый ELBO: 34925.7342
Эпоха [15/100] Train ELBO: 34086.8908 | Реконструкция: 34034.6688 | КЛ: 522.2204 | КЛ (взвешенный): 52,2220 | Тестовый ELBO: 34980.5862
Эпоха [16/100] Train ELBO: 33968.7971 | Реконструкция: 33914.5703 | КЛ: 542.2691 | КЛ (взвешенный): 54,2269 | Тестовый ELBO: 34592.5254
Эпоха [17/100] Train ELBO: 33861.9486 | Реконструкция: 33812.2085 | КЛ: 497.4020 | КЛ (взвешенный): 49,7402 | Тестовый ЭЛЬБО: 34556.8503
Эпоха [18/100] Поезд ЭЛЬБО: 33706.1630 | Реконструкция: 33656.8590 | КЛ: 493.0414 | КЛ (взвешенный): 49,3041 | Тестовый ELBO: 34511.0188
Эпоха [19/100] Train ELBO: 33802.2230 | Реконструкция: 33750.2045 | КЛ: 520.1836 | КЛ (взвешенный): 52,0184 | Тестовый ЭЛБО: 34526.2105
Эпоха [20/100] Поезд ЭЛЬБО: 33730.4231 | Реконструкция: 33680.2796 | КЛ: 501.4360 | КЛ (взвешенный):
....
Эпоха [85/100] Поезд ЭЛЬБО: 32994.7386 | Реконструкция: 32951.0619 | КЛ: 436.7674 | КЛ (взвешенный): 43,6767 | Тестовый ELBO: 34010.8639
Эпоха [86/100] Train ELBO: 32933.0684 | Разведка: 32889.8569 | КЛ: 432.1167 | КЛ (взвешенный): 43,2117 | Тестовый ELBO: 34042.3177
Эпоха [87/100] Train ELBO: 32934.8738 | Реконструкция: 32893.2703 | КЛ: 416.0365 | КЛ (взвешенный): 41,6036 | Тестовый ELBO: 33952.0976
Эпоха [88/100] Train ELBO: 32993.6404 | Реконструкция: 32950.2150 | КЛ: 434.2497 | КЛ (взвешенный): 43,4250 | Тестовый ELBO: 33978.0052
Эпоха [89/100] Train ELBO: 32914.6798 | Разведка: 32871.9780 | КЛ: 427.0166 | КЛ (взвешенный): 42,7017 | Тестовый ЭЛЬБО: 33968.7740
Эпоха [90/100] Поезд ЭЛЬБО: 32886.1937 | Реконструкция: 32844.8323 | КЛ: 413.6142 | КЛ (взвешенный): 41,3614 | Тестовый ELBO: 34007.8426
Эпоха [91/100] Train ELBO: 32932.2369 | Реконструкция: 32890.9989 | КЛ: 412.3811 | КЛ (взвешенный): 41,2381 | Тестовый ELBO: 33970.2540
Эпоха [92/100] Train ELBO: 32997.2572 | Реконструкция: 32955.8671 | КЛ: 413.9017 | КЛ (взвешенный): 41,3902 | Тестовый ELBO: 33961.6801
Эпоха [93/100] Train ELBO: 32911.2220 | Реконструкция: 32868.1441 | КЛ: 430.7813 | КЛ (взвешенный): 43,0781 | Тестовый ELBO: 34012.0193
Эпоха [94/100] Train ELBO: 32875.5987 | Реконструкция: 32833.4682 | КЛ: 421.3034 | КЛ (взвешенный): 42,1303 | Тестовый ЭЛЬБО: 33971.3513
Эпоха [95/100] Поезд ЭЛЬБО: 32991.2952 | Реконструкция: 32950.2401 | КЛ: 410.5496 | KL (взвешенный): 41,0550 | Тестовый ELBO: 33978.4424
Эпоха [96/100] Train ELBO: 32938.0615 | Реконструкция: 32893.7618 | КЛ: 442.9997 | КЛ (взвешенный): 44,3000 | Тестовый ELBO: 33971.5015
Эпоха [97/100] Train ELBO: 32930.5121 | Реконструкция: 32888.0903 | КЛ: 424.2162 | КЛ (взвешенный): 42,4216 | Тестовый ELBO: 34008.6855
Эпоха [98/100] Train ELBO: 32919.0241 | Реконструкция: 32876.8566 | КЛ: 421.6755 | КЛ (взвешенный): 42,1675 | Тестовый ELBO: 34012.3174
Эпоха [99/100] Train ELBO: 32906.9627 | Реконструкция: 32866.2580 | КЛ: 407.0465 | KL (взвешенный): 40,7046 | Тестовый ELBO: 34066.2537
Эпоха [100/100] Train ELBO: 32921.0352 | Разведка: 32879.3460 | КЛ: 416.8920 | КЛ (взвешенный): 41,6892 | Тест ELBO: 34039.4429
И окончательное качество реконструкции следующее:

Верхний ряд показывает исходное изображение, а нижний — реконструкцию. Точность реконструкции низкая и не отражает окончательных деталей. Скорее всего, на изображениях попадают более крупные объекты, но качество очень плохое и размытое. Реконструкции кажутся шумными и не всегда обеспечивают правильный контраст и правильный уровень интенсивности пикселей в изображениях.
Как мне улучшить качество и точность реконструкции, достаточно сильные для более мелких деталей моих изображений после обучения моего VAE?
Подробнее здесь: https://stackoverflow.com/questions/798 ... y-with-vae
Мобильная версия