What is the equivalent of PyTorch's nn.Module in TensorFlow?

I am trying to implement Triplet attention in TensorFlow. One of the questions I am facing is what to use in place of nn.Module in TensorFlow.

import torch
from torch import nn

class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)

What do I put in place of nn.Module here?

Answer

In this case, nn.Module is used to create a custom layer. TensorFlow has a tutorial on writing custom layers; please take a look. In short, one way to implement it is to subclass tf.keras.layers.Layer, whose call method is the equivalent of forward in PyTorch:

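import tensorflow as tf

class ChannelPool(tf.keras.layers.Layer):
    def call(self, inputs):
        # Max and mean over axis 1, each kept as a size-1 axis, then stacked along axis 1
        return tf.concat((tf.reduce_max(inputs, axis=1, keepdims=True), tf.reduce_mean(inputs, axis=1, keepdims=True)), axis=1)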
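One porting detail worth checking: PyTorch convolutions use channels-first tensors (N, C, H, W), while Keras layers such as Conv2D default to channels-last (N, H, W, C). axis=1 matches the PyTorch layout; on channels-last data the channel axis is -1. Below is a minimal sketch, not a reference implementation, that makes the axis a constructor argument (`channel_axis` is a name chosen here, not part of any API) and shows where sub-layers go: like nn.Module, a Layer creates them in __init__ and uses them in call. `SpatialGateSketch` is a hypothetical layer for illustration only; it is not taken from the Triplet attention code.

```python
import numpy as np
import tensorflow as tf


class ChannelPool(tf.keras.layers.Layer):
    """Stacks the max and the mean over the channel axis (2 output channels)."""

    def __init__(self, channel_axis=-1, **kwargs):
        super().__init__(**kwargs)
        # Hypothetical parameter: -1 for channels-last (NHWC),
        # 1 for channels-first (NCHW, the PyTorch layout)
        self.channel_axis = channel_axis

    def call(self, inputs):
        max_pool = tf.reduce_max(inputs, axis=self.channel_axis, keepdims=True)
        mean_pool = tf.reduce_mean(inputs, axis=self.channel_axis, keepdims=True)
        return tf.concat([max_pool, mean_pool], axis=self.channel_axis)


class SpatialGateSketch(tf.keras.layers.Layer):
    """Hypothetical gate: pool channels, convolve, and rescale the input (channels-last)."""

    def __init__(self, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        # Sub-layers are created in __init__, as sub-modules are in nn.Module.__init__
        self.pool = ChannelPool(channel_axis=-1)
        self.conv = tf.keras.layers.Conv2D(1, kernel_size, padding="same", use_bias=False)
        self.bn = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=None):
        # call plays the role of forward; training switches BatchNormalization's mode
        scale = tf.sigmoid(self.bn(self.conv(self.pool(inputs)), training=training))
        return inputs * scale  # (N, H, W, 1) broadcasts over the C channels


x = np.random.random((2, 8, 8, 16)).astype(np.float32)  # NHWC batch
pooled = ChannelPool()(x)
gated = SpatialGateSketch()(x)
print(pooled.shape)  # (2, 8, 8, 2)
print(gated.shape)   # (2, 8, 8, 16)
```

ChannelPool(channel_axis=1) reproduces the axis=1 behaviour of the PyTorch version, which is useful if you keep your tensors in NCHW order.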

You can check that they are equivalent like this:

import torch
from torch import nn
import tensorflow as tf
import numpy as np

class PyTorch_ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)

class TensorFlow_ChannelPool(tf.keras.layers.Layer):
    def call(self, inputs):
        return tf.concat((tf.reduce_max(inputs, axis=1, keepdims=True), tf.reduce_mean(inputs, axis=1, keepdims=True)), axis=1)

np.random.seed(2021)
x = np.random.random((1,2,3,4)).astype(np.float32)

a = PyTorch_ChannelPool()
b = TensorFlow_ChannelPool()

pytorch_output = a(torch.from_numpy(x)).numpy()
tensorflow_output = b(x).numpy()

# Compare with a tolerance, since the two frameworks may round float32 results differently
np.allclose(pytorch_output, tensorflow_output)
# >>> True
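If PyTorch isn't installed, a similar check can be made against a plain NumPy reference, since ChannelPool only takes a max and a mean over axis 1 and stacks the two results. A minimal sketch, where `numpy_channel_pool` is a hypothetical helper rather than a library function:

```python
import numpy as np
import tensorflow as tf


def numpy_channel_pool(x):
    # Hypothetical reference: max and mean over axis 1, kept as two channels
    return np.concatenate(
        (x.max(axis=1, keepdims=True), x.mean(axis=1, keepdims=True)), axis=1
    )


class TensorFlow_ChannelPool(tf.keras.layers.Layer):
    def call(self, inputs):
        return tf.concat(
            (tf.reduce_max(inputs, axis=1, keepdims=True),
             tf.reduce_mean(inputs, axis=1, keepdims=True)),
            axis=1,
        )


np.random.seed(2021)
x = np.random.random((1, 2, 3, 4)).astype(np.float32)

reference = numpy_channel_pool(x)
tensorflow_output = TensorFlow_ChannelPool()(x).numpy()

print(reference.shape)  # (1, 2, 3, 4)
print(np.allclose(reference, tensorflow_output))  # True
```

The output keeps the input's shape here only because the input happens to have two channels; in general axis 1 always ends up with size 2 (max, mean).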