Trying to backward through the graph a second time with a GAN model

I’m trying to set up a simple GAN training loop, but I’m getting the following error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
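For reference, the error itself isn’t specific to GANs; this minimal sketch (just a toy tensor, nothing from my models) reproduces it by calling backward() twice on the same graph:

```python
import torch

x = torch.ones(3, requires_grad=True)
y = (x * x).sum()

y.backward()        # first backward: fine; the graph's saved tensors are freed here
try:
    y.backward()    # second backward over the same (now freed) graph
except RuntimeError as e:
    print(type(e).__name__)  # reaches here with the same RuntimeError
```

Here is my actual training loop: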

for epoch in range(N_EPOCHS):
    # iterates over batches of real images
    for i, batch in enumerate(dataloader, 0):

        # passing target images to the Discriminator
        global_disc.zero_grad()
        output_disc = global_disc(batch.to(device))
        error_target = loss(output_disc, torch.ones(output_disc.shape).to(device))
        error_target.backward()

        # apply mask to the images
        batch = apply_mask(batch)

        # passes fake images to the Discriminator
        global_output, local_output = gen(batch.to(device))
        output_disc = global_disc(global_output.detach())
        error_fake = loss(output_disc, torch.zeros(output_disc.shape).to(device))
        error_fake.backward()

        # combines the errors
        error_total = error_target + error_fake
        optimizer_disc.step()

        # updates the generator
        gen.zero_grad()
        error_gen = loss(output_disc, torch.ones(output_disc.shape).to(device))
        error_gen.backward()
        optimizer_gen.step()

        break
    break

As far as I can tell, I have the operations in the right order, I’m zeroing out the gradients, and I’m detaching the output of the generator before it goes into the discriminator.

This article was helpful, but I’m still running into something I don’t understand.

Any help is appreciated.

Answer

Two important points come to mind:

  1. You should feed your generator noise, not the real input:

    global_output, local_output = gen(noise.to(device))
    

The noise above should have the appropriate shape: it is the input to your generator.
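For instance, assuming a hypothetical DCGAN-style convolutional generator with a 100-dimensional latent space (your actual batch size and latent dimension will differ), the noise could be sampled like this:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical sizes -- substitute whatever your generator was built with.
# DCGAN-style generators typically expect standard-normal noise shaped
# (batch_size, latent_dim, 1, 1).
batch_size, latent_dim = 64, 100
noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
```

Then global_output, local_output = gen(noise) as above.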

  2. In order to optimize the generator, you need to recompute the discriminator output. The graph from the earlier forward pass was already freed by backward(), and, more importantly, the existing output_disc was computed from the detached generator output, so no gradients could ever flow back to gen through it. Simply add this line to recompute output_disc without detaching:

    # updates the generator
    gen.zero_grad()
    output_disc = global_disc(global_output)
    # ...
    

Please refer to this tutorial provided by PyTorch for a full walkthrough.
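Putting both points together, the inner loop has the structure below. This is a runnable sketch with tiny single-output stand-in models (nn.Linear layers, BCEWithLogitsLoss, and random tensors as placeholder data) purely to show the ordering; your gen returns two outputs and your loss, shapes, and data will differ:

```python
import torch
import torch.nn as nn

# Stand-ins for the models in the question, just to make the loop runnable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gen = nn.Linear(100, 10).to(device)          # noise (N, 100) -> "image" (N, 10)
global_disc = nn.Linear(10, 1).to(device)    # "image" (N, 10) -> logit (N, 1)
loss = nn.BCEWithLogitsLoss()
optimizer_gen = torch.optim.Adam(gen.parameters(), lr=2e-4)
optimizer_disc = torch.optim.Adam(global_disc.parameters(), lr=2e-4)
dataloader = [torch.randn(8, 10) for _ in range(2)]  # placeholder "real" data

for batch in dataloader:
    real = batch.to(device)
    noise = torch.randn(real.size(0), 100, device=device)

    # ---- discriminator update ----
    global_disc.zero_grad()

    # real images -> label 1
    output_real = global_disc(real)
    error_target = loss(output_real, torch.ones_like(output_real))
    error_target.backward()

    # fake images -> label 0; detach() stops this backward pass at the
    # generator's output, so no gradients accumulate in gen
    fake = gen(noise)
    output_fake = global_disc(fake.detach())
    error_fake = loss(output_fake, torch.zeros_like(output_fake))
    error_fake.backward()
    optimizer_disc.step()

    # ---- generator update ----
    gen.zero_grad()
    # fresh forward pass through the discriminator, WITHOUT detach, so
    # gradients flow back into gen; recomputing here is what avoids
    # backpropagating through an already-freed graph
    output_disc = global_disc(fake)
    error_gen = loss(output_disc, torch.ones_like(output_disc))
    error_gen.backward()
    optimizer_gen.step()
```

Note that the generator's forward graph (fake) is built once and backpropagated only once, by error_gen.backward(); the discriminator's two losses each get their own forward pass.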