Модель Берта не учится с использованием JAX. Результаты не меняютсяPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Модель Берта не учится с использованием JAX. Результаты не меняются

Сообщение Anonymous »

Я обучаю модель BERT классификации спама с использованием JAX на TPU. Моя модель не обучалась, и ее результаты не изменились.

Код: Выделить всё

Epoch 0: Train Loss = 2.7961559295654297: Train Accuracy: 0.30608975887298584 Eval Loss = 3.6600053310394287: Eval Accuracy = 0.0
Epoch 1: Train Loss = 2.7961559295654297: Train Accuracy: 0.30608975887298584 Eval Loss = 3.6600053310394287: Eval Accuracy = 0.0
Epoch 2: Train Loss = 2.7961559295654297: Train Accuracy: 0.30608975887298584 Eval Loss = 3.6600053310394287: Eval Accuracy = 0.0
< /code>
Код для обучения: < /p>
@jax.pmap
def train_step(state, batch, labels):
def loss_fn(params):

# get everything out of the batch to the model and pass the model parameters
logits = model(**batch, params = state.params).logits
loss = compute_loss(logits, labels) # compute the loss

return loss, logits

# turn the loss function into a grad differential function
grad_fn = jax.value_and_grad(loss_fn, has_aux = True) # has_aux allows the return of the logits
# get the loss and grads from the grad_fn
(loss, logits), grads = grad_fn(state.params)
# update the model state by using the produces gradients
new_state = state.apply_gradients(grads = grads)

return loss, logits, new_state

for epoch in range(epochs):
epoch_losses, epoch_accuracies = [], []
for batch in train_dataset:

batch["input_ids"] = jnp.array(batch["input_ids"])
batch["attention_mask"] = jnp.array(batch["attention_mask"])
batch["token_type_ids"] = jnp.array(batch["token_type_ids"])

# we will replicate the value over multiple devices (tpus)
batch_inputs = {k: jax.device_put_replicated(v, jax.devices()) for k, v in batch.items() if k != "Category"}
batch_labels = jax.device_put_replicated(batch["Category"], jax.devices())  # replicate labels across devices

# remove none from data
batch_labels = safe_convert_to_jax_array(jnp.array(batch_labels))
batch_labels = batch_labels.transpose(1, 0)

loss, logits, state = train_step(state, batch_inputs, batch_labels)

cls_logits = logits[:, :, 0, :]
classification_logits = cls_logits[:, :, :2]

predicted_labels = jnp.argmax(classification_logits, axis = -1)
accuracy = compute_accuracy(predicted_labels, batch_labels)
< /code>
Код для инициализации состояния: < /p>
class TrainState(train_state.TrainState):
pass

# our model parameters
params = model.params
# create the intial state for our training
state = TrainState.create(apply_fn = model.__call__, params = params, tx = optimizer)

def safe_convert_to_jax_array(input_data, default_value = 0):
# replace None values with default_value
return jnp.array([default_value if x is None else x for x in input_data])

# replicate the state across tpus
state = jax.device_put_replicated(state, jax.devices())
Чтобы просмотреть полный код: https://www.kaggle.com/code/yousefr/ber ... x-and-tpus
Кроме того, я пытался настроить скорость обучения, но это не помогло.

Подробнее здесь: https://stackoverflow.com/questions/793 ... ont-change
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»