Код: Выделить всё
# Define our network class by using the nn.module
class ResBlockMLP(nn.Module):
def __init__(self, input_size, output_size):
super(ResBlockMLP, self).__init__()
self.norm1 = nn.LayerNorm(input_size)
self.fc1 = nn.Linear(input_size, input_size//2)
self.norm2 = nn.LayerNorm(input_size//2)
self.fc2 = nn.Linear(input_size//2, output_size)
self.fc3 = nn.Linear(input_size, output_size)
self.act = nn.ELU()
def forward(self, x):
x = self.act(self.norm1(x))
skip = self.fc3(x)
x = self.act(self.norm2(self.fc1(x)))
x = self.fc2(x)
return x + skip
class RNN(nn.Module):
def __init__(self, seq_len, output_size, num_blocks=1, buffer_size=128):
super(RNN, self).__init__()
seq_data_len = seq_len * 2
self.input_mlp = nn.Sequential(nn.Linear(seq_data_len, 4 * seq_data_len),
nn.ELU(),
nn.Linear(4 * seq_data_len, 128),
nn.ELU(),)
self.rnn = nn.Linear(256, 128)
blocks = [ResBlockMLP(128, 128) for _ in range(num_blocks)]
self.res_blocks = nn.Sequential(*blocks)
self.fc_out = nn.Linear(128, output_size)
self.fc_buffer = nn.Linear(128, buffer_size)
self.act = nn.ELU()
def forward(self, input_seq, buffer_in):
input_seq = input_seq.reshape(input_seq.shape[0], -1)
input_vec = self.input_mlp(input_seq)
# Concatenate the previous step buffer
x_cat = torch.cat((buffer_in, input_vec), 1)
x = self.rnn(x_cat)
x = self.act(self.res_blocks(x))
return self.fc_out(x), torch.tanh(self.fc_buffer(x))
Код: Выделить всё
data_pred, buffer = gps_rnn(seq_block, buffer)
Код: Выделить всё
File D:\ProgramData\Miniconda_3.9\envs\rnn-sample-py3.9\lib\site-packages\torch\nn\modules\linear.py:114, in Linear.forward(self, input)
113 def forward(self, input: Tensor) -> Tensor:
--> 114 return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x126 and 1600x6400)
Однако у меня есть проблемы с запуском модуля torchsummary:
Код: Выделить всё
from torchsummary import summary
summary(gps_rnn, (32,14,2))
Также попробовал пакет torchinfo, поскольку он был обновлен, и вместо этого получил следующий результат:< /p>
Код: Выделить всё
from torchinfo import summary
summary(gps_rnn, input_size=(batch_size, 32, 14, 2))
Я также пробовал это решение, чтобы передать два аргумента в сводную функцию, но оно также выдает ошибка: https://pastebin.com/tma4cWyN
Редактировать: проверяя формы блоков, вот что я нашел:
Код: Выделить всё
seq_block: torch.Size([32, 14, 9])
target_seq_block: torch.Size([32, 9])
buffer: torch.Size([32, 128])
input_seq: torch.Size([32, 126])
Подробнее здесь: https://stackoverflow.com/questions/783 ... ry-package