В следующем фрагменте кода:
Код: Выделить всё
from icecream import ic
import jax
from time import time
import numpy as np
@jax.jit
def my_function(x, y):
return x @ y
vectorized_function = jax.vmap(my_function, in_axes=(0, None))
shape = (1_000_000, 1_000)
x = np.ones(shape)
y = np.ones(shape[1])
start = time()
vectorized_function(x, y)
t_1 = time() - start
start = time()
vectorized_function(x, y)
t_2 = time() - start
print(f'{t_1 = }\n{t_2 = }')
Код: Выделить всё
t_1 = 13.106784582138062
t_2 = 15.664098024368286
Кто-нибудь сталкивался с подобной ситуацией или знает, почему это может происходить?
PS: Я знаю, что мог бы сделать x @ y напрямую, не вызывая vmap, но это простой пример, просто чтобы проверить его поведение. Мой реальный код сложнее, а разница во времени выполнения еще больше (примерно в 8 раз медленнее). Надеюсь, этот простой пример работает аналогично.
Подробнее здесь: https://stackoverflow.com/questions/791 ... my-example