Пропустить вычисление выходных листьев функции pytree `jax`, если для входных листьев установлено значение None`Python

Программы на Python
Ответить
Anonymous
 Пропустить вычисление выходных листьев функции pytree `jax`, если для входных листьев установлено значение None`

Сообщение Anonymous »

У меня есть функция fun, которая принимает в качестве аргумента дерево pytree jax и возвращает дерево pytree jax с той же структурой. Иногда мне не хочется вычислять функцию для конкретного листа дерева pytree. В этом случае я бы хотел, чтобы эта часть не вычислялась всякий раз, когда для соответствующего входного листа установлено значение None.
Вот минимальный пример, где fun действует на dict. Представьте, что в реальном примере вычисление a**2 обходится дорого. Тогда я бы хотел, чтобы, когда b=None на входе, a**2 не вычислялся и напрямую заменялся на None на выходе. Я хотел бы не переопределять новую функцию для этого.

Код: Выделить всё

def fun(d):
a = d["a"]
return dict(a=0, b=a**2)

d0 = dict(a=3, b=1)
res0 = fun(d0) # Output is {'a': 0, 'b': 9}

d1 = dict(a=3, b=None)
res1 = fun(d1) # Output still {'a': 0, 'b': 9}, but want {'a': 0, 'b': None} without ever calculating a**2=9
Возможно ли это? Я надеялся, что так и будет, поскольку jax построен таким образом, что сглаженные pytree читают None как отсутствие листа. Однако здесь функция действует на все дерево pytree, а не на его отдельные листья, поэтому это кажется более сложным.


Подробнее здесь: https://stackoverflow.com/questions/798 ... leaves-are
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»