JAX-сканирование по ведущему измерению (обычный способ) или по индексуPython

Программы на Python
Anonymous
JAX-сканирование по ведущему измерению (обычный способ) или по индексу

Сообщение Anonymous »

Я много раз использовал сканирование в своем проекте и случайно обнаружил, что сканирование массива по ведущей оси (normal_scan в примере ниже) медленнее, чем сканирование с индексом (scan_with_index).
Кто-нибудь скажет мне, почему это происходит? Можно ли это делать (есть ли у нас негативные побочные эффекты, например, точность)?
import jax
import jax.numpy as jnp
import time

@jax.jit
def normal_scan(x0, arr):

def body(state, input):
x = state
a = input
new_state = x + a**2 + jnp.sin(a)
return (new_state, x)
state = (x0)
input = arr
result = jax.lax.scan(body, state, input)
return result

@jax.jit
def scan_with_index(x0, arr):
N = len(arr)
def body(state, input):
x, ind = state
new_state = (x+arr[ind]**2 + jnp.sin(arr[ind]), ind + 1)
return (new_state, x)
state = (x0, 0)

result = jax.lax.scan(body, state, length=N)
return result

if __name__ == "__main__":
key = jax.random.key(0)
N = 100

arr = jax.random.normal(key, (N, 2))
x0 = jnp.array([1.0, 2.0])

# warm up
for i in range(2):
result1 = normal_scan(x0, arr)
result2 = scan_with_index(x0, arr)

start_time = time.time()
for i in range(100):
result1 = normal_scan(x0, arr)
result1[0][0].block_until_ready()
end_time = time.time()

print(f"Execution time: {end_time - start_time:.4f} seconds")
# around 0.08

start_time = time.time()
for i in range(100):
result2 = scan_with_index(x0, arr)
result2[0][0].block_until_ready()
end_time = time.time()

print(f"Execution time: {end_time - start_time:.4f} seconds")
# around 0.04


Подробнее здесь: https://stackoverflow.com/questions/797 ... h-an-index

Вернуться в «Python»