Код: Выделить всё
import jax
from functools import partial
@partial(partial, jax.jit, static_argnums=(2,4))
def init_particles_grid(n_particles: int, n_side: int, L: float, A: float):
return None
Код: Выделить всё
import jax.numpy as jnp
n_side = 4
N = n_side * n_side
L = 1.0
A = 1.0 / (4.0 * jnp.pi**2)
q, x0, v0 = init_particles_grid(N, n_side, L, A)
TypeError: jit() принимает 1 позиционный аргумент, но было передано 5 позиционных аргументов (и 1 аргумент только для ключевых слов).
Что я делаю не так?>
Подробнее здесь: https://stackoverflow.com/questions/798 ... eerror-jit
Мобильная версия