Не могу обучить свою модель многоклассовой сегментации UNET [дублируется]Python

Программы на Python
Ответить
Anonymous
 Не могу обучить свою модель многоклассовой сегментации UNET [дублируется]

Сообщение Anonymous »

Я пытался создать UNET с нуля с помощью pytorch. На выходе моей модели я не получил ничего, кроме черных масок. Мне нужно сегментировать повреждения автомобилей, поэтому я реализовал карту цветов. Я уверен на 70%, что с моим набором данных что-то не так именно с этой цветовой картой. Задача — мультиклассовое предсказание, поэтому я использую функцию перекрестной энтропии. Я предоставлю код своего набора данных и файлы обучения.
# dataset.py
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch

class Segm_Dataset(Dataset):
def __init__(self, image_dir, mask_dir, color_map):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.image_files = os.listdir(self.image_dir)
self.mask_files = os.listdir(self.mask_dir)
self.color_map = color_map

def __len__(self):
return len(self.image_files)

def __getitem__(self, idx):
image_path = os.path.join(self.image_dir, self.image_files[idx])
mask_path = os.path.join(self.mask_dir, self.mask_files[idx])
image = np.array(Image.open(image_path).convert('RGB'))
mask = np.array(Image.open(mask_path).convert('RGB'), dtype=np.float32)
label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int64)

for color, label in self.color_map.items():
color_array = np.array(color, dtype=np.float32)
mask_area = np.all(mask == color_array, axis=-1)
label_mask[mask_area] = label

image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
label_mask = torch.tensor(label_mask, dtype=torch.long)

return image, label_mask

# train.py
from model import UNET
from tqdm import tqdm
from dataset import Segm_Dataset
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import os

LEARNING_RATE = 1e-4
BATCH_SIZE = 5
NUM_EPOCHS = 10
NUM_WORKERS = 2
IMAGE_HEIGHT = 180
IMAGE_WIDTH = 180
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = r'data\train\images'
TRAIN_MASK_DIR = r'data\train\masks'
VAL_IMG_DIR = r'data\val\images'
VAL_MASK_DIR = r'data\val\masks'
SAVED_MODELS_PATH = r'saved_models'

color_map = {
(19, 164, 201): 0, # Missing part: #13A4C9
(166, 255, 71): 1, # Broken part: #A6FF47
(180, 45, 56): 2, # Scratch: #B42D38
(225, 150, 96): 3, # Cracked: #E19660
(144, 60, 89): 4, # Dent: #903C59
(167, 116, 27): 5, # Flaking: #A7741B
(180, 14, 19): 6, # Paint chip: #B40E13
(115, 194, 206): 7, # Corrosion: #73C2CE
}

train_dataset = Segm_Dataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, color_map)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

val_dataset = Segm_Dataset(VAL_IMG_DIR, VAL_MASK_DIR, color_map)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE)

model = UNET(in_channels=3, out_channels=len(color_map))
model = model.cuda() if torch.cuda.is_available() else model

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
train_loop = tqdm(enumerate(train_loader), total=len(train_loader))

for batch_index, (data, targets) in train_loop:
#Forward pass
scores = model(data)
train_loss = criterion(scores, targets)

#Backward pass
optimizer.zero_grad()
train_loss.backward()

#Gradient descent or optimizer step
optimizer.step()

if batch_index % 10 == 0:
current_batch = batch_index
val_loss = 0
with torch.no_grad():
for val_data, val_targets in val_loader:
val_scores = model(val_data)
val_loss = criterion(val_scores, val_targets)

#Update progress bar
train_loop.set_description(f'Epoch: [{epoch+1}/{NUM_EPOCHS}]')
train_loop.set_postfix(train_loss=train_loss.item(), val_loss=val_loss.item(), val_batch=current_batch)

else:
train_loop.set_description(f'Epoch: [{epoch+1}/{NUM_EPOCHS}]')
train_loop.set_postfix(train_loss=train_loss.item(), val_loss=val_loss.item(), val_batch=current_batch)

checkpoint = {
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_loss.item(),
'val_loss': val_loss.item()
}

torch.save(checkpoint, os.path.join(SAVED_MODELS_PATH, f'unet_epoch_{epoch}.pth'))

Некоторые эпохи обучения:
Epoch: [9/10]: 100%|██████████████████| 888/888 [34:24

Подробнее здесь: https://stackoverflow.com/questions/791 ... duplicated
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»