Попробовал изменить вывод модели vae, чтобы удовлетворить функцию потерь, но что-то не работаетPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Попробовал изменить вывод модели vae, чтобы удовлетворить функцию потерь, но что-то не работает

Сообщение Anonymous »

Я пытался создать модель VAE, которая обычно содержит пользовательские потери, для реализации которых используется GradientTape() или класс. Я не хотел использовать эти методы и вместо этого попробовал обходной путь, который не работает должным образом, и мне хотелось знать, почему.
код модели, который я пробовал - (кодер и декодер только базовые модели, где кодер выводит z_mean, z_log_var, z, а декодер выводит пакет изображений) -

Код: Выделить всё

def build_vae(encoder, decoder, input_shape):
input_layer = layers.Input(shape=input_shape)
z_mean, z_log_var, z = encoder(input_layer)
generated_images = decoder(z)

return models.Model(input_layer, [generated_images, z_mean, z_log_var])
таможенная потеря -

Код: Выделить всё

def custom_vae_loss(y_true, y_pred):
generated_images = y_pred[0]
z_mean = y_pred[1]
z_log_var = y_pred[2]

r_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_true, generated_images))
kl_loss = -0.5 * (tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)))

return r_loss+kl_loss
код обучения –

Код: Выделить всё

vae.train_on_batch(batch, batch)
ошибка в том, что target_shape = (8, 256, 256, 3) не соответствует output_shape = (256, 256, 3)
Я думал о том, что список, который должен был быть y_pred, сжимается или что-то в этом роде, так что z_mean и z_log_var удаляются, что означает, что y_pred (из custom_vae_loss) — это пакет сгенерированных изображений, и поэтомуgenerated_images — это просто Первое сгенерированное изображение пакета.
после добавления операторов печати в функцию потерь
y_pred равно Tensor("vae_1/decoder_1/conv2d_transpose_31_1/Sigmoid:0", shape=(8, 256, 256, 3 ), dtype=float32)
форма y_true: (8, 256, 256, 3)
форма сгенерированных_изображений: (256, 256, 3)
форма z_mean: (256, 256, 3)
форма z_log_var: (256, 256, 3)
как мне стоит это исправить? желательно без GradientTape() или класса. (Я хотел сделать это без этих методов)

Подробнее здесь: https://stackoverflow.com/questions/789 ... mething-do
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»