Я изучаю учебник Pytorch SEQ2SEQ: https://pytorch.org/tutorials/Intermedi ... orial.html
У меня есть вопрос о Decoder
. />
class DecoderRNN(nn.Module):
def __init__(self, hidden_size, output_size):
super(DecoderRNN, self).__init__()
self.embedding = nn.Embedding(output_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
self.out = nn.Linear(hidden_size, output_size)
def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
batch_size = encoder_outputs.size(0)
decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
decoder_hidden = encoder_hidden
decoder_outputs = []
for i in range(MAX_LENGTH):
decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
decoder_outputs.append(decoder_output)
if target_tensor is not None:
# Teacher forcing: Feed the target as the next input
decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing
else:
# Without teacher forcing: use its own predictions as the next input
_, topi = decoder_output.topk(1)
decoder_input = topi.squeeze(-1).detach() # detach from history as input
decoder_outputs = torch.cat(decoder_outputs, dim=1)
decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
return decoder_outputs, decoder_hidden, None # We return `None` for consistency in the training loop
< /code>
Почему «если target_tensor не это»: < /p>
decoder_input = target_tensor[:, i].unsqueeze(1)
< /code>
Но если target_tensor не является: < /p>
_, topi = decoder_output.topk(1)
decoder_input = topi.squeeze(-1).detach()
< /code>
В частности, разве форма decoder_input не отличается в обоих случаях? < /p>
Я чувствую 2D Тенсор, но 1d во втором случае
Спасибо за помощь
Подробнее здесь: https://stackoverflow.com/questions/794 ... al-decoder
Учебный декодер SEQ2SEQ от SEQ2SEQ ⇐ Python
-
- Похожие темы
- Ответы
- Просмотры
- Последнее сообщение