InvalidArgumentError в пользовательской функции потери KerasPython

Программы на Python
Ответить
Anonymous
 InvalidArgumentError в пользовательской функции потери Keras

Сообщение Anonymous »

Я следовал ответу на вопрос ниже, чтобы создать собственную функцию потерь в Keras, которая учитывает только 20 лучших прогнозов.
Как я могу отсортировать значения в пользовательской функции потерь? Функция потерь Keras/Tensorflow?
Однако, когда я пытаюсь скомпилировать свою модель с помощью этого кода, я получаю следующую ошибку о размерах:

Код: Выделить всё

InvalidArgumentError: input must have last dimension >= k = 20 but is 1 for 'loss_21/dense_65_loss/TopKV2' (op: 'TopKV2') with input shapes: [?,1], [] and with computed input tensors: input[1] = .
Следующая упрощенная версия кода, воспроизводящего ошибку.

Код: Выделить всё

import tensorflow as tf
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Dense
from tensorflow.python.keras.optimizers import SGD

top = 20

def top_loss(y_true, y_pred):
y_pred_top_k, y_pred_ind_k = tf.nn.top_k(y_pred, top)
loss_per_sample = tf.reduce_mean(tf.reduce_sum(y_pred_top_k,
axis=-1))

return loss_per_sample

model = Sequential()
model.add(Dense(50, input_dim=201))
model.add(Dense(1))
sgd = SGD(lr=0.01, decay=0, momentum=0.9)
model.compile(loss=top_loss, optimizer=sgd)
и ошибка выдается в следующей строке функции top_loss при компиляции модели.

Код: Выделить всё

y_pred_top_k, y_pred_ind_k = tf.nn.top_k(y_pred, top)
Похоже, что y_pred во время компиляции по умолчанию имеет форму [?,1], тогда как функция tf.nn.top_k ожидает размерность как минимум выше, чем 'k` (т.е. 20).
Нужно ли мне привести y_pred к чему-то, чтобы tf.nn.top_k знал, что это правильные размеры?

Подробнее здесь: https://stackoverflow.com/questions/559 ... s-function
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»