JAX LAX.SCAN: Как итерация по слоям и срезам памяти одновременно без динамической индексации в многослойной структуре RNPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 JAX LAX.SCAN: Как итерация по слоям и срезам памяти одновременно без динамической индексации в многослойной структуре RN

Сообщение Anonymous »

Я пытаюсь внедрить структуру, которая управляет сети RNN с произвольным количеством слоев (все это часть библиотеки, которую я строю на основе JAX/Equinox), проблема в том, что я не могу найти реализацию, которая позволяет мне управлять несколькими слоями модульным образом, дайте мне объяснить:
def __call__(self, x):

def forward(carry, x_t):

def layer_apply(h_t, layer):
return layer(h_t, x_t)

t_memory = lax.scan(layer_apply, carry, self.layers)

return t_memory[0], t_memory[1]

return lax.scan(forward, self.h, x)
< /code>
Этот код является функцией класса, называемого «recurrent_block '(eqx.module), который принимает как параметры, кортеж из слоев RNN и создает Self.h, который является делом Jnp.array All Zero с правильной формой для памяти каждого слоя, так: < /p>
jnp.zeros, один для слоя: (h (0), h (1) ... h (n))
self.layers = tuple of eqx.module, каждый из которых с call , который внедряет обратный пропуск rnn. (seq_len, batch_size, функции)
(как вы можете видеть, все известно во время компиляции)
Теперь: в коде:

lax.scan не принимает xs наличие неоднородных форм (w_h, w_h, и смещение имеет разные формы по своей природе), поэтому я не могу передать слой кортеж как xs для сканирования. /> < /ol>
Я бы предпочел решение использовать lax.scan < /code>, поскольку оно очень оптимизировано для этого типа операции из того, что я знаю, даже несколько вложенных сканов-это нормально, единственное важное-это то, что его можно скомпилировать в XLA без каких-либо странных закупок, и я абсолютно не хочу, чтобы это жесткое решение было таким: < /p>
, и я абсолютно не хочу такого жесткого кодируемого решения: < /p>
, и я абсолютно не хочу, чтобы это жесткое решение было подобно: < /p>
, и я абсолютно не хочу, чтобы это жесткое решение было таким: < /p>
.x = layer[0](h[0], x),
x = layer[1](h[1], x)
...
Потому что это, очевидно, было бы очень неудобно. Дело в том, что, поскольку у меня есть все, что известно во время компиляции, должно быть возможно, чтобы подготовить функцию, но я не могу найти решение. < /P>
Я попробовал несколько подходов: < /strong> < /h1>

Использование индекса для слоев и h < /li>

def __call__(self, x, i):

def forward(carry, x_t):

def layer_apply(h_t, i):
return self.layers(h_t, x_t)

t_memory = lax.scan(layer_apply, carry, i)

return t_memory[0], t_memory[1]

return lax.scan(forward, self.h, x)

Где i - индексный массив, изготовленный с Arange, ошибка этого кода:
**layer_apply
return self.layers(h_t, x_t)
~~~~~~^^^
jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]
The error occurred while tracing the function layer_apply at test.py:37 for scan. This concrete value was not available in Python because it depends on the value of the argument i.
See https://docs.jax.dev/en/latest/errors.h ... rsionError**
< /code>

Используя развернутые для цикла < /li>
< /ol>
def __call__(self, x):

def forward(carry, x_t):
for layer, h_t in zip(self.layers, carry):
x_t = layer(h_t, x_t)
return carry, x_t

return lax.scan(forward, self.h, x)[1]

Это, вероятно, мой любимый, но, как вы можете видеть, проблема в том, что перенос не обновляется, поэтому он не поднимает ошибку, но это не имеет смысла.
Если вы хотите проверить, вот фиктивный код, который работает как мой:
import jax
import jax.numpy as jnp
import jax.lax as lax
import equinox as eqx

class Rnn(eqx.Module): #layer class
dim : int = eqx.static_field()
w_x : jnp.array
w_h : jnp.array
b : jnp.array

def __init__(self, dim):
self.dim = dim
self.w_x = jnp.ones((dim,dim), dtype=jnp.float32)
self.w_h = jnp.ones((dim,dim), dtype=jnp.float32)
self.b = jnp.zeros((dim,), dtype=jnp.float32)

def __call__(self, h, x): #rnn forward
return [email protected]_x + [email protected]_h + self.b

dim = 2
batch_size = 3
seq_len = 3

x = jnp.ones((seq_len, 3, dim)) #input

layers = (Rnn(dim), Rnn(dim)) #tuple of layers

h = (jnp.zeros((batch_size, dim), dtype=jnp.float32), jnp.zeros((batch_size, dim), dtype=jnp.float32)) #temporal memory

def scan_fn():
return #your solution !

print(scan_fn()) #a general result for the iteration in every rnn layer


Подробнее здесь: https://stackoverflow.com/questions/797 ... sly-withou
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение
  • JAX LAX.SCAN: Как итерация по слоям и срезам памяти одновременно без динамической индексации в многослойной структуре RN
    Anonymous » » в форуме Python
    0 Ответы
    3 Просмотры
    Последнее сообщение Anonymous
  • Как выбрать между использованием `jax.lax.scan` против` for
    Anonymous » » в форуме Python
    0 Ответы
    4 Просмотры
    Последнее сообщение Anonymous
  • Jax.lax.cond выполняет обе ветки вместо только ветки True
    Anonymous » » в форуме Python
    0 Ответы
    35 Просмотры
    Последнее сообщение Anonymous
  • Jax.lax.cond выполняет обе ветки вместо только ветки True
    Anonymous » » в форуме Python
    0 Ответы
    23 Просмотры
    Последнее сообщение Anonymous
  • Получение столбца от другого осколка с использованием jax.lax.gather ()
    Anonymous » » в форуме Python
    0 Ответы
    4 Просмотры
    Последнее сообщение Anonymous

Вернуться в «Python»