Параллельный генератор случайных чисел с шардингом JAXPython

Программы на Python
Ответить
Anonymous
 Параллельный генератор случайных чисел с шардингом JAX

Сообщение Anonymous »

Каков правильный подход к параллельной генерации псевдослучайных чисел с использованием сегментирования в jax?
Следующее не работает (из-за выборки одной и той же цепочки)< /p>

Код: Выделить всё

sharding = jax.sharding.PositionalSharding(
jax.experimental.mesh_utils.create_device_mesh((8,)),
)

@jax.jit(static_argnum=1, out_sharding=sharding.reshape(8, 1))
def uniform_sharded(rng_key, n):
return jax.random.uniform(key=rng_key, shape=(n,))

Я рассматривал возможность выполнения pmap между устройствами с массивом ключей, но тогда результат будет зависеть от количества устройств и, похоже, противоречит цели сегментирования.< /п>

Подробнее здесь: https://stackoverflow.com/questions/767 ... x-sharding
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»