Plot the transformed (augmented) images in PyTorch

I want to use one of the image augmentation techniques (for example rotation or horizontal flip) and apply it to some images of the CIFAR-10 dataset and plot them in PyTorch.

I know that we can use the following code to augment images:

from torchvision import transforms
from torchvision.datasets import CIFAR10

# Assumed values (not given in the question); they match the [-1, 1] range mentioned below
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

data_transforms = transforms.Compose([
    # add augmentations
    transforms.RandomHorizontalFlip(p=0.5),
    # The output of torchvision datasets are PILImage images of range [0, 1].
    # We transform them to Tensors of normalized range [-1, 1]
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

I then pass these transforms when loading the CIFAR10 dataset:

train_set = CIFAR10(
    root='./data/',
    train=True,
    download=True,
    transform=data_transforms)

As far as I know, when this code is used, all CIFAR10 datasets are transformed.

Question

My question is: how can I apply data transforms or augmentation techniques to only some images in the dataset and plot them? For example, 10 images and their augmented versions.

Answer

"when this code is used, all CIFAR10 datasets are transformed"

Actually, the transform pipeline is only called when an image is fetched from the dataset via its __getitem__ method, either directly by the user or through a data loader. So at this point, train_set doesn't contain augmented images; they are transformed on the fly, each time they are accessed.
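To make the "on the fly" behaviour concrete, here is a minimal pure-Python sketch. It is not torchvision's actual implementation: LazyDataset and random_flip are hypothetical names invented for illustration, and a list reversal stands in for a horizontal flip.

```python
import random


class LazyDataset:
    """Toy stand-in for a torchvision-style dataset (hypothetical class)."""

    def __init__(self, samples, transform=None):
        self.samples = samples        # raw data, stored untouched
        self.transform = transform    # only remembered here, not applied

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        if self.transform is not None:
            sample = self.transform(sample)   # applied at access time
        return sample


def random_flip(row, p=0.5):
    """Reverse a list with probability p (a 1-D analogue of a horizontal flip)."""
    return row[::-1] if random.random() < p else row


random.seed(0)
dataset = LazyDataset([[1, 2, 3], [4, 5, 6]], transform=random_flip)

# Fetching index 0 several times re-runs the random transform each time.
draws = [dataset[0] for _ in range(20)]
```

The stored samples never change; only the value returned by each access does, which is why indexing the same item twice can give a flipped image one time and an unflipped one the next.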


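You will need to construct another dataset without augmentations. It still needs transforms.ToTensor() so that its images are tensors and can be stacked with the augmented ones.

>>> non_augmented = CIFAR10(
...     root='./data/',
...     train=True,
...     download=True,
...     transform=transforms.ToTensor())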

>>> train_set = CIFAR10(
...     root='./data/',
...     train=True,
...     download=True,
...     transform=data_transforms)

Stack some images together:

>>> import torch
>>> imgs = torch.stack((*[non_augmented[i][0] for i in range(10)],
...                     *[train_set[i][0] for i in range(10)]))

>>> imgs.shape
torch.Size([20, 3, 32, 32])
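One caveat: if the augmented pipeline ends with transforms.Normalize, as in the question, the augmented half of the stack is no longer in [0, 1] and will probably look distorted when displayed. A minimal sketch for undoing the normalization first, assuming a per-channel mean and std of 0.5 (an assumption, since the question does not state them) and using a hypothetical helper named unnormalize:

```python
import torch

# Assumed per-channel statistics; they must match what was passed to Normalize.
mean = torch.tensor([0.5, 0.5, 0.5])
std = torch.tensor([0.5, 0.5, 0.5])


def unnormalize(imgs, mean, std):
    """Invert transforms.Normalize for a batch of shape [N, C, H, W]."""
    return imgs * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)


# Random stand-in for a batch of normalized images with values in [-1, 1].
normalized = torch.rand(20, 3, 32, 32) * 2 - 1

# Map back to [0, 1]; clamp guards against tiny floating-point overshoot.
restored = unnormalize(normalized, mean, std).clamp(0, 1)
```

Only the normalized images should be passed through such a helper; images produced by ToTensor alone are already in [0, 1].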

Then torchvision.utils.make_grid can be useful to create the desired layout:

>>> import torchvision
>>> grid = torchvision.utils.make_grid(imgs, nrow=10)

There you have it!

>>> transforms.ToPILImage()(grid)
