Странно то, что моя установка JAX в принципе работает: если я запускаю простую команду JAX в одном процессе, она правильно идентифицирует и использует мои графические процессоры NVIDIA. Сбой только происходит, когда я запускаю параллельные рабочие процессы.
Вот сведения о моей системе и среде:
Среда:
- ОС: Ubuntu 18.04 LTS
- Python: 3.10 (управляется через Conda)
- Графический процессор: 8 x NVIDIA Quadro RTX 8000
- Драйвер NVIDIA: 550.144.03
- Версия CUDA (из драйвера): 12.4
- Установка JAX:
Код: Выделить всё
jax==0.4.26 Код: Выделить всё
jaxlib==0.4.26Код: Выделить всё
jax-cuda12-plugin==0.4.26
Минимальный, воспроизводимый пример:
Следующий упрощенный сценарий прекрасно демонстрирует проблему.
Код: Выделить всё
import jax
import jax.numpy as jnp
from joblib import Parallel, delayed
import multiprocessing
# Define a simple JAX function that will be executed on the GPU
def simple_worker(i):
"""A simple function that performs a JAX computation."""
try:
# Create some data on the GPU and perform a computation
x = jnp.ones((100, 100))
y = jnp.dot(x, x)
# block_until_ready() ensures the computation is finished before returning
y.block_until_ready()
return i, "Success"
except Exception as e:
return i, f"Failed with: {e}"
if __name__ == "__main__":
# --- Verification Step ---
print("--- Verifying JAX in the main process ---")
try:
devices = jax.devices()
print(f"JAX sees {len(devices)} devices: {devices}")
if 'gpu' not in str(devices[0]).lower() and 'cuda' not in str(devices[0]).lower():
print("WARNING: JAX does not see the GPU in the main process!")
except Exception as e:
print(f"Error during JAX verification: {e}")
print("-" * 40)
# --- Test 1: Serial execution (This works perfectly) ---
print("\n--- Running jobs serially (expected to work) ---")
results_serial = []
for i in range(4):
results_serial.append(simple_worker(i))
print(f"Serial results: {results_serial}\n")
print("-" * 40)
# --- Test 2: Parallel execution with joblib (This crashes) ---
print("\n--- Running jobs in parallel with joblib (expected to fail) ---")
try:
# Also tried backend='threading' and backend='multiprocessing' with 'spawn' context
multiprocessing.set_start_method('spawn', force=True)
results_parallel = Parallel(n_jobs=4)(
delayed(simple_worker)(i) for i in range(4)
)
print(f"Parallel results: {results_parallel}")
except Exception as e:
print(f"Joblib Parallel failed with error: {e}")
print("-" * 40)
Когда я запускаю этот сценарий, разделы проверки и последовательного выполнения работают отлично, показывая, что моя базовая установка JAX может использовать графический процессор. Однако параллельный раздел тут же вылетает. Перед сбоем основного сценария каждый рабочий процесс выдает серию ошибок, которые выглядят следующим образом:
Код: Выделить всё
E ... external/xla/xla/stream_executor/cuda/cuda_dnn.cc:536] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
...
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
- Проверка установки: Как показано в MRE, запуск кода JAX в основном отдельном процессе работает безупречно. jax.devices() правильно перечисляет все мои графические процессоры.
- Изменение метода запуска многопроцессорной обработки: Я попробовал добавить multiprocessing.set_start_method('spawn', Force=True) в начале моего скрипта. Сбой по-прежнему происходит с той же ошибкой.
- Изменение библиотеки заданий Серверная часть: Я попробовал Parallel(n_jobs=4, backend='threading'). Это также приводит к точно такому же сбою CUDNN_STATUS_INTERNAL_ERROR.
Известна ли проблема конфигурации с некоторыми настройками драйверов Linux/NVIDIA, которая препятствует инициализации параллельных рабочих процессов Контексты CUDA? Как правильно структурировать параллельный скрипт Python с использованием такой библиотеки, как joblib, для распределения рабочих нагрузок JAX?
Подробнее здесь: https://stackoverflow.com/questions/798 ... -multiproc
Мобильная версия