I am trying to implement the attention described in Luong et al. 2015 in PyTorch myself, but I couldn't get it to work. Below is my code; I am only interested in the "general" attention case for now. I wonder if I am missing any obvious error: it runs, but doesn't seem to learn.
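For reference, these are the equations that the `eq.` comments in the code refer to, as I read them in Luong et al. (2015). Here h_t is the current decoder (target) hidden state and \bar{h}_s is the encoder (source) state at position s. In the code, `self.attn` plays the role of W_a, `self.Whc` of W_c and `self.Ws` of W_s:

```latex
\begin{aligned}
\mathrm{score}(h_t, \bar{h}_s) &= h_t^\top W_a \bar{h}_s && \text{(8, ``general'')} \\
a_t(s) &= \frac{\exp\big(\mathrm{score}(h_t, \bar{h}_s)\big)}{\sum_{s'} \exp\big(\mathrm{score}(h_t, \bar{h}_{s'})\big)} && (7) \\
c_t &= \sum_s a_t(s)\, \bar{h}_s \\
\tilde{h}_t &= \tanh\big(W_c [c_t; h_t]\big) && (5) \\
p(y_t \mid y_{<t}, x) &= \mathrm{softmax}\big(W_s \tilde{h}_t\big) && (6)
\end{aligned}
```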
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.embedding = nn.Embedding(
            num_embeddings=self.output_size,
            embedding_dim=self.hidden_size
        )
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        # W_a of the "general" score
        self.attn = nn.Linear(self.hidden_size, self.hidden_size)
        # hc: [hidden, context]
        self.Whc = nn.Linear(self.hidden_size * 2, self.hidden_size)
        # s: softmax
        self.Ws = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        # input: one target token index; encoder_outputs: (src_len, hidden_size)
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)
        gru_out, hidden = self.gru(embedded, hidden)
        # [0] removes the (num_layers * num_directions) dimension for now
        attn_prod = torch.mm(self.attn(hidden)[0], encoder_outputs.t())
        attn_weights = F.softmax(attn_prod, dim=1)  # eq. 7/8
        context = torch.mm(attn_weights, encoder_outputs)
        # hc: [hidden; context]
        out_hc = torch.tanh(self.Whc(torch.cat([hidden[0], context], dim=1)))  # eq. 5
        output = F.log_softmax(self.Ws(out_hc), dim=1)  # eq. 6
        return output, hidden, attn_weights
```
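To check the shapes of eqs. 5 to 8 in isolation, here is a minimal sketch of one decoder step written as a plain function. The name `luong_general_attention` and its argument layout are hypothetical (my own, not from the paper or either tutorial), and it assumes a single-layer, unidirectional decoder, so that h_t is just the (1, H) hidden state:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def luong_general_attention(h_t, encoder_outputs, W_a, W_c):
    """One step of Luong "general" attention (hypothetical helper).

    h_t:             (1, H) current decoder hidden state
    encoder_outputs: (S, H) encoder states for a source of length S
    W_a:             nn.Linear(H, H, bias=False), the score matrix
    W_c:             nn.Linear(2H, H), combines [c_t; h_t]
    """
    # eq. 8: nn.Linear computes W x, so this is h_t^T W^T h_s for every s;
    # W is learned, so the transpose does not matter, and bias=False keeps
    # the score bilinear exactly as in the paper.
    scores = W_a(h_t) @ encoder_outputs.t()                       # (1, S)
    weights = F.softmax(scores, dim=1)                            # eq. 7, (1, S)
    context = weights @ encoder_outputs                           # c_t, (1, H)
    h_tilde = torch.tanh(W_c(torch.cat([context, h_t], dim=1)))   # eq. 5, (1, H)
    return h_tilde, weights


torch.manual_seed(0)
H, S, V = 8, 5, 11  # hidden size, source length, vocabulary size (example values)
W_a = nn.Linear(H, H, bias=False)
W_c = nn.Linear(2 * H, H)
W_s = nn.Linear(H, V)
h_t = torch.randn(1, H)
enc = torch.randn(S, H)

h_tilde, weights = luong_general_attention(h_t, enc, W_a, W_c)
log_probs = F.log_softmax(W_s(h_tilde), dim=1)  # eq. 6, (1, V)
print(h_tilde.shape, weights.shape, log_probs.shape)
```

The order [c_t; h_t] follows eq. 5; since W_c is learned, concatenating in the other order (as in the code above) is equivalent up to a permutation of W_c's columns.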
I have studied the attention implemented in https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html and https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb.
- The first one isn't exactly the attention mechanism I am looking for. A major disadvantage is that its attention depends on the sequence length (`self.attn = nn.Linear(self.hidden_size * 2, self.max_length)`), which could be expensive for long sequences.
- The second one is closer to what's described in the paper, but still not the same, as there is no `tanh`. Besides, it is really slow after updating it to the latest version of PyTorch (ref). Also, I don't know why it takes the last context (ref).
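To make the length-dependence point concrete: in the first tutorial, the size of `nn.Linear(self.hidden_size * 2, self.max_length)` grows with `max_length`, and it produces exactly `max_length` weights, whereas the "general" score only needs an H×H matrix and works for any source length. A small check, assuming H = 256 and max_length values of 10 and 1000 purely as example numbers:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

H = 256  # example hidden size


def n_params(module):
    return sum(p.numel() for p in module.parameters())


# Tutorial-style attention: one output unit per source position.
tutorial_attn_short = nn.Linear(H * 2, 10)    # max_length = 10
tutorial_attn_long = nn.Linear(H * 2, 1000)   # max_length = 1000

# Luong "general" score: W_a is H x H regardless of the source length.
general_attn = nn.Linear(H, H)

counts = (n_params(tutorial_attn_short), n_params(tutorial_attn_long), n_params(general_attn))
print(counts)  # (5130, 513000, 65792)

# The same W_a scores sources of any length.
h_t = torch.randn(1, H)
shapes = []
for src_len in (3, 50):
    enc = torch.randn(src_len, H)
    weights = F.softmax(general_attn(h_t) @ enc.t(), dim=1)
    shapes.append(tuple(weights.shape))
print(shapes)  # [(1, 3), (1, 50)]
```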