Обучение RNN — градиенты и сходимость моделейPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Обучение RNN — градиенты и сходимость моделей

Сообщение Anonymous »

В настоящее время я работаю над задачей прогнозирования последовательности, используя RNN с TensorFlow. сталкиваюсь с проблемами во время обучения, которые, по моему мнению, связаны с взрывом градиентов.
потеря иногда достигает чрезвычайно высоких значений (например, на несколько порядков выше, чем первоначальная потеря). Градиенты кажутся чрезмерными

Код: Выделить всё

import tensorflow as tf

class SimpleRNNModel(tf.keras.Model):
def __init__(self, input_size, hidden_size, output_size):
super(SimpleRNNModel, self).__init__()
self.rnn = tf.keras.layers.SimpleRNN(hidden_size, return_sequences=True)
self.dense = tf.keras.layers.Dense(output_size)

def call(self, x):
rnn_out = self.rnn(x)
return self.dense(rnn_out)

input_size = 10
hidden_size = 20
output_size = 1
learning_rate = 0.001
epochs = 100

model = SimpleRNNModel(input_size, hidden_size, output_size)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate), loss='mean_squared_error')

for epoch in range(epochs):
with tf.GradientTape() as tape:
predictions = model(data)
loss = tf.keras.losses.mean_squared_error(labels, predictions)

gradients = tape.gradient(loss, model.trainable_variables)

if any(tf.reduce_max(tf.abs(grad)) > 1e5 for grad in gradients):
print("Exploding gradients detected. Adjusting learning rate.")
clipped_gradients = [tf.clip_by_value(grad, -1e5, 1e5) for grad in gradients]
else:
clipped_gradients = gradients

model.optimizer.apply_gradients(zip(clipped_gradients, model.trainable_variables))

print(f"Epoch {epoch}: Loss = {loss.numpy().mean()}")
Пожалуйста, предоставьте любые указатели, входные данные для точной настройки кода для оптимизации/сокращения потерь.

Подробнее здесь: https://stackoverflow.com/questions/790 ... onvergence
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»