YOLO knowledge distillation (11x to 11n) gives lower performance than native training

I'm trying to distill a YOLO11x detection model into a YOLO11n to improve inference speed without sacrificing too much detection performance.

To do this, I overrode a few functions in the Ultralytics library to implement a custom loss and some callbacks for logging. I have played with the hyperparameters (how much of the loss comes from distillation, the learning rate, the warmup and the optimizer), but the performance is consistently worse than a naively fine-tuned YOLO11n (a simplified sketch of the combined loss I use is shown right after the table below).

I use one dataset for training (call it D) and one for testing (B). B contains only "new" images that are not present in D. In my pipeline, I fine-tune a YOLO11x on D and use it as the teacher on D to distill a YOLO11n, which I then evaluate on B. I also tried yolo-distiller, another distillation implementation from the internet; using its smarter loss function (CWD), it gets marginally better results than my distillation, but is still far worse than the native YOLO11n:



model               | mAP50    | mAP50-95  | precision | recall
v11x                | 0.543339 | 0.389167  | 0.6385566 | 0.542057
native v11n         | 0.4444…  | 0.289082  | 0.599968  | 0.468279
my distilled v11n   | 0.3646   | 0.2225889 | 0.615199  | 0.374849
yolo-distiller v11n | 0.399945 | 0.255973  | 0.564008  | 0.401043
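
For clarity, the objective I am optimizing is the usual soft/hard mix from Hinton-style knowledge distillation. A minimal standalone sketch of the idea (not the exact code from my trainer, which is further down) looks like this:

import torch
import torch.nn.functional as F

def combined_kd_loss(student_logits, teacher_logits, hard_loss, temperature=10.0, alpha=0.8):
    """Sketch: (1 - alpha) * hard loss + alpha * T^2 * KL(teacher || student)."""
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)           # softened teacher probabilities
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)  # softened student log-probabilities
    kd = F.kl_div(student_log_probs, soft_targets, reduction="batchmean") * temperature ** 2
    return (1 - alpha) * hard_loss + alpha * kd

# toy usage with random class logits
print(combined_kd_loss(torch.randn(8, 80), torch.randn(8, 80), hard_loss=torch.tensor(1.0)))
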
Any ideas why the distillation performance could be this bad? I can't think of a reason why it would end up worse than native training. Could the teacher somehow be misleading the student? I use most of the default Ultralytics configuration and only changed a few parameters in the YAML:
epochs: 200
imgsz: 1088
optimizer: AdamW
lr0: 0.001
lrf: 0.01
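
For reference, running the same overrides through a plain (non-distilled) Ultralytics fine-tune would look roughly like this; the dataset path here is just a placeholder:

from ultralytics import YOLO

# Plain fine-tuning with the same training overrides; 'dataset_D.yaml' is a placeholder path.
model = YOLO("yolo11n.pt")
model.train(data="dataset_D.yaml", epochs=200, imgsz=1088, optimizer="AdamW", lr0=0.001, lrf=0.01)
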
I have grouped everything (except the config) into a single file in case more context is needed, and translated the comments into English:
import torch
import torch.nn as nn
from ultralytics import YOLO
import ultralytics
from ultralytics.utils.loss import v8DetectionLoss
from ultralytics.models.yolo.detect.train import DetectionTrainer
import yaml
import os
import torch.nn.functional as F
from ultralytics.cfg import get_cfg
import datetime
from torch.utils.tensorboard import SummaryWriter

def distillation_loss(student_raw_outputs_by_scale, teacher_raw_outputs_by_scale, temperature, num_classes, reg_max):
"""
Calculates the distillation loss (KL Divergence) between the raw outputs of the student and teacher.
This version includes distillation for classification and regression (DFL).

Args:
student_raw_outputs_by_scale (list of Tensors): List of raw outputs from the student model per scale.
Each tensor is (B, C, H, W) where C = (4*reg_max + num_classes)
teacher_raw_outputs_by_scale (list of Tensors): List of raw outputs from the teacher model per scale.
Each tensor is (B, C, H, W) where C = (4*reg_max + num_classes)
temperature (float): Temperature for softening the distributions.
num_classes (int): Number of classes.
reg_max (int): reg_max value used in the detection head.
Returns:
torch.Tensor: The total distillation loss (classification + DFL).
"""

total_kd_loss = 0.0
num_scales = len(student_raw_outputs_by_scale) # Theoretically 3 for YOLOv8

log_debug_info(f"Num scales: {num_scales}")

for i in range(num_scales):
        student_scale_output = student_raw_outputs_by_scale[i]  # (B, C, H, W) = (batch size, total number of output channels,
        teacher_scale_output = teacher_raw_outputs_by_scale[i]  # feature map height, feature map width)

# Reshape from (B, C, H, W) to (B, H*W, C) to group H*W pixel predictions
# C = (4 * reg_max) + num_classes: 4 coords per box, so 4*reg_max channels with reg_max prob distributions (+ num_classes cls logits)
student_reshaped = student_scale_output.view(student_scale_output.shape[0], student_scale_output.shape[1], -1).permute(0, 2, 1).contiguous()
teacher_reshaped = teacher_scale_output.view(teacher_scale_output.shape[0], teacher_scale_output.shape[1], -1).permute(0, 2, 1).contiguous()

# Extract classification logits
# In YOLOv8, the last num_classes channels/logits are for classification
student_cls_logits = student_reshaped[..., -num_classes:] # (B, H*W, num_classes)
teacher_cls_logits = teacher_reshaped[..., -num_classes:] # (B, H*W, num_classes)

log_debug_info(f"Scale {i} - Student CLS Logits: "
f"min={student_cls_logits.min().item():.4f}, "
f"max={student_cls_logits.max().item():.4f}, "
f"mean={student_cls_logits.mean().item():.4f}, "
f"std={student_cls_logits.std().item():.4f}")
log_debug_info(f"Scale {i} - Teacher CLS Logits: "
f"min={teacher_cls_logits.min().item():.4f}, "
f"max={teacher_cls_logits.max().item():.4f}, "
f"mean={teacher_cls_logits.mean().item():.4f}, "
f"std={teacher_cls_logits.std().item():.4f}")

# Calculate KL Divergence for classification
cls_kd_loss_scale = F.kl_div(
F.log_softmax(student_cls_logits / temperature, dim=-1), # student log probabilities
F.softmax(teacher_cls_logits / temperature, dim=-1), # teacher probabilities
reduction="batchmean" # Mean per batch
) * (temperature ** 2) # Important: multiply by T^2, see google paper

total_kd_loss += cls_kd_loss_scale
log_debug_info(f"KD class loss: {cls_kd_loss_scale.item():.6f}")

# Extract DFL logits
# The first 4*reg_max are for DFL (bbox regression)
student_dfl_logits = student_reshaped[..., :4 * reg_max] # (B, H*W, 4*reg_max)
teacher_dfl_logits = teacher_reshaped[..., :4 * reg_max] # (B, H*W, 4*reg_max)

# Reshape to apply softmax on each DFL distribution for each coordinate (x1, y1, x2, y2)
# (B, H*W, 4, reg_max) with 4 for left, right, top, bottom
student_dfl_reshaped = student_dfl_logits.view(student_dfl_logits.shape[0], student_dfl_logits.shape[1], 4, reg_max)
teacher_dfl_reshaped = teacher_dfl_logits.view(teacher_dfl_logits.shape[0], teacher_dfl_logits.shape[1], 4, reg_max)

log_debug_info(f"Scale {i} - Student DFL Logits: "
f"min={student_dfl_reshaped.min().item():.4f}, "
f"max={student_dfl_reshaped.max().item():.4f}, "
f"mean={student_dfl_reshaped.mean().item():.4f}, "
f"std={student_dfl_reshaped.std().item():.4f}")
log_debug_info(f"Scale {i} - Teacher DFL Logits: "
f"min={teacher_dfl_reshaped.min().item():.4f}, "
f"max={teacher_dfl_reshaped.max().item():.4f}, "
f"mean={teacher_dfl_reshaped.mean().item():.4f}, "
f"std={teacher_dfl_reshaped.std().item():.4}")

# Calculate KL divergence for DFL logits
dfl_kd_loss_scale = F.kl_div(
F.log_softmax(student_dfl_reshaped / temperature, dim=-1), # student DFL log probabilities
F.softmax(teacher_dfl_reshaped / temperature, dim=-1), # teacher DFL probabilities
reduction="batchmean" # Mean per batch
) * (temperature ** 2) # Multiply by T^2

log_debug_info(f"KD dfl loss: {dfl_kd_loss_scale.item():.6f}")

total_kd_loss += dfl_kd_loss_scale

total_kd_loss = total_kd_loss * CORRECTION_GAMMA
log_debug_info(f"Corrected distillation loss: {total_kd_loss}")

return total_kd_loss / num_scales # Average over scales (if > 1)
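
# Illustrative helper added for this post (not part of my actual training run):
# a quick shape sanity check for distillation_loss() using dummy per-scale outputs
# with C = 4*reg_max + num_classes channels, as described in the docstring above.
# Note: only call it after CORRECTION_GAMMA and log_debug_info (defined below) exist.
def _kd_loss_sanity_check(num_classes=80, reg_max=16, temperature=10.0, batch=2):
    channels = 4 * reg_max + num_classes
    dummy_student = [torch.randn(batch, channels, s, s) for s in (80, 40, 20)]  # 3 scales
    dummy_teacher = [torch.randn(batch, channels, s, s) for s in (80, 40, 20)]
    return distillation_loss(dummy_student, dummy_teacher, temperature, num_classes, reg_max)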

def on_train_epoch_end_callback(trainer):
"""
Callback called by Ultralytics at the end of each training epoch.
Logs the learning rate and training/validation losses to TensorBoard and the console.
"""
if VERBOSE and writer:
current_lr = trainer.optimizer.param_groups[0]['lr']
current_epoch = trainer.epoch

writer.add_scalar('Learning_Rate', current_lr, current_epoch)
total_train_loss = trainer.tloss.sum().item()
writer.add_scalar('Train/Hard_Loss_Total', total_train_loss, current_epoch)
# Other components in trainer.tloss
for i, name in enumerate(trainer.loss_names):
            writer.add_scalar(f'Train/Hard_Loss_{name}', trainer.tloss[i].item(), current_epoch)

if hasattr(trainer, 'epoch_kd_loss') and trainer.batch_count > 0:
avg_kd_loss = trainer.epoch_kd_loss / trainer.batch_count
avg_combined_loss = trainer.epoch_combined_loss / trainer.batch_count
writer.add_scalar('Train/KD_Loss_Total', avg_kd_loss, current_epoch)
writer.add_scalar('Train/Combined_Loss_Total', avg_combined_loss, current_epoch)
print(f" Train KD Loss (Total): {avg_kd_loss:.4f}")
print(f" Train Combined Loss (Total): {avg_combined_loss:.4f}")

def on_train_epoch_start_callback(trainer):
if hasattr(trainer, 'epoch_kd_loss'):
trainer.epoch_kd_loss = 0.0
trainer.epoch_combined_loss = 0.0
trainer.batch_count = 0

TEACHER = 'v11x_fin_s3r.pt'
TRAINING_CONFIG_YAML = 'distillation_config.yaml'
TEMPERATURE = 10.0
ALPHA = 0.8
CORRECTION_GAMMA = 0.001 # Factor found empirically to scale down the huge KD loss
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IMG_SIZE= 1088
DEBUG_LOG_FILE = "distillation_debug.log"
VERBOSE = True
TENSORBOARD_LOG_DIR = "runs/distillation_logs"
writer = None

if not VERBOSE:
print("LOGS DISABLED")
else:
writer = SummaryWriter(TENSORBOARD_LOG_DIR)
print(f"TensorBoard logs will be saved to: {TENSORBOARD_LOG_DIR}")

def log_debug_info(message):
"""Writes a debug message to a file."""
if not VERBOSE:
return
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
with open(DEBUG_LOG_FILE, "a") as f:
f.write(f"[{timestamp}] {message}\n")

class DistillationTrainer(DetectionTrainer):

def __init__(self, overrides=None, teacher_model_path=None, distillation_temp=5.0, distillation_alpha=0.5):
super().__init__(overrides) # Pass the entire YAML config in overrides
self.teacher_model = YOLO(teacher_model_path).to(self.device)
self.teacher_model.model.eval() # Set teacher model to evaluation mode
self.distillation_temp = distillation_temp
self.distillation_alpha = distillation_alpha

# self.original_ultralytics_criterion is updated in _setup_train
self.original_ultralytics_criterion = None

log_debug_info(f"DistillationTrainer init(): params: path={teacher_model_path}, temp={self.distillation_temp}, alpha={self.distillation_alpha}")
print("Teacher model loaded in DistillationTrainer.")

self.epoch_kd_loss = 0.0
self.epoch_combined_loss = 0.0
self.batch_count = 0 # Counter for averaging

def setup_model(self):
"""
Configures and loads the model.
"""
ckpt = super().setup_model()
log_debug_info("setup_model(): superclass called. self.model initialized.")
return ckpt

def _setup_train(self, world_size):
"""
Build dataloaders and optimizer on correct rank process, and set up the custom loss.
"""
log_debug_info("DistillationTrainer _setup_train(): start")
super()._setup_train(world_size) # calls setup_model, set_model_attributes, etc.
log_debug_info("DistillationTrainer _setup_train(): super()._setup_train finished")

# check on the criterion, which might not be initialized (normally it should be by the super())
if getattr(self.model, "criterion", None) is None:
# if the base model's loss is not yet called, initialize it
self.model.criterion = v8DetectionLoss(self.model)
log_debug_info(f"DistillationTrainer _setup_train(): self.model.criterion initialized to v8DetectionLoss (first pass).")

self.original_ultralytics_criterion = self.model.criterion # need to keep the original criterion
log_debug_info(f"DistillationTrainer _setup_train(): original_ultralytics_criterion saved: {type(self.original_ultralytics_criterion)}")

# Replace the original criterion
self.model.criterion = self._custom_combined_loss_fn
log_debug_info("DistillationTrainer _setup_train(): self.model.criterion redirected to _custom_combined_loss_fn.")

def _custom_combined_loss_fn(self, preds, batch):
"""
Calculates the total combined loss (hard target loss + distillation loss).
Args:
preds (list of Tensors): raw outputs (logits) of the student model per scale.
Each tensor is (B, C, H, W) where C = (4*reg_max + num_classes)
batch (dict): contains img and labels.
Returns:
torch.Tensor: Total combined loss.
torch.Tensor: Loss components for logging.
"""

log_debug_info("_custom_combined_loss_fn(): Starting loss calculation")
# Native hard target loss
hard_loss, hard_loss_items = self.original_ultralytics_criterion(preds, batch)
log_debug_info(f"Custom combined loss fn: Raw hard_loss: {hard_loss}, Type: {type(hard_loss)}, Shape: {hard_loss.shape if isinstance(hard_loss, torch.Tensor) else 'N/A'}")
log_debug_info(f"Custom combined loss fn: Raw hard_loss_items: {hard_loss_items}, Type: {type(hard_loss_items)}, Shape: {hard_loss_items.shape if isinstance(hard_loss_items, torch.Tensor) else 'N/A'}")

# Sum components to get a scalar total hard loss
total_hard_loss = hard_loss.sum()
log_debug_info(f"Custom combined loss fn: Summed hard_loss: {total_hard_loss.item()}")

with torch.no_grad():
# theoretically returns something like return y if self.export else (y, x) (see head.py, class Detect)
teacher_outputs = self.teacher_model.model(batch['img'].to(self.device, non_blocking=True))

# Extract teacher logits
if isinstance(teacher_outputs, tuple) and len(teacher_outputs) == 2:
teacher_raw_outputs_by_scale = teacher_outputs[1]
else:
log_debug_info("Custom combined loss fn: problem retrieving teacher logits")
raise ValueError("Teacher model output in an unexpected format")

student_raw_outputs_by_scale = preds # Raw student logits (list of tensors)

# num_classes + reg_max from the student model head (last layer of self.model.model)
num_classes = self.model.model[-1].nc
reg_max = self.model.model[-1].reg_max

kd_loss = distillation_loss(student_raw_outputs_by_scale, teacher_raw_outputs_by_scale, self.distillation_temp, num_classes, reg_max)
log_debug_info(f"Custom combined loss fn: KD loss: {kd_loss.item()}")

# Combine losses with (1-alpha) for hard loss and alpha for KD loss
combined_loss = (1 - self.distillation_alpha) * total_hard_loss + self.distillation_alpha * kd_loss
log_debug_info(f"Custom combined loss fn: Combined loss: {combined_loss.item()}")

# hard_loss_items = [box_loss, cls_loss, dfl_loss]
if not isinstance(hard_loss_items, torch.Tensor):
# conversion in case, but it's already a tensor according to v8DetectionLoss
hard_loss_items_tensor = torch.tensor(hard_loss_items, device=self.device)
else:
hard_loss_items_tensor = hard_loss_items

# Only return the 3 standard loss components for the ultralytics validator
# kd_loss is already part of the combined loss for backpropagation
log_items = hard_loss_items_tensor * (1 - self.distillation_alpha)
log_debug_info(f"Custom combined loss fn: cooked log items: Shape: {log_items.shape}")

self.epoch_kd_loss += kd_loss.item() # Used for logging
self.epoch_combined_loss += combined_loss.item()
self.batch_count += 1

return combined_loss, log_items # Return total loss and items for logging

log_debug_info("Starting main program")
train_args_from_yaml = get_cfg(TRAINING_CONFIG_YAML)

if os.path.exists(DEBUG_LOG_FILE): # Delete the log file on each restart
os.remove(DEBUG_LOG_FILE)
print("Previous logs deleted")

# update device
if train_args_from_yaml.device is None:
train_args_from_yaml.device = DEVICE

distillation_trainer = DistillationTrainer(
overrides=train_args_from_yaml,
teacher_model_path=TEACHER,
distillation_temp=TEMPERATURE,
distillation_alpha=ALPHA
)

if VERBOSE:
distillation_trainer.add_callback("on_train_epoch_start", on_train_epoch_start_callback)
distillation_trainer.add_callback("on_train_epoch_end", on_train_epoch_end_callback)

distillation_trainer.train()

print(f"Distillation finished")

if VERBOSE and writer:
writer.close()
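
A quick way to check what the teacher forward pass actually returns in eval mode (the tuple I unpack in _custom_combined_loss_fn) is a small inspection sketch along these lines:

import torch
from ultralytics import YOLO

# Sketch: print the structure of the raw teacher output for a dummy input.
teacher = YOLO('v11x_fin_s3r.pt')
teacher.model.eval()
with torch.no_grad():
    out = teacher.model(torch.zeros(1, 3, 1088, 1088))

def describe(x, name='out'):
    # Recursively print types and tensor shapes of the (possibly nested) output.
    if isinstance(x, torch.Tensor):
        print(f"{name}: Tensor {tuple(x.shape)}")
    elif isinstance(x, (list, tuple)):
        print(f"{name}: {type(x).__name__} of length {len(x)}")
        for i, item in enumerate(x):
            describe(item, f"{name}[{i}]")
    else:
        print(f"{name}: {type(x).__name__}")

describe(out)
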
What have you tried and what were you expecting?

-> Other distillation code from the internet: better, but still bad.

More details here: https://stackoverflow.com/questions/796 ... n-native-t