Вот минимальный пример того, что я пытаюсь сделать. что сделать:
Код: Выделить всё
import jax
import sympy as sm
# Define symbols and expression
x, y, z = sm.symbols('x y z')
expr = x**2 + 2*y + z
# Attempt to iterate over expr (this will cause an error)
try:
for term in expr:
print(term)
except TypeError as e:
print(f"Error: {e}")
# Define a function that takes a SymPy expression and a value
def sympy_function(expr, x_value):
x = sm.Symbol('x')
result = expr.subs(x, x_value)
return float(result)
# Attempt to apply custom_vjp
sympy_function = jax.custom_vjp(sympy_function)
def sympy_function_fwd(expr, x_value):
y = sympy_function(expr, x_value)
return y, (expr, x_value)
def sympy_function_bwd(residual, grad_y):
expr, x_value = residual
x = sm.Symbol('x')
derivative_expr = sm.diff(expr, x)
grad_x_value = float(derivative_expr.subs(x, x_value))
grad_expr = None
return grad_expr, grad_y * grad_x_value
sympy_function.defvjp(sympy_function_fwd, sympy_function_bwd)
# Test the function
x = sm.Symbol('x')
expr = x**2 + 3*x + 2
x_value = 1.0
# This will raise an error
y = sympy_function(expr, x_value)
Код: Выделить всё
TypeError: Value x**2 + 3*x + 2 with type is not a valid JAX type
Подробнее здесь: https://stackoverflow.com/questions/790 ... g-sympy-ex