Pytorch RuntimeError: фигуры mat1 и mat2 не могут быть умноженыPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Pytorch RuntimeError: фигуры mat1 и mat2 не могут быть умножены

Сообщение Anonymous »

Я создаю CNN на Pytorch и получаю следующее сообщение об ошибке:

RuntimeError: фигуры mat1 и mat2 не могут быть умножены (32x32768 и
512x256)

Я построил следующую модель:

Код: Выделить всё

def classifier_block(input, output, kernel_size, stride, last_layer=False):
if not last_layer:
x = nn.Sequential(
nn.Conv2d(input, output, kernel_size, stride, padding=3),
nn.BatchNorm2d(output),
nn.LeakyReLU(0.2, inplace=True)
)
else:
x = nn.Sequential(
nn.Conv2d(input, output, kernel_size, stride),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
return x

class Classifier(nn.Module):
def __init__(self, input_dim, output):
super(Classifier, self).__init__()
self.classifier = nn.Sequential(
classifier_block(input_dim, 64, 7, 2),
classifier_block(64, 64, 3, 2),
classifier_block(64, 128, 3, 2),
classifier_block(128, 256, 3, 2),
classifier_block(256, 512, 3, 2, True)
)
print('CLF: ',self.classifier)

self.linear = nn.Sequential(
nn.Linear(512, 256),
nn.ReLU(inplace=True),
nn.Linear(256, 128),
nn.ReLU(inplace=True),
nn.Linear(128, 64),
nn.ReLU(inplace=True),
nn.Linear(64, output)
)
print('Linear: ', self.linear)

def forward(self, image):
print('IMG: ', image.shape)
x = self.classifier(image)
print('X: ', x.shape)
return self.linear(x.view(len(x), -1))
Входные изображения имеют размер 512x512. Вот мой тренировочный блок:

Код: Выделить всё

loss_train = []
loss_val = []

for epoch in range(epochs):
print('Epoch: {}/{}'.format(epoch, epochs))
total_train = 0
correct_train = 0
cumloss_train = 0
classifier.train()
for batch, (x, y) in enumerate(train_loader):
x = x.to(device)
print(x.shape)
print(y.shape)
output = classifier(x)
loss = criterion(output, y.to(device))
optimizer.zero_grad()
loss.backward()
optimizer.step()

print('Loss: {}'.format(loss))
Будем очень признательны за любые советы.

Подробнее здесь: https://stackoverflow.com/questions/756 ... multiplied
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»