Код: Выделить всё
from typing import NamedTuple, Optional
import jax
class Node(NamedTuple):
child: Optional['Node']
active: bool
node = Node(child=None, active=True)
child_active = jax.lax.cond(node.child is not None, lambda _: node.child.active, lambda _: False, None)
Может ли кто-нибудь объяснить, почему эта ошибка возникает, когда jax.lax.cond следует оценивать только ветку True? Кроме того, каков правильный способ условного доступа к атрибутам необязательного дочернего узла, подобного этому, с помощью JAX?
Подробнее здесь: https://stackoverflow.com/questions/786 ... rue-branch