How to remove the first N layers from a Keras model?

I would like to remove the first N layers from a pretrained Keras model. For example, EfficientNetB0, whose first 3 layers are responsible only for preprocessing:

import tensorflow as tf

efinet = tf.keras.applications.EfficientNetB0(weights=None, include_top=True)

print(efinet.layers[:3])
# [<tensorflow.python.keras.engine.input_layer.InputLayer at 0x7fa9a870e4d0>,
# <tensorflow.python.keras.layers.preprocessing.image_preprocessing.Rescaling at 0x7fa9a61343d0>,
# <tensorflow.python.keras.layers.preprocessing.normalization.Normalization at 0x7fa9a60d21d0>]

As M.Innat mentioned, the first layer is an Input layer, which should either be kept or re-attached. I would like to remove those layers, but a simple approach like this throws an error:

cut_input_model = tf.keras.Model(
    inputs=[efinet.layers[3].input],
    outputs=efinet.outputs
)

This will result in:

ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(...)

What would be the recommended way to do this?

Answer

You get the Graph disconnected error because you don’t specify the Input layer. But that’s not the main issue here. Removing intermediate layers from a Keras model is not always straightforward, with either the Sequential or the Functional API.

For a sequential model it should be comparatively easy, whereas in a functional model you need to take care of multi-input blocks (e.g. multiply, add, etc.). For example, if you want to remove some intermediate layer in a sequential model, you can easily adapt this solution. But for a functional model (EfficientNet), you can’t, because of the multi-input internal blocks, and you will encounter this error: ValueError: A merged layer should be called on a list of inputs. So that needs a bit more work AFAIK; here is a possible approach to overcome it.
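To illustrate the sequential case, here is a minimal sketch that rebuilds a small Sequential model without one of its intermediate layers. The toy model and the names `dense_a`, `dense_b`, `head`, and `trimmed` are made up for illustration, and this is not necessarily the same method as the linked solution. It only works here because the removed layer does not change the tensor shape.

```python
import tensorflow as tf

# Hypothetical toy model, used only to illustrate the idea.
seq = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(16, activation="relu", name="dense_a"),
    tf.keras.layers.Dense(16, activation="relu", name="dense_b"),
    tf.keras.layers.Dense(4, name="head"),
])

# Keep every layer except the one we want to drop. "dense_b" maps
# 16 -> 16 features, so "head" still receives the input size it was built for.
kept = [layer for layer in seq.layers if layer.name != "dense_b"]

# Rebuild a new Sequential model from the kept layer objects.
# The layers are reused, so their weights are shared with `seq`.
trimmed = tf.keras.Sequential([tf.keras.Input(shape=(8,))] + kept)

print([layer.name for layer in trimmed.layers])
```

Because the layer objects are reused rather than copied, training `trimmed` would also update the corresponding weights in `seq`.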


Here I will show a simple workaround for your case, but it’s probably not general and can be unsafe in some cases. It is based on this approach, which uses the pop method (see also why it can be unsafe to use). Okay, let’s first load the model.

func_model = tf.keras.applications.EfficientNetB0()

for i, l in enumerate(func_model.layers):
    print(l.name, l.output_shape)
    if i == 8: break

input_19 [(None, 224, 224, 3)]
rescaling_13 (None, 224, 224, 3)
normalization_13 (None, 224, 224, 3)
stem_conv_pad (None, 225, 225, 3)
stem_conv (None, 112, 112, 32)
stem_bn (None, 112, 112, 32)
stem_activation (None, 112, 112, 32)
block1a_dwconv (None, 112, 112, 32)
block1a_bn (None, 112, 112, 32)

Next, use the .pop method on the model’s layer list:

func_model._layers.pop(1) # remove rescaling
func_model._layers.pop(1) # remove normalization

for i, l in enumerate(func_model.layers):
    print(l.name, l.output_shape)
    if i == 8: break

input_22 [(None, 224, 224, 3)]
stem_conv_pad (None, 225, 225, 3)
stem_conv (None, 112, 112, 32)
stem_bn (None, 112, 112, 32)
stem_activation (None, 112, 112, 32)
block1a_dwconv (None, 112, 112, 32)
block1a_bn (None, 112, 112, 32)
block1a_activation (None, 112, 112, 32)
block1a_se_squeeze (None, 32)
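Since `_layers` is a private attribute, one possible alternative is to rebuild the graph from a fresh Input and skip the unwanted layers. The following is only a sketch under several assumptions: the toy model, the helper `drop_layers`, and `tensor_map` are hypothetical names; it assumes TF 2.6+ (for `tf.keras.layers.Rescaling`); and it relies on `layer.input` and `layer.output` returning the same symbolic tensor objects that were used when the model was built. It is shown on a small model with an Add block rather than on EfficientNet itself.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy model: one preprocessing layer plus a multi-input (Add) block.
inp = tf.keras.Input(shape=(8,), name="raw_input")
x = tf.keras.layers.Rescaling(1.0 / 255.0, name="rescaling")(inp)
h = tf.keras.layers.Dense(8, activation="relu", name="dense_1")(x)
s = tf.keras.layers.Add(name="skip_add")([x, h])
out = tf.keras.layers.Dense(2, name="head")(s)
model = tf.keras.Model(inp, out)


def drop_layers(model, names_to_drop):
    """Rebuild `model` on a new Input, bypassing the named layers.

    Only meant for single-input, shape-preserving layers, and intended
    to be run once on a freshly built single-input, single-output model.
    """
    new_input = tf.keras.Input(shape=model.inputs[0].shape[1:], name="new_input")
    # Maps id(old symbolic tensor) -> replacement tensor in the new graph.
    tensor_map = {id(model.inputs[0]): new_input}
    for layer in model.layers[1:]:  # model.layers[0] is the original InputLayer
        src = layer.input
        if isinstance(src, (list, tuple)):
            mapped = [tensor_map[id(t)] for t in src]  # e.g. Add, Multiply
        else:
            mapped = tensor_map[id(src)]
        if layer.name in names_to_drop:
            tensor_map[id(layer.output)] = mapped  # bypass: pass input through
        else:
            tensor_map[id(layer.output)] = layer(mapped)  # reuse layer and weights
    return tf.keras.Model(new_input, tensor_map[id(model.outputs[0])])


trimmed = drop_layers(model, {"rescaling"})

# Sanity check: feeding manually rescaled data to the trimmed model
# should match the original model on raw data.
raw = (np.random.rand(4, 8) * 255.0).astype("float32")
same = np.allclose(np.asarray(model(raw)), np.asarray(trimmed(raw / 255.0)), atol=1e-5)
print(same)
```

The rebuilt model shares its layers, and therefore its weights, with the original. Whether this carries over unchanged to EfficientNetB0 is untested here, so treat it as a starting point rather than a drop-in fix.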