Код: Выделить всё
Model(
(lstm): LSTM(3, 32, num_layers=3, batch_first=True, dropout=0.7)
(dense): Linear(in_features=32, out_features=2, bias=True)
)
Для функции потерь я использую перекрестную энтропийную потерю с вычисленными весами для каждого класса а для оптимизации я использую AdamW, F-Score измеряется классификацией_report из библиотеки sklearn
Проблема в том, что когда размер пакета набора тестовых и обучающих данных равен 64, производительность модели растет, как и ожидалось. Но когда размер пакета тестового набора данных равен 256, производительность сильно упадет и больше не будет расти.
На графике ниже вы можете увидеть производительность размеров пакетов. Розовый график представляет размер пакета тестовых данных 64, синий — размер пакета 256.

F-Score поезда вычисляется на наборе данных поезда с размером пакета 64
Я также всегда устанавливаю для своей модели значение eval() или train( ) режим
Код: Выделить всё
for epoch in range(epochs):
model.train()
print(f'Epoch {epoch}')
__train_loop(model, train_dataloader, loss_function, optimizer, scheduler, verbose, device=device)
model.eval()
train_accuracy, train_f_score = test_model(model, train_dataloader, device=device)
print(f'Train accuracy: {train_accuracy}')
print(f'Train F-Score: {train_f_score}')
accuracy, f_score = test_model(model, test_dataloader, device=device)
Код: Выделить всё
for batch_id, (X, y) in enumerate(train_dataloader):
X, y = X.to(device), y.to(device)
optimizer.zero_grad()
y_pred = model(X)
loss = loss_function(y_pred, y)
loss.backward()
optimizer.step()
Код: Выделить всё
with torch.no_grad():
for batch_id, (X, y) in enumerate(test_dataloader):
X, y = X.to(device), y.to(device)
y_pred = model(X)
y_pred = torch.argmax(y_pred, dim=1)
y_pred_all.append(y_pred.cpu().numpy())
y_all.append(y.cpu().numpy())
bar.update(batch_id)
y_pred_all = np.hstack(y_pred_all).flatten()
y_all = np.hstack(y_all).flatten()
cr = classification_report(y_all, y_pred_all, output_dict=True)
f_score = cr['macro avg']['f1-score']
accuracy = cr['accuracy']
Подробнее здесь: https://stackoverflow.com/questions/790 ... ize-on-mps