OutOfMemory при обучении предварительно обученной модели BERT для задачи классификации токеновPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 OutOfMemory при обучении предварительно обученной модели BERT для задачи классификации токенов

Сообщение Anonymous »

Я использую предварительно обученную BertForTokenClassification для задачи распознавания вложенных именованных объектов. Чтобы определить вложенные объекты, я использую метод нескольких меток. На выходе модель возвращает 3 списка логитов, по одному для каждого уровня, которые в конечном итоге объединяются вместе. Я запускаю процесс обучения на Linux Ubuntu 22.04 с 16 ГБ ОЗУ.
Проблема в том, что процесс обучения прерывается из-за OutOfMemory. Неважно, какой размер пакета: 1 или 16. Потребление памяти постоянно растет и процесс убивается. Чем меньше размер пакета, тем позже будет получен окончательный результат.
Класс модели:

Код: Выделить всё

import torch.nn as nn
from transformers import, BertForTokenClassification

class NestedNERMultiLabelModel(nn.Module):
def __init__(self, model_name, num_labels_level1, num_labels_level2, num_labels_level3, dropout):
super(NestedNERMultiLabelModel, self).__init__()

self.bert = BertForTokenClassification.from_pretrained(model_name, hidden_dropout_prob=dropout)

self.classifier_level1 = nn.Linear(self.bert.config.hidden_size, num_labels_level1)

self.classifier_level2 = nn.Linear(self.bert.config.hidden_size, num_labels_level2)

self.classifier_level3 = nn.Linear(self.bert.config.hidden_size, num_labels_level3)

def forward(self, input_ids, attention_mask=None):

outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
out = outputs.hidden_states[-1]

logits_level1 = self.classifier_level1(out)

logits_level2 = self.classifier_level2(out)

logits_level3 = self.classifier_level3(out)

return logits_level1, logits_level2, logits_level3
Обучающий модуль:

Код: Выделить всё

import torch
from transformers import get_linear_schedule_with_warmup

from NestedNERMultiLabelModel import NestedNERMultiLabelModel
import torch.nn.functional as F
from tqdm.auto import tqdm

class Trainer:

def __init__(self, config, preprocessor):
self.config = config
self.preprocessor = preprocessor
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = NestedNERMultiLabelModel(config["bert_model_name"], config["num_labels"], config["num_labels"], config["num_labels"], config['dropout_rate'])
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=config["learning_rate"],
weight_decay=self.config["weight_decay"]
)
self.start_epoch = 0

self.model = self.model.to(self.device)
self.epochs = config["num_epochs"]

def train(self, train_loader, valid_loader):

num_training_steps = len(train_loader) * (self.epochs - self.start_epoch)

scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps=500,
num_training_steps=num_training_steps
)

best_loss = 1000
with tqdm(range(num_training_steps)) as progress_bar:
for epoch in range(self.start_epoch, self.epochs):

train_loss = 0
self.model.train()

for input_ids, attention_mask, labels in train_loader:
input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)

self.optimizer.zero_grad()
labels_pred = self.model(input_ids, attention_mask)

total_loss = 0
for i in range(3):
loss = F.cross_entropy(labels_pred[i].view(-1, 8), labels[:, i].reshape(-1), ignore_index=0)
total_loss += loss

# Update model weights
total_loss.backward()
train_loss += total_loss

torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.config["grad_norm"])
self.optimizer.step()
scheduler.step()
progress_bar.update(1)

train_loss = train_loss / num_training_steps

with torch.no_grad():
self.model.eval()
eval_loss = self.evaluate(self.model, valid_loader)
print(f'Epoch: {epoch} | train_loss: {train_loss:

Подробнее здесь: [url]https://stackoverflow.com/questions/79126798/outofmemory-while-training-pre-trained-bert-model-for-token-classification-task[/url]
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»