Код: Выделить всё
import jax
def model(x):
return (x+1)**2 + (x-1)**2
def loss(x, y):
return y - model(x)
x = 2
grad = jax.grad(loss, argnums=0)
gradient = grad(x, 0)
TypeError: Градиент определен только для функций скалярного вывода . Выходные данные имели форму: (1,).
Приведенная ниже трассировка стека исключает внутренние фреймы JAX.
Предыдущее представляет собой исходное возникшее исключение без изменений
Как решить?
Подробнее здесь: https://stackoverflow.com/questions/707 ... ad-shape-1
Мобильная версия