Векторизовать объекты в Python JaxPython

Программы на Python
Ответить
Anonymous
 Векторизовать объекты в Python Jax

Сообщение Anonymous »

Я не уверен, как лучше всего векторизовать объекты в Python Jax.
В частности, я хочу написать код, который обрабатывает вызов метода как из одного экземпляра класса, так и из нескольких (векторизованных) экземпляры класса.
Далее я напишу простой пример того, чего я хотел бы достичь.

Код: Выделить всё

import jax
import jax.numpy as jnp
import jax.random as random

class Dummy:

def __init__(self, x, key):
self.x = x
self.key = key

def to_pytree(self):
return (self.x, self.key), None

def get_noisy_x(self):
self.key, subkey = random.split(self.key)
return self.x + random.normal(subkey, self.x.shape)

@staticmethod
def from_pytree(auxiliary, pytree):
return Dummy(*pytree)

jax.tree_util.register_pytree_node(Dummy,
Dummy.to_pytree,
Dummy.from_pytree)
Класс Dummy содержит некоторую информацию x и ключи и имеет метод get_noisy_x. Следующий код работает должным образом:

Код: Выделить всё

key = random.PRNGKey(0)
dummy = Dummy(jnp.array([1., 2., 3.]), key)
dummy.get_noisy_x()
Я бы хотел, чтобы get_noisy_x работал также с векторизованной версией объекта Dummy.

Код: Выделить всё

key = random.PRNGKey(0)
key, subkey = random.split(key)
key_batch = random.split(subkey, 100)
dummy_vmap = jax.vmap(lambda x: Dummy(jnp.array([1., 2., 3.]), x))(key_batch)
Я ожидаю, что dummy_vmap будет массивом объектов Dummy; однако вместо этого dummy_vmap оказывается только одним Dummy с векторизованным x и ключом. Для меня это не идеально, потому что это меняет поведение кода. Например, если я вызываю dummy_vmap.get_noisy_x(), мне возвращается ошибка, в которой говорится, что self.key, subkey = random.split(self.key) не работает, потому что self.key не работает. один ключ. Хотя эту ошибку можно решить несколькими способами — и на самом деле в этом примере векторизация не особо нужна, моя цель — понять, как писать код объектно-ориентированным способом, который оба обрабатываются правильно

Код: Выделить всё

dummy = Dummy(jnp.array([1., 2., 3.]), key)
dummy.get_noisy_x()
и

Код: Выделить всё

vectorized_dummy = .... ?
vectorized_dummy.get_noisy_x()
Обратите внимание, что приведенный мною пример может работать несколькими способами без векторизации. Однако я ищу более общий способ решения векторизации в гораздо более сложных сценариях.

Подробнее здесь: https://stackoverflow.com/questions/792 ... python-jax
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»