Код: Выделить всё
nn.ModuleList
Я работаю с pytorch ModuleList, как описано ниже,
Код: Выделить всё
decision_modules = nn.ModuleList([nn.Linear(768, 768) for i in range(10)])
Теперь для каждой точки входных данных в мини-пакете из 32 точек данных мы хочу выбрать 4 модуля решений из списка Decision_modules. 4 механизма принятия решений из Decision_engine выбираются с использованием списка индексов, как описано ниже.
У меня есть матрица индексов измерений ind. Матрица ind имеет размерность torch.randint(0,10,(4,4)).
Я хочу найти решение без использования циклов поскольку циклы значительно замедляют выполнение.
Но следующий код выдает ошибку.
Код: Выделить всё
import torch
import torch.nn as nn
linears = nn.ModuleList([nn.Linear(768, 768) for i in range(10)])
ind=torch.randint(0,10,(4,4))
input=torch.rand(32,768)
out=linears[ind](input)
Файл ~\AppData\Local\Programs\Python\Python312\Lib\site -packages\torch\nn\modules\container.py:334, в ModuleList.getitem(self, idx)
332 return self.класс(list(self._modules.values())[idx])
333 else:
--> 334 return self._modules[self._get_abs_string_index(idx )]
Файл ~\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\container.py:314, в ModuleList._get_abs_string_index(self, idx)
312 def _get_abs_string_index(self, idx) ):
313 """Получить абсолютный индекс списка модулей."""
--> 314 idx = оператор.index(idx)
315 если нет (-len(self)
Любая помощь будет очень полезна.
Подробнее здесь: https://stackoverflow.com/questions/793 ... ules-of-nn