Как использовать jax.custom_vjp с функциями, которые принимают в качестве входных данных типы, отличные от JAX (напримерPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Как использовать jax.custom_vjp с функциями, которые принимают в качестве входных данных типы, отличные от JAX (например

Сообщение Anonymous »

Я пытаюсь использовать JAX custom_vjp для определения пользовательских вычислений градиента для функции, которая принимает выражение SymPy в качестве входных данных. Однако я сталкиваюсь с ошибками, поскольку JAX не поддерживает типы, отличные от JAX, в качестве входных данных для преобразуемых функций (например, с помощью grad, jit или custom_vjp). Недавно я модифицировал код в ScQubits, чтобы добавить новый бэкэнд jax для повышения эффективности, а затем столкнулся с некоторой проблемой с jax и Sympy.
Вот минимальный пример того, что я пытаюсь сделать. что сделать:

Код: Выделить всё

import jax
import sympy as sm

# Define symbols and expression
x, y, z = sm.symbols('x y z')
expr = x**2 + 2*y + z

# Attempt to iterate over expr (this will cause an error)
try:
for term in expr:
print(term)
except TypeError as e:
print(f"Error: {e}")

# Define a function that takes a SymPy expression and a value
def sympy_function(expr, x_value):
x = sm.Symbol('x')
result = expr.subs(x, x_value)
return float(result)

# Attempt to apply custom_vjp
sympy_function = jax.custom_vjp(sympy_function)

def sympy_function_fwd(expr, x_value):
y = sympy_function(expr, x_value)
return y, (expr, x_value)

def sympy_function_bwd(residual, grad_y):
expr, x_value = residual
x = sm.Symbol('x')
derivative_expr = sm.diff(expr, x)
grad_x_value = float(derivative_expr.subs(x, x_value))
grad_expr = None
return grad_expr, grad_y * grad_x_value

sympy_function.defvjp(sympy_function_fwd, sympy_function_bwd)

# Test the function
x = sm.Symbol('x')
expr = x**2 + 3*x + 2
x_value = 1.0

# This will raise an error
y = sympy_function(expr, x_value)
Когда я запускаю этот код, я получаю следующую ошибку:

Код: Выделить всё

TypeError: Value x**2 + 3*x + 2 with type  is not a valid JAX type
Как использовать jax.custom_vjp с функциями, которые принимают в качестве входных данных типы, отличные от JAX, например выражения SymPy? Есть ли способ обойти это ограничение или заставить JAX принимать такие функции?

Подробнее здесь: https://stackoverflow.com/questions/790 ... g-sympy-ex
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»