Градиентная ошибка нормы партии, реализованная с нуляPython

Программы на Python
Ответить
Anonymous
 Градиентная ошибка нормы партии, реализованная с нуля

Сообщение Anonymous »

Я пытаюсь реализовать пакетную нормализацию с нуля. Вот мой код.

Код: Выделить всё

from functools import partial
import jax

@jax.tree_util.register_pytree_node_class
class MyBN:
def __init__(self, feature_count: int, epsilon: float = 1e-5, momentum: float = 0.1, axis_name: str = 'batch'):

self.scale = jax.numpy.ones((feature_count,))
self.shift = jax.numpy.zeros((feature_count,))
self.running_mean = jax.numpy.zeros((feature_count,))
self.running_var = jax.numpy.ones((feature_count,))
self.momentum = momentum
self.epsilon = epsilon

self.axis_name = axis_name

def tree_flatten(self):
return (self.scale, self.shift), (self.running_mean, self.running_var, self.momentum, self.epsilon, self.axis_name)

@classmethod
def tree_unflatten(cls, aux_data, children):
scale, shift = children
running_mean, running_var, momentum, epsilon, axis_name = aux_data
obj = cls.__new__(cls)
obj.scale = scale
obj.shift = shift
obj.running_mean = running_mean
obj.running_var = running_var
obj.momentum = momentum
obj.epsilon = epsilon
obj.axis_name = axis_name
return obj

def __repr__(self):
return f'{jax.tree_util.tree_leaves(self)}'

def __call__(self, x):
axis_to_manipulate = tuple(range(x.ndim-1))
batch_mean = jax.numpy.mean(x, axis=axis_to_manipulate, keepdims=True)
batch_mean = jax.lax.pmean(batch_mean, axis_name=self.axis_name)
batch_var = jax.numpy.mean((x - batch_mean)**2, axis=axis_to_manipulate, keepdims=True)
batch_var = jax.lax.pmean(batch_var, axis_name=self.axis_name)
batch_var = jax.numpy.maximum(0.0, batch_var)

batch_mean_squeezed = jax.numpy.squeeze(batch_mean, axis_to_manipulate)
batch_var_squeezed = jax.numpy.squeeze(batch_var, axis_to_manipulate)
running_mean_updated = self.momentum * self.running_mean + (1 - self.momentum) * batch_mean_squeezed
running_var_updated = self.momentum * self.running_var + (1. - self.momentum) * batch_var_squeezed

x_normalized = (x - batch_mean) / jax.numpy.sqrt(batch_var + self.epsilon)

scale_with_match_dimensions = jax.numpy.expand_dims(self.scale, axis_to_manipulate)
shift_with_match_dimensions = jax.numpy.expand_dims(self.shift, axis_to_manipulate)
output = x_normalized * scale_with_match_dimensions + shift_with_match_dimensions

return output, (running_mean_updated, running_var_updated)
Я также создал пакетную нормальную функцию для проверки моего кода, приведенного выше.

Код: Выделить всё

feature_count = 5
key_input = jax.random.key(0)
x_bhwc = jax.random.normal(key_input, (20, 30, 40, feature_count))
bn = MyBN(feature_count=feature_count)
del key_input

# manual batch norm
def calculate_batch_norm(x, scale, shift, running_mean, running_var, momentum, epsilon):
batch_mean = jax.numpy.mean(x, axis=(0,1,2), keepdims=True)
batch_var = jax.numpy.var(x, axis=(0,1,2), keepdims=True)

running_mean_updated = momentum * running_mean + (1 - momentum) * jax.numpy.squeeze(batch_mean)
running_var_updated = momentum * running_var + (1. - momentum) * jax.numpy.squeeze(batch_var)

x_normalized = (x - batch_mean) / jax.numpy.sqrt(batch_var + epsilon)

scale_with_match_dimensions = jax.numpy.expand_dims(scale, axis=(0,1,2))
shift_with_match_dimensions = jax.numpy.expand_dims(shift, axis=(0,1,2))
output = x_normalized * scale_with_match_dimensions + shift_with_match_dimensions

return output, (running_mean_updated, running_var_updated)
Тестирование vmap и jit дает тот же или очень близкий результат, но не для grad. Вот мой тест на выпускной.

Код: Выделить всё

# jitted gradient of a loss function that uses the vmapped batch norm
@jax.jit
@jax.grad
def loss_bn(model, x):
normalized_x, _ = jax.vmap(model, out_axes=(0, None), axis_name='batch')(x)
loss = jax.numpy.mean(normalized_x**2)
return loss

@jax.jit
@partial(jax.grad, argnums=(0,1))
def loss_manual_bn(scale, shift, running_mean, running_var, momentum, epsilon, x):
normalized_x, _ = calculate_batch_norm(x, scale, shift, running_mean, running_var, momentum, epsilon)
loss = jax.numpy.mean(normalized_x**2)
return loss

grad_loss_vbn = loss_bn(bn, x_bhwc)
grad_loss_manual_bn = loss_manual_bn(bn.scale, bn.shift, bn.running_mean, bn.running_var, bn.momentum, bn.epsilon, x_bhwc)
print(grad_loss_vbn)
print(grad_loss_manual_bn)
Результат такой.

Код: Выделить всё

[Array([0.39999574, 0.3999958 , 0.39999616, 0.3999952 , 0.3999954 ],      dtype=float32), Array([-3.5292942e-09, -3.7252903e-09, -1.1641532e-10,  3.0267984e-09,
6.0535967e-09], dtype=float32)]
(Array([0.3999961 , 0.39999634, 0.39999685, 0.3999954 , 0.39999533],      dtype=float32), Array([-3.0895535e-09,  6.0535967e-09,  7.5669959e-10, -4.7730282e-09,
1.1059456e-08], dtype=float32))
Почему оба градиента относительно смена не похожа? Какая часть кода неверна? Я очень ценю любую помощь.

Подробнее здесь: https://stackoverflow.com/questions/798 ... om-scratch
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»