Как установить имя пользовательского оператора в TensorFlow LitePython

Программы на Python
Ответить
Anonymous
 Как установить имя пользовательского оператора в TensorFlow Lite

Сообщение Anonymous »

Я пытаюсь создать собственный оператор в TFLite, завершая функцию lfilter scipy. Цель состоит в том, чтобы связать его с реализацией на C++ при вызове интерпретатора TFLite (из эквивалентной пользовательской библиотеки TFLite).
В настоящее время я могу экспортировать файл .tflite, но с одним предостережение: имя пользовательского оператора — PyFunc. Есть ли способ правильно установить имя пользовательского оператора? Обратите внимание, что моя цель — инкапсулировать операцию в черный ящик, а не в TF-трансляцию операции.
Вот минимальный пример
< pre class="lang-py Prettyprint-override">

Код: Выделить всё

import tensorflow as tf
import keras
from scipy import signal
import numpy as np

LFILTER_COEFF_DTYPE = np.float32
LFILTER_DATA_DTYPE = np.float32

@tf.numpy_function(Tout=LFILTER_DATA_DTYPE, name="Lfilter")
def np_lfilter(b, a, x):
y = signal.lfilter(b, a, x)
return y.astype(LFILTER_DATA_DTYPE)

@tf.function(
input_signature=[
tf.TensorSpec(shape=[None], dtype=LFILTER_COEFF_DTYPE, name="b"),
tf.TensorSpec(shape=[None], dtype=LFILTER_COEFF_DTYPE, name="a"),
tf.TensorSpec(shape=[None, None], dtype=LFILTER_DATA_DTYPE, name="x"),
],
autograph=False,
)
def tf_lfilter(b, a, x):
return np_lfilter(b, a, x)

@keras.saving.register_keras_serializable()
class Lfilter(keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.b = self.add_weight(
name="b",
shape=(3,),
initializer="ones",
trainable=True,
)
self.a = self.add_weight(
name="a",
shape=(3,),
initializer="ones",
trainable=True,
)

def get_config(self):
config = super().get_config()
config.update(
{
"b": self.b.numpy(),
"a": self.a.numpy(),
}
)
return config

def call(self, x, training=None):
if training:
...
# not relevant...
else:
y = tf_lfilter(self.b, self.a, x)
# For some reason, y comes out with unknown shape
# Hence we need to set the shape manually
# See https://stackoverflow.com/questions/75110247/keras-custom-layer-unknown-output-shape
y.set_shape(x.shape)
return y

def compute_output_shape(self, input_shape):
return input_shape

@keras.saving.register_keras_serializable()
class ModelLfilter(tf.keras.Model):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.filter = Lfilter()

def call(self, x):
return self.filter.call(x)

if __name__ == "__main__":
import os

CURR_DIR = os.path.dirname(os.path.abspath(__file__))
input_shape = (1, 32)

def convert_to_tflite(model, input_shape, name=None):
if name is not None:
model.name = name
tf_callable = tf.function(
model.call,
autograph=False,
input_signature=[tf.TensorSpec(input_shape, LFILTER_DATA_DTYPE)],
)
tf_concrete_function = tf_callable.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[tf_concrete_function], tf_callable
)
converter.allow_custom_ops = True
tflite_model = converter.convert()

with open(os.path.join(CURR_DIR, f"{model.name}.tflite"), "wb") as f:
f.write(tflite_model)

# SVF
lfilter = ModelLfilter()
lfilter(tf.zeros(input_shape))
convert_to_tflite(lfilter, input_shape, "lfilter")
Что я получаю вывод:

Код: Выделить всё

W0000 00:00:1733157983.471392    8472 tf_tfl_flatbuffer_helpers.cc:392] Ignored output_format.
W0000 00:00:1733157983.471771    8472 tf_tfl_flatbuffer_helpers.cc:395] Ignored drop_control_dependency.
2024-12-02 17:46:23.584285: W tensorflow/compiler/mlir/lite/flatbuffer_export.cc:3474] The following operation(s) need TFLite custom op implementation(s):
Custom ops: PyFunc
Details:
tf.PyFunc(tensor, tensor, tensor) ->  (tensor) : {Tin = [f32, f32, f32], Tout = [f32], device = "/job:localhost/replica:0/task:0/device:CPU:0", token = "pyfunc_0"}
See instructions: https://www.tensorflow.org/lite/guide/ops_custom
А вот скриншот архитектуры модели от Netron (обратите внимание, что имена весов b и a также не были сериализованы должным образом).
Я пробовал следовать официальному примеру из документации, но это не помогло, поскольку экспорт в этом случае генерирует оператор на основе уже существующего tf.atan< /п>

Подробнее здесь: https://stackoverflow.com/questions/792 ... rflow-lite
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»