How to write a forward hook function for nn.Transformer in PyTorch?

I have learnt that a forward hook function has the form hook_fn(m, x, y), where m is the module, x is its input, and y is its output. I want to write a forward hook function for nn.Transformer.
However, a transformer takes two inputs, src and tgt. For example, >>> out = transformer_model(src, tgt). How can I tell these two inputs apart inside the hook?


PyTorch calls your hook with x as a tuple of the positional inputs and y as the module's output. For nn.Transformer, y is a single tensor, not a tuple. The documentation page of torch.nn.Module.register_forward_hook doesn't spell out the types of x and y in detail, but it does say:

The input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. […].
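This rule matters as soon as you call a model with keyword arguments. Here is a minimal sketch using a hypothetical two-input toy module, TwoInput, called as m(a, b=b): a plain hook only sees a, while a hook registered with with_kwargs=True also receives the keyword arguments as a dict. The with_kwargs flag assumes PyTorch 2.0 or later.

```python
import torch
import torch.nn as nn


class TwoInput(nn.Module):
    """Hypothetical toy module whose forward takes two tensors, like nn.Transformer."""

    def forward(self, a, b):
        return a + b


seen = {}  # hypothetical storage for what each hook observes


def plain_hook(module, args, output):
    # args holds only the positional inputs
    seen["plain"] = len(args)


def kwargs_hook(module, args, kwargs, output):
    # with_kwargs=True (PyTorch >= 2.0) also passes the keyword arguments as a dict
    seen["kw_args"] = len(args)
    seen["kw_keys"] = sorted(kwargs)


m = TwoInput()
h1 = m.register_forward_hook(plain_hook)
h2 = m.register_forward_hook(kwargs_hook, with_kwargs=True)

a, b = torch.ones(2), torch.ones(2)
m(a, b=b)  # a is passed positionally, b as a keyword argument

print(seen)  # {'plain': 1, 'kw_args': 1, 'kw_keys': ['b']}
h1.remove()
h2.remove()
```

So if you call transformer_model(src=src, tgt=tgt), a plain hook receives an empty tuple for x.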

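Set up the model and the two inputs:

import torch
import torch.nn as nn

model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand(10, 32, 512)  # (source length, batch, d_model)
tgt = torch.rand(20, 32, 512)  # (target length, batch, d_model)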

Define your callback:

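def hook(module, x, y):
    # x is the tuple of positional inputs passed to model(...)
    print(f'is tuple={isinstance(x, tuple)} - length={len(x)}')
    src, tgt = x  # assumes exactly src and tgt were passed positionally
    print(f'src: {src.shape}')
    print(f'tgt: {tgt.shape}')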

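Register the hook on your nn.Module (keep the returned handle so you can remove the hook later with handle.remove()):

>>> handle = model.register_forward_hook(hook)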

Run a forward pass:

>>> out = model(src, tgt)
is tuple=True - length=2
src: torch.Size([10, 32, 512])
tgt: torch.Size([20, 32, 512])
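Printing is fine for a quick check, but usually you want to keep the tensors. Here is a minimal sketch, assuming the same model and input shapes as above. The hook stores detached tensors in a dictionary (captured is a hypothetical name). It indexes x instead of unpacking it, so extra positional arguments such as masks don't raise an error. The returned handle then removes the hook.

```python
import torch
import torch.nn as nn

model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand(10, 32, 512)  # (source length, batch, d_model)
tgt = torch.rand(20, 32, 512)  # (target length, batch, d_model)

captured = {}  # hypothetical storage for what the hook sees


def capture_hook(module, x, y):
    # x is a tuple of positional inputs; indexing (rather than unpacking)
    # keeps the hook working if extra positional arguments are passed
    captured["src"] = x[0].detach()
    captured["tgt"] = x[1].detach()
    # y is what nn.Transformer.forward returns: a single tensor
    captured["out"] = y.detach()


handle = model.register_forward_hook(capture_hook)
with torch.no_grad():
    out = model(src, tgt)
handle.remove()  # later forward passes no longer trigger the hook

print(captured["src"].shape)  # torch.Size([10, 32, 512])
print(captured["tgt"].shape)  # torch.Size([20, 32, 512])
print(captured["out"].shape)  # torch.Size([20, 32, 512])
```

The output has the target's shape because nn.Transformer returns one vector per target position.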