Правильно ли и безопасно ли я использовать tf.distribute.MirroredStrategy и Strategy.scope() для обучения с несколькими Python

Программы на Python
Ответить
Anonymous
 Правильно ли и безопасно ли я использовать tf.distribute.MirroredStrategy и Strategy.scope() для обучения с несколькими

Сообщение Anonymous »

Я обучаю модель с использованием Keras + TensorFlow с tf.distribute.MirroredStrategy в конфигурации с несколькими графическими процессорами. Я хотел бы убедиться, что я правильно использую Strategy.scope().

Код: Выделить всё

import time
import logging
import os
import json
import datetime
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
os.environ["KERAS_BACKEND"] = "tensorflow"
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
#os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '0'
import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
if gpus:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
from tensorflow.python.profiler import profiler_v2 as profiler
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import keras
from keras import ops
from keras import layers
from keras import mixed_precision
from medicai.models import UNETRPlusPlus
from medicai.metrics import BinaryDiceMetric
from medicai.losses import BinaryDiceCELoss
from medicai.utils.inference import SlidingWindowInference
from medicai.callbacks import SlidingWindowInferenceCallback
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.experiment_config import ExperimentConfig
from src.data_pipeline.data_loader import data_loader

class TFCheckpointCallback(keras.callbacks.Callback):
"""Save model + optimizer + epoch using TF checkpointing and save SWI callback best score to a JSON file for restoring."""
def __init__(self, ckpt, ckpt_manager, swi_callback, checkpoint_dir):
super().__init__()
self.ckpt = ckpt
self.ckpt_manager = ckpt_manager
self.swi_callback = swi_callback
self.best_score_file = os.path.join(checkpoint_dir, "swi_best_score.json")

def on_epoch_end(self, epoch, logs=None):
# Update epoch variable and save checkpoint
self.ckpt.epoch.assign_add(1)   # increment epoch counter
save_path = self.ckpt_manager.save()
print(f"Saved checkpoint: {save_path} (epoch {int(self.ckpt.epoch.numpy())})")

# Save SWI best score externally
best_score = getattr(self.swi_callback, "best_score", -float("inf"))
with open(self.best_score_file, "w") as f:
json.dump({"best_score": best_score}, f)
print(f"[CheckpointCallback] Saved SWI best score: {best_score}")

class HistorySaverCallback(keras.callbacks.Callback):
"""Saves training history every epoch to CSV and allows resuming."""
def __init__(self, history_file, initial_history=None):
super().__init__()
self.history_file = history_file
self.full_history = initial_history if initial_history else {}

def on_epoch_end(self, epoch, logs=None):
if logs is None:
logs = {}
for k, v in logs.items():
self.full_history.setdefault(k, []).append(v)

# Save updated history
pd.DataFrame(self.full_history).to_csv(self.history_file, index=False)

def get_model(total_device):

model = UNETRPlusPlus(
encoder_name="unetr_plusplus_encoder",
input_shape=ExperimentConfig.input_shape,
num_classes=ExperimentConfig.num_classes,
classifier_activation=None,
)

total_train_samples = 387 # 80% ( approx.) split of the total dataset for train  as Unetr

# Compute steps per epoch and total steps
steps_per_epoch = total_train_samples // (ExperimentConfig.batch_size_train * total_device)
print(f"Steps per epoch : {steps_per_epoch}")
total_steps = steps_per_epoch * ExperimentConfig.epochs

# Warmup:  10% of total steps
warmup_steps = int(total_steps * 0.1)

# CosineDecay schedule with warmup
lr_schedule = keras.optimizers.schedules.CosineDecay(
initial_learning_rate=0.01 * ExperimentConfig.lr,  # very small starting LR
decay_steps=total_steps - warmup_steps,   # decay after warmup
alpha=ExperimentConfig.alpha,
warmup_target=ExperimentConfig.lr,
warmup_steps=warmup_steps
)

model.compile(
optimizer=keras.optimizers.AdamW(
learning_rate=lr_schedule,
weight_decay=ExperimentConfig.weight_decay,
),
loss=BinaryDiceCELoss(
from_logits=True,
dice_weight=1.0,
ce_weight=1.0,
reduction="mean",
num_classes=ExperimentConfig.num_classes,
),
metrics=[
BinaryDiceMetric(
from_logits=True,
ignore_empty=True,
num_classes=ExperimentConfig.num_classes,
name='dice',
),
BinaryDiceMetric(
from_logits=True,
ignore_empty=True,
target_class_ids=[0],
num_classes=ExperimentConfig.num_classes,
name='dice_tc',
),
BinaryDiceMetric(
from_logits=True,
ignore_empty=True,
target_class_ids=[1],
num_classes=ExperimentConfig.num_classes,
name='dice_wt',
),
BinaryDiceMetric(
from_logits=True,
ignore_empty=True,
target_class_ids=[2],
num_classes=ExperimentConfig.num_classes,
name='dice_et',
)
],
)

return model

def get_inference_metric():
swi_callback_metric = BinaryDiceMetric(
from_logits=True,
ignore_empty=True,
num_classes=ExperimentConfig.num_classes,
name='val_dice',
)
return swi_callback_metric

"""def run_sliding_window_inference_per_class_average(model, ds, roi_size, sw_batch_size, overlap, metrics_list):

#    Run sliding window inference on a dataset and compute all metrics (average + per class)

for metric in metrics_list:
metric.reset_states()

swi = SlidingWindowInference(
model,
num_classes=metrics_list[0].num_classes,
roi_size=roi_size,
sw_batch_size=sw_batch_size,
overlap=overlap
)

for x, y in ds:
y_pred = swi(x)
for metric in metrics_list:
metric.update_state(ops.convert_to_tensor(y), ops.convert_to_tensor(y_pred))

# Gather results
results = {}
for metric in metrics_list:
results[metric.name] = float(ops.convert_to_numpy(metric.result()))

return results"""

def main():
# reproducibility
keras.utils.set_random_seed(101)

print(
f"keras backend: {keras.config.backend()}\n"
f"keras version: {keras.version()}\n"
f"tensorflow version:  {tf.__version__}\n"
)

# get keras backend
keras_backend = keras.config.backend()

strategy = tf.distribute.MirroredStrategy()
total_device = strategy.num_replicas_in_sync

print('Keras backend ', keras_backend)
print('Total device found ', total_device)

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
base_save_path = os.path.join(project_root, "experiments", "msd_brain")
unetrplusplus_path = os.path.join(base_save_path, "unetrplusplus")
os.makedirs(unetrplusplus_path, exist_ok=True)

# Subfolders
logs_path = os.path.join(unetrplusplus_path, "logs")
history_path = os.path.join(unetrplusplus_path, "history")
plots_path = os.path.join(unetrplusplus_path, "plots")
os.makedirs(logs_path, exist_ok=True)
os.makedirs(history_path, exist_ok=True)
os.makedirs(plots_path, exist_ok=True)

# Timestamp
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# Save path for best model weights
save_path = os.path.join(unetrplusplus_path, f"best_model_weights_{timestamp}.weights.h5")

# File for containing the learning history
history_file = os.path.join(history_path, f"training_history.csv")

# Load datasets
tfrecord_pattern = os.path.join(project_root, "data", "msd_brain", "tfrecords", "{}_shard_*.tfrec")

# batch size for training
train_batch = ExperimentConfig.batch_size_train * total_device

train_ds = data_loader(
tfrecord_pattern.format("training"),
batch_size=train_batch,
shuffle=True
)
val_ds = data_loader(
tfrecord_pattern.format("validation"),
batch_size=ExperimentConfig.batch_size_val,
shuffle=False
)
test_ds = data_loader(
tfrecord_pattern.format("test"),
batch_size=ExperimentConfig.batch_size_val,
shuffle=False
)

with strategy.scope():
model = get_model(total_device)

checkpoint_dir = os.path.join(unetrplusplus_path, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)

with strategy.scope():
ckpt = tf.train.Checkpoint(
epoch=tf.Variable(0),          # epoch counter — saved as part of checkpoint
optimizer=model.optimizer,     # optimizer state
model=model                    # model weights
)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=3)

# Validation with sliding window callback
swi_callback_metric = get_inference_metric()
# Create SWI callback
swi_callback = SlidingWindowInferenceCallback(
model,
dataset=val_ds,
metrics=swi_callback_metric,
num_classes=ExperimentConfig.num_classes,
interval= ExperimentConfig.sliding_window_interval,
overlap=ExperimentConfig.sliding_window_overlap,
roi_size=(ExperimentConfig.input_shape[0],ExperimentConfig.input_shape[1],ExperimentConfig.input_shape[2]),
sw_batch_size=ExperimentConfig.sw_batch_size * total_device ,
save_path=save_path
)

# TFCheckpointCallback (save model, optimizer, epoch + SWI best score)
tf_ckpt_callback = TFCheckpointCallback(ckpt, ckpt_manager, swi_callback, checkpoint_dir)

# History callback
# Load previous history if exists
if os.path.exists(history_file):
prev_history = pd.read_csv(history_file).to_dict(orient='list')
else:
prev_history = {}
history_callback = HistorySaverCallback(history_file, initial_history=prev_history)

# Resume or start from scratch
if ckpt_manager.latest_checkpoint:
ckpt.restore(ckpt_manager.latest_checkpoint)
initial_epoch = int(ckpt.epoch.numpy())
print(f"[Resume] Restored checkpoint:  starting from epoch {initial_epoch}")

# Restore SWI best score
best_score_file = os.path.join(checkpoint_dir, "swi_best_score.json")
if os.path.exists(best_score_file):
with open(best_score_file, "r") as f:
swi_callback.best_score = json.load(f).get("best_score", -float("inf"))
print(f"[Resume] Restored SWI best validation score: {swi_callback.best_score}")
else:
print(f"[Resume] Couldn't Restore SWI best validation score")
else:
initial_epoch = 0
print("[Resume] No checkpoint found.  Starting from scratch.")

print(f"Model size: {model.count_params() / 1e6:.2f} M")

start_time = time.time()

with strategy.scope():
history = model.fit(
train_ds,
epochs=ExperimentConfig.epochs,
initial_epoch=initial_epoch,
callbacks=[
swi_callback,
tf_ckpt_callback,
history_callback
])

end_time = time.time()
training_time = end_time - start_time
print(f"Total training time (seconds): {training_time:.2f}")

# Save history to CSV
full_history = history_callback.full_history
# Save CSV
pd.DataFrame(full_history).to_csv(history_file, index=False)

# Plot loss
plt.figure(figsize=(10, 5))
plt.plot(full_history['loss'], label='train_loss')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.legend()
plt.grid()
plt.savefig(os.path.join(plots_path, f"loss_curve_{timestamp}.png"))
plt.close()

# Plot average Dice
if 'dice' in full_history:
plt.figure(figsize=(10, 5))
plt.plot(full_history['dice'], label='train_dice')
plt.xlabel("Epoch")
plt.ylabel("Average Dice")
plt.title("Training Average Dice")
plt.legend()
plt.grid()
plt.savefig(os.path.join(plots_path, f"dice_curve_{timestamp}.png"))
plt.close()

print("Training and saving plots finished successfully.")

if __name__ == "__main__":
main()

Чтобы избежать ошибок, в настоящее время я помещаю почти все, что связано с обучением, внутри Strategy.scope(), включая некоторые объекты, в которых я не уверен, создают ли они переменные TensorFlow или нет.
В частности, внутри области, которую я создаю:
  • модель
  • оптимизатор
  • потеря
  • Все показатели обучения
  • показатель, используемый пользовательским обратным вызовом проверки
  • Объекты контрольных точек (

    Код: Выделить всё

    tf.train.Checkpoint
    , CheckpointManager)
  • Обратные вызовы, которые ссылаются на модель и метрики
Наборы данных, пути, ведение журнала и утилиты чистого Python создаются вне области действия.
Мое текущее понимание:
  • Объекты, которые создают переменные TensorFlow (модель, оптимизатор, метрики), должны быть созданы внутри Strategy.scope().
  • Объекты, которые владеют или обновляют метрики (например, пользовательские обратные вызовы, отслеживающие результаты проверки), также должны создаваться внутри области.
  • Контрольная точка объекты должны создаваться внутри области, чтобы они правильно отслеживали распределенные переменные.
  • Создание набора данных не обязательно должно быть внутри области.
Поэтому меня больше всего беспокоит то, что есть некоторые объекты, в которых я не уверен на 100%, создают ли они общие переменные TensorFlow внутри (например, пользовательские обратные вызовы или служебные классы, которые принимают метрики или модели).
Из-за этой неопределенности я выбрал, как мне кажется, самый безопасный вариант, который помещает все, в чем я не уверен, внутри Strategy.scope().
Поэтому мой вопрос: является ли мой код правильным и безопасным для распределенного обучения с несколькими графическими процессорами, или есть ли какие-либо ошибки в том, как я использую tf.distribute.MirroredStrategy и Strategy.scope()?

Подробнее здесь: https://stackoverflow.com/questions/798 ... ct-and-saf
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»