Я запустил следующие 4 команды в командной строке (Bash): < /p>
JAX_PLATFORM_NAME=cpu python -c "import jax; import jax.numpy as jnp; key = jax.random.PRNGKey(1); print(jax.random.uniform(key, (2, 2),))"
JAX_PLATFORM_NAME=cpu python -c "import jax; import jax.numpy as jnp; key = jax.random.PRNGKey(1); print(jax.random.normal(key, (2, 2),))"
JAX_PLATFORM_NAME=gpu python -c "import jax; import jax.numpy as jnp; key = jax.random.PRNGKey(1); print(jax.random.uniform(key, (2, 2),))"
JAX_PLATFORM_NAME=gpu python -c "import jax; import jax.numpy as jnp; key = jax.random.PRNGKey(1); print(jax.random.normal(key, (2, 2),))"
Все они работали нормально, за исключением единого выборки на GPU, что привело к ошибке сегментации (ядро сброшено) с кодом выхода 139 (и выходом 245, когда аналогичный код был запущен как часть более длинной программы).
partial of nvidia-smi). NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 12.7
NVIDIA GeForce RTX 4090
< /code>
jax и jaxlib версии: < /p>
Name: jax
Version: 0.4.37
Name: jaxlib
Version: 0.4.36
Chatgpt думает, что это проблема совместимости CUDA и JAX, но страница здесь, кажется, предполагает CUDA> = 12.1 должна быть в порядке.
Есть идеи?~/anaconda3/envs/unc/lib/python3.10/site-packages/jaxlib/plugin_support.py:71: RuntimeWarning: JAX plugin jax_cuda12_plugin version 0.4.33 is installed, but it is not compatible with the installed jaxlib version 0.6.2, so it will not be used.
warnings.warn(
Segmentation fault (core dumped)
Подробнее здесь: https://stackoverflow.com/questions/796 ... not-on-cpu
Jax.random.Uniforms, вызывая ошибку сегментации при вызове GPU, но не на процессоре, и Jax.random.normal Crashing ⇐ Python
-
- Похожие темы
- Ответы
- Просмотры
- Последнее сообщение
-
-
Используют ли, когда и зачем numpy.random.rand(…) и numpy.random.random(…)?
Anonymous » » в форуме Python - 0 Ответы
- 64 Просмотры
-
Последнее сообщение Anonymous
-