Colab, Jax и GPU: почему выполнение ячейки занимает 60 секунд, хотя %%timeit говорит, что это занимает всего 70 мс?Python

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Colab, Jax и GPU: почему выполнение ячейки занимает 60 секунд, хотя %%timeit говорит, что это занимает всего 70 мс?

Сообщение Anonymous »

В качестве основы для проекта по фракталам я пытаюсь использовать вычисления на графическом процессоре в Google Colab с использованием библиотеки Jax.
Я использую Мандельброта на всех ускорителях в качестве модели. , и я столкнулся с проблемой.
Когда я использую команду %%timeit, чтобы измерить, сколько времени требуется для расчета моей функции графического процессора (так же, как в модели ноутбука ), время вполне разумное и соответствует ожидаемым результатам — от 70 до 80 мс.
Но на самом деле работает % %timeit занимает примерно полную минуту. (По умолчанию функция запускается 7 раз подряд и сообщает среднее значение, но даже это должно занять меньше секунды.)
Аналогично, когда я запускаю функцию в ячейку и вывести результаты (изображение размером 6 мегапикселей), для завершения ячейки требуется около 60 секунд - для выполнения функции, которая предположительно занимает всего 70-80 мс.
Похоже как будто что-то производит огромное количество накладных расходов, которые, похоже, также масштабируются с объемом вычислений - например. когда функция содержит 1000 итеративных вычислений, %%timeit говорит, что это занимает 71 мс, тогда как на самом деле это занимает 60 секунд, но всего за 20 итераций %%timeit говорит, что это занимает 10 мс, тогда как на самом деле это занимает около 10 секунд.
Я вставляю код ниже, но вот ссылка на сам блокнот Colab — любой может сделать копию, подключиться к экземпляру «T4 GPU» и запустить это сами увидите.
import math
import numpy as np
import matplotlib.pyplot as plt
import jax

assert len(jax.devices("gpu")) == 1

def run_jax_kernel(c, fractal):
z = c
for i in range(1000):
z = z**2 + c
diverged = jax.numpy.absolute(z) > 2
diverging_now = diverged & (fractal == 1000)
fractal = jax.numpy.where(diverging_now, i, fractal)
return fractal

run_jax_gpu_kernel = jax.jit(run_jax_kernel, backend="gpu")

def run_jax_gpu(height, width):

mx = -0.69291874321833995150613818345974774914923989808007473759199
my = 0.36963080032727980808623018005116209090839988898368679237704
zw = 4 / 1e3

y, x = jax.numpy.ogrid[(my-zw/2):(my+zw/2):height*1j, (mx-zw/2):(mx+zw/2):width*1j]
c = x + y*1j
fractal = jax.numpy.full(c.shape, 1000, dtype=np.int32)
return np.asarray(run_jax_gpu_kernel(c, fractal).block_until_ready())

Создание изображения занимает около минуты:
fig, ax = plt.subplots(1, 1, figsize=(15, 10))
ax.imshow(run_jax_gpu(2000, 3000));

Отображение сообщения о том, что выполнение функции занимает всего 70–80 мс, занимает около минуты:
%%timeit -o
run_jax_gpu(2000, 3000)


Подробнее здесь: https://stackoverflow.com/questions/787 ... imeit-says
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»