How can I compare the precision of a single class inside a tf.keras.callbacks.Callback instance?

I need help designing my callback. I have the following architecture:

def CNN_exctractor(input_img, l2_loss_lambda=0.01):
    """Build the convolutional feature-extraction trunk of the network.

    Args:
        input_img: Keras tensor holding the input image batch.
        l2_loss_lambda: L2 regularization strength applied to every conv
            kernel; pass None to disable regularization. Defaults to 0.01
            (the value previously hard-coded here).

    Returns:
        Keras tensor: the final batch-normalized 256-filter feature map.
    """
    l2 = None if l2_loss_lambda is None else regularizers.l2(l2_loss_lambda)
    if l2 is not None:
        print('Using L2 regularization - l2_loss_lambda = %.7f' % l2_loss_lambda)

    # Two conv + batch-norm + max-pool stages (64 filters each).
    x = Conv2D(filters=64, kernel_size=(3, 3), padding="same", activation="relu", kernel_regularizer=l2)(input_img)
    x = BatchNormalization()(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = Conv2D(filters=64, kernel_size=(3, 3), padding="same", activation="relu", kernel_regularizer=l2)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    # Three deeper conv + batch-norm stages, no further downsampling.
    x = Conv2D(filters=128, kernel_size=(3, 3), padding="same", activation="relu", kernel_regularizer=l2)(x)
    x = BatchNormalization()(x)

    x = Conv2D(filters=256, kernel_size=(3, 3), padding="same", activation="relu", kernel_regularizer=l2)(x)
    x = BatchNormalization()(x)

    x = Conv2D(filters=256, kernel_size=(3, 3), padding="same", activation="relu", kernel_regularizer=l2)(x)
    return BatchNormalization()(x)

    
    
  
def fc1(enco, l2_loss_lambda=0.01, num_classes=2):
    """Build the fully-connected classification head.

    Args:
        enco: Keras tensor produced by the feature extractor.
        l2_loss_lambda: L2 regularization strength for the dense layers;
            pass None to disable. Defaults to 0.01 (previously hard-coded).
        num_classes: number of softmax output units. Defaults to 2.

    Returns:
        Keras tensor: softmax class probabilities of shape (batch, num_classes).
    """
    l2 = None if l2_loss_lambda is None else regularizers.l2(l2_loss_lambda)
    if l2 is not None:
        print('Using L2 regularization - l2_loss_lambda = %.7f' % l2_loss_lambda)

    flat = Flatten()(enco)

    x = Dense(256, activation='relu', kernel_regularizer=l2)(flat)
    x = BatchNormalization()(x)

    x = Dense(128, activation='relu', kernel_regularizer=l2)(x)
    x = BatchNormalization()(x)

    # Final softmax layer is left unregularized, as in the original design.
    return Dense(num_classes, activation='softmax')(x)

As you can see, I have two neurons at the output. I am using this simple code for the callback:

class myCallback(tf.keras.callbacks.Callback):
    """Stop training once both training and validation accuracy targets are met."""

    def on_epoch_end(self, epoch, logs=None):
        """Check the epoch metrics and set stop_training when thresholds are reached.

        Note: the original version had a dangling `and` (syntax error), used a
        mutable dict as the default for `logs`, and would raise TypeError if a
        metric key was missing; `.get(key, 0)` makes the comparison safe.
        """
        logs = logs or {}
        if logs.get('val_accuracy', 0) >= 0.92 and logs.get('accuracy', 0) >= 0.96:
            # "\n" was previously written as a literal "n" (missing backslash).
            print("\nReached %2.2f%% accuracy, so stopping training!!" % (0.96 * 100))
            self.model.stop_training = True

I am comparing both training and validation accuracy. What I want to do instead of comparing the whole validation accuracy is to compare the precision of a single class, something like this (if it exists):

logs.get('class_1_precision')>= 0.8

Answer

You can pass your validation data to your callback, and then compute the metric for the specific class. I don’t know how you’ve structured your validation data, but here I’m assuming it’s split into two sets (val_x and val_y). Note that precision for a class must be computed over the *whole* validation set — filtering out the other classes first would discard the false positives that precision is meant to penalize.

from sklearn.metrics import precision_score


class myCallback(tf.keras.callbacks.Callback):
    """Early-stopping callback driven by the precision of a single class.

    The precision for `class_index` is computed over the FULL validation set.
    (Filtering the validation data down to the target class first — as the
    original placeholder code suggested — would remove the false positives
    and therefore not yield precision at all.)

    Args:
        val_x: validation inputs accepted by `model.predict`.
        val_y: one-hot encoded validation labels, shape (n_samples, n_classes).
            TODO confirm: if your labels are integer-encoded instead, drop the
            `tf.argmax` on `val_y` below.
        class_index: index of the class whose precision is monitored.
        precision_target: stop training once precision reaches this value.
    """

    def __init__(self, val_x, val_y, class_index=1, precision_target=0.8):
        super(myCallback, self).__init__()
        self.val_x = val_x
        self.val_y = val_y
        self.class_index = class_index
        self.precision_target = precision_target

    def on_epoch_end(self, epoch, logs=None):
        # Predicted and true class indices for every validation sample.
        y_pred = tf.argmax(self.model.predict(self.val_x), axis=1)
        y_true = tf.argmax(self.val_y, axis=1)

        # Precision restricted to the monitored class; zero_division=0 avoids
        # a warning/undefined value when the class is never predicted.
        precision = precision_score(
            y_true, y_pred,
            labels=[self.class_index], average=None, zero_division=0,
        )[0]

        if precision >= self.precision_target:
            print("\nReached %.2f precision on class %d, so stopping training!"
                  % (precision, self.class_index))
            self.model.stop_training = True

To pass the validation data to the callback you’ll need to add something like the below to your fit function:

cbs = myCallback(val_x,val_y)

model.fit(...., callbacks=[cbs])