Anonymous
Ошибка градиента в пакетной нормализации (batch normalization), реализованной с нуля
Сообщение
Anonymous » 30 окт 2025, 21:39
Я пытаюсь реализовать пакетную нормализацию с нуля. Вот мой код.
Код: Выделить всё
from functools import partial
import jax
@jax.tree_util.register_pytree_node_class
class MyBN:
    """Batch normalization over every axis except the last (feature) axis.

    Meant to be called under ``jax.vmap(..., axis_name=axis_name)`` (or another
    transform that binds ``axis_name``): statistics computed per mapped element
    are averaged across that axis with ``jax.lax.pmean``.

    Pytree layout: only ``scale`` and ``shift`` are children, so only they are
    differentiated by ``jax.grad``; the running statistics and hyper-parameters
    are carried as auxiliary data.

    NOTE(review): JAX expects pytree aux data to be hashable and comparable with
    ``==``. ``running_mean`` / ``running_var`` are arrays, so comparing treedefs
    built from two different instances may fail; consider making them children
    or keeping them outside the pytree.
    """

    def __init__(self, feature_count: int, epsilon: float = 1e-5, momentum: float = 0.1, axis_name: str = 'batch'):
        # Affine parameters start as the identity transform (scale=1, shift=0).
        self.scale, self.shift = jax.numpy.ones((feature_count,)), jax.numpy.zeros((feature_count,))
        # Running statistics start at mean 0 / variance 1.
        self.running_mean, self.running_var = jax.numpy.zeros((feature_count,)), jax.numpy.ones((feature_count,))
        self.momentum, self.epsilon, self.axis_name = momentum, epsilon, axis_name

    def tree_flatten(self):
        """Split the instance into ``(children, aux_data)`` for JAX's pytree registry."""
        children = (self.scale, self.shift)
        aux_data = (self.running_mean, self.running_var, self.momentum, self.epsilon, self.axis_name)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from ``tree_flatten`` output without calling ``__init__``."""
        rebuilt = cls.__new__(cls)
        rebuilt.scale, rebuilt.shift = children
        (rebuilt.running_mean, rebuilt.running_var,
         rebuilt.momentum, rebuilt.epsilon, rebuilt.axis_name) = aux_data
        return rebuilt

    def __repr__(self):
        return f'{jax.tree_util.tree_leaves(self)}'

    def __call__(self, x):
        """Normalize ``x`` and compute updated running statistics.

        Args:
            x: array whose last axis holds the features; all leading axes are
                reduced over, in addition to the mapped ``axis_name``.

        Returns:
            ``(output, (running_mean_updated, running_var_updated))``. ``output``
            has the shape of ``x``; the running statistics have the feature
            shape. The instance itself is not modified.
        """
        reduce_axes = tuple(range(x.ndim - 1))

        def cross_mean(values):
            # Local mean over the leading axes, then averaged across the mapped axis.
            local = jax.numpy.mean(values, axis=reduce_axes, keepdims=True)
            return jax.lax.pmean(local, axis_name=self.axis_name)

        mean = cross_mean(x)
        # Forward-wise the clamp is a no-op (an average of squares is >= 0);
        # it is kept so behaviour stays exactly the same.
        var = jax.numpy.maximum(0.0, cross_mean((x - mean) ** 2))

        def blend(running, current):
            # Moving average where `momentum` weights the OLD running value
            # (note: PyTorch uses the opposite convention for `momentum`).
            return self.momentum * running + (1 - self.momentum) * jax.numpy.squeeze(current, reduce_axes)

        new_running_mean = blend(self.running_mean, mean)
        new_running_var = blend(self.running_var, var)

        normalized = (x - mean) / jax.numpy.sqrt(var + self.epsilon)
        # 1-D scale/shift broadcast against the trailing feature axis.
        output = normalized * self.scale + self.shift
        return output, (new_running_mean, new_running_var)
Для проверки приведённого выше кода я также написал функцию пакетной нормализации вручную (без vmap).
Код: Выделить всё
# Test fixture: a random input of shape (20, 30, 40, feature_count) — the name
# suggests batch/height/width/channel layout — and a freshly initialised MyBN
# (scale=1, shift=0, running mean 0 / variance 1).
feature_count = 5
key_input = jax.random.key(0)
x_bhwc = jax.random.normal(key_input, (20, 30, 40, feature_count))
bn = MyBN(feature_count=feature_count)
del key_input  # the PRNG key is used once; drop it so it is not reused by accident
# manual batch norm
def calculate_batch_norm(x, scale, shift, running_mean, running_var, momentum, epsilon):
    """Reference (non-vmapped) batch normalization over all axes but the last.

    Args:
        x: input array whose last axis holds the features; every other axis is
            reduced over (for the rank-4 input used in this file that is
            axes (0, 1, 2), exactly as before).
        scale, shift: per-feature affine parameters, shape (C,).
        running_mean, running_var: previous running statistics, shape (C,).
        momentum: weight of the previous running value in the moving average,
            i.e. new = momentum * old + (1 - momentum) * batch_statistic.
        epsilon: added to the variance before taking the square root.

    Returns:
        ``(output, (running_mean_updated, running_var_updated))`` where
        ``output`` has the shape of ``x``.
    """
    # Derive the reduction axes from the rank instead of hard-coding (0, 1, 2),
    # so inputs of any rank >= 2 work, matching the axis logic of MyBN.__call__.
    reduce_axes = tuple(range(x.ndim - 1))
    batch_mean = jax.numpy.mean(x, axis=reduce_axes, keepdims=True)
    batch_var = jax.numpy.var(x, axis=reduce_axes, keepdims=True)
    # Squeeze only the reduced axes so the statistics keep shape (C,) even when C == 1.
    batch_mean_squeezed = jax.numpy.squeeze(batch_mean, axis=reduce_axes)
    batch_var_squeezed = jax.numpy.squeeze(batch_var, axis=reduce_axes)
    running_mean_updated = momentum * running_mean + (1 - momentum) * batch_mean_squeezed
    running_var_updated = momentum * running_var + (1. - momentum) * batch_var_squeezed
    x_normalized = (x - batch_mean) / jax.numpy.sqrt(batch_var + epsilon)
    # scale/shift are 1-D, so they broadcast directly against the trailing feature axis.
    output = x_normalized * scale + shift
    return output, (running_mean_updated, running_var_updated)
Проверка с vmap и jit даёт тот же или очень близкий результат, а с grad — нет. Вот мой тест для градиента.
Код: Выделить всё
# Jitted gradient (w.r.t. `model`) of a mean-squared-output loss in which the
# model runs under jax.vmap with the mapped axis named 'batch'.
@jax.jit
@jax.grad
def loss_bn(model, x):
    """Return d(mean(output ** 2)) / d(model) for a vmapped callable pytree ``model``.

    ``model`` (e.g. a MyBN instance) must return ``(output, aux)`` where ``aux``
    is unbatched, since it is gathered with ``out_axes=None``. Because of
    ``jax.grad`` the result has the same pytree structure as ``model``.
    """
    batched_model = jax.vmap(model, out_axes=(0, None), axis_name='batch')
    outputs = batched_model(x)[0]  # the auxiliary (running-statistics) output is ignored
    return jax.numpy.mean(outputs**2)
@jax.jit
@partial(jax.grad, argnums=(0, 1))
def loss_manual_bn(scale, shift, running_mean, running_var, momentum, epsilon, x):
    """Return ``(dL/dscale, dL/dshift)`` for L = mean(output ** 2) of the manual batch norm.

    Mirrors ``loss_bn`` but differentiates ``calculate_batch_norm`` directly
    (no vmap); the updated running statistics it returns are discarded.
    """
    outputs = calculate_batch_norm(
        x, scale, shift, running_mean, running_var, momentum, epsilon
    )[0]
    return jax.numpy.mean(outputs**2)
# Compare the gradients of the vmapped MyBN and of the manual reference; each
# prints the per-feature gradients w.r.t. (scale, shift).
# With L = mean(y**2) and y = x_norm * scale + shift over C features:
#   dL/dscale_c = (2 / C) * scale_c * mean(x_norm_c ** 2) = 0.4 * var / (var + eps)  (C=5, scale=1)
#   dL/dshift_c = (2 / C) * (scale_c * mean(x_norm_c) + shift_c) = 0
# because x_norm has zero mean per feature and shift starts at 0. The ~1e-9 shift
# gradients printed by both versions are therefore float32 rounding noise around
# zero, and the small differences in the scale gradients presumably come from
# different summation orders (vmap + pmean vs a single mean) — the two
# implementations agree.
grad_loss_vbn = loss_bn(bn, x_bhwc)
grad_loss_manual_bn = loss_manual_bn(bn.scale, bn.shift, bn.running_mean, bn.running_var, bn.momentum, bn.epsilon, x_bhwc)
print(grad_loss_vbn)
print(grad_loss_manual_bn)
Результат такой.
Код: Выделить всё
[Array([0.39999574, 0.3999958 , 0.39999616, 0.3999952 , 0.3999954 ], dtype=float32), Array([-3.5292942e-09, -3.7252903e-09, -1.1641532e-10, 3.0267984e-09,
6.0535967e-09], dtype=float32)]
(Array([0.3999961 , 0.39999634, 0.39999685, 0.3999954 , 0.39999533], dtype=float32), Array([-3.0895535e-09, 6.0535967e-09, 7.5669959e-10, -4.7730282e-09,
1.1059456e-08], dtype=float32))
Почему градиенты по сдвигу (shift) в двух реализациях не совпадают? Какая часть кода неверна? Буду благодарен за любую помощь.
Подробнее здесь:
https://stackoverflow.com/questions/79805112/gradient-error-of-batch-norm-that-is-implemented-from-scratch
1761849579
Anonymous
Я пытаюсь реализовать пакетную нормализацию с нуля. Вот мой код. [code]from functools import partial import jax @jax.tree_util.register_pytree_node_class class MyBN: def __init__(self, feature_count: int, epsilon: float = 1e-5, momentum: float = 0.1, axis_name: str = 'batch'): self.scale = jax.numpy.ones((feature_count,)) self.shift = jax.numpy.zeros((feature_count,)) self.running_mean = jax.numpy.zeros((feature_count,)) self.running_var = jax.numpy.ones((feature_count,)) self.momentum = momentum self.epsilon = epsilon self.axis_name = axis_name def tree_flatten(self): return (self.scale, self.shift), (self.running_mean, self.running_var, self.momentum, self.epsilon, self.axis_name) @classmethod def tree_unflatten(cls, aux_data, children): scale, shift = children running_mean, running_var, momentum, epsilon, axis_name = aux_data obj = cls.__new__(cls) obj.scale = scale obj.shift = shift obj.running_mean = running_mean obj.running_var = running_var obj.momentum = momentum obj.epsilon = epsilon obj.axis_name = axis_name return obj def __repr__(self): return f'{jax.tree_util.tree_leaves(self)}' def __call__(self, x): axis_to_manipulate = tuple(range(x.ndim-1)) batch_mean = jax.numpy.mean(x, axis=axis_to_manipulate, keepdims=True) batch_mean = jax.lax.pmean(batch_mean, axis_name=self.axis_name) batch_var = jax.numpy.mean((x - batch_mean)**2, axis=axis_to_manipulate, keepdims=True) batch_var = jax.lax.pmean(batch_var, axis_name=self.axis_name) batch_var = jax.numpy.maximum(0.0, batch_var) batch_mean_squeezed = jax.numpy.squeeze(batch_mean, axis_to_manipulate) batch_var_squeezed = jax.numpy.squeeze(batch_var, axis_to_manipulate) running_mean_updated = self.momentum * self.running_mean + (1 - self.momentum) * batch_mean_squeezed running_var_updated = self.momentum * self.running_var + (1. 
- self.momentum) * batch_var_squeezed x_normalized = (x - batch_mean) / jax.numpy.sqrt(batch_var + self.epsilon) scale_with_match_dimensions = jax.numpy.expand_dims(self.scale, axis_to_manipulate) shift_with_match_dimensions = jax.numpy.expand_dims(self.shift, axis_to_manipulate) output = x_normalized * scale_with_match_dimensions + shift_with_match_dimensions return output, (running_mean_updated, running_var_updated) [/code] Я также создал пакетную нормальную функцию для проверки моего кода, приведенного выше. [code]feature_count = 5 key_input = jax.random.key(0) x_bhwc = jax.random.normal(key_input, (20, 30, 40, feature_count)) bn = MyBN(feature_count=feature_count) del key_input # manual batch norm def calculate_batch_norm(x, scale, shift, running_mean, running_var, momentum, epsilon): batch_mean = jax.numpy.mean(x, axis=(0,1,2), keepdims=True) batch_var = jax.numpy.var(x, axis=(0,1,2), keepdims=True) running_mean_updated = momentum * running_mean + (1 - momentum) * jax.numpy.squeeze(batch_mean) running_var_updated = momentum * running_var + (1. - momentum) * jax.numpy.squeeze(batch_var) x_normalized = (x - batch_mean) / jax.numpy.sqrt(batch_var + epsilon) scale_with_match_dimensions = jax.numpy.expand_dims(scale, axis=(0,1,2)) shift_with_match_dimensions = jax.numpy.expand_dims(shift, axis=(0,1,2)) output = x_normalized * scale_with_match_dimensions + shift_with_match_dimensions return output, (running_mean_updated, running_var_updated) [/code] Тестирование vmap и jit дает тот же или очень близкий результат, но не для grad. Вот мой тест на выпускной. 
[code]# jitted gradient of a loss function that uses the vmapped batch norm @jax.jit @jax.grad def loss_bn(model, x): normalized_x, _ = jax.vmap(model, out_axes=(0, None), axis_name='batch')(x) loss = jax.numpy.mean(normalized_x**2) return loss @jax.jit @partial(jax.grad, argnums=(0,1)) def loss_manual_bn(scale, shift, running_mean, running_var, momentum, epsilon, x): normalized_x, _ = calculate_batch_norm(x, scale, shift, running_mean, running_var, momentum, epsilon) loss = jax.numpy.mean(normalized_x**2) return loss grad_loss_vbn = loss_bn(bn, x_bhwc) grad_loss_manual_bn = loss_manual_bn(bn.scale, bn.shift, bn.running_mean, bn.running_var, bn.momentum, bn.epsilon, x_bhwc) print(grad_loss_vbn) print(grad_loss_manual_bn) [/code] Результат такой. [code][Array([0.39999574, 0.3999958 , 0.39999616, 0.3999952 , 0.3999954 ], dtype=float32), Array([-3.5292942e-09, -3.7252903e-09, -1.1641532e-10, 3.0267984e-09, 6.0535967e-09], dtype=float32)] (Array([0.3999961 , 0.39999634, 0.39999685, 0.3999954 , 0.39999533], dtype=float32), Array([-3.0895535e-09, 6.0535967e-09, 7.5669959e-10, -4.7730282e-09, 1.1059456e-08], dtype=float32)) [/code] Почему оба градиента относительно смена не похожа? Какая часть кода неверна? Я очень ценю любую помощь. Подробнее здесь: [url]https://stackoverflow.com/questions/79805112/gradient-error-of-batch-norm-that-is-implemented-from-scratch[/url]