Флэш-внимание дает разные результаты для токенов с одинаковыми вложениями?Python

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Флэш-внимание дает разные результаты для токенов с одинаковыми вложениями?

Сообщение Anonymous »

Я учусь интегрировать Flash Attention в свою модель, чтобы ускорить обучение. Я тестирую функцию, чтобы определить лучший способ ее реализации. Однако я столкнулся с проблемой, из-за которой Flash Attention выдает разные результаты для токенов с идентичными встраиваниями. Я не уверен, совершаю ли я принципиальную ошибку или здесь есть что-то еще.
Вот фрагмент кода, который я использую:

Код: Выделить всё

import torch
from flash_attn.modules.mha import FlashSelfAttention

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fa_attn = FlashSelfAttention(deterministic=True)
fa_attn.eval()

# Assuming batch_size, seq_len, heads, dim = 1, 4, 1, 4
x = torch.tensor([[0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1]])
q = x.unsqueeze(0).unsqueeze(2)
k = q.clone()
v = q.clone()
qkv = torch.stack([q, k, v], dim=2).half().to(device)
output = fa_attn(qkv)
print(output)
результат:

Код: Выделить всё

tensor([[[[0.1000, 0.1000, 0.1000, 0.1000]],
[[0.0757, 0.0757, 0.0757, 0.0757]],
[[0.1000, 0.1000, 0.1000, 0.1000]],
[[0.0757, 0.0757, 0.0757, 0.0757]]]], device='cuda:0', dtype=torch.float16)
Еще один

Код: Выделить всё

x = torch.tensor([[0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1]])
q = x.unsqueeze(0).unsqueeze(2)
k = q.clone()
v = q.clone()
qkv = torch.stack([q, k, v], dim=2).half().to(device)
output = fa_attn(qkv)
output
результат:

Код: Выделить всё

tensor([[[[ 0.1000,  0.1000,  0.1000,  0.1000]],

[[-0.5483,  0.5166, -0.5483,  0.5166]],

[[ 0.1000,  0.1000,  0.1000,  0.1000]]]], device='cuda:0',
dtype=torch.float16)
Большое спасибо.

Подробнее здесь: https://stackoverflow.com/questions/791 ... embeddings
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение
  • Элементар Пейдж показывает флэш -флэш -невозможного контента [закрыто]
    Anonymous » » в форуме Php
    0 Ответы
    6 Просмотры
    Последнее сообщение Anonymous
  • Я застрял, пытаясь установить флэш-внимание (Flash-Attn)
    Anonymous » » в форуме Python
    0 Ответы
    12 Просмотры
    Последнее сообщение Anonymous
  • Я застрял, пытаясь установить флэш-внимание (Flash-Attn)
    Anonymous » » в форуме Python
    0 Ответы
    27 Просмотры
    Последнее сообщение Anonymous
  • Optuna: разные результаты даже с одинаковыми random_state
    Anonymous » » в форуме Python
    0 Ответы
    4 Просмотры
    Последнее сообщение Anonymous
  • Xsd2java генерирует классы с одинаковыми именами, поскольку в xsd есть элементы с одинаковыми именами, вложенные друг в
    Anonymous » » в форуме JAVA
    0 Ответы
    61 Просмотры
    Последнее сообщение Anonymous

Вернуться в «Python»