Как правильно установить JAX с CUDA в Linux, когда `jax[cuda12_pip]` постоянно возвращается к версии ЦП?Python

Программы на Python
Ответить
Anonymous
 Как правильно установить JAX с CUDA в Linux, когда `jax[cuda12_pip]` постоянно возвращается к версии ЦП?

Сообщение Anonymous »

Я пытаюсь установить JAX с поддержкой графического процессора на мощный выделенный сервер Linux, но я застрял в ловушке-22, где каждый официальный метод установки дает сбой по-своему, что всегда приводит к тому, что JAX возвращается к процессору.
Я ищу окончательный, надежный набор команд для установки работающей установки графического процессора.
Характеристики системы:
  • ОС: Ubuntu 18.04 LTS
  • Графический процессор: 8x NVIDIA Quadro RTX 8000
  • Драйвер NVIDIA: 550.144.03
  • Версия CUDA (сообщенная драйвером): 12.4
  • Python: 3.10 (управляется Conda)
Что я пробовал
Я тщательно создавал новые среды conda для каждой попытки, чтобы гарантировать отсутствие конфликтов.
Попытка № 1: Стандартный рекомендуемый метод
Это официальная рекомендуемая команда.

Код: Выделить всё

conda create -n jax_test python=3.10 -y
conda activate jax_test
pip install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --no-cache-dir
  • Ожидаемый результат: следует загрузить и установить большое многогигабайтное колесо jaxlib с включенными библиотеками CUDA.
  • Фактический результат: pip последовательно игнорирует директиву [cuda12_pip], загружает версию jaxlib для маленького процессора (89,9 МБ), и выдает предупреждение. Команда проверки подтверждает этот сбой:

Код: Выделить всё

WARNING: jax 0.6.2 does not provide the extra 'cuda12-pip'
Downloading jaxlib-0.6.2-cp310-cp310-manylinux2014_x86_64.whl (89.9 MB)
...
$ python -c "import jax; print(jax.devices())"
WARNING: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]
Попытка №2: метод прямого URL
Это экспертный обходной путь для принудительной установки определенного колеса графического процессора.

Код: Выделить всё

# In a clean environment...
pip install "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.23+cuda12.cudnn88-cp310-cp310-manylinux2014_x86_64.whl"
pip install jax==0.4.23 "numpy

Подробнее здесь: [url]https://stackoverflow.com/questions/79817557/how-to-correctly-install-jax-with-cuda-on-linux-when-jaxcuda12-pip-consisten[/url]
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»