Почему форма этого тензора внезапно отличается от формы тензора Pytorch в библиотеке NNI?Python

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Почему форма этого тензора внезапно отличается от формы тензора Pytorch в библиотеке NNI?

Сообщение Anonymous »

Я использую библиотеку NNI для обрезки нейронной сети.
При сжатии она использует функцию. _metric_fuse. В этой функции происходит ошибка. Ошибка:
RuntimeError: выходные данные с формой [12, 1] не соответствуют широковещательной форме [12, 12]
Функция выглядит вот так:

Код: Выделить всё

def _metric_fuse(metrics: _METRICS) -> torch.Tensor:
# mean all metric value
fused_metric = None
count = 0
for _, module_metrics in metrics.items():
for _, target_metric in module_metrics.items():
print("-----------------------------------------------------------")
print(fused_metric)
print(target_metric)
if fused_metric is not None:
fused_metric += target_metric
else:
fused_metric = target_metric.clone()
count += 1
assert fused_metric is not None
return fused_metric / count
Ошибка возникает в Fused_metric += target_metric. Напечатанные тензоры выглядят так:

Код: Выделить всё

None

tensor([[1.6796e-06],
[2.3090e-07],
[6.1086e-07],
[1.5183e-07],
[3.7875e-09],
[8.8616e-07],
[6.5265e-07],
[9.6800e-08],
[8.1835e-07],
[2.9000e-07],
[1.1585e-07],
[5.4605e-08]], device='cuda:0')

-----------------------------------------------------------
tensor([[1.6796e-06],
[2.3090e-07],
[6.1086e-07],
[1.5183e-07],
[3.7875e-09],
[8.8616e-07],
[6.5265e-07],
[9.6800e-08],
[8.1835e-07],
[2.9000e-07],
[1.1585e-07],
[5.4605e-08]], device='cuda:0')

tensor([[3.4065e-06],
[2.0305e-06],
[2.3129e-06],
[1.8377e-06],
[6.3720e-05],
[3.4308e-06],
[2.4689e-06],
[2.8250e-06],
[1.7304e-06],
[4.2669e-06],
[5.7278e-07],
[4.4945e-07]], device='cuda:0')

-----------------------------------------------------------
tensor([[5.0860e-06],
[2.2613e-06],
[2.9237e-06],
[1.9896e-06],
[6.3724e-05],
[4.3169e-06],
[3.1215e-06],
[2.9218e-06],
[2.5488e-06],
[4.5569e-06],
[6.8863e-07],
[5.0406e-07]], device='cuda:0')

tensor([[2.3013e-06],
[2.6519e-07],
[6.8275e-07],
[1.2156e-07],
[8.3832e-09],
[1.1710e-06],
[7.8374e-07],
[8.3750e-08],
[8.2380e-07],
[3.3028e-07],
[1.9342e-07],
[1.1004e-07]], device='cuda:0')

-----------------------------------------------------------
tensor([[7.3874e-06],
[2.5265e-06],
[3.6065e-06],
[2.1111e-06],
[6.3732e-05],
[5.4879e-06],
[3.9053e-06],
[3.0056e-06],
[3.3726e-06],
[4.8872e-06],
[8.8205e-07],
[6.1409e-07]], device='cuda:0')

tensor([[4.2203e-06, 2.2458e-06, 5.0584e-06, 5.3744e-06, 5.3401e-05, 5.1009e-06,
4.3903e-06, 4.8045e-06, 2.6889e-06, 7.0391e-06, 3.9046e-06, 1.3097e-06]],
device='cuda:0')

Код: Выделить всё

nni==3.0rc1
torch==2.0.1
torchaudio==2.0.2
torchvision==0.15.2
transformers==4.33.1
python = 3.10.15
По какой-то причине оно меняется с [12,1] на [12,12], что меня смущает. Кто-нибудь знает, в чем может быть проблема?

Подробнее здесь: https://stackoverflow.com/questions/791 ... sor-in-the
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»