Можно ли использовать тензор с dtype uint8 для функции потерь, которая позже вызовет «.backward()»?Python

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Можно ли использовать тензор с dtype uint8 для функции потерь, которая позже вызовет «.backward()»?

Сообщение Anonymous »

Я попытался вычислить потери между тензором с dtype float32 и другим с dtype uint8.
Поскольку функция потерь выполняет автоматическое повышение типа, я сначала не выполнил явное преобразование типов.
Вот код:

Код: Выделить всё

import torch
import torch.nn as nn

a = torch.randn(3, 3, dtype=torch.float32, requires_grad=True)
b = torch.randint(0, 256, (3, 3), dtype=torch.uint8)
loss = nn.MSELoss()(a, b)
print(loss.dtype)
loss.backward()
Выход:

Код: Выделить всё

torch.float32
С моей точки зрения, это означает, что продвижение автотипа работает должным образом.
Однако возникает ошибка:

Код: Выделить всё

Traceback (most recent call last):
File "/root/.../test.py", line 8, in 
loss.backward()
File "/root/.../python3.8/site-packages/torch/_tensor.py", line 522, in backward
torch.autograd.backward(
File "/root/.../python3.8/site-packages/torch/autograd/__init__.py", line 266, in backward
Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Found dtype Byte but expected Float
затем я изменил строчку о потере:

Код: Выделить всё

# loss = nn.MSELoss()(a, b)
loss = nn.MSELoss()(a, b.to(torch.float32))
Это работает.
Но почему?
Я также заметил, что в определении 'b' , я не могу установить require_grad=True, так как это вызовет ошибку:

RuntimeError: Только тензоры с плавающей запятой и сложный тип dtype могут требуются градиенты

Я думаю, что, возможно, это какая-то связь с моей основной проблемой, поэтому я включил ее сюда.
Итак, можно ли использовать тензор с dtype uint8 для функции потерь, которая позже вызовет .backward()?


Подробнее здесь: https://stackoverflow.com/questions/783 ... later-call
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»