Пакетное умножение матриц с помощью JAX на графическом процессоре быстрее с матрицами большего размераPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Пакетное умножение матриц с помощью JAX на графическом процессоре быстрее с матрицами большего размера

Сообщение Anonymous »

Я пытаюсь выполнить пакетное умножение матриц с помощью JAX на графическом процессоре и заметил, что умножение фигур (1000, 1000, 3, 35) @ (1000, 1000, 35, 1) происходит примерно в 3 раза быстрее, чем на самом деле. умножить (1000, 1000, 3, 25) @ (1000, 1000, 25, 1) на f64 и ~5x на f32.
Что объясняет эту разницу, учитывая, что на процессоре ни JAX, ни NumPy не показывают такого поведения, а на графическом процессоре CuPy не показывает такого поведения.
Я запускаю это с JAX: 0.4.32 на NVIDIA RTX A5000 (и получаю аналогичные результаты на Tesla T4), код для воспроизведения:

Код: Выделить всё

import numpy as np
import cupy as cp
from cupyx.profiler import benchmark
from jax import config
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

rng = np.random.default_rng()

x = np.arange(5, 55, 5)
Тайминги графического процессора:

Код: Выделить всё

dtype = cp.float64
timings_cp = []
for i in range(5, 55, 5):
a = cp.array(rng.random((1000, 1000, 3, i)), dtype=dtype)
b = cp.array(rng.random((1000, 1000, i, 1)), dtype=dtype)
timings_cp.append(benchmark(lambda a, b: a@b, (a, b), n_repeat=10, n_warmup=10))

dtype = jnp.float64
timings_jax_gpu = []
with jax.default_device(jax.devices('gpu')[0]):
for i in range(5, 55, 5):
a = jnp.array(rng.random((1000, 1000, 3, i)), dtype=dtype)
b = jnp.array(rng.random((1000, 1000, i, 1)), dtype=dtype)
func = jax.jit(lambda a, b: a@b)
timings_jax_gpu.append(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=10, n_warmup=10))

plt.figure()
plt.plot(x, [i.gpu_times.mean() for i in timings_cp], label="CuPy")
plt.plot(x, [i.gpu_times.mean() for i in timings_jax_gpu], label="JAX GPU")
plt.legend()
Изображение

Время с этими конкретными фигурами:

Код: Выделить всё

dtype = jnp.float64
with jax.default_device(jax.devices('gpu')[0]):
a = jnp.array(rng.random((1000, 1000, 3, 25)), dtype=dtype)
b = jnp.array(rng.random((1000, 1000, 25, 1)), dtype=dtype)
func = jax.jit(lambda a, b: a@b)
print(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=1000, n_warmup=10).gpu_times.mean())

a = jnp.array(rng.random((1000, 1000, 3, 35)), dtype=dtype)
b = jnp.array(rng.random((1000, 1000, 35, 1)), dtype=dtype)
print(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=1000, n_warmup=10).gpu_times.mean())
Дает

Код: Выделить всё

f64:
0.01453789699935913
0.004859122595310211

f32:

0.005860503035545349
0.001209742688536644
Тайминги процессора:

Код: Выделить всё

timings_np = []
for i in range(5, 55, 5):
a = rng.random((1000, 1000, 3, i))
b = rng.random((1000, 1000, i, 1))
timings_np.append(benchmark(lambda a, b: a@b, (a, b), n_repeat=10, n_warmup=10))

timings_jax_cpu = []
with jax.default_device(jax.devices('cpu')[0]):
for i in range(5, 55, 5):
a = jnp.array(rng.random((1000, 1000, 3, i)))
b = jnp.array(rng.random((1000, 1000, i, 1)))
func = jax.jit(lambda a, b: a@b)
timings_jax_cpu.append(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=10, n_warmup=10))

plt.figure()
plt.plot(x, [i.cpu_times.mean() for i in timings_np], label="NumPy")
plt.plot(x, [i.cpu_times.mean() for i in timings_jax_cpu], label="JAX CPU")
plt.legend()
Изображение


Подробнее здесь: https://stackoverflow.com/questions/790 ... r-matrices
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»