Интересная ошибка, вызванная getattrPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Интересная ошибка, вызванная getattr

Сообщение Anonymous »

Я пытаюсь одновременно обучить 8 моделей CNN с одинаковыми структурами. После пакетного обучения модели мне нужно синхронизировать веса слоев извлечения признаков в других семи моделях.
Это модель:
class GNet(nn.Module):
def __init__(self, dim_output, dropout=0.5):
super(GNet, self).__init__()
self.out_dim = dim_output
# Load the pretrained AlexNet model
alexnet = models.alexnet(pretrained=True)

self.pre_filtering = nn.Sequential(
alexnet.features[:4]
)

# Set requires_grad to False for all parameters in the pre_filtering network
for param in self.pre_filtering.parameters():
param.requires_grad = False

# construct the feature extractor
# every intermediate feature will be fed to the feature extractor

# res: 25 x 25
self.feat_ex1 = nn.Conv2d(192, 128, kernel_size=3, stride=1)

# res: 25 x 25
self.feat_ex2 = nn.Sequential(
nn.BatchNorm2d(128),
nn.Dropout(p=dropout),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
)

# res: 25 x 25
self.feat_ex3 = nn.Sequential(
nn.BatchNorm2d(128),
nn.Dropout(p=dropout),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
)

# res: 13 x 13
self.feat_ex4 = nn.Sequential(
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.Dropout(p=dropout),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
)

# res: 13 x 13
self.feat_ex5 = nn.Sequential(
nn.BatchNorm2d(128),
nn.Dropout(p=dropout),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
)

# res: 13 x 13
self.feat_ex6 = nn.Sequential(
nn.BatchNorm2d(128),
nn.Dropout(p=dropout),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
)

# res: 13 x 13
self.feat_ex7 = nn.Sequential(
nn.BatchNorm2d(128),
nn.Dropout(p=dropout),
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
)

# define the flexible pooling field of each layer
# use a full convolution layer here to perform flexible pooling
self.fpf13 = nn.Conv2d(in_channels=448, out_channels=448, kernel_size=13, groups=448)
self.fpf25 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=25, groups=384)
self.linears = {}
for i in range(self.out_dim):
self.linears[f'linear_{i+1}'] = nn.Linear(832, 1)

self.LogTanh = LogTanh()
self.flatten = nn.Flatten()

А это функция синхронизации весов:
def sync_weights(models, current_sub, sync_seqs):
for sub in range(1, 9):
if sub != current_sub:
# Synchronize the specified layers
with torch.no_grad():
for seq_name in sync_seqs:
reference_layer = getattr(models[current_sub], seq_name)[2]
layer = getattr(models[sub], seq_name)[2]
layer.weight.data = reference_layer.weight.data.clone()
if layer.bias is not None:
layer.bias.data = reference_layer.bias.data.clone()

тогда выдается ошибка:
'Conv2d' object is not iterable

что означает, что getattr() возвращает объект Conv2D.
Но если я удалю [2]:
def sync_weights(models, current_sub, sync_seqs):
for sub in range(1, 9):
if sub != current_sub:
# Synchronize the specified layers
with torch.no_grad():
for seq_name in sync_seqs:
reference_layer = getattr(models[current_sub], seq_name)
layer = getattr(models[sub], seq_name)
layer.weight.data = reference_layer.weight.data.clone()
if layer.bias is not None:
layer.bias.data = reference_layer.bias.data.clone()

Я получаю еще одну ошибку:
'Sequential' object has no attribute 'weight'

что означает, что getattr() возвращает Sequential. Но ранее он возвращал объект Conv2D.
Кто-нибудь что-нибудь об этом знает?
Для вашей информации, параметр sync_seqs, передаваемый в sync_weights:
sync_seqs = [
'feat_ex1',
'feat_ex2',
'feat_ex3',
'feat_ex4',
'feat_ex5',
'feat_ex6',
'feat_ex7'
]


Подробнее здесь: https://stackoverflow.com/questions/787 ... by-getattr
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение
  • Интересная ошибка, вызванная getattr
    Anonymous » » в форуме Python
    0 Ответы
    13 Просмотры
    Последнее сообщение Anonymous
  • Интересная ошибка, вызванная getattr
    Anonymous » » в форуме Python
    0 Ответы
    11 Просмотры
    Последнее сообщение Anonymous
  • Интересная ошибка FileNotFound при запуске pymetamap
    Anonymous » » в форуме Python
    0 Ответы
    28 Просмотры
    Последнее сообщение Anonymous
  • У меня есть интересная ошибка с обнаружением столкновений: скорость игрока удваивается, когда он падает, а столкновение
    Anonymous » » в форуме JAVA
    0 Ответы
    16 Просмотры
    Последнее сообщение Anonymous
  • У меня есть интересная ошибка с обнаружением столкновений: скорость игрока удваивается, когда он падает, а столкновение
    Anonymous » » в форуме JAVA
    0 Ответы
    14 Просмотры
    Последнее сообщение Anonymous

Вернуться в «Python»