Я обучаю модель для задачи классификации по нескольким меткам для каждого класса. У меня есть несколько меток после запуска теста, я получил 100% для обоих классов
Я использовал 150 000 изображений для обучения и проверки и 30 000 изображений для тест, и я использую предварительно обученную модель mobilenet_v2, функцию потерь CrossEntropy
это функция расчета _accuracy, которую я использовал
Я обучаю модель для задачи классификации по нескольким меткам для каждого класса. У меня есть несколько меток после запуска теста, я получил 100% для обоих классов Я использовал 150 000 изображений для обучения и проверки и 30 000 изображений для тест, и я использую предварительно обученную модель mobilenet_v2, функцию потерь CrossEntropy это функция расчета _accuracy, которую я использовал [code]def calculate_metrics(output, target): _, predicted_action = output['action_name'].cpu().max(1) gt_action = target['action_name'].cpu()
with warnings.catch_warnings(): # sklearn may produce a warning when processing zero row in confusion matrix warnings.simplefilter("ignore") accuracy_action = accuracy_score(y_true=gt_action.numpy(), y_pred=predicted_action.numpy()) accuracy_condition = accuracy_score(y_true=gt_condition.numpy(), y_pred=predicted_condition.numpy())
return accuracy_action, accuracy_condition [/code] а это сценарий обучения и проверки [code]n_train_samples = len(train_dataloader) print("Starting training ...")
for epoch in range(start_epoch, N_epochs + 1): total_loss = 0 accuracy_action = 0 accuracy_condition = 0
for batch in train_dataloader: optimizer.zero_grad()
img = batch['img'] target_labels = batch['labels'] target_labels = {t: target_labels[t].to(device) for t in target_labels} output = model(img.to(device))
if epoch % 5 == 0: validate(model, val_dataloader, logger, epoch, device) checkpoint_save(model, savedir, epoch) [/code] и это функция проверки [code]def validate(model, dataloader, logger, iteration, device, checkpoint=None): if checkpoint is not None: checkpoint_load(model, checkpoint)
for batch in dataloader: img = batch['img'] target_labels = batch['labels'] target_labels = {t: target_labels[t].to(device) for t in target_labels} output = model(img.to(device))