import jax
import jax.numpy as jnp
import jax.typing as jt
import flax.linen as nn


class MLP(nn.Module):
    @nn.compact
    def __call__(self, inputs: jt.ArrayLike):
        inputs = jnp.array(inputs)
        return nn.Dense(4)(inputs)


if __name__ == "__main__":
    inputs = jnp.ones((2, 4))
    mlp = MLP()
    prng_key = jax.random.key(7)
    outputs: jax.Array = mlp.apply(mlp.init(prng_key, inputs), inputs)
    print(f"type of outputs is {type(outputs)}")
On the line outputs: jax.Array = mlp.apply(mlp.init(prng_key, inputs), inputs), Pylance reports this error:
Type "Any | tuple[Any, FrozenVariableDict | dict[str, Any]]" is not assignable to declared type "Array"
Type "Any | tuple[Any, FrozenVariableDict | dict[str, Any]]" is not assignable to type "Array"
"tuple[Any, FrozenVariableDict | dict[str, Any]]" is not assignable to "Array"
Pylance(reportAssignmentType)
import jax
import jax.numpy as jnp
x: jax.Array = jnp.zeros((2, 2))
print(f"x's type is {type(x)}")
More details here: https://stackoverflow.com/questions/796 ... -correctly