Ошибка на этапе прогнозирования после импорта модели Keras в тензорный поток JavaPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Ошибка на этапе прогнозирования после импорта модели Keras в тензорный поток Java

Сообщение Anonymous »

Моя цель — использовать модель Keras в программе Java.
Я экспортирую модель Keras с помощью model.export(), а не model.save(), поэтому я получаю папку с модель в формате .pb.
Затем я использовал py .\saved_model_cli.py show -- dir '.' -all , чтобы увидеть входные и выходные данные для заполнения Java-кода.
Я понимаю:

Код: Выделить всё

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
The given SavedModel SignatureDef contains the following input(s):
The given SavedModel SignatureDef contains the following output(s):
outputs['__saved_model_init_op'] tensor_info:
dtype: DT_INVALID
shape: unknown_rank
name: NoOp
Method name is:

signature_def['serve']:
The given SavedModel SignatureDef contains the following input(s):
inputs['keras_tensor'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 6)
name: serve_keras_tensor:0
The given SavedModel SignatureDef contains the following output(s):
outputs['output_0'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict

signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['keras_tensor'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 6)
name: serving_default_keras_tensor:0
The given SavedModel SignatureDef contains the following output(s):
outputs['output_0'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: StatefulPartitionedCall_1:0
Method name is: tensorflow/serving/predict
The MetaGraph with tag set ['serve'] contains the following ops: {'ReadVariableOp', 'Select', 'StatefulPartitionedCall', 'RestoreV2', 'NoOp', 'Identity', 'StaticRegexFullMatch', 'StringJoin', 'AssignVariableOp', 'SaveV2', 'MergeV2Checkpoints', 'VarIsInitializedOp', 'AddV2', 'VarHandleOp', 'DisableCopyOnRead', 'Pack', 'Placeholder', 'MatMul', 'Const', 'Relu', 'ShardedFilename'}

Concrete Functions:2024-11-12 16:47:24.597134: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

Function Name: 'serve'
Option #1
Callable with:
Argument #1
keras_tensor: TensorSpec(shape=(None, 6), dtype=tf.float32, name='keras_tensor')
Наконец, Java-код для импорта и прогнозирования:

Код: Выделить всё

public static void importKerasModel() {
try (SavedModelBundle model = SavedModelBundle.load("PATH\kerasModel", "serve")) {
float[] x = {0.48f, 0.48f, 0.48f, 0.48f, 0.48f, 0.48f};
try (Tensor input = TFloat32.vectorOf(x);
Tensor output = model.session()
.runner()
.feed("serve_keras_tensor", input)
.fetch("StatefulPartitionedCall")
.run()
.get(0)) {

float prediction = output.dataType().getNumber();
System.out.println("prediction = " + prediction);
}
}
}

Но я получаю следующее сообщение об ошибке:

Код: Выделить всё

2024-11-12 17:26:01.089591: I tensorflow/cc/saved_model/loader.cc:317] SavedModel load for tags { serve }; Status: success: OK.  Took 61548 microseconds.
2024-11-12 17:26:01.317247: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: In[0] is not a matrix
[[{{node StatefulPartitionedCall/StatefulPartitionedCall/sequential_1/dense_1/Relu}}]]
Exception in thread "main" org.tensorflow.exceptions.TFInvalidArgumentException: In[0] is not a matrix
[[{{node StatefulPartitionedCall/StatefulPartitionedCall/sequential_1/dense_1/Relu}}]]
at org.tensorflow.internal.c_api.AbstractTF_Status.throwExceptionIfNotOK(AbstractTF_Status.java:76)
at org.tensorflow.Session.run(Session.java:826)
at org.tensorflow.Session$Runner.runHelper(Session.java:549)
at org.tensorflow.Session$Runner.run(Session.java:476)
at com.ptvgroup.platform.truckslogs.converter.HelloTensorFlow.importKerasModel(HelloTensorFlow.java:471)
at com.ptvgroup.platform.truckslogs.converter.Main.main(Main.java:25)
Кто-нибудь может мне помочь? Что значит «In[0] не является матрицей»? Это потому, что мои размеры/формы входов и выходов равны (-1,6) и (-1,1)?

Подробнее здесь: https://stackoverflow.com/questions/791 ... rflow-java
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»