Я пытаюсь обучить модель OCR для проекта lpr, используя набор данных из более чем 20 тысяч изображений, и это не работает.
Получил эти файлы: train.py:
Все, что я запускаю в обучении, начинается со значения потерь, которое медленно уменьшается, а затем внезапно потери начинают увеличиваться, пока не станут NaN, и все, оно останется NaN. >
Кто-нибудь знает, как это исправить?
Я пытаюсь обучить модель OCR для проекта lpr, используя набор данных из более чем 20 тысяч изображений, и это не работает. Получил эти файлы: [b]train.py[/b]: [code]import os import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from dataset import get_dataloaders, idx_to_char from model import CRNN import editdistance
# Function to calculate Character Error Rate (CER) def calculate_cer(predictions, ground_truths): total_distance = 0 total_chars = 0 for pred, gt in zip(predictions, ground_truths): distance = editdistance.eval(pred, gt) total_distance += distance total_chars += len(gt) return total_distance / total_chars if total_chars > 0 else 0
# Function to calculate Word Error Rate (WER) def calculate_wer(predictions, ground_truths): total_distance = 0 total_words = 0 for pred, gt in zip(predictions, ground_truths): pred_words = pred.split() gt_words = gt.split() distance = editdistance.eval(pred_words, gt_words) total_distance += distance total_words += len(gt_words) return total_distance / total_words if total_words > 0 else 0
# Function to decode model predictions def decode_predictions(outputs): _, max_indices = outputs.softmax(2).max(2) predictions = [] for seq in max_indices.cpu().numpy(): decoded = [] prev_char = None for idx in seq: if idx != prev_char and idx != 0: # Skip blanks and consecutive repeats decoded.append(idx_to_char[idx]) prev_char = idx predictions.append("".join(decoded)) return predictions
# Function to decode ground truth labels def decode_labels(labels): return ["".join(idx_to_char[idx.item()] for idx in label if idx != 0) for label in labels]
# Function to evaluate the model def evaluate_ocr(model, dataloader, criterion): model.eval() predictions = [] ground_truths = [] total_loss = 0
with torch.no_grad(): for images, labels, label_lens in dataloader: images, labels, label_lens = images.to(device), labels.to(device), label_lens.to(device) outputs = model(images) input_lens = torch.tensor([outputs.size(1)] * labels.size(0), dtype=torch.long).to(device) # Calculate loss loss = criterion(outputs.log_softmax(2).permute(1, 0, 2), labels, input_lens, label_lens) total_loss += loss.item()
avg_loss = total_loss / len(dataloader) cer = calculate_cer(predictions, ground_truths) wer = calculate_wer(predictions, ground_truths) print(f"Validation - Loss: {avg_loss:.4f}, CER: {cer:.4f}, WER: {wer:.4f}") for i, (pred, gt) in enumerate(zip(predictions[:3], ground_truths[:3])): print(f"Sample {i + 1} Prediction: {pred}, Ground Truth: {gt}") return avg_loss, cer, wer
# Load the dataset train_loader, test_loader = get_dataloaders("../data/lpr.csv", "../data/cropped_lps", batch_size=16)
# Initialize CRNN model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = CRNN(num_classes=36).to(device)
def initialize_weights(module): if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias)
# Training loop best_cer = float('inf') for epoch in range(50): print(f"\U0001F504 Epoch {epoch+1} started") # ✅ Print before each epoch model.train() total_loss = 0
for batch_idx, (images, labels, label_lens) in enumerate(train_loader): images, labels, label_lens = images.to(device), labels.to(device), label_lens.to(device) optimizer.zero_grad() outputs = model(images)
# Calculate the sequence length for each batch based on the output tensor batch_size = outputs.size(0) sequence_lengths = outputs.size(1) # Assuming this is the sequence length dimension
loss = criterion(outputs.log_softmax(2).permute(1, 0, 2), labels, input_lengths, label_lens) if not torch.isnan(loss): total_loss += loss.item() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5) optimizer.step() else: print(f"Skipping step {batch_idx}, NaN loss detected")
avg_train_loss = total_loss / len(train_loader) print(f"\u2705 Epoch {epoch+1} finished. Avg Training Loss = {avg_train_loss:.4f}")
# Evaluate model val_loss, cer, wer = evaluate_ocr(model, test_loader, criterion)
# Save the best model based on CER if cer < best_cer: best_cer = cer torch.save(model.state_dict(), "../models/best.pth") print(f"\U0001F4BE Best model saved with CER: {best_cer:.4f}")
# Step the learning rate scheduler scheduler.step(val_loss)
# Save final model torch.save(model.state_dict(), "../models/final.pth") print(f"\u2705 Final model saved as models/final.pth")
[/code] [b]model.py:[/b] [code]import torch.nn as nn import torch.nn.functional as F
class CRNN(nn.Module): def __init__(self, num_classes=37): # Including the CTC blank super(CRNN, self).__init__()
# Fully connected output layer self.fc = nn.Linear(256, num_classes) # Assuming the RNN is not bidirectional
def forward(self, x): x = self.cnn(x) batch_size, channels, height, width = x.shape x = x.permute(0, 3, 1, 2).contiguous() x = x.view(batch_size, width, -1) x, _ = self.rnn(x) x = self.fc(x) return x
[/code] [b]dataset.py:[/b] [code]import os import cv2 import torch import pandas as pd from torch.utils.data import Dataset, DataLoader import torchvision.transforms as transforms from PIL import Image
# Define character set for encoding CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" char_to_idx = {c: i + 1 for i, c in enumerate(CHARS)} # Index starts at 1 (0 for blank in CTC) idx_to_char = {i: c for c, i in char_to_idx.items()}
# Transformations for images transform = transforms.Compose([ transforms.Resize((32, 128)), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) # Assuming images are in grayscale ])
class LicensePlateDataset(Dataset): def __init__(self, csv_path, img_folder, transform=None): self.data = [] df = pd.read_csv(csv_path) for _, row in df.iterrows(): img_path = os.path.join(img_folder, row['images']) label = row['labels'] if os.path.exists(img_path) and all(c in char_to_idx for c in label): self.data.append((img_path, label)) else: print(f"Skipped {img_path} due to an error.") self.transform = transform
def __len__(self): return len(self.data)
def __getitem__(self, idx): img_path, label = self.data[idx] img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) img = Image.fromarray(img) if self.transform: img = self.transform(img) label_indices = [char_to_idx[c] for c in label] return img, torch.tensor(label_indices, dtype=torch.long), len(label_indices)
return train_loader, test_loader [/code] В моем наборе данных есть изображения номерных знаков, CSV-файл, построенный следующим образом: [code]images,labels 1.jpg,RZ0047 10.jpg,V95246 100.jpg,6258TU 10000.jpg,B88082 10001.jpg,7065UK 10002.jpg,0195SM 10003.jpg,D01223 10005.jpg,V58318 [/code] ... и разделил его на train_data.txt и test_data.txt, которые выглядят так: [code]C:\Users\tnayd\OneDrive\Desktop\Projects\lpr_final_project\data\cropped_lps\25541.jpg 5182HH 6 C:\Users\tnayd\OneDrive\Desktop\Projects\lpr_final_project\data\cropped_lps\2020-16-07_1326-38_67_350.png LIM898 6 C:\Users\tnayd\OneDrive\Desktop\Projects\lpr_final_project\data\cropped_lps\32641.jpg S81237 6 C:\Users\tnayd\OneDrive\Desktop\Projects\lpr_final_project\data\cropped_lps\15097.jpg N51816 6 C:\Users\tnayd\OneDrive\Desktop\Projects\lpr_final_project\data\cropped_lps\32250.jpg N46628 6 C:\Users\tnayd\OneDrive\Desktop\Projects\lpr_final_project\data\cropped_lps\2020-16-07_1341-16_20_651.png PLB799 6 C:\Users\tnayd\OneDrive\Desktop\Projects\lpr_final_project\data\cropped_lps\6504.jpg C67233 6 C:\Users\tnayd\OneDrive\Desktop\Projects\lpr_final_project\data\cropped_lps\22983.jpg DX5930 6 [/code] Все, что я запускаю в обучении, начинается со значения потерь, которое медленно уменьшается, а затем внезапно потери начинают увеличиваться, пока не станут NaN, и все, оно останется NaN. > Кто-нибудь знает, как это исправить?