Как правильно извлечь токен CLS из магистрали Keras Hub ViT и уточнить использование препроцессора и набор данных для прPython

Программы на Python
Ответить
Anonymous
 Как правильно извлечь токен CLS из магистрали Keras Hub ViT и уточнить использование препроцессора и набор данных для пр

Сообщение Anonymous »

Я работаю с магистралью Vision Transformer (ViT) из Keras Hub и создаю собственную классификационную головку. Мой код выглядит так:
python

Код: Выделить всё

def get_vit_model(model_variant='vit_base',
input_shape=(256, 256, 3),
num_classes=3,
train_base_model=True):

preset_path = "/home/ahmed/ct_brain_project/models"

back_bone = keras_hub.models.Backbone.from_preset(preset_path)
back_bone.trainable = train_base_model

inputs = layers.Input(shape=input_shape, name='input_layer')
features = back_bone(inputs, training=train_base_model)

# Extract CLS token
cls_token = features[:, 0, :]  # (batch, embed_dim)

x = layers.Dense(128, use_bias=False)(cls_token)

# rest of code of the classification head

model = Model(inputs=inputs, outputs=outputs)
return model
Из скачанного мной конфига (

Код: Выделить всё

vit_base_patch16_224_imagenet
из Kaggle), я вижу:
json
"class_name": "ViTBackbone",

"config": {

"use_class_token": true,

"image_shape": [224, 224, 3],

"patch_size": [16, 16],

"num_layers": 12,

"num_heads": 12,

"hidden_dim": 768,

"mlp_dim": 3072


Итак, мои вопросы:
1- CLS извлечение токена: являются ли функции [:, 0, :] правильным способом извлечения встраивания токена CLS из выходных данных магистральной сети? я просмотрел класс ViTPatchingAndEmbedding и вижу это

Код: Выделить всё

            patch_embeddings = ops.concatenate(
[class_token, patch_embeddings], axis=1
)
Но я не уверен, используется ли этот класс в загруженной мной магистрали или нет.
2- Препроцессор: поскольку я использую keras_hub.models.Backbone.from_preset(...) напрямую, правильно ли я говорю, что ViTImageClassifierPreprocessor не применяется?
3- Набор данных для предварительного обучения: загруженный мной пресет имеет имя vit_base_patch16_224_imagenet. Это предварительно обучено на ImageNet‑1k или ImageNet‑21k? (Я знаю, что у Hugging Face есть google/vit-base-patch16-224, который предварительно обучен на 21 КБ, а затем точно настроен на 1 КБ, но я хочу подтвердить версию Keras Hub/Kaggle.)
Заранее спасибо
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»