tqdm not updating new set_postfix after last iteration

I want to create a tqdm progress bar, similar to tensorflow.keras, for PyTorch training. Here are my requirements:

  1. For each training step, it will show the progress and the train loss
  2. At the last iteration, it will give additional info of the validation loss

I’m following this tutorial https://towardsdatascience.com/training-models-with-a-progress-a-bar-2b664de3e13e and I managed to fulfill the 1st requirement.

The only missing feature is showing the validation loss after each training epoch.
Here’s my code:

from tqdm import tqdm

for epoch in range(EPOCH):
    with tqdm(train_dataloader, unit=" batch") as tepoch:
        train_loss = 0
        val_loss = 0

        # Training part
        for idx, batch in enumerate(tepoch):
            tepoch.set_description(f"Epoch {epoch}")
            # <do training stuff>
            train_loss += loss.item()
            tepoch.set_postfix({'Train loss': loss.item()})
        train_loss /= (idx + 1)

        # Evaluation part
        with torch.no_grad():
            for idx, batch in enumerate(val_dataloader):
                # <do inference stuff>
                val_loss += loss.item()
        val_loss /= (idx + 1)

    tepoch.set_postfix({'Train loss': train_loss, 'Val loss': val_loss})

Running this code gives:

Epoch 0: 100%|██████████| 188/188 [01:22<00:00,  2.28 batch/s, Train loss=0.511]
Epoch 1: 100%|██████████| 188/188 [01:22<00:00,  2.28 batch/s, Train loss=0.298]

But what I want is :

Epoch 0: 100%|██████████| 188/188 [01:22<00:00,  2.28 batch/s, Train loss=0.511, Val loss={number}]
Epoch 1: 100%|██████████| 188/188 [01:22<00:00,  2.28 batch/s, Train loss=0.298, Val loss={number}]

I have seen this SO question, tqdm update after last iteration, but it doesn't seem feasible in my case, since the validation loss is calculated only after all the training completes.

Answer

Working example

import random
import time

from tqdm import tqdm

EPOCH = 100
BATCH_SIZE = 10

for epoch in range(EPOCH):
    with tqdm(total=BATCH_SIZE, unit=" batch") as tepoch:
        tepoch.set_description(f"Epoch {epoch+1}")
        train_loss = 0
        val_loss = 0

        # Training part
        for idx, batch in enumerate(range(BATCH_SIZE)):
            tepoch.update(1)
            # do training stuff
            time.sleep(0.5)
            loss = random.choice(range(10))
            train_loss += loss
            tepoch.set_postfix({'Batch': idx + 1, 'Train loss (in progress)': loss})

        train_loss /= (idx + 1)

        # Evaluation part
        time.sleep(0.5)
        val_loss += random.choice(range(10))
        val_loss /= (idx + 1)

        # Set the final postfix while the bar is still open, then close it
        tepoch.set_postfix({'Train loss (final)': train_loss, 'Val loss': val_loss})
        tepoch.close()

Output

Epoch 1: 100% 10/10 [00:11<00:00, 1.18s/ batch, Train loss (final)=4.4, Val loss=0.5]
Epoch 2: 100% 10/10 [00:06<00:00, 1.62 batch/s, Train loss (final)=4.7, Val loss=0.3]
Epoch 3: 80% 8/10 [00:03<00:00, 2.07 batch/s, Batch=7, Train loss (in progress)=9]
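The key difference from the question's code is that the bar is created with an explicit `total` and advanced manually with `update(1)`, so it stays open after the training loop ends, and the final `set_postfix` is called while the bar is still open (inside the `with` block) rather than after it has closed. Mapped back onto the original train/validate structure, the pattern might look like this sketch, where `random.random()` and dummy batch lists stand in for the real dataloaders and loss computation:

```python
import random

from tqdm import tqdm

EPOCH = 2
train_batches = [None] * 5  # stand-in for train_dataloader
val_batches = [None] * 3    # stand-in for val_dataloader

for epoch in range(EPOCH):
    # total=len(...) plus manual update() keeps the bar open until the
    # `with` block exits, instead of closing after the last iteration
    with tqdm(total=len(train_batches), unit=" batch") as tepoch:
        tepoch.set_description(f"Epoch {epoch}")
        train_loss = 0.0
        val_loss = 0.0

        # Training part
        for idx, batch in enumerate(train_batches):
            loss = random.random()       # <do training stuff>
            train_loss += loss
            tepoch.set_postfix({'Train loss': loss})
            tepoch.update(1)
        train_loss /= len(train_batches)

        # Evaluation part -- the bar is still open here
        for batch in val_batches:
            val_loss += random.random()  # <do inference stuff>
        val_loss /= len(val_batches)

        # Final postfix is set before the bar closes, so it is rendered
        tepoch.set_postfix({'Train loss': train_loss, 'Val loss': val_loss})
```

In the real loop you would replace the dummy lists with `train_dataloader` / `val_dataloader` and the random values with `loss.item()`, and wrap the evaluation part in `torch.no_grad()` as in the question.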