I'm having trouble writing a custom collate_fn function for the PyTorch DataLoader class. I need the custom function because my inputs have different dimensions.
I'm currently trying to write the baseline implementation of the Stanford MURA paper. The dataset consists of labeled studies, and a study may contain more than one image. I created a custom Dataset class that stacks these multiple images using torch.stack.
The stacked tensor is then fed to the model, and the list of outputs is averaged to obtain a single output. This implementation works fine with DataLoader when batch_size=1. However, when I set batch_size to 8, as in the original paper, the DataLoader fails: it uses torch.stack to collate the batch, and the inputs in my batch have variable dimensions (since each study can contain a different number of images).
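For context, here is a minimal sketch of what my Dataset looks like (the class and field names here are illustrative stand-ins, not the exact code):

```python
import torch
from torch.utils.data import Dataset

class StudyDataset(Dataset):
    """Sketch of a MURA-style dataset: each study holds a variable
    number of images, stacked into a single tensor per study."""
    def __init__(self, studies, labels):
        self.studies = studies  # list of lists of (C, H, W) image tensors
        self.labels = labels

    def __len__(self):
        return len(self.studies)

    def __getitem__(self, idx):
        # shape: (num_images, C, H, W) -- num_images varies per study,
        # which is why the default collate_fn fails for batch_size > 1
        images = torch.stack(self.studies[idx])
        return {'images': images, 'label': self.labels[idx]}
```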
To fix this, I tried implementing my own custom collate_fn:
def collate_fn(batch):
    # keep images as a list (variable shapes), but batch the labels
    imgs = [item['images'] for item in batch]
    targets = [item['label'] for item in batch]
    targets = torch.LongTensor(targets)
    return imgs, targets
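For completeness, this is how I wire it into the DataLoader (the toy dataset below is just a self-contained stand-in for my real one):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyStudyDataset(Dataset):
    """Stand-in dataset: two studies with 2 and 3 images respectively."""
    def __init__(self):
        self.data = [
            {'images': torch.zeros(2, 3, 8, 8), 'label': 0},
            {'images': torch.zeros(3, 3, 8, 8), 'label': 1},
        ]
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    imgs = [item['images'] for item in batch]
    targets = torch.LongTensor([item['label'] for item in batch])
    return imgs, targets

loader = DataLoader(ToyStudyDataset(), batch_size=2, collate_fn=collate_fn)
imgs, targets = next(iter(loader))  # imgs is a plain list, not a stacked tensor
```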
Then, inside my training loop, I iterate over each batch like this:
for image, label in zip(*batch):
    label = label.type(torch.FloatTensor)
    # wrap them in Variable
    image = Variable(image).cuda()
    label = Variable(label).cuda()
    # forward
    output = model(image)
    output = torch.mean(output)
    loss = criterion(output, label, phase)
However, this gives me no speedup: an epoch still takes as long as it did with batch_size=1. I've also tried setting the batch size to 32, and that does not improve the timings either.
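My suspicion is that the inner loop above still runs one forward pass per study, so the GPU never sees more than one study at a time. For reference, here is a sketch of what I think a fully batched forward would look like, concatenating every image in the batch into one tensor and then averaging per study (forward_batch is my own name, and model is assumed to map (N, C, H, W) to N per-image outputs):

```python
import torch

def forward_batch(model, imgs):
    """Run one forward pass over all images from all studies in the batch,
    then average the per-image outputs back into one output per study."""
    counts = [x.size(0) for x in imgs]           # images per study
    flat = torch.cat(imgs, dim=0)                # (sum(counts), C, H, W)
    out = model(flat)                            # single forward for the batch
    per_study = torch.split(out, counts, dim=0)  # regroup outputs by study
    return torch.stack([o.mean() for o in per_study])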
Am I doing something wrong? Is there a better approach to this?