Рассчитать формулы в PyTorch, используя матрицуPython

Программы на Python
Ответить
Anonymous
 Рассчитать формулы в PyTorch, используя матрицу

Сообщение Anonymous »

У меня есть уравнения:

Код: Выделить всё

$e_{ij} = \frac{X_i W^Q (X_j W^K + A^K_{ij}) }{\sqrt{D_z}}$
$\alpha_{ij} = softmax(e_{ij})$
$z_{i} = \sum_j \alpha_{ij} (X_j W^V + A^V_{ij})$
где размеры:

Код: Выделить всё

X: [B, S, H,D]
each W: [H,D,D]
each A: [S, S, H,D]
как я могу вычислить это с помощью матричных операций?
у меня есть частичное решение

Код: Выделить всё

import torch
import torch.nn.functional as F

B, S, H, D = X.shape
d_z = D  # Assuming d_z is equal to D for simplicity

W_Q = torch.randn(H, D, D)
W_K = torch.randn(H, D, D)
W_V = torch.randn(H, D, D)

a_K = torch.randn(S, S, H, D)
a_V = torch.randn(S, S, H, D)
}
XW_Q = torch.einsum('bshd,hde->bshe', X, W_Q)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
XW_K = torch.einsum('bshd,hde->bshe', X, W_K)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]

e_ij_numerator = XW_Q.unsqueeze(2) @ (XW_K.unsqueeze(1) + a_K).transpose(-1, -2)  # [B, S, 1, H, D] @ [B, 1, S, H, D] -> [B, S, S, H, D]
e_ij = e_ij_numerator / torch.sqrt(torch.tensor(d_z, dtype=torch.float32))  # [B, S, S, H, D]

XW_V = torch.einsum('bshd,hde->bshe', X, W_V)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
alpha = F.softmax(e_ij, dim=2)  # [B, S, S, H, D]

z_i = torch.einsum('bshij,bshjd->bshid', alpha, XW_V.unsqueeze(1) + a_V)  # [B, S, S, H, D] @ [B, 1, S, H, D] -> [B, S, S, H, D]
но z должно быть [B, S, H,D]


Подробнее здесь: https://stackoverflow.com/questions/793 ... ing-matrix
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»