Восстановление контрольных точек модели льна с использованием orbax выдает ValueErrorPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Восстановление контрольных точек модели льна с использованием orbax выдает ValueError

Сообщение Anonymous »

Следующие блоки кода используются для сохранения состояния обучения модели во время обучения и восстановления этого состояния обратно в память.
from flax.training import orbax_utils
import orbax.checkpoint

directory_gen_path = "checkpoints_loc"
orbax_checkpointer_gen = orbax.checkpoint.PyTreeCheckpointer()
gen_options = orbax.checkpoint.CheckpointManagerOptions(save_interval_steps=5, create=True)
gen_checkpoint_manager = orbax.checkpoint.CheckpointManager(
directory_gen_path, orbax_checkpointer_gen, gen_options
)

def save_model_checkpoints(step_, generator_state, generator_batch_stats):

gen_ckpt = {
"model": generator_state,
"batch_stats": generator_batch_stats,
}

save_args_gen = orbax_utils.save_args_from_target(gen_ckpt)
gen_checkpoint_manager.save(step_, gen_ckpt, save_kwargs={"save_args": save_args_gen})

def load_model_checkpoints(generator_state, generator_batch_stats):
gen_target = {
"model": generator_state,
"batch_stats": generator_batch_stats,
}

latest_step = gen_checkpoint_manager.latest_step()
gen_ckpt = gen_checkpoint_manager.restore(latest_step, items=gen_target)
generator_state = gen_ckpt["model"]
generator_batch_stats = gen_ckpt["batch_stats"]

return generator_state, generator_batch_stats


Обучение модели было выполнено на графическом процессоре, и загрузка состояния на устройство графического процессора работает нормально, однако при попытке загрузить модель в процессор выдается следующая ошибка Метод восстановления менеджера контрольных точек orbax
ValueError: SingleDeviceSharding with Device=cuda:0 was not found in jax.local_devices().

Я не совсем уверен, в чем может быть причина, есть какие мысли, ребята?
Обновление: Обновлено до последняя версия orbax-checkpoint, 0.8.0
обратная трассировка изменена на следующую ошибку
ValueError: sharding passed to deserialization should be specified, concrete and an instance of `jax.sharding.Sharding`. Got None


Подробнее здесь: https://stackoverflow.com/questions/791 ... valueerror
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»