Ошибка во входных тензорах для тестирования 4D-модальностей в наборе данных BRATS.Python

Программы на Python
Ответить
Anonymous
 Ошибка во входных тензорах для тестирования 4D-модальностей в наборе данных BRATS.

Сообщение Anonymous »

Я работаю над набором данных BRATS-2021 и pytorch с моделью U-net. У меня проблема с функцией eval_fn, я не могу начать тестирование набора данных проверки, потому что у меня есть ошибки в размерах входных тензоров. У меня есть 2 тензора для обучения (с изображением в 4 каналах модальностей МРТ - t1,t2,flair,t2ce и тензор сегментации с 1 каналом в trainloader, но в testloader у меня есть только изображение для тестирования без каких-либо масок, каналы сегментации и получить ошибку.
Моя функция поезда:

Код: Выделить всё

def train_fn(dataloader, model, optimizer):
model.train()
total_loss = 0.0
total_dice_loss = 0.0
tp_list, fp_list, fn_list, tn_list = [], [], [], []
hd95_list, dsc_list = [], []

for batch in tqdm(dataloader):
images, masks = batch
images, masks = images.to(DEVICE), masks.to(DEVICE)
# logits, loss = model(images, masks)
buffer = model(images, masks)
if len(buffer) > 1:
logits, loss = buffer
else:
logits = buffer
loss = nn.BCEWithLogitsLoss()(logits, masks)
# print(buffer)

dice_loss = DiceLoss(mode="binary")(logits, masks)
optimizer.zero_grad()

loss.backward()

optimizer.step()
total_loss += loss.item() * images.size(0)
total_dice_loss += dice_loss.item() * images.size(0)

output = (logits > 0.5).float()

batch_tp, batch_fp, batch_fn, batch_tn = smp.metrics.get_stats(
output.long(), masks.long(), mode="binary", threshold=0.5
)

tp_list.append(batch_tp)
fp_list.append(batch_fp)
fn_list.append(batch_fn)
tn_list.append(batch_tn)

for pred, gt in zip(output.cpu().numpy(), masks.cpu().numpy()):

if np.sum(pred) == 0 or np.sum(gt) == 0:
# hd95_list.append(float('inf'))
continue

intersection = np.logical_and(pred, gt)
if np.sum(intersection) == 0:
# hd95_list.append(float('inf'))
continue

pred = (pred >  0.5).astype(np.uint8)
gt = gt.astype(np.uint8)

#     hd95_value = hd95(pred, gt)
#     hd95_list.append(hd95_value)

dsc_value = dc(pred, gt)
dsc_list.append(dsc_value)

avg_loss = total_loss / len(dataloader.dataset)
avg_dice_loss = total_dice_loss / len(dataloader.dataset)

tp = torch.cat(tp_list, dim=0)
fp = torch.cat(fp_list, dim=0)
fn = torch.cat(fn_list, dim=0)
tn = torch.cat(tn_list, dim=0)

# avg_hd95 = sum(hd95_list) / len(hd95_list) if hd95_list else float('inf')
# avg_dsc = sum(dsc_list) / len(dsc_list)

iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
sensitivity = smp.metrics.sensitivity(tp, fp, fn, tn, reduction="macro")
recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
precision = smp.metrics.precision(tp, fp, fn, tn, reduction="micro")

print(f"Average Loss: {avg_loss}")
print(f"Average Dice Loss: {avg_dice_loss}")
# print(f"Average HD95: {avg_hd95}")
# print(f"Average DSC: {avg_dsc}")

print(f"iou_score: {iou_score}")
print(f"f1_score: {f1_score}")
print(f"sensitivity: {sensitivity}")
print(f"recall: {recall}")
print(f"precision: {precision}")

return avg_loss, avg_dice_loss

Код: Выделить всё

def eval_fn(dataloader, model):
model.eval()
predictions = []

with torch.inference_mode():
for images in tqdm(dataloader, desc="Оценка..."):
images = images.to(DEVICE)

logits = model(images)

pred = torch.sigmoid(logits)

predictions.append(pred.cpu().numpy())

predictions = np.concatenate(predictions, axis=0)

avg_prediction = np.mean(predictions)

print(f"\n: {avg_prediction:.4f}")

return avg_prediction, 0.0

Код: Выделить всё

best_valid_loss = np.inf
train_losses = []
test_losses = []
train_dice_losses = []
test_dice_losses = []

for i in range(EPOCHS):
train_loss, train_dice_loss = train_fn(trainloader, model, optimizer)
test_loss, test_dice_loss = eval_fn(testloader, model)

train_losses.append(train_loss)
test_losses.append(test_loss)
train_dice_losses.append(train_dice_loss)
test_dice_losses.append(test_dice_loss)

if train_loss < best_valid_loss:
torch.save(model.state_dict(), "best_model.pt")
print("saved model")
best_valid_loss = train_loss

print(
f"epochs: {i + 1} , train loss: {train_loss:.4f}")
введите здесь описание изображения

введите здесь описание изображения

введите здесь описание изображения
Я думаю, ошибка в загрузчиках набора данных.

Подробнее здесь: https://stackoverflow.com/questions/793 ... ts-dataset
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»