Я обучаю модель с использованием Keras + TensorFlow с tf.distribute.MirroredStrategy в конфигурации с несколькими графическими процессорами. Я хотел бы убедиться, что я правильно использую Strategy.scope().
import time
import logging
import os
import json
import datetime
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
os.environ["KERAS_BACKEND"] = "tensorflow"
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
#os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '0'
import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
if gpus:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
from tensorflow.python.profiler import profiler_v2 as profiler
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import keras
from keras import ops
from keras import layers
from keras import mixed_precision
from medicai.models import UNETRPlusPlus
from medicai.metrics import BinaryDiceMetric
from medicai.losses import BinaryDiceCELoss
from medicai.utils.inference import SlidingWindowInference
from medicai.callbacks import SlidingWindowInferenceCallback
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.experiment_config import ExperimentConfig
from src.data_pipeline.data_loader import data_loader
class TFCheckpointCallback(keras.callbacks.Callback):
"""Save model + optimizer + epoch using TF checkpointing and save SWI callback best score to a JSON file for restoring."""
def __init__(self, ckpt, ckpt_manager, swi_callback, checkpoint_dir):
super().__init__()
self.ckpt = ckpt
self.ckpt_manager = ckpt_manager
self.swi_callback = swi_callback
self.best_score_file = os.path.join(checkpoint_dir, "swi_best_score.json")
def on_epoch_end(self, epoch, logs=None):
# Update epoch variable and save checkpoint
self.ckpt.epoch.assign_add(1) # increment epoch counter
save_path = self.ckpt_manager.save()
print(f"Saved checkpoint: {save_path} (epoch {int(self.ckpt.epoch.numpy())})")
# Save SWI best score externally
best_score = getattr(self.swi_callback, "best_score", -float("inf"))
with open(self.best_score_file, "w") as f:
json.dump({"best_score": best_score}, f)
print(f"[CheckpointCallback] Saved SWI best score: {best_score}")
class HistorySaverCallback(keras.callbacks.Callback):
"""Saves training history every epoch to CSV and allows resuming."""
def __init__(self, history_file, initial_history=None):
super().__init__()
self.history_file = history_file
self.full_history = initial_history if initial_history else {}
def on_epoch_end(self, epoch, logs=None):
if logs is None:
logs = {}
for k, v in logs.items():
self.full_history.setdefault(k, []).append(v)
# Save updated history
pd.DataFrame(self.full_history).to_csv(self.history_file, index=False)
def get_model(total_device):
model = UNETRPlusPlus(
encoder_name="unetr_plusplus_encoder",
input_shape=ExperimentConfig.input_shape,
num_classes=ExperimentConfig.num_classes,
classifier_activation=None,
)
total_train_samples = 387 # 80% ( approx.) split of the total dataset for train as Unetr
# Compute steps per epoch and total steps
steps_per_epoch = total_train_samples // (ExperimentConfig.batch_size_train * total_device)
print(f"Steps per epoch : {steps_per_epoch}")
total_steps = steps_per_epoch * ExperimentConfig.epochs
# Warmup: 10% of total steps
warmup_steps = int(total_steps * 0.1)
# CosineDecay schedule with warmup
lr_schedule = keras.optimizers.schedules.CosineDecay(
initial_learning_rate=0.01 * ExperimentConfig.lr, # very small starting LR
decay_steps=total_steps - warmup_steps, # decay after warmup
alpha=ExperimentConfig.alpha,
warmup_target=ExperimentConfig.lr,
warmup_steps=warmup_steps
)
model.compile(
optimizer=keras.optimizers.AdamW(
learning_rate=lr_schedule,
weight_decay=ExperimentConfig.weight_decay,
),
loss=BinaryDiceCELoss(
from_logits=True,
dice_weight=1.0,
ce_weight=1.0,
reduction="mean",
num_classes=ExperimentConfig.num_classes,
),
metrics=[
BinaryDiceMetric(
from_logits=True,
ignore_empty=True,
num_classes=ExperimentConfig.num_classes,
name='dice',
),
BinaryDiceMetric(
from_logits=True,
ignore_empty=True,
target_class_ids=[0],
num_classes=ExperimentConfig.num_classes,
name='dice_tc',
),
BinaryDiceMetric(
from_logits=True,
ignore_empty=True,
target_class_ids=[1],
num_classes=ExperimentConfig.num_classes,
name='dice_wt',
),
BinaryDiceMetric(
from_logits=True,
ignore_empty=True,
target_class_ids=[2],
num_classes=ExperimentConfig.num_classes,
name='dice_et',
)
],
)
return model
def get_inference_metric():
swi_callback_metric = BinaryDiceMetric(
from_logits=True,
ignore_empty=True,
num_classes=ExperimentConfig.num_classes,
name='val_dice',
)
return swi_callback_metric
"""def run_sliding_window_inference_per_class_average(model, ds, roi_size, sw_batch_size, overlap, metrics_list):
# Run sliding window inference on a dataset and compute all metrics (average + per class)
for metric in metrics_list:
metric.reset_states()
swi = SlidingWindowInference(
model,
num_classes=metrics_list[0].num_classes,
roi_size=roi_size,
sw_batch_size=sw_batch_size,
overlap=overlap
)
for x, y in ds:
y_pred = swi(x)
for metric in metrics_list:
metric.update_state(ops.convert_to_tensor(y), ops.convert_to_tensor(y_pred))
# Gather results
results = {}
for metric in metrics_list:
results[metric.name] = float(ops.convert_to_numpy(metric.result()))
return results"""
def main():
# reproducibility
keras.utils.set_random_seed(101)
print(
f"keras backend: {keras.config.backend()}\n"
f"keras version: {keras.version()}\n"
f"tensorflow version: {tf.__version__}\n"
)
# get keras backend
keras_backend = keras.config.backend()
strategy = tf.distribute.MirroredStrategy()
total_device = strategy.num_replicas_in_sync
print('Keras backend ', keras_backend)
print('Total device found ', total_device)
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
base_save_path = os.path.join(project_root, "experiments", "msd_brain")
unetrplusplus_path = os.path.join(base_save_path, "unetrplusplus")
os.makedirs(unetrplusplus_path, exist_ok=True)
# Subfolders
logs_path = os.path.join(unetrplusplus_path, "logs")
history_path = os.path.join(unetrplusplus_path, "history")
plots_path = os.path.join(unetrplusplus_path, "plots")
os.makedirs(logs_path, exist_ok=True)
os.makedirs(history_path, exist_ok=True)
os.makedirs(plots_path, exist_ok=True)
# Timestamp
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# Save path for best model weights
save_path = os.path.join(unetrplusplus_path, f"best_model_weights_{timestamp}.weights.h5")
# File for containing the learning history
history_file = os.path.join(history_path, f"training_history.csv")
# Load datasets
tfrecord_pattern = os.path.join(project_root, "data", "msd_brain", "tfrecords", "{}_shard_*.tfrec")
# batch size for training
train_batch = ExperimentConfig.batch_size_train * total_device
train_ds = data_loader(
tfrecord_pattern.format("training"),
batch_size=train_batch,
shuffle=True
)
val_ds = data_loader(
tfrecord_pattern.format("validation"),
batch_size=ExperimentConfig.batch_size_val,
shuffle=False
)
test_ds = data_loader(
tfrecord_pattern.format("test"),
batch_size=ExperimentConfig.batch_size_val,
shuffle=False
)
with strategy.scope():
model = get_model(total_device)
checkpoint_dir = os.path.join(unetrplusplus_path, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
with strategy.scope():
ckpt = tf.train.Checkpoint(
epoch=tf.Variable(0), # epoch counter — saved as part of checkpoint
optimizer=model.optimizer, # optimizer state
model=model # model weights
)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=3)
# Validation with sliding window callback
swi_callback_metric = get_inference_metric()
# Create SWI callback
swi_callback = SlidingWindowInferenceCallback(
model,
dataset=val_ds,
metrics=swi_callback_metric,
num_classes=ExperimentConfig.num_classes,
interval= ExperimentConfig.sliding_window_interval,
overlap=ExperimentConfig.sliding_window_overlap,
roi_size=(ExperimentConfig.input_shape[0],ExperimentConfig.input_shape[1],ExperimentConfig.input_shape[2]),
sw_batch_size=ExperimentConfig.sw_batch_size * total_device ,
save_path=save_path
)
# TFCheckpointCallback (save model, optimizer, epoch + SWI best score)
tf_ckpt_callback = TFCheckpointCallback(ckpt, ckpt_manager, swi_callback, checkpoint_dir)
# History callback
# Load previous history if exists
if os.path.exists(history_file):
prev_history = pd.read_csv(history_file).to_dict(orient='list')
else:
prev_history = {}
history_callback = HistorySaverCallback(history_file, initial_history=prev_history)
# Resume or start from scratch
if ckpt_manager.latest_checkpoint:
ckpt.restore(ckpt_manager.latest_checkpoint)
initial_epoch = int(ckpt.epoch.numpy())
print(f"[Resume] Restored checkpoint: starting from epoch {initial_epoch}")
# Restore SWI best score
best_score_file = os.path.join(checkpoint_dir, "swi_best_score.json")
if os.path.exists(best_score_file):
with open(best_score_file, "r") as f:
swi_callback.best_score = json.load(f).get("best_score", -float("inf"))
print(f"[Resume] Restored SWI best validation score: {swi_callback.best_score}")
else:
print(f"[Resume] Couldn't Restore SWI best validation score")
else:
initial_epoch = 0
print("[Resume] No checkpoint found. Starting from scratch.")
print(f"Model size: {model.count_params() / 1e6:.2f} M")
start_time = time.time()
with strategy.scope():
history = model.fit(
train_ds,
epochs=ExperimentConfig.epochs,
initial_epoch=initial_epoch,
callbacks=[
swi_callback,
tf_ckpt_callback,
history_callback
])
end_time = time.time()
training_time = end_time - start_time
print(f"Total training time (seconds): {training_time:.2f}")
# Save history to CSV
full_history = history_callback.full_history
# Save CSV
pd.DataFrame(full_history).to_csv(history_file, index=False)
# Plot loss
plt.figure(figsize=(10, 5))
plt.plot(full_history['loss'], label='train_loss')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.legend()
plt.grid()
plt.savefig(os.path.join(plots_path, f"loss_curve_{timestamp}.png"))
plt.close()
# Plot average Dice
if 'dice' in full_history:
plt.figure(figsize=(10, 5))
plt.plot(full_history['dice'], label='train_dice')
plt.xlabel("Epoch")
plt.ylabel("Average Dice")
plt.title("Training Average Dice")
plt.legend()
plt.grid()
plt.savefig(os.path.join(plots_path, f"dice_curve_{timestamp}.png"))
plt.close()
print("Training and saving plots finished successfully.")
if __name__ == "__main__":
main()
Чтобы избежать ошибок, в настоящее время я помещаю почти все, что связано с обучением, внутри Strategy.scope(), включая некоторые объекты, в которых я не уверен, создают ли они переменные TensorFlow или нет.
В частности, внутри области, которую я создаю:
Обратные вызовы, которые ссылаются на модель и метрики
Наборы данных, пути, ведение журнала и утилиты чистого Python создаются вне области действия.
Мое текущее понимание:
Объекты, которые создают переменные TensorFlow (модель, оптимизатор, метрики), должны быть созданы внутри Strategy.scope().
Объекты, которые владеют или обновляют метрики (например, пользовательские обратные вызовы, отслеживающие результаты проверки), также должны создаваться внутри области.
Контрольная точка объекты должны создаваться внутри области, чтобы они правильно отслеживали распределенные переменные.
Создание набора данных не обязательно должно быть внутри области.
Поэтому меня больше всего беспокоит то, что есть некоторые объекты, в которых я не уверен на 100%, создают ли они общие переменные TensorFlow внутри (например, пользовательские обратные вызовы или служебные классы, которые принимают метрики или модели).
Из-за этой неопределенности я выбрал, как мне кажется, самый безопасный вариант, который помещает все, в чем я не уверен, внутри Strategy.scope().
Поэтому мой вопрос: является ли мой код правильным и безопасным для распределенного обучения с несколькими графическими процессорами, или есть ли какие-либо ошибки в том, как я использую tf.distribute.MirroredStrategy и Strategy.scope()?
Я обучаю модель с использованием Keras + TensorFlow с tf.distribute.MirroredStrategy в конфигурации с несколькими графическими процессорами. Я хотел бы убедиться, что я правильно использую Strategy.scope(). [code]import time import logging import os import json import datetime os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async" os.environ['TF_CUDNN_DETERMINISTIC'] = '1' os.environ["KERAS_BACKEND"] = "tensorflow" #os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' #os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '0' import tensorflow as tf gpus = tf.config.list_physical_devices('GPU') if gpus: for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) from tensorflow.python.profiler import profiler_v2 as profiler import numpy as np import pandas as pd from matplotlib import pyplot as plt import keras from keras import ops from keras import layers from keras import mixed_precision from medicai.models import UNETRPlusPlus from medicai.metrics import BinaryDiceMetric from medicai.losses import BinaryDiceCELoss from medicai.utils.inference import SlidingWindowInference from medicai.callbacks import SlidingWindowInferenceCallback import sys sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from src.experiment_config import ExperimentConfig from src.data_pipeline.data_loader import data_loader
class TFCheckpointCallback(keras.callbacks.Callback): """Save model + optimizer + epoch using TF checkpointing and save SWI callback best score to a JSON file for restoring.""" def __init__(self, ckpt, ckpt_manager, swi_callback, checkpoint_dir): super().__init__() self.ckpt = ckpt self.ckpt_manager = ckpt_manager self.swi_callback = swi_callback self.best_score_file = os.path.join(checkpoint_dir, "swi_best_score.json")
# Save SWI best score externally best_score = getattr(self.swi_callback, "best_score", -float("inf")) with open(self.best_score_file, "w") as f: json.dump({"best_score": best_score}, f) print(f"[CheckpointCallback] Saved SWI best score: {best_score}")
class HistorySaverCallback(keras.callbacks.Callback): """Saves training history every epoch to CSV and allows resuming.""" def __init__(self, history_file, initial_history=None): super().__init__() self.history_file = history_file self.full_history = initial_history if initial_history else {}
def on_epoch_end(self, epoch, logs=None): if logs is None: logs = {} for k, v in logs.items(): self.full_history.setdefault(k, []).append(v)
# Save updated history pd.DataFrame(self.full_history).to_csv(self.history_file, index=False)
def get_model(total_device):
model = UNETRPlusPlus( encoder_name="unetr_plusplus_encoder", input_shape=ExperimentConfig.input_shape, num_classes=ExperimentConfig.num_classes, classifier_activation=None, )
total_train_samples = 387 # 80% ( approx.) split of the total dataset for train as Unetr
# Compute steps per epoch and total steps steps_per_epoch = total_train_samples // (ExperimentConfig.batch_size_train * total_device) print(f"Steps per epoch : {steps_per_epoch}") total_steps = steps_per_epoch * ExperimentConfig.epochs
# Warmup: 10% of total steps warmup_steps = int(total_steps * 0.1)
# CosineDecay schedule with warmup lr_schedule = keras.optimizers.schedules.CosineDecay( initial_learning_rate=0.01 * ExperimentConfig.lr, # very small starting LR decay_steps=total_steps - warmup_steps, # decay after warmup alpha=ExperimentConfig.alpha, warmup_target=ExperimentConfig.lr, warmup_steps=warmup_steps )
with strategy.scope(): ckpt = tf.train.Checkpoint( epoch=tf.Variable(0), # epoch counter — saved as part of checkpoint optimizer=model.optimizer, # optimizer state model=model # model weights )
# History callback # Load previous history if exists if os.path.exists(history_file): prev_history = pd.read_csv(history_file).to_dict(orient='list') else: prev_history = {} history_callback = HistorySaverCallback(history_file, initial_history=prev_history)
# Resume or start from scratch if ckpt_manager.latest_checkpoint: ckpt.restore(ckpt_manager.latest_checkpoint) initial_epoch = int(ckpt.epoch.numpy()) print(f"[Resume] Restored checkpoint: starting from epoch {initial_epoch}")
# Restore SWI best score best_score_file = os.path.join(checkpoint_dir, "swi_best_score.json") if os.path.exists(best_score_file): with open(best_score_file, "r") as f: swi_callback.best_score = json.load(f).get("best_score", -float("inf")) print(f"[Resume] Restored SWI best validation score: {swi_callback.best_score}") else: print(f"[Resume] Couldn't Restore SWI best validation score") else: initial_epoch = 0 print("[Resume] No checkpoint found. Starting from scratch.")
# Plot average Dice if 'dice' in full_history: plt.figure(figsize=(10, 5)) plt.plot(full_history['dice'], label='train_dice') plt.xlabel("Epoch") plt.ylabel("Average Dice") plt.title("Training Average Dice") plt.legend() plt.grid() plt.savefig(os.path.join(plots_path, f"dice_curve_{timestamp}.png")) plt.close()
print("Training and saving plots finished successfully.")
if __name__ == "__main__": main()
[/code] Чтобы избежать ошибок, в настоящее время я помещаю [b]почти все, что связано с обучением[/b], внутри Strategy.scope(), включая некоторые объекты, в которых я [b]не уверен, создают ли они переменные TensorFlow или нет[/b]. В частности, внутри области, которую я создаю: [list] [*][b]модель[/b]
[*][b]Обратные вызовы, которые ссылаются на модель и метрики[/b]
[/list] Наборы данных, пути, ведение журнала и утилиты чистого Python создаются вне области действия. Мое текущее понимание: [list] [*]Объекты, которые [b]создают переменные TensorFlow[/b] (модель, оптимизатор, метрики), должны быть созданы внутри Strategy.scope().
[*]Объекты, которые [b]владеют или обновляют метрики[/b] (например, пользовательские обратные вызовы, отслеживающие результаты проверки), также должны создаваться внутри области.
[*]Контрольная точка объекты должны создаваться внутри области, чтобы они правильно отслеживали распределенные переменные.
[*]Создание набора данных не обязательно должно быть внутри области.
[/list] Поэтому меня больше всего беспокоит то, что есть [b]некоторые объекты, в которых я не уверен на 100%[/b], создают ли они общие переменные TensorFlow внутри (например, пользовательские обратные вызовы или служебные классы, которые принимают метрики или модели). Из-за этой неопределенности я выбрал, как мне кажется, самый безопасный вариант, который [b]помещает все, в чем я не уверен, внутри Strategy.scope()[/b]. Поэтому мой вопрос: является ли мой код правильным и безопасным для распределенного обучения с несколькими графическими процессорами, или есть ли какие-либо ошибки в том, как я использую tf.distribute.MirroredStrategy и Strategy.scope()?