Как предотвратить влияние определенных входов на определенные результаты нейронных сетей в Pytorch?Python

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Как предотвратить влияние определенных входов на определенные результаты нейронных сетей в Pytorch?

Сообщение Anonymous »

У меня есть модель LSTM, которая получает 5 входов для прогнозирования 3 выходов: < /p>
import torch
import torch.nn as nn

class LstmModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(CustomLSTMModel, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)

def forward(self, x):
None
< /code>
Я хочу не дать определенным вводу оказать какое -либо влияние на определенный выход. Допустим, первый вход не должен влиять на прогноз второго вывода. Другими словами, второй прогноз не должен быть функцией первого входа.class LstmModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(CustomLSTMModel, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)

def forward(self, x):
# Assume x is of shape (batch_size, seq_length, input_size)
# Split inputs
input1, input2, input3, input4, input5 = x.split(1, dim=2)

# Mask inputs for each output
# For output1, exclude input2
input1_for_output1 = torch.cat((input1, input3, input4, input5), dim=2)

# For output2, exclude input3
input2_for_output2 = torch.cat((input1, input2, input4, input5), dim=2)

# For output3, exclude input4
input3_for_output3 = torch.cat((input1, input2, input3, input5), dim=2)

# Process through LSTM
_, (hn, _) = self.lstm(input1_for_output1)
output1 = self.fc(hn[-1])

_, (hn, _) = self.lstm(input2_for_output2)
output2 = self.fc(hn[-1])

_, (hn, _) = self.lstm(input3_for_output3)
output3 = self.fc(hn[-1])

return output1, output2, output3
< /code>
Проблема с этим подходом заключается в том, что для запуска модели требуется как минимум в 3 раза больше (поскольку я запускаю LSTM 3 раза, 1 для каждого вывода). Можно ли делать то, что я хочу добиться более эффективного, с одним пробежком?

Подробнее здесь: https://stackoverflow.com/questions/794 ... etworks-in
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»