Сценарий JAX завершается с ошибкой INTERNAL: нет поддержки BLAS для потока при параллельном запуске нескольких процессовPython

Программы на Python
Ответить
Anonymous
 Сценарий JAX завершается с ошибкой INTERNAL: нет поддержки BLAS для потока при параллельном запуске нескольких процессов

Сообщение Anonymous »

Я столкнулся с неприятной ошибкой выполнения JAX на сервере с несколькими графическими процессорами. Мой сценарий отлично работает для простого теста, но терпит неудачу из-за отсутствия поддержки BLAS для ошибки потока, когда я пытаюсь запустить несколько его экземпляров параллельно на разных графических процессорах. Я подозреваю, что это глубокая проблема со средой или связями, и я исчерпал свои шаги по отладке.
Цель:
У меня есть сценарий научных вычислений (worker.py), который я хочу запустить примерно 600 раз с разными параметрами. Я использую сценарий bash для отправки этих заданий на 8 доступных графических процессоров NVIDIA Quadro RTX 8000, назначая каждое задание определенному графическому процессору с помощью CUDA_VISIBLE_DEVICES.
Проблема:
Когда я запускаю диспетчер, некоторые из первых заданий запускаются, но быстро завершаются сбоем. Журнал ошибок для каждого сбойного задания всегда один и тот же:

Код: Выделить всё

jaxlib._jax.XlaRuntimeError: INTERNAL: No BLAS support for stream
Это происходит, когда код вызывает JIT-компилированную функцию, выполняющую умножение матриц.
Среда:
  • ОС: Ubuntu 18.04 LTS
  • Аппаратное обеспечение: 8 x NVIDIA Quadro RTX 8000 (48 ГБ)
  • Драйвер: Команда nvidia-smi работает отлично.
  • Python: Управляется с помощью Conda (Python 3.10).
Что я сделал (шаги отладки):
  • Полная перестройка среды: Я запустил conda env удалил -n jax_working и пересобрал среду с нуля, используя официально рекомендованную команду pip со ссылкой на CUDA 12:

Код: Выделить всё

    conda create --name jax_working python=3.10 -y
conda activate jax_working
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install pandas tqdm # and other packages
  • Проверка ресурсов: htop показывает, что на сервере достаточно свободного процессора и оперативной памяти. nvidia-smi показывает, что все графические процессоры полностью простаивают, прежде чем я приступаю к работе.
  • Минимальный тест JAX: Простой тестовый сценарий работает правильно.

Код: Выделить всё

    # test_jax.py
import jax
print(jax.devices())

Код: Выделить всё

Running python test_jax.py correctly prints the list of all available GPUs. This proves the fresh environment can see the GPUs.
Минимально воспроизводимый пример:
Это упрощенная структура моего проекта, воспроизводящая ошибку.
1. Рабочий скрипт (mcmc_worker.py):

Код: Выделить всё

import jax
import jax.numpy as jnp

# Enable 64-bit precision, as my real code uses it.
jax.config.update("jax_enable_x64", True)

@jax.jit
def perform_math(A, x):
# The real error happens in a function that does matrix math like this.
return A @ (x * 2.0)

if __name__ == "__main__":
# Create some dummy data.
key = jax.random.PRNGKey(0)
matrix = jax.random.normal(key, (100, 50), dtype=jnp.float64)
vector = jax.random.normal(key, (50,), dtype=jnp.float64)

# This is where the script fails.
result = perform_math(matrix, vector)
result.block_until_ready() # Force the computation

print("Job finished successfully.")
2. Сценарий диспетчера (run_all.sh):

Код: Выделить всё

#!/bin/bash

# Configuration
AVAILABLE_GPUS=(0 1 2 3 4 5 6 7)
NUM_GPUS=${#AVAILABLE_GPUS[@]}
MAX_JOBS=16 # Run 2 jobs per GPU
TASK_ID=0

# Clean up
rm -f logs/*.txt
mkdir -p logs

# Dispatch 50 dummy jobs
for i in {1..50}; do

# Throttle to avoid overloading the system
if [[ $(jobs -r -p | wc -l) -ge $MAX_JOBS ]]; then
wait -n
fi

GPU_INDEX=$((TASK_ID % NUM_GPUS))
GPU_ID=${AVAILABLE_GPUS[$GPU_INDEX]}

echo "Dispatching Job $i on GPU $GPU_ID"

# Run the worker in the background
CUDA_VISIBLE_DEVICES=$GPU_ID nohup python mcmc_worker.py > logs/job_${i}.txt &

TASK_ID=$((TASK_ID + 1))
done

wait
echo "All jobs dispatched."
Мой основной вопрос:
Почему ошибка «Нет поддержки BLAS» возникает периодически при запуске нескольких процессов JAX, даже если работает один процесс и среда только что создана? Это похоже на состояние гонки во время инициализации или на фундаментальную неправильную конфигурацию между Conda, установленным с помощью pip JAX и системными библиотеками CUDA/cuBLAS, которая проявляется только при параллельной загрузке.
Как я могу отладить это дальше? Существуют ли какие-либо известные проблемы или конкретные переменные среды, которые мне следует установить, чтобы каждый процесс JAX правильно и независимо связывался с библиотеками BLAS?

Подробнее здесь: https://stackoverflow.com/questions/798 ... g-multiple
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»