Модель Pytorch всегда выводит 0Python

Программы на Python
Ответить
Anonymous
 Модель Pytorch всегда выводит 0

Сообщение Anonymous »

Я обучил PVT (код здесь: https://github.com/CupidJay/Training-Vi ... 040-images), используя алгоритм IDMM (алгоритм, который используется для предварительного обучения ИИ).
Вот фиктивная функция с ИИ:

Код: Выделить всё

import torch
import torchvision
import VIT.models.pvt_modelv2 as KEVIN_pvtv2
from collections import OrderedDict

num1 = torch.rand(32,3,224,224)
num2 = torch.rand(32,3,224,224)
print(num1, num2)

checkpoint = torch.load("C:\\Users\\vvenkata\\Desktop\\coding\\KEVIN\\model\\VIT\\checkpoints\\224_finetune\\Training\\pvt_v2_b3_lr_0.05_wd_0.0013_bs_16_epochs_200_dim_192_cutmix_0.7_path_Training\\pretrained_pvt_v2_b3_lr_0.054_w\\model_best.pth.tar", map_location='cuda:0')
pvt = KEVIN_pvtv2.pvt_v2_b3(num_classes=3)
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
name = k[7:] if k.startswith("module.") else k
new_state_dict[name] = v
pvt.load_state_dict(new_state_dict)
pvt.eval()

output1 = pvt(num1)
output2 = pvt(num2)
print(f"Output1 {output1}, output2 {output2}")
Набор данных состоит из 3 классов, в общей сложности около 2800 изображений. Я ожидал, что точность будет около 80-90%, поскольку я обучал модель в течение 500 эпох. Однако я получил точность 34%. При дальнейшей проверке это произошло из-за того, что выходные данные модели всегда были тензорными ([[ 0,1031, 0,0980, -0,2104]] независимо от того, какое входное изображение (я проверял это, вводя по одному изображению за раз. Благодаря этому примеру вывод кода приведен ниже:

Код: Выделить всё

input:  [[[0.6883, 0.8841, 0.8812,  ..., 0.2528, 0.6816, 0.1573],
[0.4715, 0.0426, 0.7539,  ..., 0.3806, 0.3291, 0.4134],
[0.1406, 0.2837, 0.1914,  ..., 0.6967, 0.9686, 0.9358],
...,
[0.9271, 0.6876, 0.2966,  ..., 0.8610, 0.3527, 0.9421],
[0.8541, 0.4442, 0.5698,  ..., 0.4103, 0.9373, 0.9191],
[0.5845, 0.6987, 0.1419,  ..., 0.3499, 0.8706, 0.8108]],

[[0.7734, 0.1418, 0.2561,  ..., 0.9039, 0.5064, 0.8974],
[0.9297, 0.5981, 0.6904,  ..., 0.5158, 0.0603, 0.4466],
[0.7587, 0.5091, 0.2655,  ..., 0.7955, 0.5502, 0.9425],
...,
[0.4653, 0.0370, 0.9665,  ..., 0.6748, 0.5855, 0.4120],
[0.6278, 0.7346, 0.1757,  ..., 0.9457, 0.8254, 0.5958],
[0.8474, 0.2402, 0.7168,  ..., 0.2539, 0.8296, 0.0168]]...

output:  [ 0.1031,  0.0980, -0.2104],
[ 0.1031,  0.0980, -0.2104],
[ 0.1031,  0.0980, -0.2104],
[ 0.1031,  0.0980, -0.2104],
[ 0.1031,  0.0980, -0.2104],
[ 0.1031,  0.0980, -0.2104],
[ 0.1031,  0.0980, -0.2104],
[ 0.1031,  0.0980, -0.2104], ...
Как это исправить? Я определенно могу исключить, что это какая-то проблема с загрузкой данных, поскольку этот базовый пример также воспроизводит проблему.

Подробнее здесь: https://stackoverflow.com/questions/792 ... tputting-0
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»