Код: Выделить всё
import jax
from functools import partial
from typing import List
def dummy(a: int, b: List[str]):
return a + 1
Код: Выделить всё
j_dummy = jax.jit(dummy, static_argnames=['b'])
j_dummy(2, ['kek'])
ValueError: Non-hashable static arguments are not supported
Итак, я немного потерян здесь: как я должен продолжить в более широкой картине? Должен ли я создавать частичные функции с любыми аргументами, а затем их считать или я должен попытаться сохранить свои аргументы с хранением? Каковы ситуации, когда один подход лучше других? Есть ли недостатки?
Подробнее здесь: https://stackoverflow.com/questions/791 ... le-partial