Plot the transformed (augmented) images in PyTorch

Hello Developer, hope you are doing great. Today at the Tutorial Guruji official website, we are sharing the answer to Plot the transformed (augmented) images in PyTorch without wasting too much of your time.

This question was published by the Tutorial Guruji team.

I want to apply an image augmentation technique (for example, rotation or horizontal flip) to some images of the CIFAR-10 dataset and plot them in PyTorch.

I know that we can use the following code to augment images:

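import torch
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10

data_transforms = transforms.Compose([
    # augmentations run on the PIL image (the rotation angle is an example value)
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    # ToTensor turns the PIL image (values 0-255) into a float tensor in [0, 1]
    transforms.ToTensor(),
    # mean = std = 0.5 maps [0, 1] to the normalized range [-1, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])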

Then I pass these transforms when loading the CIFAR10 dataset:

train_set = CIFAR10(
    root='./data/',
    train=True,
    download=True,
    transform=data_transforms)

As far as I know, when this code is used, all images in the CIFAR10 dataset are transformed.


My question is: how can I apply a data transform or augmentation technique to some images in the dataset and plot them? For example, 10 images and their augmented versions.


when this code is used, all images in the CIFAR10 dataset are transformed

Actually, the transform pipeline is only called when images in the dataset are fetched, either by indexing the dataset (which calls its __getitem__ method) or through a data loader. So at this point, train_set doesn't contain augmented images; they are transformed on the fly. Because the augmentations are random, fetching the same index twice can also give different results.
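To illustrate this lazy behaviour without downloading anything, here is a minimal pure-Python sketch. `LazyDataset` and `random_flip` are hypothetical stand-ins (not torchvision classes) that mimic how a torchvision dataset keeps raw samples and applies `transform` only inside `__getitem__`:

```python
import random


class LazyDataset:
    """Hypothetical stand-in for a torchvision dataset: raw samples are stored
    as-is and the transform runs only when an item is fetched."""

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform
        self.transform_calls = 0  # how many times the transform has actually run

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        if self.transform is not None:
            self.transform_calls += 1
            sample = self.transform(sample)
        return sample


def random_flip(row):
    """Toy 'horizontal flip': return a reversed copy with probability 0.5."""
    return row[::-1] if random.random() < 0.5 else list(row)


dataset = LazyDataset([[1, 2, 3], [4, 5, 6]], transform=random_flip)
print(dataset.transform_calls)  # nothing has been transformed yet
first = dataset[0]              # the transform runs now, on this item only
print(dataset.transform_calls, first in ([1, 2, 3], [3, 2, 1]), dataset.data[0])
```

The stored data is never modified; fetching `dataset[0]` again would run `random_flip` a second time and may return the other orientation.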

You will need to construct another dataset without augmentations.

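>>> import torch, torchvision
>>> from torchvision import transforms
>>> from torchvision.datasets import CIFAR10
>>> non_augmented = CIFAR10(
...     root='./data/',
...     train=True,
...     download=True,
...     transform=transforms.ToTensor())  # no augmentation, tensors in [0, 1]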

>>> train_set = CIFAR10(
...     root='./data/',
...     train=True,
...     download=True,
...     transform=data_transforms)

Stack the first 10 original images and their augmented versions into one tensor:

>>> imgs = torch.stack((*[non_augmented[i][0] for i in range(10)],
...                     *[train_set[i][0] * 0.5 + 0.5 for i in range(10)]))  # undo the [-1, 1] normalization (mean = std = 0.5) for display

>>> imgs.shape
torch.Size([20, 3, 32, 32])
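`transforms.Normalize(mean, std)` computes `(x - mean) / std` per channel, so with mean = std = 0.5 (the choice that gives the [-1, 1] range mentioned in the question) a `ToTensor` value in [0, 1] ends up in [-1, 1], and `y * 0.5 + 0.5` maps it back. Otherwise the augmented row would look different from the original row for reasons unrelated to the augmentation. A small pure-Python check of the arithmetic; `normalize` and `denormalize` are hypothetical helper names:

```python
def normalize(x, mean=0.5, std=0.5):
    """Same per-value formula transforms.Normalize applies: (x - mean) / std."""
    return (x - mean) / std


def denormalize(y, mean=0.5, std=0.5):
    """Inverse of normalize: y * std + mean."""
    return y * std + mean


pixels = [0.0, 0.25, 0.5, 1.0]               # values as produced by ToTensor
normed = [normalize(p) for p in pixels]      # [-1.0, -0.5, 0.0, 1.0]
restored = [denormalize(n) for n in normed]  # back to [0.0, 0.25, 0.5, 1.0]
print(normed, restored)
```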

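Then torchvision.utils.make_grid can arrange them in the desired layout. With nrow=10, the 10 originals fill the first row and their augmented versions the second, so each column shows one image before and after augmentation: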

>>> grid = torchvision.utils.make_grid(imgs, nrow=10)
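With the default `padding=2`, `make_grid` puts `nrow` images per row and adds a padding border around every image, so 20 images of 32×32 with `nrow=10` form 2 rows. The size arithmetic can be sketched as follows; `grid_size` is a hypothetical helper that reproduces how the output height and width come out, not part of torchvision:

```python
import math


def grid_size(n_images, img_h, img_w, nrow=8, padding=2):
    """Hypothetical helper reproducing make_grid's output height and width."""
    cols = min(nrow, n_images)         # images per row
    rows = math.ceil(n_images / cols)  # number of rows needed
    height = rows * (img_h + padding) + padding
    width = cols * (img_w + padding) + padding
    return height, width


print(grid_size(20, 32, 32, nrow=10))  # (70, 342)
```

So for the 20 CIFAR-10 images above, `grid.shape` should be `torch.Size([3, 70, 342])`.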

There you have it!

>>> transforms.ToPILImage()(grid)


