RuntimeError: Given groups=1, weight of size [64, 1, 4, 4], expected input[256, 3, 32, 32] to have 1 channels, but got 3 channels instead

Could you help me fix the above error? When I load the MNIST dataset, no error pops up. The error has to do with the dimensions of the other datasets (CIFAR10, FMNIST, and so on), and the script cannot be run on them. Any help is appreciated.
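For context, here is a quick standalone check of the channel dimension per dataset (plain torchvision, just a diagnostic and not part of the script below), which shows where the shapes differ:

from torchvision import datasets, transforms

# ToTensor() yields (channels, height, width) tensors
mnist = datasets.MNIST('./data/mnist', train=True, download=True,
                       transform=transforms.ToTensor())
cifar = datasets.CIFAR10('./data/cifar10', train=True, download=True,
                         transform=transforms.ToTensor())

print(mnist[0][0].shape)  # torch.Size([1, 28, 28]) -> 1 channel
print(cifar[0][0].shape)  # torch.Size([3, 32, 32]) -> 3 channels

The full script: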

# imports
# noinspection PyUnresolvedReferences
import os
# noinspection PyUnresolvedReferences
import pickle
from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import datasets, transforms
from torchvision.utils import save_image

import site
site.addsitedir('/content/gw_gan/model') 
from loss import gwnorm_distance, loss_total_variation, loss_procrustes
from model_cnn import Generator, Adversary
from model_cnn import weights_init_generator, weights_init_adversary

# internal imports
from utils import *

# get arguments
args = get_args()

# system preferences
seed = np.random.randint(100)
torch.set_default_dtype(torch.double)
np.random.seed(seed)
torch.manual_seed(seed)

# settings
batch_size = 256
z_dim = 100
lr = 0.0002
ngen = 3
beta = args.beta
lam = 0.5
niter = 10
epsilon = 0.005
num_epochs = args.num_epochs
cuda = args.cuda
channels = args.n_channels
id1 = args.id

model = 'gwgan_{}_eps_{}_tv_{}_procrustes_{}_ngen_{}_channels_{}_{}'.format(
    args.data, epsilon, lam, beta, ngen, channels, id1)
save_fig_path = 'out_' + model
if not os.path.exists(save_fig_path):
    os.makedirs(save_fig_path)

# data import
dataloader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data/cifar10', train=True, download=True,
                         transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5))])),
        batch_size=batch_size, drop_last=True, shuffle=True)

# print example images
save_image(next(iter(dataloader))[0][:25],
           os.path.join(save_fig_path, 'real.pdf'), nrow=5, normalize=True)

# define networks and parameters
generator = Generator(output_dim=channels)
adversary = Adversary(input_dim=channels)

# weight initialisation
generator.apply(weights_init_generator)
adversary.apply(weights_init_adversary)

if cuda:
    generator = generator.cuda()
    adversary = adversary.cuda()

# create optimizer
g_optimizer = torch.optim.Adam(generator.parameters(), lr, betas=(0.5, 0.99))
# zero gradients
generator.zero_grad()

c_optimizer = torch.optim.Adam(adversary.parameters(), lr, betas=(0.5, 0.99))
# zero gradients
adversary.zero_grad()

# sample for plotting
num_test_samples = batch_size
z_ex = torch.randn(num_test_samples, z_dim)
if cuda:
    z_ex = z_ex.cuda()

loss_history = list()
loss_tv = list()
loss_orth = list()
loss_og = 0
is_hist = list()

for epoch in range(num_epochs):
    t0 = time()

    for it, (image, _) in enumerate(dataloader):
        train_c = ((it + 1) % (ngen + 1) == 0)

        x = image.double()
        if cuda:
            x = x.cuda()

        # sample random number z from Z
        z = torch.randn(image.shape[0], z_dim)

        if cuda:
            z = z.cuda()

        if train_c:
            for q in generator.parameters():
                q.requires_grad = False
            for p in adversary.parameters():
                p.requires_grad = True
        else:
            for q in generator.parameters():
                q.requires_grad = True
            for p in adversary.parameters():
                p.requires_grad = False

        # result generator
        g = generator.forward(z)

        # result adversary
        f_x = adversary.forward(x)
        f_g = adversary.forward(g)

        # compute inner distances
        D_g = get_inner_distances(f_g, metric='euclidean', concat=False)
        D_x = get_inner_distances(f_x, metric='euclidean', concat=False)

        # distance matrix normalisation
        D_x_norm = normalise_matrices(D_x)
        D_g_norm = normalise_matrices(D_g)

        # compute normalized gromov-wasserstein distance
        loss, T = gwnorm_distance((D_x, D_x_norm), (D_g, D_g_norm),
                                  epsilon, niter, loss_fun='square_loss',
                                  coupling=True, cuda=cuda)

        if train_c:
            # train adversary
            loss_og = loss_procrustes(f_x, x.view(x.shape[0], -1), cuda)
            loss_to = -loss + beta * loss_og
            loss_to.backward()

            # parameter updates
            c_optimizer.step()
            # zero gradients
            reset_grad(generator, adversary)

        else:
            # train generator
            loss_t = loss_total_variation(g)
            loss_to = loss + lam * loss_t
            loss_to.backward()

            # parameter updates
            g_optimizer.step()
            # zero gradients
            reset_grad(generator, adversary)

    # plotting
    # get generator example
    g_ex = generator.forward(z_ex)
    g_plot = g_ex.cpu().detach()

    # plot result
    save_image(g_plot.data[:25],
               os.path.join(save_fig_path, 'g_%d.pdf' % epoch),
               nrow=5, normalize=True)

    fig1, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax0 = ax[0].imshow(T.cpu().detach().numpy(), cmap='RdBu_r')
    colorbar(ax0)
    ax1 = ax[1].imshow(D_x.cpu().detach().numpy(), cmap='Blues')
    colorbar(ax1)
    ax2 = ax[2].imshow(D_g.cpu().detach().numpy(), cmap='Blues')
    colorbar(ax2)
    ax[0].set_title(r'$T$')
    ax[1].set_title(r'inner distances of $D$')
    ax[2].set_title(r'inner distances of $G$')
    plt.tight_layout(h_pad=1)
    fig1.savefig(os.path.join(save_fig_path, '{}_ccc.pdf'.format(
            str(epoch).zfill(3))), bbox_inches='tight')

    loss_history.append(loss)
    loss_tv.append(loss_t)
    loss_orth.append(loss_og)
    plt.close('all')

# plot loss history
fig2 = plt.figure(figsize=(2.4, 2))
ax2 = fig2.add_subplot(111)
ax2.plot(loss_history, 'k.')
ax2.set_xlabel('Iterations')
ax2.set_ylabel(r'$\overline{GW}_\epsilon$ Loss')
plt.tight_layout()
plt.grid()
fig2.savefig(save_fig_path + '/loss_history.pdf')

fig3 = plt.figure(figsize=(2.4, 2))
ax3 = fig3.add_subplot(111)
ax3.plot(loss_tv, 'k.')
ax3.set_xlabel('Iterations')
ax3.set_ylabel(r'Total Variation Loss')
plt.tight_layout()
plt.grid()
fig3.savefig(save_fig_path + '/loss_tv.pdf')

fig4 = plt.figure(figsize=(2.4, 2))
ax4 = fig4.add_subplot(111)
ax4.plot(loss_orth, 'k.')
ax4.set_xlabel('Iterations')
ax4.set_ylabel(r'$R_\beta(f_\omega(X), X)$ Loss')
plt.tight_layout()
plt.grid()
fig4.savefig(save_fig_path + '/loss_orth.pdf')

The full error traceback:

Traceback (most recent call last):
  File "/content/gw_gan/main_gwgan_cnn.py", line 160, in <module>
    f_x = adversary.forward(x)
  File "/content/gw_gan/model/model_cnn.py", line 62, in forward
    x = self.conv(input)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py", line 423, in forward
    return self._conv_forward(input, self.weight)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py", line 420, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size [64, 1, 4, 4], expected input[256, 3, 32, 32] to have 1 channels, but got 3 channels instead

This is an application of a generative model built with a CNN. The code is taken from main_gwgan_cnn at https://github.com/bunnech/gw_gan, which proposes a GAN that learns from incomparable spaces.
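The weight size in the message, [64, 1, 4, 4], is the adversary's first convolution: 64 output channels, 1 input channel, 4x4 kernel. It was built for single-channel input (channels = args.n_channels), while the CIFAR10 batch is [256, 3, 32, 32]. A minimal standalone sketch (not the repo's model, just an illustration of the mismatch) reproduces the same error:

import torch
import torch.nn as nn

# a first conv layer built for 1-channel (MNIST/FMNIST) input
conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4)

# a CIFAR10-shaped batch: 256 images, 3 channels, 32x32 pixels
x = torch.randn(256, 3, 32, 32)

try:
    conv1(x)
except RuntimeError as e:
    # Given groups=1, weight of size [64, 1, 4, 4], expected input[256, 3, 32, 32]
    # to have 1 channels, but got 3 channels instead
    print(e)

# with in_channels=3 the same batch passes through without error
conv3 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4)
print(conv3(x).shape)  # torch.Size([256, 64, 29, 29])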

Answer

You have to set --n_channels, otherwise args.n_channels defaults to 1, as seen here. The example given there is for FMNIST, which has a single channel. Since you are running on CIFAR10, you should set it to 3, because those images have three channels.
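For reference, the relevant argument is presumably defined along these lines in the repo's get_args() (a sketch inferred from the attributes the script reads; check utils.py for the exact flag names and defaults):

import argparse

def get_args():
    parser = argparse.ArgumentParser()
    # defaults to a single channel, which fits MNIST/FMNIST but not CIFAR10
    parser.add_argument('--n_channels', type=int, default=1)
    # ... plus the other options the script reads: data, beta, num_epochs, cuda, id ...
    return parser.parse_args()

With --n_channels 3 passed on the command line, channels becomes 3, so Generator(output_dim=3) and Adversary(input_dim=3) are built with a 3-channel first convolution and the CIFAR10 batches pass through.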