Ошибка в Keras3 RuntimeError: метод требует нахождения в контексте перекрестной реплики, используйте get_replica_contextPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Ошибка в Keras3 RuntimeError: метод требует нахождения в контексте перекрестной реплики, используйте get_replica_context

Сообщение Anonymous »

Я хочу обучать GAN в Keras 3, насколько это возможно, независимо от бэкэнда, а это означает, что мне не следует использовать низкоуровневый API, такой как tf.GradientTape или что-то связанное с синтаксисом tf. Вместо этого я должен использовать API высокого уровня, например ops и так далее. Но в документации DCGAN используется синтаксис tf.
Вот моя ошибка:

Код: Выделить всё

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 1
----> 1 static_gan.fit(X_train, batch_size=64, epochs=10, callbacks=[GANMonitor()])

File ~/.venv/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback..error_handler(*args, **kwargs)
119     filtered_tb = _process_traceback_frames(e.__traceback__)
120     # To get the full stack trace, call:
121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
123 finally:
124     del filtered_tb

Cell In[4], line 36
33 real_labels = ops.ones((batch_size, 1))
34 fake_labels = ops.zeros((batch_size, 1))
--->  36 d_loss_real = self.discriminator_trainable.train_on_batch(real_sample, real_labels)
37 d_loss_fake = self.discriminator_trainable.train_on_batch(generated_sample, fake_labels)
39 d_loss = 0.5 * (d_loss_real + d_loss_fake)

RuntimeError: Method requires being in cross-replica context, use get_replica_context().merge_call()
А вот определение StaticGAN:

Код: Выделить всё

class StaticGAN(Model):
def __init__(self, generator:Model, discriminator:Model, **kwargs):
super().__init__(**kwargs)
self.seed_generator = keras.random.SeedGenerator(42)

self.generator = generator

self.discriminator_trainable = discriminator
self.discriminator_trainable.trainable = True

self.discriminator_notrainable = keras.models.clone_model(self.discriminator_trainable)
self.discriminator_notrainable.trainable = False

self.generator_judgement = Model(generator.input, self.discriminator_notrainable(generator.output))

def compile(self, d_optimizer, g_optimizer, loss_fn):
super().compile()
self.loss_fn = loss_fn
self.d_loss_metric = keras.metrics.Mean(name="d_loss")
self.g_loss_metric = keras.metrics.Mean(name="g_loss")
self.discriminator_trainable.compile(optimizer=d_optimizer, loss=loss_fn)
self.generator_judgement.compile(optimizer=g_optimizer, loss=loss_fn)

@property
def metrics(self):
return [self.d_loss_metric, self.g_loss_metric]

def train_step(self, real_sample):
batch_size = ops.shape(real_sample)[0]

noise = ops.random.normal((batch_size, self.generator.input_shape[1]), seed=self.seed_generator)
generated_sample = self.generator(noise)
real_labels = ops.ones((batch_size, 1))
fake_labels = ops.zeros((batch_size, 1))

d_loss_real = self.discriminator_trainable.train_on_batch(real_sample, real_labels)
d_loss_fake = self.discriminator_trainable.train_on_batch(generated_sample, fake_labels)

d_loss = 0.5 * (d_loss_real + d_loss_fake)

# Transfer weights from discriminator_trainable to discriminator_notrainable
self.discriminator_notrainable.set_weights(self.discriminator_trainable.get_weights())

noise = ops.random.normal((batch_size, self.generator.input_shape[1]), seed=self.seed_generator)
g_loss = self.generator_judgement.train_on_batch(noise, real_labels)

# Transfer weights from discriminator_notrainable to generator
self.generator.set_weights(self.generator_judgement.get_weights())

# Update metrics
self.d_loss_metric.update_state(d_loss)
self.g_loss_metric.update_state(g_loss)
return {
"d_loss": self.d_loss_metric.result(),
"g_loss":  self.g_loss_metric.result(),
}
Вот генератор и дискриминатор:

Код: Выделить всё

#@title Generator Model
generator_input_layer = layers.Input(shape=(64,))
generator_hidden_layer = layers.Dense(7 * 7 * 28)(generator_input_layer)
generator_hidden_layer = layers.Reshape((7, 7, 28))(generator_hidden_layer)
generator_hidden_layer = layers.Conv2DTranspose(28, kernel_size=4, strides=2, padding="same")(generator_hidden_layer)
generator_hidden_layer = layers.LeakyReLU(negative_slope=0.2)(generator_hidden_layer)
generator_hidden_layer = layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding="same")(generator_hidden_layer)
generator_hidden_layer = layers.LeakyReLU(negative_slope=0.2)(generator_hidden_layer)
generator_output_layer = layers.Conv2D(1, kernel_size=5, padding="same", activation="sigmoid", name='g_ol')(generator_hidden_layer)
gm = Model(generator_input_layer, generator_output_layer)

#@title Discriminator Model
discriminator_input_layer = layers.Input(shape=(28, 28, 1))
discriminator_hidden_layer = layers.Conv2D(64, kernel_size=4, strides=2, padding="same")(discriminator_input_layer)
discriminator_hidden_layer = layers.LeakyReLU(negative_slope=0.2)(discriminator_hidden_layer)
discriminator_hidden_layer = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(discriminator_hidden_layer)
discriminator_hidden_layer = layers.LeakyReLU(negative_slope=0.2)(discriminator_hidden_layer)
discriminator_hidden_layer = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(discriminator_hidden_layer)
discriminator_hidden_layer = layers.LeakyReLU(negative_slope=0.2)(discriminator_hidden_layer)
discriminator_hidden_layer = layers.Flatten()(discriminator_hidden_layer)
discriminator_hidden_layer = layers.Dropout(0.2)(discriminator_hidden_layer)
discriminator_output_layer = layers.Dense(1, activation="sigmoid", name='d_ol')(discriminator_hidden_layer)
dm = Model(inputs=discriminator_input_layer, outputs=discriminator_output_layer)
А вот компиляция:

Код: Выделить всё

static_gan = StaticGAN(generator=gm, discriminator=dm)

static_gan.compile(
d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
loss_fn=keras.losses.BinaryCrossentropy(),
)
Итак, что не так с моей настройкой, все должно быть в порядке.

Подробнее здесь: https://stackoverflow.com/questions/790 ... context-us
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»