Как определить параметр «input_size» torchsummary.summary (model = model.policy, input_size = (int, int, int))?Python

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Как определить параметр «input_size» torchsummary.summary (model = model.policy, input_size = (int, int, int))?

Сообщение Anonymous »

Это моя сеть CNN, напечатанная 'print (model.policy)': < /p>
CnnPolicy(
(actor): Actor(
(features_extractor): CustomCNN(
(cnn): Sequential(
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(3): ReLU()
(4): Flatten(start_dim=1, end_dim=-1)
)
(linear): Sequential(
(0): Linear(in_features=6, out_features=128, bias=True)
(1): ReLU()
)
)
(mu): Sequential(
(0): Linear(in_features=128, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=128, bias=True)
(3): ReLU()
(4): Linear(in_features=128, out_features=3, bias=True)
(5): Tanh()
)
)
< /code>
Когда я пытаюсь распечатать сетевую архитектуру с помощью TorchSummary.summary (model = model.policy, input_size = (1, 32, 32)). Я получил следующую ошибку:
RuntimeError: формы MAT1 и MAT2 не могут быть умножены (2x50176 и 6x128) < /p>
Я пробовал множество комбинаций «input_size», но все были неправильными.>

Подробнее здесь: https://stackoverflow.com/questions/778 ... -model-pol
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»