Можно ли скомпилировать собственный оператор vmapped в PyTorch?Python

Программы на Python
Ответить
Anonymous
 Можно ли скомпилировать собственный оператор vmapped в PyTorch?

Сообщение Anonymous »

Я пытаюсь скомпилировать собственный оператор vmapped в Pytorch и получаю ошибку «INTERNAL ASSERT FAILED».
Вот четыре тестовые функции, которые показывают, что собственные операторы Pytorch могут работают с torch.compile и vmap, а пользовательские операторы работают с torch.compile или vmap отдельно, но не вместе.
import torch, os
from torch import Tensor

lib = torch.library.Library('mylib', 'FRAGMENT')

@torch.library.custom_op('mylib::inc_custom', mutates_args=())
def inc_custom(x: Tensor) -> Tensor:
return x + 1

@torch.library.register_fake('mylib::inc_custom')
def _(x):
return torch.empty_like(x)

@torch.library.register_vmap('mylib::inc_custom')
def inc_vmap(info, in_dims, x):
return inc_custom(x), 0

def inc_simple(x: Tensor) -> Tensor:
return x + 1

f1 = torch.compile( torch.vmap(inc_simple) )
f2 = torch.compile(inc_custom)
f3 = torch.vmap(inc_custom)
f4 = torch.compile(f3)

a = torch.arange(8)
b = a.view(2, 4)

print('inc_simple vmap + compile:')
print( f1(b) )

print('inc_custom compile:')
print( f2(a) )

print('inc_custom vmap:')
print( f3(b) )

print('inc_custom vmap + compile:')
print( f4(b) )

Выход:
inc_simple vmap + compile:
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
inc_custom compile:
tensor([1, 2, 3, 4, 5, 6, 7, 8])
inc_custom vmap:
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
inc_custom vmap + compile:
Traceback (most recent call last):
File "/home/alex/python_tests/test2.py", line 42, in
print( f4(b) )

...
...
...

File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_ops.py", line 716, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_function mylib.inc_custom.default(*(BatchedTensor(lvl=1, bdim=0, value=
FakeTensor(..., size=(2, 4), dtype=torch.int64)
),), **{}):
tls_on_entry.has_value() INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1729647429097/work/aten/src/ATen/core/PythonFallbackKernel.cpp":49, please report a bug to PyTorch.

from user code:
File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 203, in wrapped
return vmap_impl(
File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
return _flat_vmap(
File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/alex/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/_library/custom_ops.py", line 669, in __call__
return self._opoverload(*args, **kwargs)


Подробнее здесь: https://stackoverflow.com/questions/792 ... in-pytorch
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»