ValueError: tf.function-decorated function tried to create variables on non-first call

I have a model written with the Model Subclassing API in TensorFlow 2. It contains a custom layer. The problem is that, inside the custom layer, I need to pass the channel count of the input tensor to a Conv2D layer at runtime. Please see the code below:

Custom Layer

import tensorflow as tf 

class AuxNet(tf.keras.layers.Layer):
    def __init__(self, ratio=8):
        super(AuxNet, self).__init__()
        self.ratio = ratio
        self.avg = tf.keras.layers.GlobalAveragePooling2D()
        self.max = tf.keras.layers.GlobalMaxPooling2D()

    def call(self, inputs):
        avg = self.avg(inputs)
        max = self.max(inputs)
        avg = tf.keras.layers.Reshape((1, 1, avg.shape[1]))(avg)  
        max = tf.keras.layers.Reshape((1, 1, max.shape[1]))(max)   
        
        # ALERT: a new Conv2D layer (with fresh weights) is created
        # on every call --------------------
        input_shape = inputs.get_shape().as_list()
        _, h, w, channels = input_shape
        
        conv1a = tf.keras.layers.Conv2D(channels, kernel_size=1, 
                         strides=1, padding='same', use_bias=True, 
                         activation=tf.nn.relu)(avg)
        
        conv1b = tf.keras.layers.Conv2D(channels, kernel_size=1, strides=1, 
                                    padding='same', use_bias=True, 
                                    activation=tf.nn.relu)(max)
        return tf.nn.sigmoid(conv1a + conv1b)

The whole model

class Net(tf.keras.Model):
    def __init__(self, dim):
        super(Net, self).__init__()
        self.base  = tf.keras.layers.Conv2D(124, 3, 1)
        self.gap   = tf.keras.layers.GlobalAveragePooling2D()
        self.aux   = AuxNet() # init the custom layer
        self.dense  = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        self.out   = tf.keras.layers.Dense(10, activation='softmax')
    
    def call(self, input_tensor, training=False):
        x  = self.base(input_tensor)
        
        # Using custom layer on the input tensor
        aux = self.aux(x)*x

        x = self.gap(aux)
        x = self.dense(x)
        return self.out(x)

As you can see, the AuxNet class contains Conv2D layers whose filter count equals the channel count of the layer's input, which is just the output of the base convolution of the model class, Net. When initializing the custom layer in the model class, I couldn't set the channel number of its Conv2D layers. So what I do here is compute the channel number for these Conv2D layers in the call method of the AuxNet layer, which I believe is bad practice.

This causes a runtime problem: I can't run the model in graph mode and am forced to enable eager execution.

import numpy as np
import tensorflow as tf 

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# train set / data 
x_train = x_train.astype('float32') / 255

# train set / target 
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)

model = Net((32, 32, 3))

tf.config.run_functions_eagerly(True) # < ----------------

model.compile(
          loss      = tf.keras.losses.CategoricalCrossentropy(),
          metrics   = tf.keras.metrics.CategoricalAccuracy(),
          optimizer = tf.keras.optimizers.Adam())
# fit 
model.fit(x_train, y_train, batch_size=128, epochs=1)

It works, but training is very slow. Without this line, the following error occurs:

ValueError: tf.function-decorated function tried to create variables on non-first call.

Is there a workaround that avoids enabling eager mode? How can I efficiently pass the required argument to this custom layer, so that I don't have to compute the channel depth in the call method?
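For context, the error is not specific to this model: any tf.function that tries to create new variables after its first trace fails the same way. A minimal sketch (`bad_fn` is an illustrative name, not from the post):

```python
import tensorflow as tf

@tf.function
def bad_fn(x):
    # A fresh tf.Variable is created every time the function body is
    # traced, which tf.function forbids after the first trace.
    v = tf.Variable(1.0)
    return x * v

raised = False
for value in (2.0, 3.0):
    try:
        bad_fn(tf.constant(value))
    except ValueError:
        raised = True
print(raised)  # True
```

Creating Conv2D layers inside `call` hits exactly this rule, because each Conv2D builds its own kernel and bias variables.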

Answer

Solved

Basically, I needed to look at how to define built-in layers inside a custom layer. The general advice is to create all sublayers in the __init__ method, but here the filter count depends on the channel depth of the input tensor, which isn't known yet at that point. The build method solves this: it receives the input shape, so the sublayers can be created there with the right number of filters.

class AuxNet(tf.keras.layers.Layer):
    def __init__(self, ratio=8):
        super(AuxNet, self).__init__()
        self.ratio = ratio
        self.avg = tf.keras.layers.GlobalAveragePooling2D()
        self.max = tf.keras.layers.GlobalMaxPooling2D()

    def build(self, input_shape):
        # input_shape is known here, so the channel count can be used
        # to set the number of filters.
        self.conv1 = tf.keras.layers.Conv2D(input_shape[-1], 
                               kernel_size=1, strides=1, padding='same',
                               use_bias=True, activation=tf.nn.relu)
        self.conv2 = tf.keras.layers.Conv2D(input_shape[-1], 
                               kernel_size=1, strides=1, padding='same',
                               use_bias=True, activation=tf.nn.relu)
        
        super(AuxNet, self).build(input_shape)
        
    def call(self, inputs):
        # Only stateless ops here; the Conv2D sublayers were created in build().
        avg_out = self.avg(inputs)
        max_out = self.max(inputs)
        avg_out = tf.keras.layers.Reshape((1, 1, avg_out.shape[1]))(avg_out)
        max_out = tf.keras.layers.Reshape((1, 1, max_out.shape[1]))(max_out)
        
        conv1a = self.conv1(avg_out)
        conv1b = self.conv2(max_out)
        
        return tf.nn.sigmoid(conv1a + conv1b)
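Here is the same deferred-creation pattern in isolation, as a minimal self-contained sketch (`ChannelConv` is an illustrative name, not from the post): the sublayer is created in build(), so its filter count matches whatever channel depth arrives at runtime, and no variables are created on later calls.

```python
import tensorflow as tf

class ChannelConv(tf.keras.layers.Layer):
    def build(self, input_shape):
        # Created once, on the first call, when the channel count is known.
        self.conv = tf.keras.layers.Conv2D(input_shape[-1], kernel_size=1)
        super().build(input_shape)

    def call(self, inputs):
        return self.conv(inputs)

layer = ChannelConv()
y = layer(tf.zeros((2, 8, 8, 5)))   # build() runs here with 5 channels
print(layer.conv.filters)           # 5
print(y.shape)                      # (2, 8, 8, 5)
```

Because no variables are created in `call`, the layer traces cleanly under tf.function, and `tf.config.run_functions_eagerly(True)` is no longer needed.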