Как правильно обрабатывать сбой выполнения нескольких графических процессоров на одном графическом процессоре из-за OOMPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Как правильно обрабатывать сбой выполнения нескольких графических процессоров на одном графическом процессоре из-за OOM

Сообщение Anonymous »

Я работаю с несколькими графическими процессорами, обрабатывающими большие объемы данных.
Я хочу создать систему обнаружения нехватки памяти (OOM), которая пропускает текущий пакет на всех графических процессорах, если на каком-либо из них не хватает памяти.
Однако по непонятным мне причинам только графический процессор OOM достигает точки синхронизации dist.all_reduce. Остальные не регистрируют ничего, кроме первой печати, и выполнение зависает и завершается без дальнейшего сообщения.
Мне кажется, что я упускаю что-то простое или какие-то мелочи распределенных вычислений, которые мне не нужны. не знаю. Если бы кто-нибудь указал на мою ошибку, я был бы благодарен.

Код: Выделить всё

def train_epoch(model, loader, optimizer, device, loss_fn):
for batch_idx, data in enumerate(loader):
if hasattr(data, 'stores') and isinstance(data.stores, list):
for store in data.stores:
if 'name' in store:
print(f"[rank {idr_torch.rank}] Batch {train_count} contains samples with names: {store['name']}")

# Initialize OOM flag
oom_flag = torch.tensor(0, device=device)

try:
# Move data to device
data = data.to(device)
optimizer.zero_grad()

# Forward pass
pred = model(data)

# Compute loss
loss = loss_fn(pred, data=data, device=device)

# Backward pass
loss.backward()

# Optimizer step
optimizer.step()

except RuntimeError as e:
if 'CUDA out of memory' in str(e):
print(f"[rank {idr_torch.rank}] CUDA OOM at batch {batch_idx}. Skipping batch...")
torch.cuda.empty_cache()

# Log problematic batch
if hasattr(data, 'stores'):
for store in data.stores:
if 'name' in store:
print(f"[rank {idr_torch.rank}] Problematic batch samples: {store['name']}")

# Set OOM flag
oom_flag = torch.tensor(1, device=device)

# Clear gradients and cache to prevent residue state
optimizer.zero_grad(set_to_none=True)
torch.cuda.empty_cache()

else:
raise e  # Raise non-OOM exceptions

# Synchronize OOM flag across ranks (ensures all GPUs check if any had an OOM)
print(f"[rank {idr_torch.rank}] Waiting on OOM-flag synch in batch {batch_idx}...")
torch.distributed.all_reduce(oom_flag, op=torch.distributed.ReduceOp.MAX)
print(f"[rank {idr_torch.rank}] Synch complete at batch {batch_idx}.")

# If any rank had OOM, skip the batch
if oom_flag.item() > 0:
print(f"[rank {idr_torch.rank}] Skipping synchronized batch {batch_idx} due to OOM...")
skip_count += 1
continue  # Skip optimizer step

# Ensure memory cleanup after epoch
torch.cuda.empty_cache()
return True
Выход:

Код: Выделить всё

            
2024-12-11 17:10:24,414 - INFO - [rank 14] Batch 26 contains samples with names: ['test1', 'test2']
2024-12-11 17:10:24,414 - INFO - [rank 8] Batch 26 contains samples with names: ['test2', 'test3']
2024-12-11 17:10:24,414 - INFO - [rank 15] Batch 26 contains samples with names: ['test4', 'test5']
2024-12-11 17:10:30,923 - INFO - [rank 3] CUDA OOM at batch 26. Skipping batch...
2024-12-11 17:10:30,925 - INFO - [rank 3] Problematic batch samples: ['test6', 'test7']
2024-12-11 17:10:30,932 - INFO - [rank 3] Waiting on OOM-flag synch in batch 26...
2024-12-11 17:10:30,934 - INFO - [rank 3] Synch complete at batch 26.

Я также пробовал добавить дополнительный dist.barrier() перед all_reduce, это приводит к зависанию графического процессора OOM на «Ожидание флага OOM». синхронизировать"

Подробнее здесь: https://stackoverflow.com/questions/792 ... due-to-oom
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»