When I first started using PyTorch to implement recurrent neural networks (RNNs), I faced a small issue when trying to use DataLoader in conjunction with variable-length sequences. What I specifically wanted to do was to automate the process of distributing training data among multiple graphics cards. Even though there are numerous examples online of how to do the actual padding, I couldn't find any concrete example of using DataLoader together with padding, and my months-old question on their forum is still unanswered!
The standard way of working with inputs of variable lengths is to pad all the sequences with zeros to make their lengths equal to the length of the largest sequence. This padding is done with the pad_sequence function. PyTorch’s RNN (LSTM, GRU, etc) modules are capable of working with inputs of a padded sequence type and intelligently ignore the zero paddings in the sequence.
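As a quick illustration, here is a minimal sketch of pad_sequence in action (the tensor contents are arbitrary):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three sequences of lengths 4, 2, and 3
seqs = [torch.ones(4), torch.ones(2), torch.ones(3)]

# pad_sequence pads every sequence with zeros up to the longest length
padded = pad_sequence(seqs, batch_first=True)
print(padded.shape)  # torch.Size([3, 4])
```

The shorter sequences come back zero-padded on the right, so all rows of the result have the length of the longest sequence.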
If the goal is to train with mini-batches, one needs to pad the sequences in each batch. In other words, given a mini-batch of size N, if the length of the largest sequence is L, one needs to pad every sequence shorter than L with zeros to make all lengths equal to L. Moreover, it is important that the sequences in the batch are sorted in descending order of length.
To do proper padding with DataLoader, we can use the
collate_fn argument to specify a class that performs the collation operation, which in our case is zero padding. The following is a minimal example of a collation class that does the padding we need:
import torch

class PadSequence:
    def __call__(self, batch):
        # Each element in "batch" is a tuple (data, label).
        # Sort the batch in descending order of sequence length
        sorted_batch = sorted(batch, key=lambda x: x[0].shape[0], reverse=True)
        # Get each sequence and pad it
        sequences = [x[0] for x in sorted_batch]
        sequences_padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
        # Also store the length of each sequence;
        # this is later needed in order to unpad the sequences
        lengths = torch.LongTensor([len(x) for x in sequences])
        # Don't forget to grab the labels of the *sorted* batch
        labels = torch.LongTensor([x[1] for x in sorted_batch])
        return sequences_padded, lengths, labels
Note the importance of batch_first=True in my code above. By default, DataLoader puts the batch in the first dimension of the data, whereas PyTorch's RNN modules, by default, put the batch in the second dimension (which I absolutely hate). Fortunately, this behavior can be changed for both the RNN modules and the DataLoader. I personally always prefer to have the batch be the first dimension of the data.
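To see what batch_first changes in practice, here is a small sketch (the layer sizes are arbitrary) comparing the two conventions:

```python
import torch
from torch import nn

gru_bf = nn.GRU(input_size=4, hidden_size=8, batch_first=True)
gru_sf = nn.GRU(input_size=4, hidden_size=8)  # default: sequence dimension first

x_bf = torch.randn(3, 5, 4)   # (batch, seq_len, features)
x_sf = torch.randn(5, 3, 4)   # (seq_len, batch, features)

out_bf, _ = gru_bf(x_bf)   # shape (3, 5, 8): batch first
out_sf, _ = gru_sf(x_sf)   # shape (5, 3, 8): sequence first
```

The same layer, same data sizes, but the batch and sequence axes swap depending on the flag.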
With my code above, the DataLoader instance is created as follows (assuming the collation class is named PadSequence):

data_loader = DataLoader(dataset,
                         ... more arguments ...,
                         collate_fn=PadSequence())
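To make this concrete, here is a self-contained sketch wiring the pieces together; the dataset, its contents, and the PadSequence name are all illustrative assumptions, not part of any library:

```python
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class ToySequenceDataset(Dataset):
    """Illustrative dataset: each item is a (sequence, label) tuple."""
    def __init__(self):
        self.data = [(torch.randn(n, 10), n % 2) for n in (7, 3, 5, 2)]
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

class PadSequence:
    def __call__(self, batch):
        # Sort by sequence length (descending), then pad and collect labels
        batch = sorted(batch, key=lambda x: x[0].shape[0], reverse=True)
        sequences = [x[0] for x in batch]
        padded = pad_sequence(sequences, batch_first=True)
        lengths = torch.LongTensor([len(s) for s in sequences])
        labels = torch.LongTensor([x[1] for x in batch])
        return padded, lengths, labels

loader = DataLoader(ToySequenceDataset(), batch_size=2, collate_fn=PadSequence())
padded, lengths, labels = next(iter(loader))
print(padded.shape, lengths)  # torch.Size([2, 7, 10]) tensor([7, 3])
```

The first batch contains the sequences of lengths 7 and 3, so every sequence in it is padded to length 7 and the true lengths ride along for later unpadding.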
The last remaining step here is to pass each batch to the RNN module during training/inference. This can be done by using the pack_padded_sequence function as follows:
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence as PACK

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.gru = nn.GRU(10, 20, 2, batch_first=True)  # Note that "batch_first" is set to "True"

    def forward(self, batch):
        x, x_lengths, _ = batch
        x_pack = PACK(x, x_lengths, batch_first=True)
        output, hidden = self.gru(x_pack)
        return output, hidden
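Since we stored the lengths precisely so we can unpad later, the packed RNN output can be turned back into a regular padded tensor with pad_packed_sequence. A minimal sketch (layer sizes and data are arbitrary):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

gru = torch.nn.GRU(input_size=1, hidden_size=2, batch_first=True)
x = torch.randn(2, 5, 1)            # batch of 2, padded to length 5
lengths = torch.LongTensor([5, 3])  # true lengths, in descending order

packed = pack_padded_sequence(x, lengths, batch_first=True)
packed_output, hidden = gru(packed)

# Recover a padded tensor; positions past each true length come back as zeros
output, out_lengths = pad_packed_sequence(packed_output, batch_first=True)
print(output.shape)  # torch.Size([2, 5, 2])
```

The returned out_lengths match the lengths we fed in, and the time steps beyond each sequence's true length contain only zeros, so downstream code can mask them out safely.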