Я пытаюсь использовать код Jax в ядре Pallas, но по какой-то причине мой код больше не работает.
импортировать функциональные инструменты импортировать Джакс из jax.experimental импортировать Pallas как pl импортировать jax.numpy как jnp импортировать numpy как np из jaxtyping import Array из jax.experimental импорт разреженный ключ = jax.random.PRNGKey(52) другое = jax.random.normal(ключ, (10, 10)) диагноз = jax.random.normal(ключ, (3, 10)) смещения = (-2, 1, 2) defdia_matmul_kernel(diags_ref, offsets,other_ref, o_ref): Diags, Other = Diags_ref[...], Other_ref[...] N = другое.форма[0] out = jnp.zeros((N, N)) печать (смещение) для смещения, Diag в zip(offsets, Diags): начало = jax.lax.max(0, смещение) конец = мин(N, N + смещение) верх = макс (0, -смещение) низ = верх + конец - начало out = out.at[top:bottom, :].add( диагноз[начало:конец, Нет] * другое[начало:конец, :] ) o_ref[...] = выход @functools.partial(jax.jit, static_argnums=(1, )) defdia_matmul(diags: Array, offsets: tuple[int],other:Array) -> Массив: вернуть pl.pallas_call( диаметр_matmul_kernel, out_shape=jax.ShapeDtypeStruct(other.shape,other.dtype) )(диаграммы, смещения, прочее) dia_matmul(diags, смещения, другое) Я понимаю, что печатать данные в JIT-функции Jax не рекомендуется, но когда я печатаю свои смещения, которые должны оставаться статическими из static_argnums=(1,), там написано:
(Tracedwith, Tracedwith, отслеживаетсяwith) Я не понимаю, почему это так, я новичок в Jax и Pallas, поэтому я еще не полностью уверен во всей этой концепции отслеживания. Кроме того, последняя операция цикла for с out не работает, так что если у кого-нибудь тоже есть идеи

Большое спасибо!