Use PyTorch’s DataLoader with Variable-Length Sequences for LSTM/GRU

When I first started using PyTorch to implement recurrent neural networks (RNNs), I ran into a small issue trying to use DataLoader in conjunction with variable-length sequences. Specifically, I wanted to automate the process of distributing training data among multiple graphics cards. Even though there are numerous examples online showing how to do the actual padding, I couldn’t find any concrete example of using DataLoader together with padding, and my months-old question on the PyTorch forum is still unanswered!

The standard way of working with inputs of variable lengths is to pad all the sequences with zeros so that their lengths match the length of the longest sequence. This padding is done with the pad_sequence function. PyTorch’s RNN modules (LSTM, GRU, etc.) can then consume the batch as a packed sequence (built with pack_padded_sequence, shown later) and skip the zero padding instead of wasting computation on it.
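As a quick illustration of pad_sequence (the shapes here are arbitrary):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three sequences of lengths 4, 2 and 3, each with 5 features per timestep
seqs = [torch.randn(4, 5), torch.randn(2, 5), torch.randn(3, 5)]

# Stack them into a single tensor, padding the shorter sequences with zeros
padded = pad_sequence(seqs, batch_first=True)
print(padded.shape)  # torch.Size([3, 4, 5])
```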

If the goal is to train with mini-batches, one needs to pad the sequences in each batch. In other words, given a mini-batch of size N, if the length of the longest sequence is L, every sequence shorter than L must be padded with zeros until its length equals L. Moreover, the sequences in the batch must be sorted by length in descending order, since that is what pack_padded_sequence expects (newer PyTorch versions can relax this with enforce_sorted=False); the collate example below handles both the padding and the sorting.

To do proper padding with DataLoader, we can use the collate_fn argument to specify a class that performs the collation operation, which in our case is zero padding. The following is a minimal sketch of such a collation class (the class name PadSequence is illustrative, and each dataset item is assumed to be a (sequence, label) pair):
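```python
import torch
from torch.nn.utils.rnn import pad_sequence

class PadSequence:
    def __call__(self, batch):
        # batch is a list of (sequence, label) pairs produced by the Dataset.
        # Sort the pairs by sequence length in descending order, as
        # pack_padded_sequence expects.
        sorted_batch = sorted(batch, key=lambda x: x[0].shape[0], reverse=True)
        sequences = [x[0] for x in sorted_batch]

        # Zero-pad every sequence up to the length of the longest one
        sequences_padded = pad_sequence(sequences, batch_first=True)

        # Keep the original lengths around; they are needed for packing later
        lengths = torch.LongTensor([len(s) for s in sequences])

        # Labels are assumed to be integers here
        labels = torch.LongTensor([x[1] for x in sorted_batch])
        return sequences_padded, lengths, labels
```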

Note the importance of batch_first=True in the code above. By default, DataLoader returns batches whose first dimension is the batch dimension, whereas PyTorch’s RNN modules, by default, expect the batch in the second dimension (which I absolutely hate). Fortunately, this behavior can be changed on both sides: the RNN modules take a batch_first argument, and the collate function can match it through pad_sequence’s own batch_first argument. I personally always prefer to have the batch be the first dimension of the data.
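For the RNN modules, this is a single constructor argument (the sizes here are placeholders):

```python
import torch.nn as nn

# With batch_first=True, the module expects inputs of shape
# (batch, seq_len, features) rather than the default (seq_len, batch, features)
rnn = nn.GRU(input_size=5, hidden_size=64, batch_first=True)
```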

With the collate class above, a DataLoader instance can be created as follows (the dataset, batch size, and shuffle settings are placeholders):
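```python
from torch.utils.data import DataLoader

# `dataset` stands for any torch.utils.data.Dataset whose items are
# (sequence, label) pairs
train_loader = DataLoader(dataset,
                          batch_size=32,
                          shuffle=True,
                          collate_fn=PadSequence())
```

Using a class instance (rather than a lambda) for collate_fn keeps it picklable, which matters if you load data with multiple worker processes.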

The last remaining step is to pass each batch to the RNN module during training/inference. This can be done with the pack_padded_sequence function, along the following lines (the model and loop are a sketch):
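```python
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

# A placeholder model; input_size must match the feature dimension of the data
lstm = nn.LSTM(input_size=5, hidden_size=64, batch_first=True)

for sequences_padded, lengths, labels in train_loader:
    # Pack the padded batch so the LSTM skips the padded timesteps.
    # PadSequence already sorted the batch in descending order of length,
    # which is what pack_padded_sequence expects by default.
    packed = pack_padded_sequence(sequences_padded, lengths, batch_first=True)
    output, (h_n, c_n) = lstm(packed)
    # h_n[-1] holds the final hidden state of each sequence in the batch,
    # ready to be fed to a classifier/loss together with `labels`
```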

