У меня возникли проблемы с пониманием этой строки кода pytorch при чтении исходных кодов mamba_ssm. (исходный код здесь:
https://github.com/state-spaces/mamba/b ... _interface. py#L121).
Код: Выделить всё
# Tensor shape
# b: batch_size, d: d_inner, l: sequence_length, n: d_state
# delta: [b,d,l]
# A: [d,n]
...
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
...
4D-тензор deltaA рассчитывается на основе 3D-тензора delta и 2D-тензора A. Я знаю torch.einsum, но не могу понять, какие операции были выполнены для получения deltaA (объяснение) о том, как тензоры были умножены/сложены/транспонированы и т. д.)?
Я могу понять код ниже, поскольку это умножение матриц.
Однако мне пока сложно полностью понять, что именно произошло в исходном коде. Кажется, что это связано с тензорным внешним произведением, но я не уверен.
Подробнее здесь:
https://stackoverflow.com/questions/790 ... -2d-tensor