Shuffle patches in image batch

I am trying to create a transform that shuffles the patches of each image in a batch. I want to use it the same way as the other torchvision transforms:

from torchvision import transforms

trans = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ShufflePatches(patch_size=(16, 16)),  # our new transform
])

More specifically, the input is a BxCxHxW tensor. I want to split each image in the batch into non-overlapping patches of size patch_size, shuffle them, and regroup them into a single image.
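For reference, the operation described above (cut each image into a grid of patches, permute the grid cells independently per image, reassemble) can also be written directly with reshape and permute. This is a minimal sketch using a different technique from unfold/fold; the function name shuffle_patches_reshape is hypothetical, and it assumes H and W are divisible by the patch height ph and width pw.

```python
import torch

def shuffle_patches_reshape(x, ph, pw):
    # hypothetical helper; x: B x C x H x W with H % ph == 0 and W % pw == 0
    B, C, H, W = x.shape
    nh, nw = H // ph, W // pw
    # B x C x nh x ph x nw x pw -> B x (nh*nw) x C x ph x pw (one entry per patch)
    patches = (x.reshape(B, C, nh, ph, nw, pw)
                .permute(0, 2, 4, 1, 3, 5)
                .reshape(B, nh * nw, C, ph, pw))
    # an independent random permutation of the patch index for each image
    idx = torch.stack([torch.randperm(nh * nw, device=x.device) for _ in range(B)])
    batch_idx = torch.arange(B, device=x.device)[:, None]
    patches = patches[batch_idx, idx]  # B x (nh*nw) x C x ph x pw
    # regroup: B x nh x nw x C x ph x pw -> B x C x nh x ph x nw x pw -> B x C x H x W
    return (patches.reshape(B, nh, nw, C, ph, pw)
                   .permute(0, 3, 1, 4, 2, 5)
                   .reshape(B, C, H, W))

x = torch.rand(2, 3, 224, 224)
out = shuffle_patches_reshape(x, 112, 112)
print(out.shape)  # torch.Size([2, 3, 224, 224])
```

Each patch's contents are kept intact; only the positions of the patches within each image change.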

Given the image (of size 224x224):

[Image: the original 224x224 input]

With ShufflePatches(patch_size=(112,112)), I would like to produce this output image:

[Image: the desired output, the same image with its four 112x112 quadrants shuffled]

I think the solution involves unfold and fold (torch.nn.functional.unfold and torch.nn.functional.fold), but I didn't manage to get any further.

Any help would be appreciated!

Answer

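Indeed, unfold and fold are a good fit here: with stride equal to the kernel size, unfold extracts the non-overlapping patches of each image as columns, and fold puts the columns back into place.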
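To see what unfold produces, here is a small sketch on a random batch (the tensor sizes are chosen for illustration only). It prints the shape of the patch matrix and checks that folding the unshuffled patches reconstructs the input, which holds because the patches do not overlap.

```python
import torch
import torch.nn.functional as nnf

x = torch.randn(2, 3, 224, 224)  # B x C x H x W
ps = (16, 16)

# unfold: one column per patch, each column holds C*ph*pw values
u = nnf.unfold(x, kernel_size=ps, stride=ps)
print(u.shape)  # torch.Size([2, 768, 196]): 3*16*16 = 768 values, (224/16)**2 = 196 patches

# fold sums overlapping contributions; with stride == kernel_size nothing
# overlaps, so folding the unchanged columns gives back x exactly
f = nnf.fold(u, output_size=x.shape[-2:], kernel_size=ps, stride=ps)
print(torch.equal(f, x))  # True
```

Shuffling therefore amounts to permuting the last dimension of u separately for each image before calling fold.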

import torch
import torch.nn.functional as nnf

class ShufflePatches:
  def __init__(self, patch_size):
    # int or (ph, pw); H and W must be divisible by it
    self.ps = patch_size

  def __call__(self, x):
    # x: B x C x H x W
    # divide the batch of images into non-overlapping patches:
    # u has shape B x (C*ph*pw) x L, one column per patch
    u = nnf.unfold(x, kernel_size=self.ps, stride=self.ps, padding=0)
    # permute the patches (columns) of each image in the batch independently
    pu = torch.stack([b_[:, torch.randperm(b_.shape[-1], device=x.device)] for b_ in u], dim=0)
    # fold the permuted patches back together into B x C x H x W
    f = nnf.fold(pu, x.shape[-2:], kernel_size=self.ps, stride=self.ps, padding=0)
    return f
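One caveat when plugging this into the Compose pipeline from the question: Compose applies each transform to one image at a time, so after ToTensor the transform receives a C x H x W tensor rather than a batch, while nnf.unfold is documented for batched 4-D input. A minimal sketch, assuming you want the same transform to accept both shapes: the class is repeated so the block runs on its own, and ShufflePatchesAnyDim is a hypothetical subclass that adds and removes the batch dimension.

```python
import torch
import torch.nn.functional as nnf

class ShufflePatches:
    def __init__(self, patch_size):
        self.ps = patch_size

    def __call__(self, x):
        # B x C x H x W -> B x (C*ph*pw) x L patch columns
        u = nnf.unfold(x, kernel_size=self.ps, stride=self.ps, padding=0)
        # independent column permutation per image
        pu = torch.stack([b_[:, torch.randperm(b_.shape[-1], device=x.device)] for b_ in u], dim=0)
        return nnf.fold(pu, x.shape[-2:], kernel_size=self.ps, stride=self.ps, padding=0)

class ShufflePatchesAnyDim(ShufflePatches):
    # hypothetical wrapper: accepts a single C x H x W image (as produced by
    # ToTensor inside Compose) as well as a B x C x H x W batch
    def __call__(self, x):
        if x.dim() == 3:
            return super().__call__(x.unsqueeze(0)).squeeze(0)
        return super().__call__(x)

batch = torch.rand(4, 3, 224, 224)
out = ShufflePatches(patch_size=(16, 16))(batch)
print(out.shape)  # torch.Size([4, 3, 224, 224])

img = torch.rand(3, 224, 224)
out1 = ShufflePatchesAnyDim(patch_size=(112, 112))(img)
print(out1.shape)  # torch.Size([3, 224, 224])
```

Since only patch positions change, each output image contains exactly the same pixel values as its input, just rearranged.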

Here’s an example with patch_size=16:
[Image: example output with shuffled 16x16 patches]