Why is the call method in TensorFlow called so few times?

I am writing a wrapper around the TensorFlow Conv1D layer (the class is shown further down).

And in order to monitor the behavior of the layer, I’ve created a counter in the global scope, initialized to 0.

counter = 0

I then create the following two-layer network:

model = tf.keras.Sequential([
    ConvWrapper1D(filters=hidden_dim, kernel_size=5),
    layers.ReLU(),
    ConvWrapper1D(filters=output_dim, kernel_size=5)
])

Then I run training for 50 epochs, with 128 batches per epoch:

model.fit(x_train, y_train, batch_size=2, epochs=50, verbose=1)

Since each batch requires a single forward pass through each of the two layers, one would expect the counter to hold 50 * 128 * 2 = 12800. However, to my surprise, the counter has a value of 6 after running this line of code.

I have no idea where this value comes from.

The counting is implemented as follows:

import tensorflow as tf
from tensorflow.keras import layers

class ConvWrapper1D(layers.Conv1D):

    def __init__(self, filters, kernel_size, **kwargs):
        super(ConvWrapper1D, self).__init__(
            filters=filters,
            kernel_size=kernel_size,
            **kwargs
        )

    def call(self, inputs, **kwargs):
        # Python-side effect: increments the module-level counter
        globals()['counter'] += 1
        # do something
        return super(ConvWrapper1D, self).call(inputs, **kwargs)

Does TensorFlow perform some optimization under the hood that skips this step?

I use the latest stable version of TensorFlow, 2.6.0.

Answer

The instinct to develop, which the TF docs never quite spell out, is that TF is a compiler. It compiles your Python code into a graph, something closer to silicon, but it's an imperfect compiler, meaning some of the things you'd expect to occur do not.

Quoting the guide to tracing (which I wish they'd just call compiling, because that's what it is):

https://www.tensorflow.org/guide/function#executing_python_side_effects

Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly inside a Function, sometimes executing twice or not all. They only happen the first time you call a Function with a set of inputs. Afterwards, the traced tf.Graph is reexecuted, without executing the Python code.
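To make the quoted behaviour concrete, here is a minimal sketch. It is not taken from the guide, and `trace_count` and `double` are made-up names for illustration. The Python increment runs only while the function is being traced, so repeated calls with the same input signature do not touch it. The 6 in the question is consistent with the same thing happening to the two layers: each call body runs a few times while Keras builds and traces the model, not once per batch.

```python
import tensorflow as tf

trace_count = 0  # plain Python global (hypothetical name)


@tf.function
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return x * 2


# Five calls with the same dtype and shape reuse one traced graph.
for _ in range(5):
    double(tf.constant(1.0))
calls_after_same_shape = trace_count  # 1, not 5

# A new input shape forces a retrace, so the Python body runs once more.
double(tf.constant([1.0, 2.0]))
calls_after_new_shape = trace_count  # 2

print(calls_after_same_shape, calls_after_new_shape)
```

The tensor maths still runs on every call. Only the Python statements around it are skipped once the graph exists.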

Net net, to solve your issue, track your count in a tf.Variable that you create in the build method of the Layer. See this guide for more details.

https://www.tensorflow.org/tutorials/customization/custom_layers#implementing_custom_layers
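A rough sketch of that suggestion, assuming a non-trainable integer weight is an acceptable place to keep the count. `CountingConv1D` and `call_count` are hypothetical names, not part of TensorFlow or of the question's code. Because the increment is a variable update rather than a Python statement, it becomes part of the traced graph and runs on every execution.

```python
import tensorflow as tf
from tensorflow.keras import layers


class CountingConv1D(layers.Conv1D):
    """Hypothetical Conv1D wrapper that counts forward passes in a variable."""

    def build(self, input_shape):
        super().build(input_shape)
        # Non-trainable scalar weight; updates to it are graph operations.
        self.call_count = self.add_weight(
            name="call_count",
            shape=(),
            dtype="int64",
            initializer="zeros",
            trainable=False,
        )

    def call(self, inputs):
        self.call_count.assign_add(1)  # recorded in the graph, runs each step
        return super().call(inputs)


layer = CountingConv1D(filters=4, kernel_size=5)
x = tf.zeros((2, 16, 3))  # (batch, steps, channels), arbitrary example shape

layer(x)  # one eager call: builds the layer and counts once


@tf.function
def forward(inputs):
    return layer(inputs)


for _ in range(5):
    forward(x)  # traced once, but the variable update runs all five times

total_calls = int(layer.call_count.numpy())
print(total_calls)
```

Here the count should come out as 6: one eager call plus five graph calls. Inside model.fit the same variable would be incremented once per training step for each layer instance, so you would read each layer's `call_count` after training instead of a global.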