Какова связь между весом transposeconv и соответствующими ему заданными группами > 1?Python

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Какова связь между весом transposeconv и соответствующими ему заданными группами > 1?

Сообщение Anonymous »

Я моделирую рабочий процесс функции convtranspose в pytorch, поэтому подумываю установить вес сгруппированного convtranspose соответствующему сгруппированному convtranspose, и я протестировал следующим образом:

Код: Выделить всё

import torch
import torch.nn as nn
class SimulatedTransposeConv2d(nn.Module):
def __init__(self, input_channel, output_channel, k_h, k_w, s_w, p_w,groups, weight, bias) -> None:
super().__init__()
self.conv1 = nn.Conv2d(input_channel, output_channel ,kernel_size = (k_h, k_w), stride = (1, 1), groups=groups)
self.output_channel = output_channel
self.input_channel = input_channel
self.pad_insert = s_w - 1
self.pad_sides = k_w - p_w - 1
with torch.no_grad():
weight = weight.flip([-1, -2]).permute(1,0,2,3).reshape(output_channel, input_channel // groups, k_h, k_w)
self.conv1.weight.copy_(weight)
self.conv1.bias.copy_(bias)
def forward(self, x):
x = self.pad(x, self.pad_insert, self.pad_sides)
return self.conv1(x)
def pad(self, x, pad_insert, pad_sides):
'''It's written in C style'''
b = x.shape[0]
c = x.shape[1]
h = x.shape[2]
w = x.shape[3]
pad_w_size = pad_sides * 2 + pad_insert* (w - 1) + w
res = torch.zeros(pad_w_size * b * h * c)
x = x.flatten()
for k in range(c):
for i in range(h):
for j in range(w):
res[k * h * pad_w_size + i * pad_w_size  + j * (pad_insert + 1)+ pad_sides] = x[k * h * w + i * w + j]
return res.reshape(b ,c ,h, pad_w_size)

if __name__ == "__main__":
k_h = 3
k_w = 3
p_w = 0
s_w = 1
groups = 8
tranpose2d = nn.ConvTranspose2d(in_channels = 8, out_channels = 8, kernel_size = (k_h, k_w), stride = (1, s_w), padding = (k_h - 1, p_w), groups=groups)
simulated_transpose2d = SimulatedTransposeConv2d(8, 8, k_h, k_w, s_w, p_w, groups, tranpose2d.state_dict()['weight'], tranpose2d.state_dict()['bias'])
x = torch.rand((1 ,8 ,8, 8)).type(torch.float32)
print((tranpose2d(x) - simulated_transpose2d(x)).abs().mean())

согласно моему коду, если groups=1 или groups = input_channel, результат будет правильным, но если параметр groups не соответствует описанным выше ситуациям (например, input_channel = 8 и groups = 4), результат будет неверным, я думаю, это возможно из-за того, что соотношение весов между convtranspose и conv неверно. Weight_conv=weight_transpose.flip([-1,-2]).permute(1,0,2,3).reshape(output_channel, input_channel // groups, k_h, k_w), это правильный образ мышления ? Если это правда, то какова связь между ними? Спасибо!

Подробнее здесь: https://stackoverflow.com/questions/792 ... rrespondin
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»