Как визуализировать векторное поле и деформацию сетки при применении Grid_samplePython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Как визуализировать векторное поле и деформацию сетки при применении Grid_sample

Сообщение Anonymous »

Я пытаюсь визуализировать деформацию сетки при применении Grid_sample. Например, мое исходное изображение представляет собой простое квадратное изображение с полем смещения, которое деформирует квадратное изображение в круглое. Я хочу визуализировать деформацию сетки с помощью векторного поля того, как перемещается каждый пиксель. Я могу сделать это, создав изображение сетки и применив Grid_sample к изображению сетки. Однако я не уверен, как воспроизвести это и показать векторное поле с помощью LineCollection и quiver с помощью matplotlib. Не могли бы вы помочь мне с этим? Спасибо!

Код: Выделить всё

import numpy as np
import cv2
import torch

from torch.nn.functional import grid_sample

class SimpleReg(torch.nn.Module):
"""
a simple cnn model that take coordinates map as input and output the registration result
such that when apply the grid_sample function, the input image will be transformed to the target image
input: nx 2 x h x w
output: nx 2 x h x w
"""

def __init__(self):
super(SimpleReg, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(2, 16, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.Conv2d(16, 32, 3, 1, 1),

torch.nn.ReLU(),
torch.nn.Conv2d(32, 64, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.Conv2d(64, 2, 3, 1, 1),
)

def forward(self, x):
return self.conv(x)

def gen_square(image_size: int, fraction: float) -> np.ndarray:
image = np.zeros((image_size, image_size), dtype=np.float32)
square_size = int(image_size * fraction)

start = (image_size - square_size) // 2
end = start + square_size
image[start:end, start:end] = 1.0
return image

def gen_circle(image_size: int, fraction: float) -> np.ndarray:
image = np.zeros((image_size, image_size), dtype=np.float32)
circle_size = int(image_size * fraction)
cv2.circle(image, (image_size // 2, image_size // 2), circle_size // 2, 1.0, -1)
return image

def gen_grid_image(size, n_lines=50, line_width=0.1):
image = np.zeros((size, size), dtype=np.float32)

# Generate horizontal lines
for i in range(n_lines):
y = int(size / (n_lines + 1) * (i + 1))
image[y, :] = 1.0

# Generate vertical lines
for i in range(n_lines):
x = int(size / (n_lines + 1) * (i + 1))
image[:, x] = 1.0

return image

def visualize_deformation_field_with_lines(phi, grid_size=25, vector_scale=1):
"""
Visualize the deformation field by plotting deformed grid lines and overlaying the deformation vector field.
"""

pass

# Create the images
size = 256
fraction = 0.7

input_shape = gen_square(size, fraction)
target_shape = gen_circle(size, fraction)
grid_image = gen_grid_image(size, 50)

# Convert the images to tensors

input_tensor = torch.tensor(input_shape).unsqueeze(0).unsqueeze(0)
target_tensor = torch.tensor(target_shape).unsqueeze(0).unsqueeze(0)
grid_tensor = torch.tensor(grid_image).unsqueeze(0).unsqueeze(0)

# Get domain of the input and target coordinates
input_coords = np.indices((size, size)).astype(np.float32)
input_coords = input_coords / (size * 0.5) - 1.0
input_coords = torch.tensor(input_coords).unsqueeze(0)

# Create the model
model = SimpleReg()

# Define the loss function and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# convert everything to cuda
model = model.cuda()
input_tensor = input_tensor.cuda()
target_tensor = target_tensor.cuda()
input_coords = input_coords.cuda()
grid_tensor = grid_tensor.cuda()

# Train the model
for epoch in range(1000):
optimizer.zero_grad()
outputs_grid = model(input_coords)
outputs_grid = input_coords.detach() + outputs_grid
# normalize the outputs to [-1, 1]
outputs_grid = torch.clamp(outputs_grid, -1, 1)
outputs = grid_sample(input_tensor, outputs_grid.permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros')
vis_grid_tensor = grid_sample(grid_tensor, outputs_grid.permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros')
loss = criterion(outputs, target_tensor)
loss.backward()
optimizer.step()

if epoch % 10 == 0:
print(f"Epoch {epoch}: Loss: {loss.item()}")

vis = np.concatenate([input_shape,
outputs.squeeze().detach().cpu().numpy(),
target_shape,
vis_grid_tensor.squeeze().detach().cpu().numpy()
], axis=1)

cv2.imshow("Visualization", vis)
cv2.waitKey(0)

Изображение


Подробнее здесь: https://stackoverflow.com/questions/791 ... g-grid-sam
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»