Джакс отслеживает статический аргументPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Джакс отслеживает статический аргумент

Сообщение Anonymous »


Я пытаюсь использовать код Jax в ядре Pallas, но по какой-то причине мой код больше не работает.

импортировать функциональные инструменты импортировать Джакс из jax.experimental импортировать Pallas как pl импортировать jax.numpy как jnp импортировать numpy как np из jaxtyping import Array из jax.experimental импорт разреженный ключ = jax.random.PRNGKey(52) другое = jax.random.normal(ключ, (10, 10)) диагноз = jax.random.normal(ключ, (3, 10)) смещения = (-2, 1, 2) defdia_matmul_kernel(diags_ref, offsets,other_ref, o_ref): Diags, Other = Diags_ref[...], Other_ref[...] N = другое.форма[0] out = jnp.zeros((N, N)) печать (смещение) для смещения, Diag в zip(offsets, Diags): начало = jax.lax.max(0, смещение) конец = мин(N, N + смещение) верх = макс (0, -смещение) низ = верх + конец - начало out = out.at[top:bottom, :].add( диагноз[начало:конец, Нет] * другое[начало:конец, :] ) o_ref[...] = выход @functools.partial(jax.jit, static_argnums=(1, )) defdia_matmul(diags: Array, offsets: tuple[int],other:Array) -> Массив: вернуть pl.pallas_call( диаметр_matmul_kernel, out_shape=jax.ShapeDtypeStruct(other.shape,other.dtype) )(диаграммы, смещения, прочее) dia_matmul(diags, смещения, другое) Я понимаю, что печатать данные в JIT-функции Jax не рекомендуется, но когда я печатаю свои смещения, которые должны оставаться статическими из static_argnums=(1,), там написано:

(Tracedwith, Tracedwith, отслеживаетсяwith) Я не понимаю, почему это так, я новичок в Jax и Pallas, поэтому я еще не полностью уверен во всей этой концепции отслеживания. Кроме того, последняя операция цикла for с out не работает, так что если у кого-нибудь тоже есть идеи :D

Большое спасибо!
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение
  • Джакс отслеживает статический аргумент
    Anonymous » » в форуме Python
    0 Ответы
    34 Просмотры
    Последнее сообщение Anonymous
  • JDK 11; ДЖАКС-WS; Поставщик com.sun.xml.internal.ws.spi.ProviderImpl не найден
    Anonymous » » в форуме JAVA
    0 Ответы
    61 Просмотры
    Последнее сообщение Anonymous
  • Джакс отбрасывает код kd-дерева, занимая невероятно много времени
    Anonymous » » в форуме Python
    0 Ответы
    34 Просмотры
    Последнее сообщение Anonymous
  • Джакс постоянный кэш -разрывы вызывает недостаток?
    Anonymous » » в форуме Python
    0 Ответы
    9 Просмотры
    Последнее сообщение Anonymous
  • Есть ли в Python статический конструктор или статический инициализатор?
    Anonymous » » в форуме Python
    0 Ответы
    34 Просмотры
    Последнее сообщение Anonymous

Вернуться в «Python»