Решение большого количества (1 миллион) отдельных небольших нелинейных систем уравнений с использованием JAXPython

Программы на Python
Ответить
Anonymous
 Решение большого количества (1 миллион) отдельных небольших нелинейных систем уравнений с использованием JAX

Сообщение Anonymous »

У меня есть некоторые технические вопросы относительно возможностей JAX в решении значительного числа (1 миллион) отдельных небольших нелинейных систем уравнений. В настоящее время мой подход предполагает распределение этих заданий между 200 процессорами с использованием MPI, где каждый процессор последовательно обрабатывает 5000 отдельных нелинейных систем. Этот процесс требует дифференцирования и использования решателя линейной алгебры.
После рассмотрения документации JAX я хотел бы уточнить следующие моменты:
(1) JAX оказывается эффективным для решения одной нелинейной системы уравнений с использованием автоматического дифференцирования и jax.numpy. Однако какой метод параллельного выполнения рекомендуется использовать при обработке 1 миллиона отдельных заданий? Насколько я понимаю, графические процессоры и TPU в первую очередь повышают производительность отдельных задач, а не нескольких параллельных задач.
(2) Учитывая, что размерность нелинейных систем меньше 100, как влияет производительность JAX? сравнить с NumPy для этих более мелких проблем? Является ли преимущество JAX очевидным в первую очередь в сценариях более высокого уровня?
Я ценю любую информацию, которую вы можете предоставить. Спасибо.

Подробнее здесь: https://stackoverflow.com/questions/790 ... -of-equati
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»