Изменение модели Vision Transformer (ViT) в формате Timm для пользовательской головы в PyTorchPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Изменение модели Vision Transformer (ViT) в формате Timm для пользовательской головы в PyTorch

Сообщение Anonymous »

Я работаю с моделью Vision Transformer (ViT), используя PyTorch и библиотеку timm. Моя цель — изменить модель ViT, чтобы заменить заголовок классификации по умолчанию пользовательским заголовком, который принимает среднее значение всех токенов и добавляет новый уровень классификации.
Сводка модели ViT по умолчанию в Тимм заканчивается так:

Код: Выделить всё

       LayerNorm-247             [-1, 197, 768]           1,536
Identity-248                  [-1, 768]               0
Dropout-249                  [-1, 768]               0
Linear-250                 [-1, 1000]         769,000
VisionTransformer-251                 [-1, 1000]               0

Чтобы удалить последние написанные мной слои:

Код: Выделить всё

class VisionTransformerWithoutHead(nn.Module):

def __init__(self, model_name):
super(VisionTransformerWithoutHead, self).__init__()

# Load the ViT model
vit_model = timm.create_model(model_name, pretrained=True)

# Remove the final layers
self.features = nn.Sequential(*list(vit_model.children())[:-1])

def forward(self, x):
# Forward pass through the modified model
output = self.features(x)
return output
Сводка теперь заканчивается:

Код: Выделить всё

       LayerNorm-247             [-1, 196, 768]           1,536
Identity-248             [-1, 196, 768]               0
Dropout-249             [-1, 196, 768]               0
Уменьшилось количество токенов со 197 до 196 и похоже удалился токен класса. Хотелось бы понять, почему это происходит и есть ли способ удалить только последние слои сохраняя при этом токен класса.

Подробнее здесь: https://stackoverflow.com/questions/779 ... in-pytorch
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»