Anonymous
Рассчитать формулы в PyTorch, используя матрицу
Сообщение
Anonymous » 13 янв 2025, 01:54
У меня есть уравнения:
Код: Выделить всё
$e_{ij} = \frac{X_i W^Q (X_j W^K + A^K_{ij}) }{\sqrt{D_z}}$
$\alpha_{ij} = softmax(e_{ij})$
$z_{i} = \sum_j \alpha_{ij} (X_j W^V + A^V_{ij})$
где размеры:
Код: Выделить всё
X: [B, S, H,D]
each W: [H,D,D]
each A: [S, S, H,D]
как я могу вычислить это с помощью матричных операций?
у меня есть частичное решение
Код: Выделить всё
import torch
import torch.nn.functional as F
B, S, H, D = X.shape
d_z = D # Assuming d_z is equal to D for simplicity
W_Q = torch.randn(H, D, D)
W_K = torch.randn(H, D, D)
W_V = torch.randn(H, D, D)
a_K = torch.randn(S, S, H, D)
a_V = torch.randn(S, S, H, D)
}
XW_Q = torch.einsum('bshd,hde->bshe', X, W_Q) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
XW_K = torch.einsum('bshd,hde->bshe', X, W_K) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
e_ij_numerator = XW_Q.unsqueeze(2) @ (XW_K.unsqueeze(1) + a_K).transpose(-1, -2) # [B, S, 1, H, D] @ [B, 1, S, H, D] -> [B, S, S, H, D]
e_ij = e_ij_numerator / torch.sqrt(torch.tensor(d_z, dtype=torch.float32)) # [B, S, S, H, D]
XW_V = torch.einsum('bshd,hde->bshe', X, W_V) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
alpha = F.softmax(e_ij, dim=2) # [B, S, S, H, D]
z_i = torch.einsum('bshij,bshjd->bshid', alpha, XW_V.unsqueeze(1) + a_V) # [B, S, S, H, D] @ [B, 1, S, H, D] -> [B, S, S, H, D]
но z должно быть [B, S, H,D]
Подробнее здесь:
https://stackoverflow.com/questions/793 ... ing-matrix
1736722490
Anonymous
У меня есть уравнения: [code]$e_{ij} = \frac{X_i W^Q (X_j W^K + A^K_{ij}) }{\sqrt{D_z}}$ $\alpha_{ij} = softmax(e_{ij})$ $z_{i} = \sum_j \alpha_{ij} (X_j W^V + A^V_{ij})$ [/code] где размеры: [code]X: [B, S, H,D] each W: [H,D,D] each A: [S, S, H,D] [/code] как я могу вычислить это с помощью матричных операций? у меня есть частичное решение [code]import torch import torch.nn.functional as F B, S, H, D = X.shape d_z = D # Assuming d_z is equal to D for simplicity W_Q = torch.randn(H, D, D) W_K = torch.randn(H, D, D) W_V = torch.randn(H, D, D) a_K = torch.randn(S, S, H, D) a_V = torch.randn(S, S, H, D) } XW_Q = torch.einsum('bshd,hde->bshe', X, W_Q) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D] XW_K = torch.einsum('bshd,hde->bshe', X, W_K) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D] e_ij_numerator = XW_Q.unsqueeze(2) @ (XW_K.unsqueeze(1) + a_K).transpose(-1, -2) # [B, S, 1, H, D] @ [B, 1, S, H, D] -> [B, S, S, H, D] e_ij = e_ij_numerator / torch.sqrt(torch.tensor(d_z, dtype=torch.float32)) # [B, S, S, H, D] XW_V = torch.einsum('bshd,hde->bshe', X, W_V) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D] alpha = F.softmax(e_ij, dim=2) # [B, S, S, H, D] z_i = torch.einsum('bshij,bshjd->bshid', alpha, XW_V.unsqueeze(1) + a_V) # [B, S, S, H, D] @ [B, 1, S, H, D] -> [B, S, S, H, D] [/code] но z должно быть [B, S, H,D] Подробнее здесь: [url]https://stackoverflow.com/questions/79350403/calculate-formulas-on-pytorch-using-matrix[/url]