I am trying to distill a YOLO11x detection model into a YOLO11n to improve inference speed without giving up too much detection performance.
To do this, I overrode some functions in the Ultralytics library to implement a custom loss and a few callbacks for logging. I have tried many variations (how much of the total loss comes from distillation, learning rate, warmup, optimizer), but the performance is consistently worse than a naively fine-tuned YOLO11n.
I use one dataset for training (call it D) and one for testing (B). B contains only "new" images that are not present in D. In my pipeline I fine-tune v11x on D, then use it as the teacher on D to distill into v11n, which I evaluate on B. I also tried yolo-distiller, which uses a smarter loss function (channel-wise distillation, CWD; a minimal sketch of the idea follows the results table). It achieves marginally better performance than my distillation, but is still much worse than the native v11n:
model | mAP50 | mAP50-95 | precision | recall
v11x | 0.543339 | 0.389167 | 0.638556 | 0.542057
native v11n | 0.404444 | 0.289082 | 0.599968 | 0.468279
my distilled v11n | 0.3646 | 0.222589 | 0.615199 | 0.374849
yolo-distiller v11n | 0.399945 | 0.255973 | 0.564008 | 0.401043
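For reference, here is a minimal sketch of the channel-wise distillation (CWD) idea mentioned above, as I understand it: each channel of a feature map is turned into a distribution over its spatial positions with a temperature-scaled softmax, and the student's per-channel distributions are matched to the teacher's with a KL divergence scaled by T^2. The function name and the assumption that it runs on matching intermediate feature maps (rather than on the raw head outputs) are mine, not the exact yolo-distiller implementation:

import torch
import torch.nn.functional as F

def cwd_loss(student_feat, teacher_feat, temperature=4.0):
    """Channel-wise distillation for one pair of feature maps (illustrative sketch).

    student_feat, teacher_feat: (B, C, H, W) tensors with matching shapes
    (a 1x1 conv adapter would be needed if the channel counts differ).
    """
    b, c, h, w = student_feat.shape
    # Each channel becomes a distribution over its H*W spatial positions
    s = student_feat.view(b, c, -1) / temperature
    t = teacher_feat.view(b, c, -1) / temperature
    # KL(teacher || student) summed over positions, averaged over batch and channels
    loss = F.kl_div(F.log_softmax(s, dim=-1),
                    F.softmax(t, dim=-1),
                    reduction="sum") * (temperature ** 2) / (b * c)
    return loss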
Any ideas why the distillation performance could be this bad? I can't think of a reason why it should be worse than native training. Is the teacher somehow leading the student astray? I use mostly the default Ultralytics configuration and only changed a few parameters in the YAML:
epochs: 200
imgsz: 1088
optimizer: AdamW
lr0: 0.001
lrf: 0.01
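For reference (assuming the standard Ultralytics meaning of lrf as the final learning rate expressed as a fraction of lr0), these two values imply the learning rate decays by two orders of magnitude over the run:

lr0 = 0.001           # initial learning rate
lrf = 0.01            # final LR as a fraction of lr0 (Ultralytics convention, to my understanding)
final_lr = lr0 * lrf  # = 1e-05 at the end of the 200-epoch schedule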
I have put everything (except the config) into a single file in case context is needed, and translated the comments into English:
import torch
import torch.nn as nn
from ultralytics import YOLO
import ultralytics
from ultralytics.utils.loss import v8DetectionLoss
from ultralytics.models.yolo.detect.train import DetectionTrainer
import yaml
import os
import torch.nn.functional as F
from ultralytics.cfg import get_cfg
import datetime
from torch.utils.tensorboard import SummaryWriter
def distillation_loss(student_raw_outputs_by_scale, teacher_raw_outputs_by_scale, temperature, num_classes, reg_max):
"""
Calculates the distillation loss (KL Divergence) between the raw outputs of the student and teacher.
This version includes distillation for classification and regression (DFL).
Args:
student_raw_outputs_by_scale (list of Tensors): List of raw outputs from the student model per scale.
Each tensor is (B, C, H, W) where C = (4*reg_max + num_classes)
teacher_raw_outputs_by_scale (list of Tensors): List of raw outputs from the teacher model per scale.
Each tensor is (B, C, H, W) where C = (4*reg_max + num_classes)
temperature (float): Temperature for softening the distributions.
num_classes (int): Number of classes.
reg_max (int): reg_max value used in the detection head.
Returns:
torch.Tensor: The total distillation loss (classification + DFL).
"""
total_kd_loss = 0.0
num_scales = len(student_raw_outputs_by_scale) # Theoretically 3 for YOLOv8
log_debug_info(f"Num scales: {num_scales}")
for i in range(num_scales):
student_scale_output = student_raw_outputs_by_scale[i] # (B, C, H, W) = (batch size, total number of output channels,
teacher_scale_output = teacher_raw_outputs_by_scale[i] # feature map height, feature map width)
# Reshape from (B, C, H, W) to (B, H*W, C) to group H*W pixel predictions
# C = (4 * reg_max) + num_classes: 4 coords per box, so 4*reg_max channels with reg_max prob distributions (+ num_classes cls logits)
student_reshaped = student_scale_output.view(student_scale_output.shape[0], student_scale_output.shape[1], -1).permute(0, 2, 1).contiguous()
teacher_reshaped = teacher_scale_output.view(teacher_scale_output.shape[0], teacher_scale_output.shape[1], -1).permute(0, 2, 1).contiguous()
# Extract classification logits
# In YOLOv8, the last num_classes channels/logits are for classification
student_cls_logits = student_reshaped[..., -num_classes:] # (B, H*W, num_classes)
teacher_cls_logits = teacher_reshaped[..., -num_classes:] # (B, H*W, num_classes)
log_debug_info(f"Scale {i} - Student CLS Logits: "
f"min={student_cls_logits.min().item():.4f}, "
f"max={student_cls_logits.max().item():.4f}, "
f"mean={student_cls_logits.mean().item():.4f}, "
f"std={student_cls_logits.std().item():.4f}")
log_debug_info(f"Scale {i} - Teacher CLS Logits: "
f"min={teacher_cls_logits.min().item():.4f}, "
f"max={teacher_cls_logits.max().item():.4f}, "
f"mean={teacher_cls_logits.mean().item():.4f}, "
f"std={teacher_cls_logits.std().item():.4f}")
# Calculate KL Divergence for classification
cls_kd_loss_scale = F.kl_div(
F.log_softmax(student_cls_logits / temperature, dim=-1), # student log probabilities
F.softmax(teacher_cls_logits / temperature, dim=-1), # teacher probabilities
reduction="batchmean" # Mean per batch
) * (temperature ** 2) # Important: multiply by T^2, see Hinton et al., "Distilling the Knowledge in a Neural Network"
total_kd_loss += cls_kd_loss_scale
log_debug_info(f"KD class loss: {cls_kd_loss_scale.item():.6f}")
# Extract DFL logits
# The first 4*reg_max are for DFL (bbox regression)
student_dfl_logits = student_reshaped[..., :4 * reg_max] # (B, H*W, 4*reg_max)
teacher_dfl_logits = teacher_reshaped[..., :4 * reg_max] # (B, H*W, 4*reg_max)
# Reshape to apply softmax on each DFL distribution for each coordinate (x1, y1, x2, y2)
# (B, H*W, 4, reg_max) with 4 for left, right, top, bottom
student_dfl_reshaped = student_dfl_logits.view(student_dfl_logits.shape[0], student_dfl_logits.shape[1], 4, reg_max)
teacher_dfl_reshaped = teacher_dfl_logits.view(teacher_dfl_logits.shape[0], teacher_dfl_logits.shape[1], 4, reg_max)
log_debug_info(f"Scale {i} - Student DFL Logits: "
f"min={student_dfl_reshaped.min().item():.4f}, "
f"max={student_dfl_reshaped.max().item():.4f}, "
f"mean={student_dfl_reshaped.mean().item():.4f}, "
f"std={student_dfl_reshaped.std().item():.4f}")
log_debug_info(f"Scale {i} - Teacher DFL Logits: "
f"min={teacher_dfl_reshaped.min().item():.4f}, "
f"max={teacher_dfl_reshaped.max().item():.4f}, "
f"mean={teacher_dfl_reshaped.mean().item():.4f}, "
f"std={teacher_dfl_reshaped.std().item():.4}")
# Calculate KL divergence for DFL logits
dfl_kd_loss_scale = F.kl_div(
F.log_softmax(student_dfl_reshaped / temperature, dim=-1), # student DFL log probabilities
F.softmax(teacher_dfl_reshaped / temperature, dim=-1), # teacher DFL probabilities
reduction="batchmean" # Mean per batch
) * (temperature ** 2) # Multiply by T^2
log_debug_info(f"KD dfl loss: {dfl_kd_loss_scale.item():.6f}")
total_kd_loss += dfl_kd_loss_scale
total_kd_loss = total_kd_loss * CORRECTION_GAMMA
log_debug_info(f"Corrected distillation loss: {total_kd_loss}")
return total_kd_loss / num_scales # Average over scales (if > 1)
def on_train_epoch_end_callback(trainer):
"""
Callback called by Ultralytics at the end of each training epoch.
Logs the learning rate and training/validation losses to TensorBoard and the console.
"""
if VERBOSE and writer:
current_lr = trainer.optimizer.param_groups[0]['lr']
current_epoch = trainer.epoch
writer.add_scalar('Learning_Rate', current_lr, current_epoch)
total_train_loss = trainer.tloss.sum().item()
writer.add_scalar('Train/Hard_Loss_Total', total_train_loss, current_epoch)
# Other components in trainer.tloss
for i, name in enumerate(trainer.loss_names):
writer.add_scalar(f'Train/Hard_Loss_{name}', trainer.tloss[i].item(), current_epoch)
if hasattr(trainer, 'epoch_kd_loss') and trainer.batch_count > 0:
avg_kd_loss = trainer.epoch_kd_loss / trainer.batch_count
avg_combined_loss = trainer.epoch_combined_loss / trainer.batch_count
writer.add_scalar('Train/KD_Loss_Total', avg_kd_loss, current_epoch)
writer.add_scalar('Train/Combined_Loss_Total', avg_combined_loss, current_epoch)
print(f" Train KD Loss (Total): {avg_kd_loss:.4f}")
print(f" Train Combined Loss (Total): {avg_combined_loss:.4f}")
def on_train_epoch_start_callback(trainer):
if hasattr(trainer, 'epoch_kd_loss'):
trainer.epoch_kd_loss = 0.0
trainer.epoch_combined_loss = 0.0
trainer.batch_count = 0
TEACHER = 'v11x_fin_s3r.pt'
TRAINING_CONFIG_YAML = 'distillation_config.yaml'
TEMPERATURE = 10.0
ALPHA = 0.8
CORRECTION_GAMMA = 0.001 # Factor found empirically to scale down the huge KD loss
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IMG_SIZE= 1088
DEBUG_LOG_FILE = "distillation_debug.log"
VERBOSE = True
TENSORBOARD_LOG_DIR = "runs/distillation_logs"
writer = None
if not VERBOSE:
print("LOGS DISABLED")
else:
writer = SummaryWriter(TENSORBOARD_LOG_DIR)
print(f"TensorBoard logs will be saved to: {TENSORBOARD_LOG_DIR}")
def log_debug_info(message):
"""Writes a debug message to a file."""
if not VERBOSE:
return
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
with open(DEBUG_LOG_FILE, "a") as f:
f.write(f"[{timestamp}] {message}\n")
class DistillationTrainer(DetectionTrainer):
def __init__(self, overrides=None, teacher_model_path=None, distillation_temp=5.0, distillation_alpha=0.5):
super().__init__(overrides) # Pass the entire YAML config in overrides
self.teacher_model = YOLO(teacher_model_path).to(self.device)
self.teacher_model.model.eval() # Set teacher model to evaluation mode
self.distillation_temp = distillation_temp
self.distillation_alpha = distillation_alpha
# self.original_ultralytics_criterion is updated in _setup_train
self.original_ultralytics_criterion = None
log_debug_info(f"DistillationTrainer init(): params: path={teacher_model_path}, temp={self.distillation_temp}, alpha={self.distillation_alpha}")
print("Teacher model loaded in DistillationTrainer.")
self.epoch_kd_loss = 0.0
self.epoch_combined_loss = 0.0
self.batch_count = 0 # Counter for averaging
def setup_model(self):
"""
Configures and loads the model.
"""
ckpt = super().setup_model()
log_debug_info("setup_model(): superclass called. self.model initialized.")
return ckpt
def _setup_train(self, world_size):
"""
Build dataloaders and optimizer on correct rank process, and set up the custom loss.
"""
log_debug_info("DistillationTrainer _setup_train(): start")
super()._setup_train(world_size) # calls setup_model, set_model_attributes, etc.
log_debug_info("DistillationTrainer _setup_train(): super()._setup_train finished")
# check on the criterion, which might not be initialized (normally it should be by the super())
if getattr(self.model, "criterion", None) is None:
# if the base model's loss is not yet called, initialize it
self.model.criterion = v8DetectionLoss(self.model)
log_debug_info(f"DistillationTrainer _setup_train(): self.model.criterion initialized to v8DetectionLoss (first pass).")
self.original_ultralytics_criterion = self.model.criterion # need to keep the original criterion
log_debug_info(f"DistillationTrainer _setup_train(): original_ultralytics_criterion saved: {type(self.original_ultralytics_criterion)}")
# Replace the original criterion
self.model.criterion = self._custom_combined_loss_fn
log_debug_info("DistillationTrainer _setup_train(): self.model.criterion redirected to _custom_combined_loss_fn.")
def _custom_combined_loss_fn(self, preds, batch):
"""
Calculates the total combined loss (hard target loss + distillation loss).
Args:
preds (list of Tensors): raw outputs (logits) of the student model per scale.
Each tensor is (B, C, H, W) where C = (4*reg_max + num_classes)
batch (dict): contains img and labels.
Returns:
torch.Tensor: Total combined loss.
torch.Tensor: Loss components for logging.
"""
log_debug_info("_custom_combined_loss_fn(): Starting loss calculation")
# Native hard target loss
hard_loss, hard_loss_items = self.original_ultralytics_criterion(preds, batch)
log_debug_info(f"Custom combined loss fn: Raw hard_loss: {hard_loss}, Type: {type(hard_loss)}, Shape: {hard_loss.shape if isinstance(hard_loss, torch.Tensor) else 'N/A'}")
log_debug_info(f"Custom combined loss fn: Raw hard_loss_items: {hard_loss_items}, Type: {type(hard_loss_items)}, Shape: {hard_loss_items.shape if isinstance(hard_loss_items, torch.Tensor) else 'N/A'}")
# Sum components to get a scalar total hard loss
total_hard_loss = hard_loss.sum()
log_debug_info(f"Custom combined loss fn: Summed hard_loss: {total_hard_loss.item()}")
with torch.no_grad():
# theoretically returns something like return y if self.export else (y, x) (see head.py, class Detect)
teacher_outputs = self.teacher_model.model(batch['img'].to(self.device, non_blocking=True))
# Extract teacher logits
if isinstance(teacher_outputs, tuple) and len(teacher_outputs) == 2:
teacher_raw_outputs_by_scale = teacher_outputs[1]
else:
log_debug_info("Custom combined loss fn: problem retrieving teacher logits")
raise ValueError("Teacher model output in an unexpected format")
student_raw_outputs_by_scale = preds # Raw student logits (list of tensors)
# num_classes + reg_max from the student model head (last layer of self.model.model)
num_classes = self.model.model[-1].nc
reg_max = self.model.model[-1].reg_max
kd_loss = distillation_loss(student_raw_outputs_by_scale, teacher_raw_outputs_by_scale, self.distillation_temp, num_classes, reg_max)
log_debug_info(f"Custom combined loss fn: KD loss: {kd_loss.item()}")
# Combine losses with (1-alpha) for hard loss and alpha for KD loss
combined_loss = (1 - self.distillation_alpha) * total_hard_loss + self.distillation_alpha * kd_loss
log_debug_info(f"Custom combined loss fn: Combined loss: {combined_loss.item()}")
# hard_loss_items = [box_loss, cls_loss, dfl_loss]
if not isinstance(hard_loss_items, torch.Tensor):
# conversion in case, but it's already a tensor according to v8DetectionLoss
hard_loss_items_tensor = torch.tensor(hard_loss_items, device=self.device)
else:
hard_loss_items_tensor = hard_loss_items
# Only return the 3 standard loss components for the ultralytics validator
# kd_loss is already part of the combined loss for backpropagation
log_items = hard_loss_items_tensor * (1 - self.distillation_alpha)
log_debug_info(f"Custom combined loss fn: cooked log items: Shape: {log_items.shape}")
self.epoch_kd_loss += kd_loss.item() # Used for logging
self.epoch_combined_loss += combined_loss.item()
self.batch_count += 1
return combined_loss, log_items # Return total loss and items for logging
log_debug_info("Starting main program")
train_args_from_yaml = get_cfg(TRAINING_CONFIG_YAML)
if os.path.exists(DEBUG_LOG_FILE): # Delete the log file on each restart
os.remove(DEBUG_LOG_FILE)
print("Previous logs deleted")
# update device
if train_args_from_yaml.device is None:
train_args_from_yaml.device = DEVICE
distillation_trainer = DistillationTrainer(
overrides=train_args_from_yaml,
teacher_model_path=TEACHER,
distillation_temp=TEMPERATURE,
distillation_alpha=ALPHA
)
if VERBOSE:
distillation_trainer.add_callback("on_train_epoch_start", on_train_epoch_start_callback)
distillation_trainer.add_callback("on_train_epoch_end", on_train_epoch_end_callback)
distillation_trainer.train()
print(f"Distillation finished")
if VERBOSE and writer:
writer.close()
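One note on the KD term above, which is probably why CORRECTION_GAMMA had to be so small: as far as I understand, F.kl_div with reduction="batchmean" divides the summed KL by the size of the first dimension only, so with inputs shaped (B, H*W, C) the loss still scales with the number of predictions H*W. A hedged sketch (the function name is mine, not part of my pipeline) of normalizing per prediction instead:

import torch
import torch.nn.functional as F

def kd_kl_per_prediction(student_logits, teacher_logits, temperature):
    """Temperature-scaled KL averaged over every prediction, not just the batch dim.

    student_logits, teacher_logits: (B, N, C) raw class logits,
    where N is the number of anchors/pixels at a given scale.
    """
    log_p_s = F.log_softmax(student_logits / temperature, dim=-1)
    p_t = F.softmax(teacher_logits / temperature, dim=-1)
    # reduction="none" keeps elementwise terms; sum over classes, then mean over B*N
    kl = F.kl_div(log_p_s, p_t, reduction="none").sum(dim=-1).mean()
    return kl * (temperature ** 2)

With this normalization the KD term no longer grows with the number of anchors, which should make it easier to balance against the hard loss through alpha alone, without an extra empirical constant.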
What did you try and what were you expecting?
-> I also tried other distillation code from the internet: better, but still bad.
More details here: https://stackoverflow.com/questions/796 ... n-native-t
YOLO knowledge distillation (11x to 11n) gives lower performance than native training