Почему jit-компиляция JAX при втором запуске в моем примере медленнее?Python

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Почему jit-компиляция JAX при втором запуске в моем примере медленнее?

Сообщение Anonymous »

Я новичок в использовании JAX и все еще знакомлюсь с тем, как он работает. Насколько я понимаю, при использовании JIT-компиляции (jax.jit) первое выполнение функции может быть медленнее из-за накладных расходов на компиляцию, но последующие выполнения должны быть быстрее. Однако я наблюдаю противоположное поведение.
В следующем фрагменте кода:

Код: Выделить всё

from icecream import ic
import jax
from time import time
import numpy as np

@jax.jit
def my_function(x, y):
return x @ y

vectorized_function = jax.vmap(my_function, in_axes=(0, None))

shape = (1_000_000, 1_000)

x = np.ones(shape)
y = np.ones(shape[1])

start = time()
vectorized_function(x, y)
t_1 = time() - start

start = time()
vectorized_function(x, y)
t_2 = time() - start

print(f'{t_1 = }\n{t_2 = }')

Я получаю следующие результаты:

Код: Выделить всё

t_1 = 13.106784582138062
t_2 = 15.664098024368286
Как видите, второй запуск (t_2) на самом деле медленнее, чем первый (t_1), что мне кажется нелогичным. Я ожидал, что второй запуск будет быстрее из-за JIT-кэширования JAX.
Кто-нибудь сталкивался с подобной ситуацией или знает, почему это может происходить?
PS: Я знаю, что мог бы сделать x @ y напрямую, не вызывая vmap, но это простой пример, просто чтобы проверить его поведение. Мой реальный код сложнее, а разница во времени выполнения еще больше (примерно в 8 раз медленнее). Надеюсь, этот простой пример работает аналогично.

Подробнее здесь: https://stackoverflow.com/questions/791 ... my-example
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»