Точная настройка модели с API Trainer | TypeError: объект типа «неэтип» не имеет len ()Python

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Точная настройка модели с API Trainer | TypeError: объект типа «неэтип» не имеет len ()

Сообщение Anonymous »

Я использую API Trainer Trainer Hagging Face. Когда я запускаю Trainer.train (), я получаю следующую ошибку:

plish_strong> video_dataset.pype.py.py-pily.py.

Код: Выделить всё

import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import glob

class VideoFrameDataset(Dataset):
def __init__(self, root_dir, num_frames=16, transform=None):
self.root_dir = root_dir
self.num_frames = num_frames
self.transform = transform
self.video_dirs = []
self.labels = []

for label, subdir in enumerate(['real', 'fake']):
subdir_path = os.path.join(root_dir, subdir)
video_dirs = [
os.path.join(subdir_path, video)
for video in os.listdir(subdir_path)
if os.path.isdir(os.path.join(subdir_path, video))
]
self.video_dirs.extend(video_dirs)
self.labels.extend([label] * len(video_dirs))

def __len__(self):
return len(self.video_dirs)

def __getitem__(self, idx):
video_dir = self.video_dirs[idx]
label = self.labels[idx]

frame_paths = sorted(glob.glob(os.path.join(video_dir, '*.png')))[:self.num_frames]
if len(frame_paths) < self.num_frames:
raise ValueError(f"Video {video_dir} has fewer than {self.num_frames} frames.")

frames = []
for frame_path in frame_paths:
frame = Image.open(frame_path)
if self.transform:
frame = self.transform(frame)
frames.append(frame)

frames_tensor = torch.stack(frames)  # Shape: [num_frames, channels, height, width]
# return frames_tensor, label
# return {
#     "pixel_values": frames_tensor,
#     "labels": label
# }

return {
"pixel_values": frames_tensor,
"labels": torch.tensor(label, dtype=torch.long)
}

# for timesformer
mean=[0.45, 0.45, 0.45]
std=[0.225, 0.225, 0.225]

transform = transforms.Compose([
# transforms.Resize((224, 224)),         # Frames are already resized
transforms.ToTensor(),
transforms.Normalize(
mean=mean,
std=std
)
])

dataset_root = r'/home/oper/Desktop/Dataset_Thumbnail'
train_dataset = VideoFrameDataset(os.path.join(dataset_root, 'train'), num_frames=16, transform=transform)
val_dataset = VideoFrameDataset(os.path.join(dataset_root, 'val'), num_frames=16, transform=transform)
test_dataset = VideoFrameDataset(os.path.join(dataset_root, 'test'), num_frames=16, transform=transform)
main.py:

Код: Выделить всё

import sys, os, json, random
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import numpy as np
import torch
import torch.distributed as dist
from utils.save_model import save_pretrained_model
from utils.load_model import load_model
from utils.save_finetuned_model import save_finetuned_model
from utils.load_finetuned_model import load_finetuned_model
from video_dataset import train_dataset, val_dataset
from train_model import train_model

def seed_everything(seed=42):
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if __name__ == "__main__":
print("Options:")
print("  1 - Download the pretrained model from transformers library")
print("  2 - Fine-tune the model on custom dataset")

choice = input("Enter your choice (1, 2): ")
if choice not in ["1", "2"]:
print(f"Invalid command: {choice}.  Please enter 1, 2")
sys.exit(1)

if choice == "1":
save_pretrained_model(
model_name="facebook/timesformer-base-finetuned-k600",
save_directory="./model"
)

elif choice == "2":
seed_everything()

base_model = load_model("./model")
if not base_model:
print("Unable to load the model.")
sys.exit()

for param in base_model.parameters():
param.requires_grad = False

for i in [11]:
for param in base_model.timesformer.encoder.layer[i].parameters():
param.requires_grad = True
base_model.timesformer.layernorm.weight.requires_grad = True
base_model.timesformer.layernorm.bias.requires_grad = True
base_model.classifier.weight.requires_grad = True
base_model.classifier.bias.requires_grad = True

trained_model = train_model(base_model, train_dataset, val_dataset)

save_finetuned_model(trained_model, "./weights")
print("Training complete. Model saved.")
else:
if dist.get_rank() == 0:
print(f"Invalid command: {choice}.  Please enter 1, 2")
sys.exit(1)
train_model.py:

Код: Выделить всё

import torch
from torch.utils.data.dataloader import default_collate
import os, sys
import numpy as np
import wandb
from transformers import Trainer, TrainingArguments, TrainerCallback, default_data_collator
from sklearn.metrics import accuracy_score, roc_auc_score, precision_recall_fscore_support

os.environ["WANDB_PROJECT"] = "deepfake-detection"

class MetricsLoggerCallback(TrainerCallback):
def on_epoch_end(self, args, state, control, **kwargs):
train_logs = [x for x in state.log_history if "loss" in x and "eval_loss" not in x]
loss = train_logs[-1]["loss"]
print(f"Epoch {int(state.epoch)}/{args.num_train_epochs} — Train Loss: {loss:.4f}")

def on_evaluate(self, args, state, control, metrics, **kwargs):
print(f"Epoch {int(metrics.get('epoch', state.epoch))} — "
f"Eval Loss: {metrics['eval_loss']:.4f} | "
f"Accuracy: {metrics['eval_accuracy']:.4f} | "
f"Precision: {metrics['eval_precision']:.4f} | "
f"Recall: {metrics['eval_recall']:.4f} | "
f"F1: {metrics['eval_f1']:.4f} | "
f"AUC: {metrics['eval_auc']:.4f}")

def video_collate_fn(batch):
if len(batch) == 0:
return None

for item in batch:
if not isinstance(item["labels"], torch.Tensor):
item["labels"] = torch.tensor(item["labels"], dtype=torch.long)

pixel_values = torch.stack([item["pixel_values"] for item in batch])
labels = torch.stack([item["labels"] for item in batch])

return {
"pixel_values": pixel_values,
"labels": labels
}

def compute_metrics(eval_pred):
logits, labels = eval_pred
preds = np.argmax(logits, axis=-1)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
return {
"accuracy": accuracy_score(labels, preds),
"auc": roc_auc_score(labels, logits[:, 1]),
'f1': f1,
'precision': precision,
'recall': recall
}

def train_model(model, train_dataset, val_dataset, num_epochs=5, warmup_epochs=1):
total_steps = num_epochs * len(train_dataset)
warmup_steps = warmup_epochs * len(train_dataset)

training_args = TrainingArguments(
output_dir="./results",
overwrite_output_dir=True,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
optim="adamw_torch",
learning_rate=1.5e-5,
weight_decay=0.01,
max_grad_norm=1.0,
gradient_accumulation_steps=1,
label_smoothing_factor=0.1,
num_train_epochs=num_epochs,
lr_scheduler_type="cosine",
warmup_steps=warmup_steps,
evaluation_strategy="epoch",
save_strategy="epoch",
logging_dir='./logs',
logging_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="accuracy",
dataloader_drop_last=True,
dataloader_num_workers=0,
disable_tqdm=False,
report_to='wandb',
run_name="TALL-TimeSformer"
)

trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
data_collator=video_collate_fn,
compute_metrics=compute_metrics,
callbacks=[MetricsLoggerCallback()]
)

trainer.train()

model.save_pretrained('./final_model')
return model
Вопрос
Что может вызвать эту ошибку, и как ее исправить? Ошибка, по -видимому, возникает в классе тренера, когда она пытается вызвать Len () на объекте входов, который каким -то образом не является. Поскольку моя модель отлично работает с теми же данными, когда я теста я тестает их напрямую, я подозреваю, что это проблема с тем, как тренер обрабатывает данные.>

Подробнее здесь: https://stackoverflow.com/questions/795 ... e-nonetype
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»