Я пытаюсь внедрить структуру, которая управляет сети RNN с произвольным количеством слоев (все это часть библиотеки, которую я строю на основе JAX/Equinox), проблема в том, что я не могу найти реализацию, которая позволяет мне управлять несколькими слоями модульным образом, дайте мне объяснить:
def __call__(self, x):
def forward(carry, x_t):
def layer_apply(h_t, layer):
return layer(h_t, x_t)
t_memory = lax.scan(layer_apply, carry, self.layers)
return t_memory[0], t_memory[1]
return lax.scan(forward, self.h, x)
< /code>
Этот код является функцией класса, называемого «recurrent_block '(eqx.module), который принимает как параметры, кортеж из слоев RNN и создает Self.h, который является делом Jnp.array All Zero с правильной формой для памяти каждого слоя, так: < /p>
jnp.zeros, один для слоя: (h (0), h (1) ... h (n))
self.layers = tuple of eqx.module, каждый из которых с call , который внедряет обратный пропуск rnn. (seq_len, batch_size, функции)
(как вы можете видеть, все известно во время компиляции)
Теперь: в коде:
lax.scan не принимает xs наличие неоднородных форм (w_h, w_h, и смещение имеет разные формы по своей природе), поэтому я не могу передать слой кортеж как xs для сканирования. /> < /ol>
Я бы предпочел решение использовать lax.scan < /code>, поскольку оно очень оптимизировано для этого типа операции из того, что я знаю, даже несколько вложенных сканов-это нормально, единственное важное-это то, что его можно скомпилировать в XLA без каких-либо странных закупок, и я абсолютно не хочу, чтобы это жесткое решение было таким: < /p>
, и я абсолютно не хочу такого жесткого кодируемого решения: < /p>
, и я абсолютно не хочу, чтобы это жесткое решение было подобно: < /p>
, и я абсолютно не хочу, чтобы это жесткое решение было таким: < /p>
.x = layer[0](h[0], x),
x = layer[1](h[1], x)
...
Потому что это, очевидно, было бы очень неудобно. Дело в том, что, поскольку у меня есть все, что известно во время компиляции, должно быть возможно, чтобы подготовить функцию, но я не могу найти решение. < /P>
Я попробовал несколько подходов: < /strong> < /h1>
Использование индекса для слоев и h < /li>
def __call__(self, x, i):
def forward(carry, x_t):
def layer_apply(h_t, i):
return self.layers(h_t, x_t)
t_memory = lax.scan(layer_apply, carry, i)
return t_memory[0], t_memory[1]
return lax.scan(forward, self.h, x)
Где i - индексный массив, изготовленный с Arange, ошибка этого кода:
**layer_apply
return self.layers(h_t, x_t)
~~~~~~^^^
jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]
The error occurred while tracing the function layer_apply at test.py:37 for scan. This concrete value was not available in Python because it depends on the value of the argument i.
See https://docs.jax.dev/en/latest/errors.h ... rsionError**
< /code>
Используя развернутые для цикла < /li>
< /ol>
def __call__(self, x):
def forward(carry, x_t):
for layer, h_t in zip(self.layers, carry):
x_t = layer(h_t, x_t)
return carry, x_t
return lax.scan(forward, self.h, x)[1]
Это, вероятно, мой любимый, но, как вы можете видеть, проблема в том, что перенос не обновляется, поэтому он не поднимает ошибку, но это не имеет смысла.
Если вы хотите проверить, вот фиктивный код, который работает как мой:
import jax
import jax.numpy as jnp
import jax.lax as lax
import equinox as eqx
class Rnn(eqx.Module): #layer class
dim : int = eqx.static_field()
w_x : jnp.array
w_h : jnp.array
b : jnp.array
def __init__(self, dim):
self.dim = dim
self.w_x = jnp.ones((dim,dim), dtype=jnp.float32)
self.w_h = jnp.ones((dim,dim), dtype=jnp.float32)
self.b = jnp.zeros((dim,), dtype=jnp.float32)
def __call__(self, h, x): #rnn forward
return [email protected]_x + [email protected]_h + self.b
dim = 2
batch_size = 3
seq_len = 3
x = jnp.ones((seq_len, 3, dim)) #input
layers = (Rnn(dim), Rnn(dim)) #tuple of layers
h = (jnp.zeros((batch_size, dim), dtype=jnp.float32), jnp.zeros((batch_size, dim), dtype=jnp.float32)) #temporal memory
def scan_fn():
return #your solution !
print(scan_fn()) #a general result for the iteration in every rnn layer
Подробнее здесь: https://stackoverflow.com/questions/797 ... sly-withou
JAX LAX.SCAN: Как итерация по слоям и срезам памяти одновременно без динамической индексации в многослойной структуре RN ⇐ Python
-
- Похожие темы
- Ответы
- Просмотры
- Последнее сообщение
-
-
Получение столбца от другого осколка с использованием jax.lax.gather ()
Anonymous » » в форуме Python - 0 Ответы
- 4 Просмотры
-
Последнее сообщение Anonymous
-