Я новичок в JAX и пытаюсь использовать его с Pennylane и Optax для оптимизации простой квантовой схемы. Тем не менее, я заметил, что мой оператор печати внутри функции стоимости не выполняется в каждой итерации. В частности, он печатает только один раз в начале, а затем перестает появляться. Я просто хотел как можно больше упростить пример. Я считаю, что схема на самом деле не имеет отношения к вопросу, но она включена в качестве примера.import pennylane as qml
import jax
import jax.numpy as jnp
import optax
jax.config.update("jax_enable_x64", True)
device = qml.device("default.qubit", wires=1)
@qml.qnode(device, interface='jax')
def circuit(params):
qml.RX(params, wires=0)
return qml.expval(qml.PauliZ(0))
def cost(params):
print('Evaluating')
return circuit(params)
# Define optimizer
params = jnp.array(0.1)
opt = optax.adam(learning_rate=0.1)
opt_state = opt.init(params)
# JIT the gradient function
grad = jax.jit(jax.grad(cost))
for epoch in range(5):
print(f'{epoch = }')
grad_value = grad(params)
updates, opt_state = opt.update(grad_value, opt_state)
params = optax.apply_updates(params, updates)
< /code>
ожидаемый вывод: < /h2>
epoch = 0
Evaluating
epoch = 1
Evaluating
epoch = 2
Evaluating
epoch = 3
Evaluating
epoch = 4
Evaluating
< /code>
фактический вывод: < /h2>
epoch = 0
Evaluating
epoch = 1
Evaluating
epoch = 2
epoch = 3
epoch = 4
< /code>
Вопрос: < /h2>
Почему оператор печати внутри стоимости не выполняется после первой итерации? Является ли JAX кэшировать вызов функции или оптимизировать его таким образом, чтобы пропустить выполнение? Как я могу гарантировать, что стоимость оценивается на каждой итерации?
Подробнее здесь: https://stackoverflow.com/questions/794 ... t-function
Почему выпускник JAX не всегда печатает внутри функции стоимости? ⇐ Python
-
- Похожие темы
- Ответы
- Просмотры
- Последнее сообщение
-
-
Как недавний выпускник, как я могу процветать в области разработки Android [закрыто]
Anonymous » » в форуме Android - 0 Ответы
- 5 Просмотры
-
Последнее сообщение Anonymous
-