В частности, я пытаюсь реализовать метод раннего завершения лучей, чтобы прекратить добавление сэмплов вдоль любого луча, который уже накопил достаточный коэффициент пропускания (непрозрачность, близкая к 1). Я нашел вдохновение в методе раннего завершения лучей, используемом в реализации Instant-NGP (https://github.com/NVlabs/instant-ngp/b ... ed_nerf.cu).
Проблема в том, что когда я запускаю HashNeRF (100 итераций на графическом процессоре Google COLAB T4) с реализацией раннего завершения, rendering_time на порядок больше, чем значение по умолчанию. ХэшНеРФ (
Код: Выделить всё
rendering_time_early = 170 s
Я выполнил следующие шаги, чтобы включить ускоренный маршинг лучей к HashNeRF
1. Определение метода раннего завершения
Я начал с добавления файла Early_ray_termination.py
Код: Выделить всё
import torch
import torch.nn.functional as F
def apply_early_termination(raw, z_vals, rays_d, transmittance, accumulated_rgb, accumulated_weights, i, early_termination_threshold):
"""
Apply early ray termination logic during raymarching.
Args:
raw: The raw model output at the current sample [num_rays, 4].
z_vals: Sampled depth values along the ray [num_rays, num_samples along the ray].
rays_d: Ray directions [num_rays, 3].
transmittance: Current transmittance value (T) [num_rays, 1].
accumulated_rgb: Accumulated RGB values for rays [num_rays, 3].
accumulated_weights: Accumulated opacity (weights) [num_rays].
i: Current sample index.
early_termination_threshold: Threshold for stopping raymarching early.
Returns:
updated_transmittance, updated_rgb, updated_weights, stop_raymarching (bool)
"""
# # Extract the device from transmittance, as it should always be on the correct device
# device = transmittance.device
# # Initialize sparsity_loss on the same device as transmittance
# sparsity_loss = torch.tensor(0.0, device=device)
#Extract sigma and RGB values from raw model output
sigma = F.relu(raw[..., 3]) #Density
rgb = torch.sigmoid(raw[..., :3]) #Color
#Compute delta between samples
if i + 1 < z_vals.shape[1]:
delta = z_vals[:, i + 1] - z_vals[:, i]
else:
delta = 1e10 # Far plane
delta = delta * torch.norm(rays_d, dim=-1) # Convert to real-world distances
#Compute alpha and weights
alpha = 1. - torch.exp(-sigma * delta) # Alpha for this sample
weights = transmittance * alpha # Weighted alpha
#print(f"Shape of weights:\n {weights.shape}\n")
#Update transmittance, RGB and accumulated weights
transmittance = transmittance * (1. - alpha) # Update transmittance
# Accumulate RGB contributions, summing over the sample dimension
accumulated_rgb += torch.sum(weights[..., None] * rgb, -2)
#print("Shape of accumulated_weights:", accumulated_weights.shape)
#accumulated_rgb += weights[..., None] * rgb # Add weighted RGB
#Accumulate opacity weights (weights shape [1024, 1024] needs to match with accumulated_weights [1024] so we reduce the dimension of weights)
accumulated_weights[:, i] = weights[:, i] # Track weights per sample
#accumulated_weights += torch.sum(weights, dim=-1) #weights
# Check for early termination
stop_raymarching = torch.max(transmittance) < early_termination_threshold
return transmittance, accumulated_rgb, accumulated_weights, stop_raymarching
Параметр sigma представляет (объемную) плотность в точке выборки внутри луча. . Плотность необходима для вычисления непрозрачности данной точки луча. Я использовал ReLU, чтобы убедиться, что включены только положительные значения плотности. Параметр rbg хранит цвет выбранной точки. Я использовал сигмоид, чтобы убедиться, что значения цвета отображаются в диапазоне [0, 1].
Затем расстояние (дельта) между последовательными точками выборки вдоль Луч рассчитывается и измеряется в реальных единицах расстояния (я использовал норму для преобразования нормализованных координат в координаты реального мира). Эта переменная delta используется позже для определения вклада образца в непрозрачность (
Код: Выделить всё
alpha
Затем я вычисляю, сколько света поглощается или рассеивается (непрозрачность) в определенной точке луча, а затем определяю вклад этой точки (веса) в окончательный результат рендеринга.
В последней части я вычисляю, сколько света точка выборки вдоль луча вносит в окончательное изображение (цвет и непрозрачность), обновляю оставшийся свет (коэффициент пропускания ) и суммировать результаты по всем точкам выборки. Рэймарчинг прекращается, когда оставшийся уровень освещенности становится ниже заданного порога.
2. Изменение метода render_rays
Функция render_rays (https://github.com/yashbhalgat/HashNeRF ... un_nerf.py) обрабатывает каждый ray путем выборки точек вдоль него, опроса нейронной сети и накопления результатов (цвет, глубина, непрозрачность). Здесь происходят все вычисления, связанные с выборкой и накоплением лучей. Цель изменения render_rays — применить к нему метод раннего завершения лучей.
Вы увидите, что render — это метод, вызываемый в поезде (также в run_nerf.py) вместо render_rays. render обрабатывает общую картину: делит лучи на пакеты, вызывая render_rays для каждого пакета и комбинируя результаты. Batchify_rays соединяет рендеринг и render_rays, разделяя тензор большого луча из рендеринга на более мелкие фрагменты и вызывая render_rays для каждого фрагмента.
Код: Выделить всё
def render_rays(ray_batch,
network_fn,
network_query_fn,
N_samples,
embed_fn=None,
retraw=False,
lindisp=False,
perturb=0.,
N_importance=0,
network_fine=None,
white_bkgd=False,
raw_noise_std=0.,
verbose=False,
pytest=False,
early_termination_threshold=1e-4,
enable_early_termination=True)
[...]
#Initialize sparsity_loss to ensure it is defined when enable_early_termination=True
sparsity_loss = torch.zeros((N_rays,), device=ray_batch.device) # Match batch size
#sparsity_loss = torch.tensor(0.0, device=ray_batch.device) #Default sparsity loss
if enable_early_termination:
#Initialize accumulators for early termination
transmittance = torch.ones((N_rays, 1), device=ray_batch.device) #Start with T = 1
accumulated_rgb = torch.zeros((N_rays, 3), device=ray_batch.device) #Accumulated color
accumulated_weights = torch.zeros_like(z_vals) # [N_rays, N_samples]
#accumulated_weights = torch.zeros((N_rays,), device=ray_batch.device) #Accumulated weights
#print("Shape of accumulated weights:\n", accumulated_weights.shape)
#Iterate through samples with early termination
for i in range(N_samples):
raw = network_query_fn(pts[:, i:i + 1], viewdirs, network_fn)
transmittance, accumulated_rgb, accumulated_weights, stop_raymarching = apply_early_termination(
raw, z_vals, rays_d, transmittance, accumulated_rgb, accumulated_weights, i, early_termination_threshold
)
if stop_raymarching:
break
weights = torch.where(
torch.sum(accumulated_weights, dim=-1, keepdim=True) > 0,
accumulated_weights / (torch.sum(accumulated_weights, dim=-1, keepdim=True) + 1e-10),
torch.zeros_like(accumulated_weights)
)
#weights = accumulated_weights / (torch.sum(accumulated_weights, dim=-1, keepdim=True) + 1e-10)
#weights=accumulated_weights
rgb_map = accumulated_rgb
depth_map = torch.sum(accumulated_weights * z_vals, -1) / torch.sum(accumulated_weights, -1)#depth_map = torch.sum(accumulated_weights * z_vals, -1)
disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map)#disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map), depth_map / accumulated_weights)
#Compute accumulated opacity map
acc_map = torch.sum(accumulated_weights, -1)#acc_map = accumulated_weights
else:
#Original HashNeRF behavior (no early termination)
raw = network_query_fn(pts, viewdirs, network_fn)
rgb_map, disp_map, acc_map, weights, depth_map, sparsity_loss = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
[...]
Я инициализировал аккумуляторы для коэффициента пропускания, цвета (RGB). и веса (непрозрачность) для каждого луча. Затем код перебирает выборки вдоль луча, запрашивая у нейронной сети прогнозы (
Код: Выделить всё
raw
Код: Выделить всё
rgb_map
Код: Выделить всё
depth_map
Код: Выделить всё
disp_map
Код: Выделить всё
acc_map
Любая помощь по оптимизации/изменению кода для достижения цели будет оценена по достоинству.
Подробнее здесь: https://stackoverflow.com/questions/792 ... ementation