I am trying to create a transform that shuffles the patches of each image in a batch.
I aim to use it in the same manner as the rest of the transformations in a transforms.Compose pipeline:

    trans = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ShufflePatches(patch_size=(16, 16)),  # our new transform
    ])
More specifically, the input is a BxCxHxW tensor. I want to split each image in the batch into non-overlapping patches of size patch_size, shuffle them, and regroup them into a single image.
Given an input image and ShufflePatches(patch_size=(112,112)), I would like to produce an output image made of the same 112x112 patches in shuffled order.
I think the solution has to do with torch.fold, but I didn't manage to get any further.
Any help would be appreciated!
unfold and fold (from torch.nn.functional) seem appropriate in this case.
    import torch
    import torch.nn.functional as nnf

    class ShufflePatches(object):
        def __init__(self, patch_size):
            self.ps = patch_size

        def __call__(self, x):
            # divide the batch of images into non-overlapping patches
            u = nnf.unfold(x, kernel_size=self.ps, stride=self.ps, padding=0)
            # permute the patches of each image in the batch
            pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None, ...] for b_ in u], dim=0)
            # fold the permuted patches back together
            f = nnf.fold(pu, x.shape[-2:], kernel_size=self.ps, stride=self.ps, padding=0)
            return f
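To sanity-check the unfold/fold idea, here is a minimal self-contained sketch. The names shuffle_patches and patch_multiset are hypothetical helpers introduced only for illustration, and the sketch assumes that H and W are exact multiples of the patch size. It also accepts a single CxHxW image, since transforms.ToTensor() returns one image without a batch dimension when the pipeline runs per sample.

```python
import torch
import torch.nn.functional as F


def shuffle_patches(x, patch_size):
    """Shuffle non-overlapping patches of a CxHxW image or a BxCxHxW batch.

    Hypothetical helper, written for illustration only.
    """
    single = x.dim() == 3
    if single:
        x = x.unsqueeze(0)  # add a batch dimension: 1xCxHxW
    ph, pw = patch_size
    h, w = x.shape[-2:]
    # Assumption: H and W are exact multiples of the patch size.
    if h % ph or w % pw:
        raise ValueError("image size must be divisible by patch size")
    # (B, C*ph*pw, L): each of the L columns is one flattened patch
    u = F.unfold(x, kernel_size=patch_size, stride=patch_size)
    # draw an independent permutation of the L columns for every image
    pu = torch.stack([img[:, torch.randperm(img.shape[-1])] for img in u])
    # with stride == kernel_size no pixels overlap, so fold just re-tiles
    out = F.fold(pu, output_size=(h, w), kernel_size=patch_size, stride=patch_size)
    return out.squeeze(0) if single else out


def patch_multiset(x, patch_size):
    """Sorted list of flattened patches per image, for comparing contents."""
    u = F.unfold(x, kernel_size=patch_size, stride=patch_size)
    return [sorted(tuple(p.tolist()) for p in img.t()) for img in u]


x = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4)
y = shuffle_patches(x, (2, 2))
assert y.shape == x.shape
# same patches per image, possibly in a different order
assert patch_multiset(x, (2, 2)) == patch_multiset(y, (2, 2))
```

The check compares the patches of the input and output as sorted lists, so it passes for any permutation, including the identity. Because the stride equals the kernel size, fold places each pixel exactly once and the values are copied without any summation.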