Градиент определен только для функций скалярного вывода. Выходные данные имели форму: (1,)Python

Программы на Python
Ответить
Anonymous
 Градиент определен только для функций скалярного вывода. Выходные данные имели форму: (1,)

Сообщение Anonymous »

Я работаю с градиентами, и у меня возникли проблемы. Вот мой код

Код: Выделить всё

import jax

def model(x):
return (x+1)**2 + (x-1)**2

def loss(x, y):
return y - model(x)

x = 2
grad = jax.grad(loss, argnums=0)
gradient = grad(x, 0)

И в последней строке я получаю следующую ошибку

TypeError: Градиент определен только для функций скалярного вывода . Выходные данные имели форму: (1,).
Приведенная ниже трассировка стека исключает внутренние фреймы JAX.
Предыдущее представляет собой исходное возникшее исключение без изменений
Как решить?

Подробнее здесь: https://stackoverflow.com/questions/707 ... ad-shape-1
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»