При использовании JAX для обучения модели машинного обучения мы лишь пытаемся минимизировать потери при обучении.
В то время как в моем требовании, чтобы оценить количество эпох или избежать Из-за перетренированности мне также нужно знать потери при тестировании на каждом этапе обновления параметров. Но опция обратного вызова или отладки, доступная в JAX, явно предполагает, что мне не следует выполнять какие-либо задачи, требующие больших вычислительных ресурсов, например определение потерь и точности тестов.
Приведенная ниже оптимизация работает безупречно
TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]
The error occurred while tracing the function print_fn at ...\AppData\Local\Temp\ipykernel_37368\1395230403.py:10 for cond. This value became a tracer due to JAX operations on these lines:
operation a:i32[] b:i32[] = pjit[
name=divmod
jaxpr={ lambda ; c:i32[] d:i32[]. let
e:i32[] = pjit[
name=floor_divide
jaxpr={ lambda ; f:i32[] g:i32[]. let
h:i32[] = div f g
i:i32[] = sign f
j:i32[] = sign g
k:bool[] = ne i j
l:i32[] = rem f g
m:bool[] = ne l 0
n:bool[] = convert_element_type[new_dtype=bool weak_type=False] k
o:bool[] = convert_element_type[new_dtype=bool weak_type=False] m
p:bool[] = and n o
q:i32[] = sub h 1
r:i32[] = pjit[
name=_where
jaxpr={ lambda ; s:bool[] t:i32[] u:i32[]. let
v:i32[] = select_n s u t
in (v,) }
] p q h
in (r,) }
] c d
w:i32[] = pjit[
name=remainder
jaxpr={ lambda ; x:i32[] y:i32[]. let
z:bool[] = eq y 0
ba:i32[] = pjit[
name=_where
jaxpr={ lambda ; bb:bool[] bc:i32[] bd:i32[]. let
be:i32[] = select_n bb bd bc
in (be,) }
] z 1 y
bf:i32[] = rem x ba
bg:bool[] = ne bf 0
bh:bool[] = lt bf 0
bi:bool[] = lt ba 0
bj:bool[] = ne bh bi
bk:bool[] = and bj bg
bl:i32[] = add bf ba
bm:i32[] = select_n bk bf bl
in (bm,) }
] c d
in (e, w) }
] bn bo
from line ...\AppData\Local\Temp\ipykernel_37368\1395230403.py:12:27 (update_step_jit..print_fn)
Мое время обучения резко увеличится, если я не буду использовать JAX.
Таким образом, с JAX я могу получить производительность на тестовом наборе только в конце обучения , и я никак не могу получить тестовые потери в середине обучения или есть обходной путь?
Для минимально воспроизводимого примера вы можете попробовать запустить приведенный пример в этом демонстрационном коде.
При использовании JAX для обучения модели машинного обучения мы лишь пытаемся минимизировать потери при обучении. В то время как в моем требовании, чтобы оценить количество эпох или избежать Из-за перетренированности мне также нужно знать потери при тестировании на каждом этапе обновления параметров. Но опция обратного вызова или отладки, доступная в JAX, явно предполагает, что мне не следует выполнять какие-либо задачи, требующие больших вычислительных ресурсов, например определение потерь и точности тестов. Приведенная ниже оптимизация работает безупречно [code]@jax.jit def update_step_jit(i, args): params, opt_state, data, targets, batch_no, print_training = args _data = data[batch_no % num_batch] _targets = targets[batch_no % num_batch] loss_val, grads = jax.value_and_grad(cost)(params, _data, _targets) updates, opt_state = opt.update(grads, opt_state) params = optax.apply_updates(params, updates)
def print_fn(): jax.debug.print("Step: {i}, Train Loss: {loss_val}", i=i, loss_val=loss_val) # if print_training=True, print the loss every 5 steps jax.lax.cond((jnp.mod(i, 5) == 0) & print_training, print_fn, lambda: None) return (params, opt_state, data, targets, batch_no + 1, print_training)
# Calculate accuracy and loss for training and test sets train_accuracy = acc(params, X_train_ndarray, y_train_ndarray) test_predictions = [variational_classifier(params["weights"], params["bias"], x) for x in X_test_ndarray] test_loss = square_loss(y_test_ndarray, test_predictions) test_accuracy = accuracy(y_test_ndarray, test_predictions)
jax.debug.print("Step: {i}, Train Accuracy {train_accuracy}", i=i, train_accuracy = train_accuracy) jax.debug.print("Step: {i}, Test Accuracy {test_accuracy}", i=i, test_accuracy = test_accuracy) jax.debug.print("Step: {i}, Test Loss {test_loss}", i=i, test_loss = test_loss)
# if print_training=True, print the loss every 5 steps jax.lax.cond((jnp.mod(i, 1) == 0) & print_training, print_fn, lambda: None) return (params, opt_state, data, targets, batch_no + 1, print_training)
params = optimization_jit(params, X_batched, y_batched, print_training = True) [/code] Выдает странную ошибку типа: [code]TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[] The error occurred while tracing the function print_fn at ...\AppData\Local\Temp\ipykernel_37368\1395230403.py:10 for cond. This value became a tracer due to JAX operations on these lines:
operation a:i32[] b:i32[] = pjit[ name=divmod jaxpr={ lambda ; c:i32[] d:i32[]. let e:i32[] = pjit[ name=floor_divide jaxpr={ lambda ; f:i32[] g:i32[]. let h:i32[] = div f g i:i32[] = sign f j:i32[] = sign g k:bool[] = ne i j l:i32[] = rem f g m:bool[] = ne l 0 n:bool[] = convert_element_type[new_dtype=bool weak_type=False] k o:bool[] = convert_element_type[new_dtype=bool weak_type=False] m p:bool[] = and n o q:i32[] = sub h 1 r:i32[] = pjit[ name=_where jaxpr={ lambda ; s:bool[] t:i32[] u:i32[]. let v:i32[] = select_n s u t in (v,) } ] p q h in (r,) } ] c d w:i32[] = pjit[ name=remainder jaxpr={ lambda ; x:i32[] y:i32[]. let z:bool[] = eq y 0 ba:i32[] = pjit[ name=_where jaxpr={ lambda ; bb:bool[] bc:i32[] bd:i32[]. let be:i32[] = select_n bb bd bc in (be,) } ] z 1 y bf:i32[] = rem x ba bg:bool[] = ne bf 0 bh:bool[] = lt bf 0 bi:bool[] = lt ba 0 bj:bool[] = ne bh bi bk:bool[] = and bj bg bl:i32[] = add bf ba bm:i32[] = select_n bk bf bl in (bm,) } ] c d in (e, w) } ] bn bo from line ...\AppData\Local\Temp\ipykernel_37368\1395230403.py:12:27 (update_step_jit..print_fn) [/code] Мое время обучения резко увеличится, если я не буду использовать JAX. Таким образом, с JAX я могу получить производительность на тестовом наборе только в конце обучения , и я никак не могу получить тестовые потери в середине обучения или есть обходной путь? Для минимально воспроизводимого примера вы можете попробовать запустить приведенный пример в этом демонстрационном коде.