Finding the ideal num_workers for PyTorch DataLoaders
One of the biggest bottlenecks in deep learning is loading data. Having fast drives and fast access to the data is important, especially if you are trying to saturate a GPU or multiple processors. PyTorch has DataLoaders, which help you manage the task of getting the data into your model. These can be fantastic to use, especially for large datasets, as they are very powerful and can handle things such as shuffling, batching, and even memory management. PyTorch's DataLoaders also work in parallel, so you can specify a number of "workers", with the parameter num_workers, to load your data.
Figuring out the correct num_workers can be difficult. One thought is to use the number of CPU cores you have available. In many cases, this works well. Sometimes the best value is half that number, or one quarter of it. There are a lot of factors, such as what else the machine is doing and the type of data you are working with.
The nice thing about DataLoaders is that they can be loading data while your GPU is processing data. This is one reason why loading data into CPU memory is not a bad idea: it saves valuable GPU memory and allows your computer to make use of the CPU and GPU simultaneously.
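The core-count rule of thumb above can be turned into a short list of values worth benchmarking. This is only a minimal sketch: the helper name candidate_num_workers is made up for illustration, and it assumes os.cpu_count() is a reasonable stand-in for "cores available", even though it reports logical CPUs and does not know what else the machine is doing.

```python
import os


def candidate_num_workers(cores):
    """Return the worker counts suggested by the core-count rule of thumb.

    Gives a quarter, a half, and all of the cores, with a floor of 1,
    de-duplicated and sorted in ascending order.
    (Hypothetical helper, not part of PyTorch.)
    """
    candidates = {max(1, cores // 4), max(1, cores // 2), max(1, cores)}
    return sorted(candidates)


# os.cpu_count() can return None if the count cannot be determined,
# so fall back to 1 in that case.
cores = os.cpu_count() or 1
print("cores:", cores, "candidates to benchmark:", candidate_num_workers(cores))
```

Treat these values as starting points for a benchmark, not as the answer.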
The best way to go about tackling this is to run a basic test. One thing I can tell you for sure is that leaving num_workers at its default of 0, which loads all the data in the main process, can be painfully slow. You should absolutely at least try setting it to something higher. Using 0 or 1 workers can take, say, 1-2 minutes to load a batch. Having it set correctly can get this down to a few seconds! When you are running a bunch of iterations in your model this really adds up, and it can be the primary way you speed up your training.
Here is code I used to benchmark and find my ideal num_workers:
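import time

import torch

# train_data, train_sampler, and batch_size are assumed to be defined earlier in your script.
pin_memory = True
print('pin_memory is', pin_memory)

for num_workers in range(0, 20, 1):
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                               sampler=train_sampler, num_workers=num_workers,
                                               pin_memory=pin_memory)
    start = time.time()
    # Iterate over the whole dataset for 4 epochs without doing any work,
    # so that only the data loading itself is timed.
    for epoch in range(1, 5):
        for i, data in enumerate(train_loader):
            pass
    end = time.time()
    print("Finish with:{} second, num_workers={}".format(end - start, num_workers))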
Here is an example of the output on CIFAR-10 data:
pin_memory is True
Finish with:38.776511430740356 second, num_workers=0
Finish with:51.23906326293945 second, num_workers=1
Finish with:25.652921199798584 second, num_workers=2
Finish with:18.647804975509644 second, num_workers=3
Finish with:15.828237295150757 second, num_workers=4
Finish with:14.37824296951294 second, num_workers=5
Finish with:13.42700719833374 second, num_workers=6
Finish with:12.521363258361816 second, num_workers=7
Finish with:12.273895263671875 second, num_workers=8
Finish with:12.463818550109863 second, num_workers=9
Finish with:12.434542179107666 second, num_workers=10
Finish with:12.590770721435547 second, num_workers=11
Finish with:12.783716201782227 second, num_workers=12
Finish with:12.741865396499634 second, num_workers=13
Finish with:12.966291427612305 second, num_workers=14
Finish with:13.075541257858276 second, num_workers=15
Finish with:13.420445442199707 second, num_workers=16
Finish with:13.358919143676758 second, num_workers=17
Finish with:13.629449844360352 second, num_workers=18
Finish with:13.735612154006958 second, num_workers=19
Obviously there are a lot of factors that contribute to the speed at which you load data, and this is just one of them. But it is an important one. When you have multiple GPUs it is very important that you can feed them as fast as they can handle, and oftentimes people fall short of this. You can see from the above output that using at least num_workers=4 is highly beneficial: the time drops from about 39 seconds with 0 workers to about 16 seconds with 4, bottoms out around 12 seconds at 8 workers, and slowly creeps back up after that. Note that num_workers=1 was actually slower than 0 in this run. I have had datasets where setting this parameter much higher was required for a drastic increase in speed. It's always good to check!
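Since the curve flattens out well before its minimum, one way to read a benchmark like the one above is to take the smallest num_workers that is "close enough" to the fastest time, which leaves more CPU free for everything else. The sketch below is just one possible rule, not something PyTorch provides: pick_num_workers and the 5% tolerance are my own assumptions, and the timings are the CIFAR-10 numbers from the output above, rounded to two decimals.

```python
def pick_num_workers(timings, tolerance=0.05):
    """Return the smallest num_workers whose time is within `tolerance`
    (as a fraction) of the fastest measured time.
    (Hypothetical helper, not part of PyTorch.)
    """
    best = min(timings.values())
    cutoff = best * (1.0 + tolerance)
    return min(n for n, seconds in timings.items() if seconds <= cutoff)


# Timings in seconds, rounded from the CIFAR-10 output above.
timings = {
    0: 38.78, 1: 51.24, 2: 25.65, 3: 18.65, 4: 15.83,
    5: 14.38, 6: 13.43, 7: 12.52, 8: 12.27, 9: 12.46,
    10: 12.43, 11: 12.59, 12: 12.78, 13: 12.74, 14: 12.97,
    15: 13.08, 16: 13.42, 17: 13.36, 18: 13.63, 19: 13.74,
}

print("fastest:", min(timings, key=timings.get))
print("smallest within 5% of fastest:", pick_num_workers(timings))
```

On these numbers the fastest setting is 8 workers, while 7 workers is already within 5% of it. A single run is noisy, so it is worth repeating the benchmark before trusting small differences like that.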