Anonymous
Ошибка во входных тензорах для тестирования 4D-модальностей в наборе данных BRATS.
Сообщение
Anonymous » 28 дек 2024, 01:33
Я работаю над набором данных BRATS-2021 и pytorch с моделью U-net. У меня проблема с функцией eval_fn, я не могу начать тестирование набора данных проверки, потому что у меня есть ошибки в размерах входных тензоров. У меня есть 2 тензора для обучения (с изображением в 4 каналах модальностей МРТ - t1,t2,flair,t2ce и тензор сегментации с 1 каналом в trainloader, но в testloader у меня есть только изображение для тестирования без каких-либо масок, каналы сегментации и получить ошибку.
Моя функция поезда:
Код: Выделить всё
def train_fn(dataloader, model, optimizer):
model.train()
total_loss = 0.0
total_dice_loss = 0.0
tp_list, fp_list, fn_list, tn_list = [], [], [], []
hd95_list, dsc_list = [], []
for batch in tqdm(dataloader):
images, masks = batch
images, masks = images.to(DEVICE), masks.to(DEVICE)
# logits, loss = model(images, masks)
buffer = model(images, masks)
if len(buffer) > 1:
logits, loss = buffer
else:
logits = buffer
loss = nn.BCEWithLogitsLoss()(logits, masks)
# print(buffer)
dice_loss = DiceLoss(mode="binary")(logits, masks)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item() * images.size(0)
total_dice_loss += dice_loss.item() * images.size(0)
output = (logits > 0.5).float()
batch_tp, batch_fp, batch_fn, batch_tn = smp.metrics.get_stats(
output.long(), masks.long(), mode="binary", threshold=0.5
)
tp_list.append(batch_tp)
fp_list.append(batch_fp)
fn_list.append(batch_fn)
tn_list.append(batch_tn)
for pred, gt in zip(output.cpu().numpy(), masks.cpu().numpy()):
if np.sum(pred) == 0 or np.sum(gt) == 0:
# hd95_list.append(float('inf'))
continue
intersection = np.logical_and(pred, gt)
if np.sum(intersection) == 0:
# hd95_list.append(float('inf'))
continue
pred = (pred > 0.5).astype(np.uint8)
gt = gt.astype(np.uint8)
# hd95_value = hd95(pred, gt)
# hd95_list.append(hd95_value)
dsc_value = dc(pred, gt)
dsc_list.append(dsc_value)
avg_loss = total_loss / len(dataloader.dataset)
avg_dice_loss = total_dice_loss / len(dataloader.dataset)
tp = torch.cat(tp_list, dim=0)
fp = torch.cat(fp_list, dim=0)
fn = torch.cat(fn_list, dim=0)
tn = torch.cat(tn_list, dim=0)
# avg_hd95 = sum(hd95_list) / len(hd95_list) if hd95_list else float('inf')
# avg_dsc = sum(dsc_list) / len(dsc_list)
iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
sensitivity = smp.metrics.sensitivity(tp, fp, fn, tn, reduction="macro")
recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
precision = smp.metrics.precision(tp, fp, fn, tn, reduction="micro")
print(f"Average Loss: {avg_loss}")
print(f"Average Dice Loss: {avg_dice_loss}")
# print(f"Average HD95: {avg_hd95}")
# print(f"Average DSC: {avg_dsc}")
print(f"iou_score: {iou_score}")
print(f"f1_score: {f1_score}")
print(f"sensitivity: {sensitivity}")
print(f"recall: {recall}")
print(f"precision: {precision}")
return avg_loss, avg_dice_loss
Код: Выделить всё
def eval_fn(dataloader, model):
model.eval()
predictions = []
with torch.inference_mode():
for images in tqdm(dataloader, desc="Оценка..."):
images = images.to(DEVICE)
logits = model(images)
pred = torch.sigmoid(logits)
predictions.append(pred.cpu().numpy())
predictions = np.concatenate(predictions, axis=0)
avg_prediction = np.mean(predictions)
print(f"\n: {avg_prediction:.4f}")
return avg_prediction, 0.0
Код: Выделить всё
best_valid_loss = np.inf
train_losses = []
test_losses = []
train_dice_losses = []
test_dice_losses = []
for i in range(EPOCHS):
train_loss, train_dice_loss = train_fn(trainloader, model, optimizer)
test_loss, test_dice_loss = eval_fn(testloader, model)
train_losses.append(train_loss)
test_losses.append(test_loss)
train_dice_losses.append(train_dice_loss)
test_dice_losses.append(test_dice_loss)
if train_loss < best_valid_loss:
torch.save(model.state_dict(), "best_model.pt")
print("saved model")
best_valid_loss = train_loss
print(
f"epochs: {i + 1} , train loss: {train_loss:.4f}")
введите здесь описание изображения
введите здесь описание изображения
введите здесь описание изображения
Я думаю, ошибка в загрузчиках набора данных.
Подробнее здесь:
https://stackoverflow.com/questions/793 ... ts-dataset
1735338786
Anonymous
Я работаю над набором данных BRATS-2021 и pytorch с моделью U-net. У меня проблема с функцией eval_fn, я не могу начать тестирование набора данных проверки, потому что у меня есть ошибки в размерах входных тензоров. У меня есть 2 тензора для обучения (с изображением в 4 каналах модальностей МРТ - t1,t2,flair,t2ce и тензор сегментации с 1 каналом в trainloader, но в testloader у меня есть только изображение для тестирования без каких-либо масок, каналы сегментации и получить ошибку. Моя функция поезда: [code]def train_fn(dataloader, model, optimizer): model.train() total_loss = 0.0 total_dice_loss = 0.0 tp_list, fp_list, fn_list, tn_list = [], [], [], [] hd95_list, dsc_list = [], [] for batch in tqdm(dataloader): images, masks = batch images, masks = images.to(DEVICE), masks.to(DEVICE) # logits, loss = model(images, masks) buffer = model(images, masks) if len(buffer) > 1: logits, loss = buffer else: logits = buffer loss = nn.BCEWithLogitsLoss()(logits, masks) # print(buffer) dice_loss = DiceLoss(mode="binary")(logits, masks) optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() * images.size(0) total_dice_loss += dice_loss.item() * images.size(0) output = (logits > 0.5).float() batch_tp, batch_fp, batch_fn, batch_tn = smp.metrics.get_stats( output.long(), masks.long(), mode="binary", threshold=0.5 ) tp_list.append(batch_tp) fp_list.append(batch_fp) fn_list.append(batch_fn) tn_list.append(batch_tn) for pred, gt in zip(output.cpu().numpy(), masks.cpu().numpy()): if np.sum(pred) == 0 or np.sum(gt) == 0: # hd95_list.append(float('inf')) continue intersection = np.logical_and(pred, gt) if np.sum(intersection) == 0: # hd95_list.append(float('inf')) continue pred = (pred > 0.5).astype(np.uint8) gt = gt.astype(np.uint8) # hd95_value = hd95(pred, gt) # hd95_list.append(hd95_value) dsc_value = dc(pred, gt) dsc_list.append(dsc_value) avg_loss = total_loss / len(dataloader.dataset) avg_dice_loss = total_dice_loss / len(dataloader.dataset) tp = torch.cat(tp_list, dim=0) fp = torch.cat(fp_list, dim=0) fn = torch.cat(fn_list, dim=0) tn = torch.cat(tn_list, dim=0) # avg_hd95 = sum(hd95_list) / len(hd95_list) if hd95_list else float('inf') # avg_dsc = sum(dsc_list) / len(dsc_list) iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro") f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro") sensitivity = smp.metrics.sensitivity(tp, fp, fn, tn, reduction="macro") recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise") precision = smp.metrics.precision(tp, fp, fn, tn, reduction="micro") print(f"Average Loss: {avg_loss}") print(f"Average Dice Loss: {avg_dice_loss}") # print(f"Average HD95: {avg_hd95}") # print(f"Average DSC: {avg_dsc}") print(f"iou_score: {iou_score}") print(f"f1_score: {f1_score}") print(f"sensitivity: {sensitivity}") print(f"recall: {recall}") print(f"precision: {precision}") return avg_loss, avg_dice_loss [/code] [code]def eval_fn(dataloader, model): model.eval() predictions = [] with torch.inference_mode(): for images in tqdm(dataloader, desc="Оценка..."): images = images.to(DEVICE) logits = model(images) pred = torch.sigmoid(logits) predictions.append(pred.cpu().numpy()) predictions = np.concatenate(predictions, axis=0) avg_prediction = np.mean(predictions) print(f"\n: {avg_prediction:.4f}") return avg_prediction, 0.0 [/code] [code]best_valid_loss = np.inf train_losses = [] test_losses = [] train_dice_losses = [] test_dice_losses = [] for i in range(EPOCHS): train_loss, train_dice_loss = train_fn(trainloader, model, optimizer) test_loss, test_dice_loss = eval_fn(testloader, model) train_losses.append(train_loss) test_losses.append(test_loss) train_dice_losses.append(train_dice_loss) test_dice_losses.append(test_dice_loss) if train_loss < best_valid_loss: torch.save(model.state_dict(), "best_model.pt") print("saved model") best_valid_loss = train_loss print( f"epochs: {i + 1} , train loss: {train_loss:.4f}") [/code] введите здесь описание изображения введите здесь описание изображения введите здесь описание изображения Я думаю, ошибка в загрузчиках набора данных. Подробнее здесь: [url]https://stackoverflow.com/questions/79312426/error-in-input-tensors-for-testing-of-4d-modalities-in-brats-dataset[/url]