RuntimeError: тензор веса должен быть определен либо для всех 1000 классов, либо ни для одного класса, но получен тензорPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 RuntimeError: тензор веса должен быть определен либо для всех 1000 классов, либо ни для одного класса, но получен тензор

Сообщение Anonymous »

Я пытаюсь использовать VGG16 для ** набора данных из 5 классов**.
Я уже добавил 5 новых слоев, чтобы настроить вывод для logit как 5 .
model = models.vgg16(pretrained=True) #Загружает модель vgg16, предварительно обученную на наборе данных Imagenet.

Код: Выделить всё

#Replace the Final layer of pretrained vgg16 with 5 new layers.
model.fc = nn.Sequential(nn.Linear(1000,512),
nn.ReLU(inplace=True),
nn.Linear(512,256),
nn.ReLU(inplace=True),
nn.Linear(256,128),
nn.ReLU(inplace=True),
nn.Linear(128,64),
nn.ReLU(inplace=True),
nn.Linear(64,5))
И моя функция потерь выглядит следующим образом:

Код: Выделить всё

loss_fn = nn.CrossEntropyLoss(weight=class_weights) #CrossEntropyLoss with class_weights.
где class_weights определен следующим образом:

Код: Выделить всё

from sklearn.utils import class_weight #For calculating weights for each class.
class_weights = class_weight.compute_class_weight(class_weight='balanced',classes=np.array([0,1,2,3,4]),y=train_df['level'].values)
class_weights = torch.tensor(class_weights,dtype=torch.float).to(device)

print(class_weights) #Prints the calculated weights for the classes.
вывод: tensor([0.2556, 4.6000, 1.5333, 9.2000, 9.2000], device='cuda:0')
После первой эпохи I получите ошибку, указанную ниже:

RuntimeError: тензор веса должен быть определен либо для всех 1000 классов или отсутствие классов, но получен весовой тензор формы: [5]


Подробнее здесь: https://stackoverflow.com/questions/715 ... sses-or-no
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»