Цель:
У меня есть сценарий научных вычислений (worker.py), который я хочу запустить примерно 600 раз с разными параметрами. Я использую сценарий bash для отправки этих заданий на 8 доступных графических процессоров NVIDIA Quadro RTX 8000, назначая каждое задание определенному графическому процессору с помощью CUDA_VISIBLE_DEVICES.
Проблема:
Когда я запускаю диспетчер, некоторые из первых заданий запускаются, но быстро завершаются сбоем. Журнал ошибок для каждого сбойного задания всегда один и тот же:
Код: Выделить всё
jaxlib._jax.XlaRuntimeError: INTERNAL: No BLAS support for stream
Среда:
- ОС: Ubuntu 18.04 LTS
- Аппаратное обеспечение: 8 x NVIDIA Quadro RTX 8000 (48 ГБ)
- Драйвер: Команда nvidia-smi работает отлично.
- Python: Управляется с помощью Conda (Python 3.10).
- Полная перестройка среды: Я запустил conda env удалил -n jax_working и пересобрал среду с нуля, используя официально рекомендованную команду pip со ссылкой на CUDA 12:
Код: Выделить всё
conda create --name jax_working python=3.10 -y
conda activate jax_working
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install pandas tqdm # and other packages
- Проверка ресурсов: htop показывает, что на сервере достаточно свободного процессора и оперативной памяти. nvidia-smi показывает, что все графические процессоры полностью простаивают, прежде чем я приступаю к работе.
- Минимальный тест JAX: Простой тестовый сценарий работает правильно.
Код: Выделить всё
# test_jax.py
import jax
print(jax.devices())
Код: Выделить всё
Running python test_jax.py correctly prints the list of all available GPUs. This proves the fresh environment can see the GPUs.
Это упрощенная структура моего проекта, воспроизводящая ошибку.
1. Рабочий скрипт (mcmc_worker.py):
Код: Выделить всё
import jax
import jax.numpy as jnp
# Enable 64-bit precision, as my real code uses it.
jax.config.update("jax_enable_x64", True)
@jax.jit
def perform_math(A, x):
# The real error happens in a function that does matrix math like this.
return A @ (x * 2.0)
if __name__ == "__main__":
# Create some dummy data.
key = jax.random.PRNGKey(0)
matrix = jax.random.normal(key, (100, 50), dtype=jnp.float64)
vector = jax.random.normal(key, (50,), dtype=jnp.float64)
# This is where the script fails.
result = perform_math(matrix, vector)
result.block_until_ready() # Force the computation
print("Job finished successfully.")
Код: Выделить всё
#!/bin/bash
# Configuration
AVAILABLE_GPUS=(0 1 2 3 4 5 6 7)
NUM_GPUS=${#AVAILABLE_GPUS[@]}
MAX_JOBS=16 # Run 2 jobs per GPU
TASK_ID=0
# Clean up
rm -f logs/*.txt
mkdir -p logs
# Dispatch 50 dummy jobs
for i in {1..50}; do
# Throttle to avoid overloading the system
if [[ $(jobs -r -p | wc -l) -ge $MAX_JOBS ]]; then
wait -n
fi
GPU_INDEX=$((TASK_ID % NUM_GPUS))
GPU_ID=${AVAILABLE_GPUS[$GPU_INDEX]}
echo "Dispatching Job $i on GPU $GPU_ID"
# Run the worker in the background
CUDA_VISIBLE_DEVICES=$GPU_ID nohup python mcmc_worker.py > logs/job_${i}.txt &
TASK_ID=$((TASK_ID + 1))
done
wait
echo "All jobs dispatched."
Почему ошибка «Нет поддержки BLAS» возникает периодически при запуске нескольких процессов JAX, даже если работает один процесс и среда только что создана? Это похоже на состояние гонки во время инициализации или на фундаментальную неправильную конфигурацию между Conda, установленным с помощью pip JAX и системными библиотеками CUDA/cuBLAS, которая проявляется только при параллельной загрузке.
Как я могу отладить это дальше? Существуют ли какие-либо известные проблемы или конкретные переменные среды, которые мне следует установить, чтобы каждый процесс JAX правильно и независимо связывался с библиотеками BLAS?
Подробнее здесь: https://stackoverflow.com/questions/798 ... g-multiple
Мобильная версия