I’m trying to set up a simple GAN training loop but am getting the following error:
RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
```python
for epoch in range(N_EPOCHS):
    # gets data for the generator
    for i, batch in enumerate(dataloader, 0):
        # passing target images to the Discriminator
        global_disc.zero_grad()
        output_disc = global_disc(batch.to(device))
        error_target = loss(output_disc, torch.ones(output_disc.shape).cuda())
        error_target.backward()

        # apply mask to the images
        batch = apply_mask(batch)

        # passes fake images to the Discriminator
        global_output, local_output = gen(batch.to(device))
        output_disc = global_disc(global_output.detach())
        error_fake = loss(output_disc, torch.zeros(output_disc.shape).to(device))
        error_fake.backward()

        # combines the errors
        error_total = error_target + error_fake
        optimizer_disc.step()

        # updates the generator
        gen.zero_grad()
        error_gen = loss(output_disc, torch.ones(output_disc.shape).to(device))
        error_gen.backward()
        optimizer_gen.step()
        break
    break
```
As far as I can tell, I have the operations in the right order, I’m zeroing out the gradients, and I’m detaching the output of the generator before it goes into the discriminator.
This article was helpful but I’m still running into something I don’t understand.
Any help is appreciated.
Two important points come to mind:
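For a standard GAN that generates images from scratch, you should feed your generator noise, not the real input:

```python
global_output, local_output = gen(noise.to(device))
```

`noise` must have whatever shape your generator expects as its input. (If your generator is instead meant to fill in the masked regions, as `apply_mask` suggests, then the masked batch is a legitimate input and this point does not apply.)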
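As a rough illustration of "the appropriate shape", here is a sketch assuming a DCGAN-style generator that maps a latent vector shaped `(N, latent_dim, 1, 1)` to an image. The `batch_size`, `latent_dim` and layer sizes are hypothetical, and this toy generator returns a single tensor rather than the `(global_output, local_output)` pair in the question.

```python
import torch
import torch.nn as nn

# Hypothetical sizes; use whatever your own generator was built for
batch_size, latent_dim = 16, 100

# Toy DCGAN-style generator: upsamples a (N, latent_dim, 1, 1) latent to (N, 3, 8, 8)
gen = nn.Sequential(
    nn.ConvTranspose2d(latent_dim, 64, kernel_size=4, stride=1, padding=0),  # 1x1 -> 4x4
    nn.ReLU(),
    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),  # 4x4 -> 8x8
    nn.Tanh(),
)

# Fresh Gaussian noise for every batch; its shape must match the generator's input
noise = torch.randn(batch_size, latent_dim, 1, 1)
fake = gen(noise)
print(fake.shape)
```

Drawing new noise each iteration is what lets the generator learn a mapping from a whole distribution rather than memorizing a fixed input.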
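To optimize the generator, you need to recompute the discriminator output. The graph behind `output_disc` was already freed by `error_fake.backward()`, and because it was computed from `global_output.detach()`, it could not carry gradients back to the generator anyway. Recompute it from the non-detached `global_output`:

```python
# updates the generator
gen.zero_grad()
output_disc = global_disc(global_output)
# ...
```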
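A self-contained sketch of the difference, assuming tiny `nn.Linear` models as hypothetical stand-ins for `gen` and `global_disc`, and `BCEWithLogitsLoss` in place of the question's unspecified `loss`. Detaching only stops gradients at the generator's output; the discriminator graph behind `output_disc` is still freed by `error_fake.backward()`. Reusing it therefore fails, while a fresh forward pass on the non-detached `global_output` builds a new graph that reaches the generator.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy models standing in for gen and global_disc
gen = nn.Linear(4, 4)
global_disc = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
loss = nn.BCEWithLogitsLoss()
optimizer_gen = torch.optim.Adam(gen.parameters(), lr=1e-3)
optimizer_disc = torch.optim.Adam(global_disc.parameters(), lr=1e-3)

real = torch.randn(8, 4)       # stand-in for a batch of real images
gen_input = torch.randn(8, 4)  # stand-in for the generator's input


def train_step(recompute: bool) -> bool:
    """One GAN iteration; returns False if the generator backward raises RuntimeError."""
    # --- discriminator update ---
    global_disc.zero_grad()
    output_disc = global_disc(real)
    error_target = loss(output_disc, torch.ones_like(output_disc))
    error_target.backward()

    global_output = gen(gen_input)
    output_disc = global_disc(global_output.detach())  # no gradients flow into gen here
    error_fake = loss(output_disc, torch.zeros_like(output_disc))
    error_fake.backward()  # frees the saved tensors behind output_disc
    optimizer_disc.step()

    # --- generator update ---
    gen.zero_grad()
    if recompute:
        # Fresh forward pass on the non-detached output: a new graph that reaches gen
        output_disc = global_disc(global_output)
    error_gen = loss(output_disc, torch.ones_like(output_disc))
    try:
        error_gen.backward()
    except RuntimeError:
        return False
    optimizer_gen.step()
    return True


print(train_step(recompute=False))  # reuses the already-freed output_disc
print(train_step(recompute=True))   # recomputes it before the generator loss
```

`retain_graph=True` would silence the error in the first variant, but the generator would still receive no gradient, because the reused `output_disc` was built from a detached tensor.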
Please refer to this tutorial provided by PyTorch for a full walkthrough.