I have a tensor holding a batch of permutations of the integers 0 to time-1, for example with shape [batch, time].
Now I want to invert all these permutations to get a tensor of the same shape.
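Inverting each row means that for every batch row b the result satisfies inv[b, perm[b, i]] = i. The following plain NumPy reference shows that row-wise definition on a small batch. It is only an illustration for checking a TensorFlow result against, not a TensorFlow solution, and the helper name `invert_permutations_reference` is hypothetical.

```python
import numpy as np


def invert_permutations_reference(perms: np.ndarray) -> np.ndarray:
    """Hypothetical helper: row-wise inverse, inv[b, perms[b, i]] = i."""
    batch, time = perms.shape
    inv = np.empty_like(perms)
    # Row indices of shape [batch, 1] broadcast against perms of shape [batch, time].
    rows = np.arange(batch)[:, None]
    # Write position i into slot perms[b, i] of every row b at once.
    inv[rows, perms] = np.arange(time)
    return inv


perms = np.array([[2, 0, 1],
                  [0, 2, 1]])
inv = invert_permutations_reference(perms)
print(inv)  # [[1 2 0]
            #  [0 2 1]]
```

Composing each row with its inverse gives the identity, i.e. `np.take_along_axis(perms, inv, axis=1)` yields `[0, 1, ..., time-1]` in every row.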
I know this can be done with tf.math.invert_permutation for a single tensor of shape [time], but that function does not support batched input: it throws an error if the input tensor has more than one dimension.
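A minimal sketch of the limitation, assuming TensorFlow 2.x with eager execution: the op works on a 1-D permutation but rejects a [batch, time] tensor. The exact exception class and message may differ between versions, so the sketch catches both `tf.errors.InvalidArgumentError` and `ValueError`.

```python
import tensorflow as tf

# A single permutation of 0..time-1 (shape [time]) is accepted.
single = tf.constant([2, 0, 1])
single_inv = tf.math.invert_permutation(single)
print(single_inv.numpy())  # [1 2 0]

# A batch of permutations (shape [batch, time]) is rejected,
# because the op only accepts 1-D input.
batched = tf.constant([[2, 0, 1],
                       [0, 2, 1]])
try:
    tf.math.invert_permutation(batched)
    raised = False
except (tf.errors.InvalidArgumentError, ValueError) as err:
    # Assumption: the error type may vary by TF version and execution mode.
    raised = True
    print(type(err).__name__)
```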
What can I do to make tf.math.invert_permutation work with batched input?
