Trying to implement CAT-Seg: Cost Aggregation for Open-Vocabulary Semantic Segmentation

Post by Anonymous »

So, I'm trying to implement a modified version of the above paper in a Kaggle notebook, using the ADE20K dataset for training and validation. After testing it on a sample input I get a reasonably good segmentation. However, the class names are shown as class 1, class 2, etc. instead of their real names, and I can't fix this.

The dataset contains text files named objectInfo.txt and sceneCategories.txt, and my code loads them as shown below. Since I can't disclose my updated code, I'm sharing my simplified version of the CAT-Seg code, which has the same problem.
# Enhanced CAT-Seg (Kaggle-Friendly with Full Embedding Guidance + Intermediate Feature Fusion)
!pip install transformers==4.30.0 ftfy timm matplotlib
!pip install 'git+https://github.com/facebookresearch/detectron2.git@v0.6'

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
from PIL import Image
import numpy as np
import os
from tqdm import tqdm
import timm
import matplotlib.pyplot as plt
from timm.models.swin_transformer import SwinTransformerBlock
from matplotlib.patches import Patch
from collections import defaultdict

class Config:
    clip_model = "openai/clip-vit-base-patch16"
    num_classes = 150  # ADE20K
    cost_embed_dim = 128
    image_size = (224, 224)
    batch_size = 4
    lr_backbone = 2e-6
    lr_head = 2e-4
    epochs = 20
    num_tta_scales = 3
config = Config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_ade20k_classes(data_root):
"""Load class metadata with robust parsing from the dataset's objectInfo150.txt file."""
classes = []
class_file = os.path.join(data_root, "annotations", "objectInfo150.txt")
try:
with open(class_file, 'r', encoding='utf-8') as f:
# Skip header line (Idx Ratio Train Val Name)
next(f)
for line in f:
parts = line.strip().split('\t')
if len(parts) >= 5:
class_name = parts[4].strip().split(',')[0].strip()
classes.append(class_name)
print(f"Successfully loaded {len(classes)} classes from metadata.")
except Exception as e:
print(f"Error loading classes from metadata: {e}. Using a default list.")
classes = [f"class_{i}" for i in range(150)]
return classes

ADE20K_CLASSES = load_ade20k_classes("/kaggle/input/ade20k-dataset/ADEChallengeData2016")
config.num_classes = len(ADE20K_CLASSES)
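# Quick diagnostic (a minimal sketch, not part of the original CAT-Seg code):
# if the objectInfo150.txt path above is wrong, the except-branch silently
# substitutes generic "class_{i}" names, which is exactly the "class 1, class 2"
# symptom. Note that in some copies of ADEChallengeData2016 the file lives at
# the dataset root rather than under annotations/.
if ADE20K_CLASSES[0].startswith("class_"):
    print("WARNING: fallback class names in use -- check the objectInfo150.txt path")
else:
    print("First five classes:", ADE20K_CLASSES[:5])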

class DenseCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.clip = CLIPVisionModel.from_pretrained(config.clip_model, output_hidden_states=True)
        # Fine-tune only the attention q/v projections; freeze everything else
        for name, param in self.clip.named_parameters():
            if "vision_model.encoder.layers" in name and ("q_proj" in name or "v_proj" in name):
                param.requires_grad = True
            else:
                param.requires_grad = False
        self.proj = nn.Conv2d(768, 512, 1)

    def forward(self, x):
        outputs = self.clip(pixel_values=x)
        last_feat = outputs.last_hidden_state[:, 1:]  # drop the CLS token
        B, L, C = last_feat.shape
        H = W = int(L ** 0.5)  # 14x14 patch grid for ViT-B/16 at 224x224
        # reshape, not view: the permuted tensor is non-contiguous and view would raise
        img_feat = self.proj(last_feat.permute(0, 2, 1).reshape(B, C, H, W))
        mid_feat = outputs.hidden_states[6][:, 1:].permute(0, 2, 1).reshape(B, -1, H, W)
        return img_feat, mid_feat

class ClassTransformer(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.attn = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads, batch_first=True)

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.view(B, C, -1).permute(0, 2, 1)  # (B, HW, C) token sequence
        x = self.attn(x)
        x = x.permute(0, 2, 1).reshape(B, C, H, W)  # reshape: permute output is non-contiguous
        return x

def get_text_embeddings():
    processor = CLIPProcessor.from_pretrained(config.clip_model)
    prompt_templates = [
        "a photo of a {}",
        "an image of a {}"
    ]

    all_embeddings = []
    model = CLIPModel.from_pretrained(config.clip_model).to(device)
    model.eval()

    for class_name in tqdm(ADE20K_CLASSES, desc="Generating text embeddings"):
        prompts = [template.format(class_name) for template in prompt_templates]
        inputs = processor(text=prompts, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            embeddings = model.get_text_features(**inputs)
        all_embeddings.append(embeddings.mean(dim=0))  # prompt-ensemble average per class

    return torch.stack(all_embeddings)
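# Optional sanity check (a minimal sketch, not part of the original pipeline):
# verify the ensembled embeddings have the expected shape and that prompt
# averaging keeps classes distinguishable.
RUN_EMBED_CHECK = False  # flip on to inspect the embeddings
if RUN_EMBED_CHECK:
    emb = get_text_embeddings()  # expected: (config.num_classes, 512) for ViT-B/16
    print("text embedding shape:", emb.shape)
    print("cos(class 0, class 1):", F.cosine_similarity(emb[0:1], emb[1:2]).item())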

class CATSeg(nn.Module):
    def __init__(self):
        super().__init__()
        self.clip = DenseCLIP()
        self.cost_proj = nn.Conv2d(config.num_classes, config.cost_embed_dim, 1)

        self.guidance_v = nn.Linear(512, config.cost_embed_dim)
        self.guidance_t = nn.Linear(512, config.cost_embed_dim)

        # NOTE: feeding channels-last (B, H, W, C) tensors to SwinTransformerBlock
        # assumes a recent timm (>= 0.9); older releases expect (B, L, C) tokens.
        self.spatial_agg = nn.Sequential(
            SwinTransformerBlock(dim=config.cost_embed_dim, input_resolution=(14, 14), num_heads=4, window_size=7),
            SwinTransformerBlock(dim=config.cost_embed_dim, input_resolution=(14, 14), num_heads=4, window_size=7, shift_size=3)
        )
        self.class_agg = nn.Sequential(
            ClassTransformer(config.cost_embed_dim, num_heads=4),
            ClassTransformer(config.cost_embed_dim, num_heads=4)
        )

        self.decoder = nn.Sequential(
            nn.Conv2d(config.cost_embed_dim + 768, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, config.num_classes, 1)
        )

    def forward_once(self, x, text_embeds):
        img_feats, mid_feats = self.clip(x)
        img_feats = F.normalize(img_feats, dim=1)
        text_embeds = F.normalize(text_embeds, dim=1)

        # Cosine-similarity cost volume: one (H, W) map per class
        cost = torch.einsum('bchw,nc->bnhw', img_feats, text_embeds)
        cost_embed = self.cost_proj(cost)

        # Embedding guidance: project image and (averaged) text features into the cost space
        g_img = self.guidance_v(img_feats.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        B, _, H, W = img_feats.shape
        g_txt_proj = self.guidance_t(text_embeds)
        g_txt_avg = g_txt_proj.mean(dim=0).view(1, -1, 1, 1)
        g_txt = g_txt_avg.expand(B, -1, H, W)

        x = cost_embed + g_img + g_txt
        x = self.spatial_agg(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        x = self.class_agg(x)
        x = torch.cat([x, mid_feats], dim=1)  # intermediate feature fusion
        return self.decoder(x)

    def forward(self, x, text_embeds, tta=False):
        if not tta:
            return self.forward_once(x, text_embeds)

        # Multi-scale + horizontal-flip test-time augmentation
        H, W = config.image_size
        scales = [0.75, 1.0, 1.25][:config.num_tta_scales]
        preds = []
        for s in scales:
            new_size = (int(H * s), int(W * s))
            scaled = F.interpolate(x, size=new_size, mode='bilinear', align_corners=False)
            scaled = F.interpolate(scaled, size=config.image_size, mode='bilinear', align_corners=False)
            out1 = self.forward_once(scaled, text_embeds)

            flipped = torch.flip(scaled, dims=[3])
            out2 = torch.flip(self.forward_once(flipped, text_embeds), dims=[3])
            preds.append((out1 + out2) / 2.0)

        return torch.mean(torch.stack(preds, dim=0), dim=0)
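# Shape walkthrough of the cost volume (a minimal sketch with random tensors,
# independent of any trained weights): forward_once builds a per-pixel,
# per-class cosine cost with the same einsum used above.
_demo_img = F.normalize(torch.randn(2, 512, 14, 14), dim=1)  # dense image features
_demo_txt = F.normalize(torch.randn(150, 512), dim=1)        # one embedding per class
print(torch.einsum('bchw,nc->bnhw', _demo_img, _demo_txt).shape)  # torch.Size([2, 150, 14, 14])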

class ADE20KDataset(Dataset):
    def __init__(self, image_dir, annotation_dir):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.files = sorted([f for f in os.listdir(image_dir) if f.endswith('.jpg')])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.files[idx])
        ann_path = os.path.join(self.annotation_dir, self.files[idx].replace('.jpg', '.png'))
        image = Image.open(img_path).convert("RGB").resize(config.image_size)
        image = np.array(image) / 255.0
        # CLIP normalization constants
        image = (image - [0.48145466, 0.4578275, 0.40821073]) / [0.26862954, 0.26130258, 0.27577711]
        image = torch.FloatTensor(image).permute(2, 0, 1)
        # Masks are downsampled to the decoder's 56x56 output resolution
        mask = Image.open(ann_path).resize((56, 56), Image.NEAREST)
        mask = torch.LongTensor(np.array(mask))
        # Shift labels to 0-indexed; note the clamp also maps ADE20K's label 0
        # ("unlabeled") onto class 0 instead of ignoring it
        return image, torch.clamp(mask - 1, 0, config.num_classes - 1)
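# Dataset smoke test (a minimal sketch; the guard assumes the Kaggle layout used
# below): one sample should yield a (3, 224, 224) float image and a (56, 56)
# long mask matching the decoder's output stride.
_train_img_dir = "/kaggle/input/ade20k-dataset/ADEChallengeData2016/images/training"
_train_ann_dir = "/kaggle/input/ade20k-dataset/ADEChallengeData2016/annotations/training"
if os.path.isdir(_train_img_dir):
    _ds = ADE20KDataset(_train_img_dir, _train_ann_dir)
    _img, _msk = _ds[0]
    print(_img.shape, _msk.shape, _msk.min().item(), _msk.max().item())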

def calculate_miou(conf_matrix):
    iou = torch.diag(conf_matrix) / (conf_matrix.sum(0) + conf_matrix.sum(1) - torch.diag(conf_matrix) + 1e-10)
    return iou.nanmean().item()
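# Worked example for the mIoU computation (a hand-built 2x2 confusion matrix,
# rows = ground truth, cols = prediction, as produced by the bincount below):
# class 0 -> IoU 8/(8+2+1) = 0.727, class 1 -> 4/(4+1+2) = 0.571, mIoU ~ 0.649.
_cm = torch.tensor([[8., 2.], [1., 4.]])
print(round(calculate_miou(_cm), 3))  # ~0.649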

def train_and_validate():
    model = CATSeg().to(device)
    text_embeddings = get_text_embeddings()
    # Separate learning rates: tiny for the CLIP backbone, larger for the new head
    optimizer = torch.optim.AdamW([
        {'params': model.clip.parameters(), 'lr': config.lr_backbone},
        {'params': [p for n, p in model.named_parameters() if 'clip' not in n], 'lr': config.lr_head}
    ])

    train_dataset = ADE20KDataset("/kaggle/input/ade20k-dataset/ADEChallengeData2016/images/training",
                                  "/kaggle/input/ade20k-dataset/ADEChallengeData2016/annotations/training")
    val_dataset = ADE20KDataset("/kaggle/input/ade20k-dataset/ADEChallengeData2016/images/validation",
                                "/kaggle/input/ade20k-dataset/ADEChallengeData2016/annotations/validation")
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size)

    best_miou = 0.0
    train_mious, val_mious = [], []

    for epoch in range(config.epochs):
        model.train()
        train_conf_matrix = torch.zeros(config.num_classes, config.num_classes)
        for images, masks in tqdm(train_loader, desc=f"Train Epoch {epoch+1}"):
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images, text_embeddings)
            loss = F.cross_entropy(outputs, masks)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                preds = outputs.argmax(1).cpu()
                for pred, mask in zip(preds, masks.cpu()):
                    train_conf_matrix += torch.bincount(
                        mask.flatten() * config.num_classes + pred.flatten(),
                        minlength=config.num_classes**2
                    ).reshape(config.num_classes, config.num_classes)

        model.eval()
        val_conf_matrix = torch.zeros(config.num_classes, config.num_classes)
        with torch.no_grad():
            for images, masks in tqdm(val_loader, desc="Validating"):
                outputs = model(images.to(device), text_embeddings)
                preds = outputs.argmax(1).cpu()
                for pred, mask in zip(preds, masks):
                    val_conf_matrix += torch.bincount(
                        mask.flatten() * config.num_classes + pred.flatten(),
                        minlength=config.num_classes**2
                    ).reshape(config.num_classes, config.num_classes)

        current_train_miou = calculate_miou(train_conf_matrix)
        current_val_miou = calculate_miou(val_conf_matrix)
        train_mious.append(current_train_miou)
        val_mious.append(current_val_miou)

        print(f"Epoch {epoch+1}: Train mIoU={current_train_miou:.4f}, Val mIoU={current_val_miou:.4f}")

        # Save best model
        if current_val_miou > best_miou:
            best_miou = current_val_miou
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'train_mious': train_mious,
                'val_mious': val_mious,
                'best_miou': best_miou,
                'text_embeddings': text_embeddings,
                'classes': ADE20K_CLASSES
            }, 'best_model.pth')
            print(f"New best model saved with Val mIoU: {best_miou:.4f}")

    return model, train_mious, val_mious
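# Launch training (assumes the Kaggle dataset paths above exist in this session;
# this is the call that produces the 'best_model.pth' checkpoint used below):
model, train_mious, val_mious = train_and_validate()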

# 🔍 Example Inference + Visualization Block

def visualize_prediction(model_path, sample_idx=0):
"""Visualize predictions with proper class names and legends."""
ckpt = torch.load(model_path)
model = CATSeg().to(device)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

val_dataset = ADE20KDataset(
"/kaggle/input/ade20k-dataset/ADEChallengeData2016/images/validation",
"/kaggle/input/ade20k-dataset/ADEChallengeData2016/annotations/validation"
)
image_tensor, mask_tensor = val_dataset[sample_idx]

with torch.no_grad():
image_tensor = image_tensor.unsqueeze(0).to(device)
pred_logits = model(image_tensor, ckpt['text_embeddings'].to(device), tta=True)
pred_mask = pred_logits.argmax(dim=1)[0].cpu().numpy()

# Denormalize image for display
mean = np.array([0.48145466, 0.4578275, 0.40821073])
std = np.array([0.26862954, 0.26130258, 0.27577711])
img_display = image_tensor[0].permute(1, 2, 0).cpu().numpy()
img_display = img_display * std + mean
img_display = np.clip(img_display, 0, 1)

# Prepare for visualization
plt.figure(figsize=(18, 6))
cmap = plt.get_cmap('tab20', config.num_classes)

# Plot 1: Original Image
plt.subplot(1, 3, 1)
plt.imshow(img_display)
plt.title("Input Image")
plt.axis("off")

# Plot 2: Ground Truth with Legend
plt.subplot(1, 3, 2)
unique_gt_classes = np.unique(mask_tensor.numpy())
overlay_gt = np.zeros((*mask_tensor.shape, 4))
for class_id in unique_gt_classes:
if 0
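# Example usage (assumes 'best_model.pth' was written by the training loop above):
visualize_prediction('best_model.pth', sample_idx=0)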

More details here: https://stackoverflow.com/questions/797 ... ntic-segme