
I have followed this tutorial https://www.youtube.com/watch?v=U0s0f995w14 to create a minimal version of the transformer architecture, but I am confused about the final shape of the output.

Here's the code: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/558557c7989f0b10fee6e8d8f953d7269ae43d4f/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py#L2

On the final lines (slightly modified):

print(x.shape)
print(trg[:, :-1].shape)
out = model(x, trg[:, :-1])
print(out.shape)

The output shapes don't seem to make sense:

torch.Size([2, 9]) #input sentence (num_examples, num_tokens)
torch.Size([2, 7]) #generated sentence so far (num_examples, num_tokens_generated)
torch.Size([2, 7, 10]) # probabilities for next token (num_examples, ???, size_vocab)

The transformer is supposed to predict the next token across 2 training examples (which is why there's a 2 for the number of examples and a 10 for the vocabulary size), by generating probabilities for each token in the vocab. But I can't make sense of why there's a 7 there. The only explanation I can come up with is that it outputs all predictions simultaneously, but that would require feeding the outputs iteratively through the transformer, which never happens (see lines 267-270).

So is there a mistake, or am I not understanding something correctly? What is that output shape supposed to represent?

Can somebody make sense of this?

1 Answer


7 is the length of the target sequence passed as an argument to the model, which is trg[:, :-1], that is, the target sequence except its last token. The last token is removed because it is either the end-of-sequence token of the longest sentence in the batch or a padding token of the shorter sequences, so it is not useful as decoder input.

The output of the decoder has the same length as its input: trg[:, :-1] has shape [2, 7], so the model produces one vector of vocabulary probabilities for each of those 7 positions, giving [2, 7, 10].
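To see the same shape relationship outside the tutorial's code, here is a minimal sketch using PyTorch's built-in nn.Transformer rather than the hand-written model from the video (layer sizes are arbitrary, and masks and positional encodings are omitted here since they don't affect the shapes):

import torch
import torch.nn as nn

vocab_size = 10                 # toy vocabulary, as in the tutorial
d_model = 32                    # arbitrary embedding size for this sketch

embed = nn.Embedding(vocab_size, d_model)
transformer = nn.Transformer(d_model=d_model, nhead=4, batch_first=True)
to_vocab = nn.Linear(d_model, vocab_size)    # decoder states -> vocabulary logits

src = torch.randint(0, vocab_size, (2, 9))   # like x:   [num_examples, src_len]
trg = torch.randint(0, vocab_size, (2, 8))   # like trg: [num_examples, trg_len]

dec_in = trg[:, :-1]                         # drop the last token -> shape [2, 7]
logits = to_vocab(transformer(embed(src), embed(dec_in)))

print(dec_in.shape)   # torch.Size([2, 7])
print(logits.shape)   # torch.Size([2, 7, 10]): one distribution over the vocab per decoder position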

Note that in the video they invoke the model in an unusual way: they pass a whole target sequence to the model but never train it. Normally, the model would be used in one of the following ways:

  • In training mode, the model receives the full target sequence, and its output is used to compute the loss and update the network weights via gradient descent.
  • In inference mode, the model is used auto-regressively, that is, we decode token by token, feeding each newly predicted token back in as input for the next step.

I guess they used the model this way just to illustrate that the model works.
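To make the two modes concrete, here is a rough sketch that continues the snippet above (reusing embed, transformer, to_vocab, src and trg from it). The loss setup, optimizer, sos_idx and max_len are illustrative assumptions rather than the tutorial's actual code; in particular, comparing each output position with the next target token (trg[:, 1:]) is just the standard teacher-forcing recipe:

def run_model(src_tokens, trg_tokens):
    # Causal mask so each target position can only attend to earlier positions.
    tgt_mask = transformer.generate_square_subsequent_mask(trg_tokens.size(1))
    out = transformer(embed(src_tokens), embed(trg_tokens), tgt_mask=tgt_mask)
    return to_vocab(out)

# Training mode: one forward pass over the shifted target, loss against the next tokens.
criterion = nn.CrossEntropyLoss()
params = list(embed.parameters()) + list(transformer.parameters()) + list(to_vocab.parameters())
optimizer = torch.optim.Adam(params)

logits = run_model(src, trg[:, :-1])              # [2, 7, 10]
loss = criterion(logits.reshape(-1, vocab_size),  # flatten all positions for cross-entropy
                 trg[:, 1:].reshape(-1))          # the label at position i is token i+1
loss.backward()
optimizer.step()

# Inference mode: decode auto-regressively, feeding each prediction back in.
sos_idx, max_len = 1, 8                           # assumed start-token id and length limit
generated = torch.full((src.size(0), 1), sos_idx, dtype=torch.long)
with torch.no_grad():
    for _ in range(max_len):
        step_logits = run_model(src, generated)   # [2, current_length, 10]
        next_token = step_logits[:, -1, :].argmax(dim=-1, keepdim=True)  # only the last position matters
        generated = torch.cat([generated, next_token], dim=1)

print(generated.shape)                            # torch.Size([2, 9]) after 8 decoding steps

Note how training needs only one forward pass (each of the 7 output positions is trained in parallel to predict the following target token), whereas inference needs one forward pass per generated token. That is why the [2, 7, 10] output in the question appears without any decoding loop.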

noe