How to optimize PyTorch code?

Reading time ~8 minutes

Optimizing deep learning code may seem quite complicated. After all, PyTorch is already heavily optimized, so why (and how) could one improve what is already great?

For the why, there are many reasons:

Why?

GPU time is expensive

As data scientists, we are used to telling our bosses or investors that we need more money to improve our models (oh, look at that beautiful GPU). But what if you could halve your training time for free?

Less training time means more accuracy

The money argument aside, you can test more models in the same amount of time! In the end, it means you are more likely to find a better model within a given time budget.

Optimizing code is super satisfying

Honestly, there is no other way to put it. Instead of optimizing the accuracy of your model, try to optimize its training time. See how good this feels!

Save the planet

Well, it is actually a pity to realize that, in the example below (taken from a typical PyTorch training loop), half of the time is simply wasted. We will not degrade the (validation) loss of the model, yet we will spend less electricity and less time to reach it!

How?

Well, we will not touch anything inside PyTorch itself, obviously. I have just noticed that, as data scientists, we are not always aware of the low-hanging fruit in terms of performance in our own scripts. Here, the tricks will mostly lie in data loading and transformations.
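A cheap way to check where the time actually goes in your own script is to time the data pipeline alone, for instance by iterating over the DataLoader without ever touching the model. A minimal sketch (training_loader stands for whatever loader your script builds):

from tqdm import tqdm

# If this loop alone eats a large share of the epoch time, the bottleneck is
# the data pipeline (loading and transforms), not the GPU.
for inputs, labels in tqdm(training_loader):
    pass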

A concrete case

Without further ado, let’s start! I will focus on the most common parts of a PyTorch training script, in the context of image recognition. Here, the problem is transfer learning from a pretrained model (resnet34, because it is fast to run) to a binary classification problem.

The images are satellite data so they have more channels than a usual RGB image.
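If you are curious about what a raw tile looks like before it enters the pipeline, a quick check is enough (the file name below is only a placeholder):

import tifffile as tiff

image = tiff.imread("some_tile.tif")  # placeholder path
print(image.shape, image.dtype)       # (H, W, n_bands), with more than the usual 3 bands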

We will focus on the image loading function:

import numpy as np
import tifffile as tiff  # assuming the tifffile package is what reads the rasters
from PIL import Image

def load_and_convert_tiff(file_path):
    image = tiff.imread(file_path)
    R = image[:, :, 1]*255*2
    G = image[:, :, 2]*255*2
    B = image[:, :, 3]*255*2
    rgb_image = np.stack((R, G, B), axis=2).astype(np.uint8)
    rgb_image = Image.fromarray(rgb_image)
    return rgb_image

And the Dataset class:

from torch.utils.data import Dataset

class Sentinel2Dataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        image = load_and_convert_tiff(img_path)
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label

Along with the composition of transformations:

from torchvision import transforms

augment_and_transform = transforms.Compose(
    [transforms.Resize(334),
     transforms.ToTensor(),
     transforms.RandomRotation(degrees=90),
     transforms.RandomVerticalFlip(p=0.5),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.Normalize(
         mean=[0.485, 0.456, 0.406],
         std=[0.229, 0.224, 0.225]),])

We have a usual training loop, similar to what the PyTorch tutorials suggest.

In order to keep the article simple, I will not dive into other parts of the code until it becomes necessary. Let’s just assume that the training loop is properly implemented, and we use tqdm to measure the training time of our model.
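For reference, here is roughly how those timing numbers are produced; the loader arguments below (the batch size in particular) are my assumptions, not values taken from the original script:

from torch.utils.data import DataLoader
from tqdm import tqdm

# train_dataset is an instance of the Sentinel2Dataset class above (name assumed).
# With ~1000 training images and a batch size of 32, tqdm sees 32 batches per
# epoch, hence the "32it" in the logs below.
training_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

for i, data in tqdm(enumerate(training_loader)):
    inputs, labels = data
    ...  # forward pass, loss, backward pass, optimizer step, as in the full loop shown later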

These results will be our baseline for what follows:

32it [00:11,  2.76it/s]
EPOCH: 0 | LOSS train: 0.511 | LOSS valid: 0.495
32it [00:11,  2.82it/s]
EPOCH: 1 | LOSS train: 0.389 | LOSS valid: 0.405
32it [00:10,  3.00it/s]
EPOCH: 2 | LOSS train: 0.352 | LOSS valid: 0.359

Here, 2.76it/s means that 2.76 batches are processed per second (the dataset is split into 32 batches, so an epoch takes about 32 / 2.76 ≈ 11.6 seconds, matching the 00:11 displayed above). This is the number we will focus on. We will also keep an eye on the training and validation losses to make sure we do not break anything.

Get rid of the useless

The part rgb_image = Image.fromarray(rgb_image) is actually useless. Some people use it because transforms.Resize() may behave slightly differently on PIL images than on tensors (depending on the parameters you feed to the transform).

We can simply remove it from the function:

def load_and_convert_tiff(file_path):
    image = tiff.imread(file_path)
    R = image[:, :, 1]*255*2
    G = image[:, :, 2]*255*2
    B = image[:, :, 3]*255*2
    rgb_image = np.stack((R, G, B), axis=2).astype(np.uint8)
    return rgb_image

And now we swap transforms.ToTensor() and transforms.Resize(...), since Resize does not work on NumPy arrays directly, while ToTensor happily accepts them.

from torchvision.transforms import InterpolationMode

augment_and_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize(334, antialias=False, interpolation=InterpolationMode.NEAREST_EXACT),
     transforms.RandomRotation(degrees=90),
     transforms.RandomVerticalFlip(p=0.5),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.Normalize(
         mean=[0.485, 0.456, 0.406],
         std=[0.229, 0.224, 0.225]),])
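As a sanity check, transforms.ToTensor() accepts a NumPy uint8 array of shape H x W x C and returns a float tensor of shape C x H x W scaled to [0, 1], which is exactly what the rest of the pipeline expects. A minimal sketch with a dummy array:

import numpy as np
from torchvision import transforms

dummy = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)  # stand-in for a loaded image
tensor = transforms.ToTensor()(dummy)
print(tensor.shape, tensor.dtype, tensor.min().item(), tensor.max().item())
# torch.Size([3, 64, 64]) torch.float32, values in [0, 1]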

We can rerun our script and…

32it [00:09,  3.45it/s]
EPOCH: 0 | LOSS train: 0.519 | LOSS valid: 0.506 
32it [00:09,  3.43it/s]
EPOCH: 1 | LOSS train: 0.389 | LOSS valid: 0.407 
32it [00:09,  3.38it/s]
EPOCH: 2 | LOSS train: 0.351 | LOSS valid: 0.357 

TADA! The speedup is already impressive!

Use your RAM if possible!

Note that this advice will not work if your dataset is too large! Here, we have about 1000 training images. This is quite small, and my machine has 32 GB of RAM, so I might as well load them once and for all. Besides, it will spare my SSD.

Instead of reading the image from disk each time __getitem__ is called, we can build a list of images stored in memory.

In my case, these images only represent about 10% of my RAM.

class Sentinel2Dataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.images = []
        for file_path in tqdm(self.file_paths):
            image = load_and_convert_tiff(file_path)
            self.images.append(image)

        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label

Note the overhead! Indeed, instantiating the datasets now takes about 2 seconds to load all the images into memory.

100%|███████████████████████████████| 993/993 [00:02<00:00, 418.36it/s]
100%|███████████████████████████████| 249/249 [00:00<00:00, 380.70it/s]

But wow, the speedup during training is totally worth it! We are close to halving our initial training time.

32it [00:06,  5.01it/s]
EPOCH: 0 | LOSS train: 0.519 | LOSS valid: 0.506
32it [00:06,  5.31it/s]
EPOCH: 1 | LOSS train: 0.389 | LOSS valid: 0.407
32it [00:06,  5.30it/s]
EPOCH: 2 | LOSS train: 0.351 | LOSS valid: 0.357
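Before adopting this trick blindly, it is worth estimating whether the cached images actually fit in RAM. A back-of-the-envelope sketch (the shape and dtype below are assumptions, plug in your own values):

n_images = 993 + 249   # training + validation images, from the tqdm output above
h, w, c = 334, 334, 3  # assumed image size; your raw tiles may differ
bytes_per_value = 4    # float32 tensors; use 1 if you cache uint8 arrays

cache_gb = n_images * h * w * c * bytes_per_value / 1024**3
print(f"Estimated cache size: {cache_gb:.2f} GB")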

Factor the transformations!

The resize is applied every single time an image is fetched, so each pass over the training set resizes the same images again and again. Let’s get rid of this repetition by turning the augmentation transform:

augment_and_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize(334, antialias=False, interpolation=InterpolationMode.NEAREST_EXACT),
     transforms.RandomRotation(degrees=90),
     transforms.RandomVerticalFlip(p=0.5),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.Normalize(
         mean=[0.485, 0.456, 0.406],
         std=[0.229, 0.224, 0.225]),])

To:

augment_and_transform = transforms.Compose(
    [transforms.RandomRotation(degrees=90),
     transforms.RandomVerticalFlip(p=0.5),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.Normalize(
         mean=[0.485, 0.456, 0.406],
         std=[0.229, 0.224, 0.225]),])

This way, only the data augmentation happens at fetch time. Now, when loading the images into memory, let’s apply the common transformations once and for all:

class Sentinel2Dataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):

        factored_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Resize(334, antialias=False, interpolation=InterpolationMode.NEAREST_EXACT)])

        self.file_paths = file_paths
        self.images = []
        for file_path in tqdm(self.file_paths):
            image = load_and_convert_tiff(file_path)
            transformed_image = factored_transform(image)
            self.images.append(transformed_image)

        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label

And this is it, the training time decreased once more:

32it [00:06,  5.12it/s]
EPOCH: 0 | LOSS train: 0.519 | LOSS valid: 0.506
32it [00:05,  5.45it/s]
EPOCH: 1 | LOSS train: 0.389 | LOSS valid: 0.407
32it [00:05,  5.46it/s]
EPOCH: 2 | LOSS train: 0.351 | LOSS valid: 0.357

Play with deterministic / benchmark

Some of you may be familiar with the deterministic and benchmark flags. I usually see this useful function:

def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

If you flip the booleans, your results may not be identical from one run to the next (the difference should be small, though), but you will gain some extra time.

def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

As the running output shows:

32it [00:07,  4.39it/s]
EPOCH: 0 | LOSS train: 0.519 | LOSS valid: 0.487
32it [00:05,  5.61it/s]
EPOCH: 1 | LOSS train: 0.392 | LOSS valid: 0.410
32it [00:05,  5.63it/s]
EPOCH: 2 | LOSS train: 0.358 | LOSS valid: 0.400

Note that the first epoch is slower (4.39it/s versus ~5.6it/s afterwards): with benchmark = True, cuDNN tries several convolution algorithms the first time it encounters a given input shape and then reuses the fastest one, so there is a small warm-up cost. Also note that the variation in the losses is worrisome…

Use (try) gradient accumulation

Gradient accumulation seemed promising to me. I read about it on this blog and the heuristic seemed interesting.

Besides, since it reduces the number of gradient updates, I expected to gain some speed (and this should be particularly true for larger models).

The recipe consists of turning this training loop:

def train_one_epoch(epoch_index):
    total_loss = 0.

    for i, data in tqdm(enumerate(training_loader)):

        inputs, labels = data
        inputs = inputs.to(torch.device(device))
        labels = labels.to(torch.device(device))

        optimizer.zero_grad()
        outputs = model(inputs)

        batch_loss = loss_fn(outputs, labels)
        batch_loss.backward()

        optimizer.step()
        total_loss += batch_loss.item()

    return total_loss / (i+1)

Into this one, where the gradient update is performed every accum_iter steps:

def train_one_epoch(epoch_index):
    total_loss = 0.
    accum_iter = 4

    optimizer.zero_grad()
    for i, data in tqdm(enumerate(training_loader)):

        inputs, labels = data
        inputs = inputs.to(torch.device(device))
        labels = labels.to(torch.device(device))

        outputs = model(inputs)

        batch_loss = loss_fn(outputs, labels)
        batch_loss = batch_loss / accum_iter
        batch_loss.backward()
        total_loss += batch_loss.item()

        if ((i + 1) % accum_iter == 0) or (i + 1 == len(training_loader)):
            optimizer.step()
            optimizer.zero_grad()

    return total_loss / (i+1)

But the drop in validation performance is too large (maybe I am doing something wrong?) for no gain in execution time. Note also that the reported training loss is not directly comparable to the previous runs: each batch loss is divided by accum_iter before being accumulated, so the displayed value is roughly four times smaller than before.

32it [00:06,  5.05it/s]
EPOCH: 0 | LOSS train: 0.156 | LOSS valid: 0.612
32it [00:06,  5.33it/s]
EPOCH: 1 | LOSS train: 0.127 | LOSS valid: 0.532
32it [00:06,  5.30it/s]
EPOCH: 2 | LOSS train: 0.110 | LOSS valid: 0.482

More performance with smaller images!

By reducing the size of the images, we can achieve a massive speedup. Just note the 228 instead of 334:

factored_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize(228, antialias=False, interpolation=InterpolationMode.NEAREST_EXACT)])

And we are almost 4 times faster than our initial baseline! The extra factor of two over the previous step is no surprise: (228/334)² ≈ 0.47, so each image now carries roughly half as many pixels.

32it [00:04,  7.78it/s]
EPOCH: 0 | LOSS train: 0.518 | LOSS valid: 0.500
32it [00:02, 10.76it/s]
EPOCH: 1 | LOSS train: 0.393 | LOSS valid: 0.436
32it [00:02, 10.76it/s]
EPOCH: 2 | LOSS train: 0.363 | LOSS valid: 0.393

However, this kind of optimization changes what we are actually doing and seems to harm the loss of the model…

Conclusion

This is it! We almost halved the training time of our model without harming its performance :) And if we allow ourselves to trade away some of the model’s performance, we saw that smaller images are a way to go much faster.

I do not have other tricks that can easily be applied at the moment. I hope you liked this article; do not hesitate to share it with your friends and colleagues, and on social media!

Learning more

If you are new to machine learning, Deep Learning by Ian Goodfellow, Yoshua Bengio and Aaron Courville is an excellent introduction to the topic. The algorithms and mathematics are presented without any code, so it will not become outdated as soon as a new breaking change is introduced in the main packages ;) *Note that this is a sponsored link.
