Dataset with 4D images: expected Byte but found Float

I have some MRI scans that I want to turn into a custom PyTorch Dataset. Each scan is a folder of 31 RGB .png images, so a scan is 4-dimensional (Channels, Depth, Height, Width). After loading the scans, I tried passing one through an nn.Conv3d layer, but I got an error (full traceback at the end):

x = torch.unsqueeze(dataset[0][0], 0)
x.shape  # torch.Size([1, 3, 31, 512, 512])
m = nn.Conv3d(3,12,3)
out = m(x)

RuntimeError: expected scalar type Byte but found Float

How can I solve this error? I think it happens because I load the scans in as a NumPy array of NumPy arrays, but I don’t know how else to do it. How can I load 4D image data into a custom Dataset?

Here’s my custom Dataset class:

import os
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset

class TrainImages(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        # The folder containing the images of a scan
        img_path = os.path.join(self.root_dir, str(self.annotations.iloc[index, 0]).zfill(5))
        # Load each of the 31 slices as a (Height, Width, Channels) NumPy array
        slices = [
            np.array(Image.open(os.path.join(img_path, "rgb-" + str(i) + ".png")))
            for i in range(31)
        ]
        # Stack to (Depth, Height, Width, Channels), then move channels first: (Channels, Depth, Height, Width)
        image = torch.from_numpy(np.array(slices).transpose(3, 0, 1, 2).astype(np.uint8))
        y_label = torch.tensor(int(self.annotations.iloc[index, 1]))
        return (image, y_label)

Full traceback:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-29-f3c4dfbd5496> in <module>
      1 m=nn.Conv3d(3,12,3)
----> 2 out=m(x)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/conv.py in forward(self, input)
    571                             self.dilation, self.groups)
    572         return F.conv3d(input, self.weight, self.bias, self.stride,
--> 573                         self.padding, self.dilation, self.groups)
    574 
    575 

RuntimeError: expected scalar type Byte but found Float

Answer

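The error message is confusing, but the cause is a dtype mismatch: your input tensor has the Byte type (torch.uint8), while the weights of nn.Conv3d are Float (torch.float32) by default, and the input must have the same dtype as the weights. Change np.uint8 to np.float32 in the __getitem__(...) of your Dataset: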

image = torch.from_numpy(np.array([
    np.array(Image.open(os.path.join(str(img_path),"rgb-"+str(i)+".png")))
    for i in range(31)
]).transpose(3, 0, 1, 2).astype(np.float32))  # <<< changed from np.uint8 to float32

Alternatively, cast x to Float before passing it to the model:

out = m(x.float())
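
To check the cast in isolation, here is a minimal sketch that reproduces the mismatch without the Dataset. It assumes a random uint8 tensor as a stand-in for one scan, with a smaller height and width (16 instead of 512) to keep it fast. The exact wording of the error message can differ between PyTorch versions.

```python
import torch
import torch.nn as nn

# Stand-in for one scan: (batch, channels, depth, height, width), uint8 like the PNG data.
# The 16x16 spatial size is an assumption chosen only to keep the example small.
x = torch.randint(0, 256, (1, 3, 31, 16, 16), dtype=torch.uint8)

m = nn.Conv3d(3, 12, 3)  # weights and bias are float32 by default

# Passing the uint8 tensor directly fails with a dtype mismatch.
try:
    m(x)
    raised = False
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# Casting the input to float32 makes it match the layer's weights.
out = m(x.float())
print(out.dtype, out.shape)
```

With a kernel size of 3 and no padding, each of the depth, height and width dimensions shrinks by 2 in the output.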

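Note that a transform such as torchvision.transforms.ToTensor() would also fix this, since it converts a PIL image or uint8 array to a float32 tensor (and scales the values to [0, 1]). It operates on a single 2D image, though, so here it would have to be applied to each slice before stacking, and the Dataset above never actually calls self.transform.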
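
If you want the conversion to live in one place, the stacking step could be sketched as a small helper. The name stack_slices is hypothetical, and the division by 255 is an optional assumption that mirrors the [0, 1] scaling ToTensor applies. The example uses in-memory arrays instead of PNG files so that it runs on its own.

```python
import numpy as np
import torch

def stack_slices(slices):
    """Hypothetical helper: list of (H, W, C) uint8 arrays -> (C, D, H, W) float32 tensor."""
    volume = np.stack(slices, axis=0)      # (D, H, W, C)
    volume = volume.transpose(3, 0, 1, 2)  # (C, D, H, W)
    tensor = torch.from_numpy(np.ascontiguousarray(volume))
    # Cast to float32 and scale to [0, 1]; the scaling is optional (assumption).
    return tensor.float().div(255.0)

# Fake scan: 31 all-white 8x8 RGB slices standing in for the PNG files.
fake_slices = [np.full((8, 8, 3), 255, dtype=np.uint8) for _ in range(31)]
image = stack_slices(fake_slices)
print(image.dtype, image.shape)
```

Inside __getitem__, the list passed to the helper would be the arrays loaded with Image.open, one per slice.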