Преобразование .pth в .onnx разрушает модель u2netPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Преобразование .pth в .onnx разрушает модель u2net

Сообщение Anonymous »

Задача такая: на сайт нужно добавить функционал по удалению фона с изображения автомобильного диска.
За основу решил взять библиотеку rembg: https ://github.com/danielgatis/rembg
Эта библиотека в свою очередь работает на базе u2net: https://github.com/xuebinqin/U-2-Net
Однако стандартная модель u2net убирает весь фон снаружи, оставляя нетронутым пространство внутри диска - между спицами, отверстиями и т.д.
Погуглив немного, я пришел к вывод, что я могу дополнительно обучить модель u2net под свои конкретные нужды.
Алгоритм действий следующий:
  • < li>дальнейшее обучение модели
  • загрузка ее в rembg
  • использование пользовательской модели для обрезки
Мне удалось обучить стандартную модель u2net, она отлично вырезает чёрно-белые маски так, как мне нужно.
Однако при конвертации модели из .pth в . onnx (который необходим для работы в rembg), он начинает плохо работать.
Маски размытые и мыльные. Пробовал конвертировать стандартную необученную модель u2net
и использовать в rembg - результат тот же, маски размыты, обрезка фона не работает.
Поэтому вывод - обучение прошло успешно.
Проблема в конвертации.
Итак, вот примеры маски моей обученной модели.
Исходное изображение
Маска, которую создает моя обученная модель
Маска, которую создает моя обученная модель после преобразования в формат .onnx
Для создания маски в u2net я использую:

Код: Выделить всё

python3 u2net_test.py
Чтобы сгенерировать маску в rembg, я использую команду:

Код: Выделить всё

rembg i -om -m u2net_custom -x '{"model_path": "~/.u2net/u2net_custom.onnx"}' 55.jpg 55.png
Я попробовал конвертировать готовую модель. Вот код преобразования:

Код: Выделить всё

import torch
import torch.onnx
from model.u2net import U2NET

def load_model(model_path, model_class):
checkpoint = torch.load(model_path, map_location='cpu')
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
model = model_class()
model.load_state_dict(checkpoint['state_dict'])
else:
model = model_class()
model.eval()
return model

def convert_to_onnx(model, output_path):
dummy_input = torch.randn(1, 3, 320, 320)
torch.onnx.export(model, dummy_input, output_path, opset_version=12,
dynamic_axes={'input': {0: 'batch_size', 2: 'height', 3: 'width'},
'output': {0: 'batch_size', 2: 'height', 3: 'width'}})
print(f"success {output_path}")

if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="conversion PyTorch to ONNX")
parser.add_argument('--model-path', type=str, required=True, help='path to .pth file')
parser.add_argument('--output-path', type=str, required=True, help='save ONNX file')

args = parser.parse_args()
model = load_model(args.model_path, U2NET)
convert_to_onnx(model, args.output_path)
и попытался сохранить модель в процессе обучения:

Код: Выделить всё

        if ite_num % save_frq == 0:
timestamp = int(time.time())
filePath = model_dir + model_name+"_%d_%d." % (ite_num, timestamp)

torch.save(net.state_dict(), filePath + 'pth')

dummy_input = torch.randn(1, 3, 320, 320)
net.eval()
torch.onnx.export(net, dummy_input, filePath + 'onnx', opset_version=12)

running_loss = 0.0
running_tar_loss = 0.0
net.train()  # resume train
ite_num4val = 0
Я пробовал менять настройки, менять версии библиотек, менять opset_version и все остальное, что предлагает ChatGpt.
Результат всегда один и тот же.
После конвертации модель перестает работать .
Какие ошибки я допустил?

Подробнее здесь: https://stackoverflow.com/questions/791 ... odel-u2net
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение
  • Преобразование .pth в .onnx разрушает модель u2net
    Anonymous » » в форуме Python
    0 Ответы
    27 Просмотры
    Последнее сообщение Anonymous
  • Преобразование .pth в .onnx разрушает модель u2net
    Anonymous » » в форуме Python
    0 Ответы
    14 Просмотры
    Последнее сообщение Anonymous
  • Преобразование .pth в .onnx разрушает модель u2net
    Anonymous » » в форуме Python
    0 Ответы
    45 Просмотры
    Последнее сообщение Anonymous
  • Преобразовать пользовательскую модель Yolox (.pth) в модель Tflite
    Anonymous » » в форуме Python
    0 Ответы
    6 Просмотры
    Последнее сообщение Anonymous
  • ONNX - Как мне преобразовать модель ONNX Float32 в BFLOAT16?
    Anonymous » » в форуме Python
    0 Ответы
    20 Просмотры
    Последнее сообщение Anonymous

Вернуться в «Python»