При использовании JAX для обучения модели машинного обучения мы лишь пытаемся минимизировать потери при обучении.
В то время как в моем требовании, чтобы оценить количество эпох или избежать Из-за перетренированности мне также нужно знать потери при тестировании на каждом этапе обновления параметров. Но опция обратного вызова или отладки, доступная в JAX, явно предполагает, что мне не следует выполнять какие-либо задачи, требующие больших вычислительных ресурсов, например определение потерь и точности тестов.
Приведенная ниже оптимизация работает безупречно
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[]
The error occurred while tracing the function print_fn at C:\Users\mysore\AppData\Local\Temp\ipykernel_43468\2623796165.py:10 for cond. This value became a tracer due to JAX operations on these lines:
operation a:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] b
from line C:\Users\mysore\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] b
from line C:\Users\mysore\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] b
from line C:\Users\mysore\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(4,) start_indices=(3,) strides=None] b
from line C:\Users\mysore\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(5,) start_indices=(4,) strides=None] b
from line C:\Users\mysore\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
(Additional originating lines are not shown.)
Мое время обучения резко увеличится, если я не буду использовать JAX.
Таким образом, с JAX я могу получить производительность на тестовом наборе только в конце обучения , и я никак не могу получить тестовые потери в середине обучения или есть обходной путь для этого?
Для минимально воспроизводимого примера вы можете попробовать запустить приведенный пример в этом демонстрационном коде.
При использовании JAX для обучения модели машинного обучения мы лишь пытаемся минимизировать потери при обучении. В то время как в моем требовании, чтобы оценить количество эпох или избежать Из-за перетренированности мне также нужно знать потери при тестировании на каждом этапе обновления параметров. Но опция обратного вызова или отладки, доступная в JAX, явно предполагает, что мне не следует выполнять какие-либо задачи, требующие больших вычислительных ресурсов, например определение потерь и точности тестов. Приведенная ниже оптимизация работает безупречно [code]@jax.jit def update_step_jit(i, args): params, opt_state, data, targets, batch_no, print_training = args _data = data[batch_no % num_batch] _targets = targets[batch_no % num_batch] loss_val, grads = jax.value_and_grad(cost)(params, _data, _targets) updates, opt_state = opt.update(grads, opt_state) params = optax.apply_updates(params, updates)
# Calculate accuracy and loss for training and test sets train_accuracy = acc(params, X_train, y_train) test_predictions = jnp.array([variational_classifier(params["weights"], params["bias"], x) for x in X_test]) test_loss = square_loss(y_test, test_predictions) test_accuracy = accuracy(y_test, test_predictions)
jax.debug.print("Step: {i}, Train Accuracy {train_accuracy}", i=i, train_accuracy = train_accuracy) jax.debug.print("Step: {i}, Test Accuracy {test_accuracy}", i=i, test_accuracy = test_accuracy) jax.debug.print("Step: {i}, Test Loss {test_loss}", i=i, test_loss = test_loss)
# if print_training=True, print the loss every 5 steps jax.lax.cond((jnp.mod(i, 1) == 0) & print_training, print_fn, lambda: None) return (params, opt_state, data, targets, batch_no + 1, print_training)
# Calculate accuracy and loss for training and test sets train_accuracy = acc(params, X_train, y_train) test_predictions = jnp.array([variational_classifier(params["weights"], params["bias"], x) for x in X_test]) test_loss = square_loss(y_test, test_predictions) test_accuracy = accuracy(y_test, test_predictions)
jax.debug.print("Step: {i}, Train Accuracy {train_accuracy}", i=i, train_accuracy = train_accuracy) jax.debug.print("Step: {i}, Test Accuracy {test_accuracy}", i=i, test_accuracy = test_accuracy) jax.debug.print("Step: {i}, Test Loss {test_loss}", i=i, test_loss = test_loss)
# if print_training=True, print the loss every 5 steps jax.lax.cond((jnp.mod(i, 1) == 0) & print_training, print_fn, lambda: None) return (params, opt_state, data, targets, batch_no + 1, print_training)
print("Training accuracy: ", var_train_acc) print("Testing accuracy: ", var_test_acc) [/code] Выдает странную ошибку типа: [code]TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[] The error occurred while tracing the function print_fn at C:\Users\mysore\AppData\Local\Temp\ipykernel_43468\2623796165.py:10 for cond. This value became a tracer due to JAX operations on these lines:
operation a:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] b from line C:\Users\mysore\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] b from line C:\Users\mysore\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] b from line C:\Users\mysore\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(4,) start_indices=(3,) strides=None] b from line C:\Users\mysore\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(5,) start_indices=(4,) strides=None] b from line C:\Users\mysore\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
(Additional originating lines are not shown.) [/code] Мое время обучения резко увеличится, если я не буду использовать JAX. Таким образом, с JAX я могу получить производительность на тестовом наборе только в конце обучения , и я никак не могу получить тестовые потери в середине обучения или есть обходной путь для этого? Для минимально воспроизводимого примера вы можете попробовать запустить приведенный пример в этом демонстрационном коде.