How to load a pre-trained PyTorch model?

I’m following this guide on saving and loading checkpoints. However, something is not right. My model would train and the parameters would correctly update during the training phase. However, there seem to be a problem when I load the checkpoints. That is, the parameters are not being updated anymore.

My model:

import torch
import torch.nn as nn
import torch.optim as optim

PATH = 'test.pt'

class model(nn.Module): 
        def __init__(self):
            super(model, self).__init__()
            self.a = torch.nn.Parameter(torch.rand(1, requires_grad=True))
            self.b = torch.nn.Parameter(torch.rand(1, requires_grad=True))
            self.c = torch.nn.Parameter(torch.rand(1, requires_grad=True))
            #print(self.a, self.b, self.c)

        def load(self):
          try:
            checkpoint = torch.load(PATH)  
            print('nloading pre-trained model...')
            self.a = checkpoint['a']
            self.b = checkpoint['b']
            self.c = checkpoint['c']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print(self.a, self.b, self.c)
          except: #file doesn't exist yet
            pass

        @property
        def b_opt(self):
            return torch.tanh(self.b)*2

        def train(self):
          print('training...')
          for epoch in range(3):
            print(self.a, self.b, self.c)
            for r in range(5):
              optimizer.zero_grad()
              loss = torch.square(5 * (r > 2) * (3) - model_net.a * torch.sigmoid((r - model_net.b)) * (model_net.c))
              loss.backward(retain_graph=True) #accumulate gradients

            #checkpoint save
            torch.save({
                'model': model_net.state_dict(),
                'a': model_net.a,
                'b': model_net.b,
                'c': model_net.c,
                'optimizer_state_dict': optimizer.state_dict(),
                }, PATH)

            
            optimizer.step() 
          


model_net = model()
optimizer = optim.Adam(model_net.parameters(), lr = 0.1)


print(model_net.a)
print(model_net.b)
print(model_net.c)

This prints

Parameter containing:
tensor([0.4214], requires_grad=True)
Parameter containing:
tensor([0.3862], requires_grad=True)
Parameter containing:
tensor([0.8812], requires_grad=True)

I then run model_net.train() to see that the parameters are being updated and this outputs:

training...
Parameter containing:
tensor([0.9990], requires_grad=True) Parameter containing:
tensor([0.1580], requires_grad=True) Parameter containing:
tensor([0.1517], requires_grad=True)
Parameter containing:
tensor([1.0990], requires_grad=True) Parameter containing:
tensor([0.0580], requires_grad=True) Parameter containing:
tensor([0.2517], requires_grad=True)
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)

Running model_net.load() outputs:

loading pre-trained model...
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)

And lastly, running model_net.train() again outputs:

training...
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)
Parameter containing:
tensor([1.1974], requires_grad=True) Parameter containing:
tensor([-0.0404], requires_grad=True) Parameter containing:
tensor([0.3518], requires_grad=True)

Update 1.
Following @jhso suggestion I changed my load to:

def load(self):
  try:
    checkpoint = torch.load(PATH)  
    print('nloading pre-trained model...')
    self.load_state_dict(checkpoint['model'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print(self.a, self.b, self.c)
  except: #file doesn't exist yet
    pass

This almost seems to work (the network is training now), but I don’t think the optimizer is loading correctly. That is because it doesn’t go pass the line self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']).
You can see that since it doesn’t print(self.a, self.b, self.c) when I run

model_net.load()

Answer

The way you are loading your data is not the recommended way to load your parameters because you’re overwriting the graph connections (or something along those lines…). You even save the model state_dict, so why not use it!

I changed the load function to:

def load(self):
      try:
        checkpoint = torch.load(PATH)  
        print('nloading pre-trained model...')
        self.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(self.a, self.b, self.c)
        self.train()
      except: #file doesn't exist yet
        pass

But note to do this, you have to add your optimizer to your model:

model_net = model()
optimizer = optim.Adam(model_net.parameters(), lr = 0.1)
model_net.optimizer = optimizer

Which then gave the output (running train, load, train):

Parameter containing:
tensor([0.2316], requires_grad=True) Parameter containing:
tensor([0.4561], requires_grad=True) Parameter containing:
tensor([0.8626], requires_grad=True)
Parameter containing:
tensor([0.3316], requires_grad=True) Parameter containing:
tensor([0.3561], requires_grad=True) Parameter containing:
tensor([0.9626], requires_grad=True)
Parameter containing:
tensor([0.4317], requires_grad=True) Parameter containing:
tensor([0.2568], requires_grad=True) Parameter containing:
tensor([1.0620], requires_grad=True)

loading pre-trained model...
Parameter containing:
tensor([0.4317], requires_grad=True) Parameter containing:
tensor([0.2568], requires_grad=True) Parameter containing:
tensor([1.0620], requires_grad=True)
training...
Parameter containing:
tensor([0.4317], requires_grad=True) Parameter containing:
tensor([0.2568], requires_grad=True) Parameter containing:
tensor([1.0620], requires_grad=True)
Parameter containing:
tensor([0.5321], requires_grad=True) Parameter containing:
tensor([0.1577], requires_grad=True) Parameter containing:
tensor([1.1612], requires_grad=True)
Parameter containing:
tensor([0.6328], requires_grad=True) Parameter containing:
tensor([0.0583], requires_grad=True) Parameter containing:
tensor([1.2606], requires_grad=True)