Индекс класса Pytorch SSDLite выходит за пределы диапазона и ошибка functionnal.cross_entropy torch.nnPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Индекс класса Pytorch SSDLite выходит за пределы диапазона и ошибка functionnal.cross_entropy torch.nn

Сообщение Anonymous »

Я пытаюсь точно настроить модель SSDLite320_Mobilenet_V3_Large на пользовательском наборе данных с тремя классами, используя код pytorch для моделей видения/обнаружения. Однако я подозреваю, что существует проблема при создании цели внутри библиотеки torch или в коде, указанном в субгиде (ссылка выше).
Действительно, при достижении функции cross_entropy из файл function.py из torch.nn, я получаю

ошибку CUDA: утверждение на стороне устройства вызвало ошибки ядра CUDA .

Как упоминалось ptrblck на форуме Pytorch по этой проблеме и в документации pytorch, функция принимает цели в диапазоне [0,C), где C — количество классов.
Поэтому я попробовал две вещи:
  • Я установил ignore_index на 3, чтобы гарантировать что не будет значения, выходящего за пределы индекса передано. Код выполнился без ошибок. Однако, посмотрев на предсказанные метки и матрицу путаницы (см. ниже, не беспокойтесь о точности модели, это всего лишь набор отладочных данных с 50 эпохами) после обучения я увидел, что модель не смогла предсказать класс номер 3. Установка ignore_index на 3 не кажется решением, если это приводит к тому, что в прогнозах модели отсутствует класс.
Изображение
  • Я проверил минимальное и максимальное значения мои цели непосредственно в функции cross_entropy прямо перед тем, как они были переданы для возврата torch._C._nn.cross_entropy_loss(input, target, Weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing). Я получил 0 как минимальное и 3 как максимальное значения на всех проходах. Идентификаторы моих классов: [1,2,3]. Я подозреваю, что проблема связана с этим увеличением диапазона. Однако, не будучи знаком с внутренней работой библиотеки torch, я, кажется, не понимаю, откуда взялся этот 0 (или почему 3 не уменьшаются в 2, если сокращение есть) или как target< Создается объект /strong>.
Поскольку эта проблема может возникнуть из-за создания объектов, передаваемых функциям библиотеки Torch, вот некоторый контекст их творение.
Мое модель выглядит следующим образом:
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(num_classes=num_classes, pretrained_backbone=True, trainable_backbone_layers=0)

И часть моего кода, посвященная созданию загрузчиков данных, выглядит следующим образом:
# Data loading code
print("Loading data")

dataset, num_classes = get_dataset(is_train=True, args=args)
dataset_test, _ = get_dataset(is_train=False, args=args)
dataset_val, _ = get_dataset(is_train=False, args=args)

print("Creating data loaders")
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
else:
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
test_val = torch.utils.data.SequentialSampler(dataset_val)

if args.aspect_ratio_group_factor >= 0:
group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
else:
train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)

train_collate_fn = utils.collate_fn
if args.use_copypaste:
if args.data_augmentation != "lsj":
raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies")

train_collate_fn = copypaste_collate_fn

data_loader = torch.utils.data.DataLoader(
dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=train_collate_fn
)

data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)

data_loader_val = torch.utils.data.DataLoader(
dataset_val, batch_size=1, sampler=val_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)


Подробнее здесь: https://stackoverflow.com/questions/782 ... ionnal-cro
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»