Как восстановить контрольную точку Orbax с JAX/леном?Python

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Как восстановить контрольную точку Orbax с JAX/леном?

Сообщение Anonymous »

Я сохранил контрольную точку Orbax с кодом ниже: < /p>

Код: Выделить всё

check_options = ocp.CheckpointManagerOptions(max_to_keep=5, create=True)
check_path = Path(os.getcwd(), out_dir, 'checkpoint')
checkpoint_manager = ocp.CheckpointManager(check_path, options=check_options, item_names=('state', 'metadata'))
checkpoint_manager.save(
step=iter_num,
args=ocp.args.Composite(
state=ocp.args.StandardSave(state),
metadata=ocp.args.JsonSave((model_args, iter_num, best_val_loss, losses['val'].item(), config))))
Когда я пытаюсь возобновить сохраненные контрольные точки, я использовал код ниже, чтобы восстановить переменную Code :

Код: Выделить всё

state, lr_schedule = init_train_state(model, params['params'], learning_rate, weight_decay, beta1, beta2, decay_lr, warmup_iters,
lr_decay_iters, min_lr)  # Here state is the initialied state variable with type Train_state.
state = checkpoint_manager.restore(checkpoint_manager.latest_step(), items={'state': state})
< /code>
Но когда я пытаюсь использовать восстановленное состояние в учебном цикле, я получил эту ошибку: < /p>
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File /opt/conda/envs/py_3.10/lib/python3.10/site-packages/jax/_src/api_util.py:584, in shaped_abstractify(x)
583 try:
--> 584   return _shaped_abstractify_handlers[type(x)](x)
585 except KeyError:

KeyError: 

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[40], line 37
34 if iter_num == 0 and eval_only:
35     break
---> 37 state, loss = train_step(state, get_batch('train'))
39 # timing and logging
40 t1 = time.time()

[... skipping hidden 6 frame]

File /opt/conda/envs/py_3.10/lib/python3.10/site-packages/jax/_src/api_util.py:575, in _shaped_abstractify_slow(x)
573   dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
574 else:
--> 575   raise TypeError(
576       f"Cannot interpret value of type {type(x)} as an abstract array; it "
577       "does not have a dtype attribute")
578 return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
579                         named_shape=named_shape)

TypeError: Cannot interpret value of type  as an abstract array; it does not have a dtype attribute
Итак, как мне правильно восстановить контрольную точку состояния и использовать ее в цикле обучения?
Спасибо!

Подробнее здесь: https://stackoverflow.com/questions/783 ... h-jax-flax
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»