Я подключаю модуль cross_attention для более быстрого rcnn,Python

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Я подключаю модуль cross_attention для более быстрого rcnn,

Сообщение Anonymous »

Код: Выделить всё

import torch
import torch.nn as nn
from math import sqrt

class CalculateAttention(nn.Module):
def __init__(self):
super().__init__()

def forward(self, Q, K, V):
attention = torch.matmul(Q, torch.transpose(K, -1, -2))
attention = torch.softmax(attention / sqrt(Q.size(-1)), dim=-1)
attention = torch.matmul(attention,V)
return attention

class Multi_CrossAttention(nn.Module):
"""

"""
def __init__(self, hidden_size, all_head_size, head_num):
super().__init__()
self.hidden_size = hidden_size
self.all_head_size = all_head_size
self.num_heads = head_num
self.h_size = all_head_size // head_num

assert all_head_size % head_num == 0

#  W_q, W_k, W_v (hidden_size, all_head_size)
self.linear_q = nn.Linear(hidden_size, all_head_size, bias=False)
self.linear_k = nn.Linear(1024, all_head_size, bias=False)
self.linear_v = nn.Linear(1024, all_head_size, bias=False)
self.linear_output = nn.Linear(all_head_size, hidden_size)

#  normalization
self.norm = sqrt(all_head_size)

def print(self):
print(self.hidden_size, self.all_head_size)
print(self.linear_k, self.linear_q, self.linear_v)

def forward(self, x, y):

"""

"""
batch_size = x.size(0)
# (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)

# q_s: [batch_size, num_heads, seq_length, h_size]
print(f"x device is {x.device}")
print(f"self.linear_q device is {self.linear_q.weight.device}")

q_s = self.linear_q(x).view(batch_size, -1, self.num_heads, self.h_size).transpose(1, 2)
print("1")

# k_s: [batch_size, num_heads, seq_length, h_size]
k_s = self.linear_k(y).view(batch_size, -1, self.num_heads, self.h_size).transpose(1, 2)

# v_s: [batch_size, num_heads, seq_length, h_size]
v_s = self.linear_v(y).view(batch_size, -1, self.num_heads, self.h_size).transpose(1, 2)

attention = CalculateAttention()(q_s, k_s, v_s)
# attention : [batch_size , seq_length , num_heads * h_size]
attention = attention.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.h_size)

# output : [batch_size , seq_length , hidden_size]
output = self.linear_output(attention)
print(output.shape)

return output

Выше находится модуль внимания

Код: Выделить всё

 prototype_data = prototype_data.to(self.device)
cross_fearures = OrderedDict()
for key in features.keys():
B, D, W, H = features[key].shape
flatten_features = features[key].reshape(B, D, -1).to(self.device)
print(f"flatten_features device is {flatten_features.device}")
print(f"prototype_data device is {prototype_data.device}")

cross_attention = Multi_CrossAttention(flatten_features.shape[2],  W ** 2, 8)

cross_output = cross_attention(flatten_features, prototype_data)
cross_output = cross_output.reshape(cross_output.shape[0], cross_output.shape[1], W, -1)

cross_fearures[key] = cross_output
features = cross_fearures
Выше я вставил модуль внимания.
Когда я запускаю программу, я получаю следующую ошибку

Код: Выделить всё

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)
Я обнаружил, что вес этой части (как показано ниже) все еще приходится на процессор.

Код: Выделить всё

self.linear_q = nn.Linear(hidden_size, all_head_size, bias=False)
self.linear_k = nn.Linear(1024, all_head_size, bias=False)
self.linear_v = nn.Linear(1024, all_head_size, bias=False)
self.linear_output = nn.Linear(all_head_size, hidden_size)
Но я отправил всю модель в графический процессор,
Я не знаю, почему это происходит, может кто-нибудь мне помочь, спасибо
Теперь я не знаю, почему происходит вышеописанное, может кто-нибудь мне помочь?

Подробнее здесь: https://stackoverflow.com/questions/791 ... aster-rcnn
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»