Код: Выделить всё
import torch
import torch.nn as nn
class LogisticNet(nn.Module):
def __init__(self):
super(LogisticNet, self).__init__()
self.pol1 = nn.MaxPool3d(2)
self.fc1 = nn.Linear(2048, 4)
self.act1 = nn.Softmax(dim=1)
def forward(self, input):
p1 = torch.flatten(self.pol1(input), 1)
s1 = self.fc1(p1)
output = self.act1(s1)
return output
Код: Выделить всё
def export_ONNX(model):
tensor_x = torch.rand((1, 1, 32, 16, 32), dtype=torch.float32)
torch.onnx.export(model, (tensor_x,), "my_model.onnx", input_names=["input"], dynamo=True)
def main():
net = LogisticNet()
print(net)
params = list(net.parameters())
#print(len(params))
#print(params[0].size())
print(params)
export_ONNX(net)
[*]MaxPool
Изменить форму
[*]Gemm (для линейного слоя)
[*]Softmax
Вместо этого граф выглядит слишком сложным из-за дополнительных узлов I не понимаю. Вот скриншот графика:

Кто-нибудь может объяснить:
- < li>Почему граф ONNX сложнее, чем ожидалось?
- Как создать упрощенный граф ONNX, содержащий только ожидаемые узлы?
- Версия PyTorch: 2.5.1
- Версия ONNX: 1.17.0
Подробнее здесь: https://stackoverflow.com/questions/793 ... ated-graph
Мобильная версия