Почему мой экспорт PyTorch ONNX создает слишком сложный график?Python

Программы на Python
Ответить
Anonymous
 Почему мой экспорт PyTorch ONNX создает слишком сложный график?

Сообщение Anonymous »

Я пытаюсь экспортировать простую нейронную сеть (NN), определенную в PyTorch, в формат ONNX, но сгенерированный график кажется слишком сложным. Вот моя модель PyTorch:

Код: Выделить всё

import torch
import torch.nn as nn

class LogisticNet(nn.Module):
def __init__(self):
super(LogisticNet, self).__init__()
self.pol1 = nn.MaxPool3d(2)
self.fc1 = nn.Linear(2048, 4)
self.act1 = nn.Softmax(dim=1)

def forward(self, input):
p1 = torch.flatten(self.pol1(input), 1)
s1 = self.fc1(p1)
output = self.act1(s1)
return output
Вот мой код экспорта ONNX:

Код: Выделить всё

def export_ONNX(model):
tensor_x = torch.rand((1, 1, 32, 16, 32), dtype=torch.float32)
torch.onnx.export(model, (tensor_x,), "my_model.onnx", input_names=["input"], dynamo=True)

def main():
net = LogisticNet()
print(net)
params = list(net.parameters())
#print(len(params))
#print(params[0].size())
print(params)
export_ONNX(net)
Когда я открываю экспортированный граф ONNX, я ожидаю, что он будет содержать только следующие узлы:

[*]MaxPool
Изменить форму
[*]Gemm (для линейного слоя)
[*]Softmax

Вместо этого граф выглядит слишком сложным из-за дополнительных узлов I не понимаю. Вот скриншот графика:
Изображение

Кто-нибудь может объяснить:
  • < li>Почему граф ONNX сложнее, чем ожидалось?
  • Как создать упрощенный граф ONNX, содержащий только ожидаемые узлы?
Окружающая среда:
  • Версия PyTorch: 2.5.1
  • Версия ONNX: 1.17.0


Подробнее здесь: https://stackoverflow.com/questions/793 ... ated-graph
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»