Проблема с запуском RNN и запуском пакета torchsummary.Python

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Проблема с запуском RNN и запуском пакета torchsummary.

Сообщение Anonymous »

Сейчас я тренирую сеть RNN для своего приложения DGPS, измеряя широту, долготу и высоту. Вот архитектура сети:

Код: Выделить всё

# Define our network class by using the nn.module
class ResBlockMLP(nn.Module):
def __init__(self, input_size, output_size):
super(ResBlockMLP, self).__init__()
self.norm1 = nn.LayerNorm(input_size)
self.fc1 = nn.Linear(input_size, input_size//2)

self.norm2 = nn.LayerNorm(input_size//2)
self.fc2 = nn.Linear(input_size//2, output_size)

self.fc3 = nn.Linear(input_size, output_size)

self.act = nn.ELU()

def forward(self, x):
x = self.act(self.norm1(x))
skip = self.fc3(x)

x = self.act(self.norm2(self.fc1(x)))
x = self.fc2(x)

return x + skip

class RNN(nn.Module):
def __init__(self, seq_len, output_size, num_blocks=1, buffer_size=128):
super(RNN, self).__init__()

seq_data_len = seq_len * 2

self.input_mlp = nn.Sequential(nn.Linear(seq_data_len, 4 * seq_data_len),
nn.ELU(),
nn.Linear(4 * seq_data_len, 128),
nn.ELU(),)

self.rnn = nn.Linear(256, 128)

blocks = [ResBlockMLP(128, 128) for _ in range(num_blocks)]
self.res_blocks = nn.Sequential(*blocks)

self.fc_out = nn.Linear(128, output_size)
self.fc_buffer = nn.Linear(128, buffer_size)
self.act = nn.ELU()

def forward(self, input_seq, buffer_in):
input_seq = input_seq.reshape(input_seq.shape[0], -1)
input_vec = self.input_mlp(input_seq)

# Concatenate the previous step buffer
x_cat = torch.cat((buffer_in, input_vec), 1)
x = self.rnn(x_cat)

x  = self.act(self.res_blocks(x))

return self.fc_out(x), torch.tanh(self.fc_buffer(x))
Однако эта строка кода выдает ошибку:

Код: Выделить всё

data_pred, buffer = gps_rnn(seq_block, buffer)
Ошибка:

Код: Выделить всё

File D:\ProgramData\Miniconda_3.9\envs\rnn-sample-py3.9\lib\site-packages\torch\nn\modules\linear.py:114, in Linear.forward(self, input)
113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x126 and 1600x6400)
Кто-то посоветовал мне использовать модуль torchsummary, чтобы увидеть, как тензор проходит через вашу сеть. Я получил входную форму (32,14,2), запустив print(seq_block.size()).
Однако у меня есть проблемы с запуском модуля torchsummary:

Код: Выделить всё

from torchsummary import summary
summary(gps_rnn, (32,14,2))
Ошибка: https://pastebin.com/Lt9rZD3y
Также попробовал пакет torchinfo, поскольку он был обновлен, и вместо этого получил следующий результат:< /p>

Код: Выделить всё

from torchinfo import summary
summary(gps_rnn, input_size=(batch_size, 32, 14, 2))
Ошибка: https://pastebin.com/rmuSH0j7
Я также пробовал это решение, чтобы передать два аргумента в сводную функцию, но оно также выдает ошибка: https://pastebin.com/tma4cWyN
Редактировать: проверяя формы блоков, вот что я нашел:

Код: Выделить всё

seq_block: torch.Size([32, 14, 9])
target_seq_block: torch.Size([32, 9])
buffer: torch.Size([32, 128])
input_seq: torch.Size([32, 126])
Как мне решить эту проблему и обучить сеть? Ваша помощь очень ценится.

Подробнее здесь: https://stackoverflow.com/questions/783 ... ry-package
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»