Forward hooks are custom functions that get executed right after the forward pass. Among other things, one can use them together with TensorBoard to visualize activations of any layer.
If you have ever worked with torch.nn.Module, did you ever wonder what the difference between the forward and the __call__ methods is? One can roughly say that __call__ = forward + execution of various hooks.
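A minimal sketch of this difference (the hook and variable names are my own): calling the module itself goes through __call__ and therefore runs registered hooks, while calling forward directly bypasses them.

```python
import torch

# Collect evidence that the hook actually ran
fired = []

def my_hook(module, input, output):
    fired.append("hook fired")

layer = torch.nn.Linear(4, 2)
layer.register_forward_hook(my_hook)

x = torch.randn(1, 4)
_ = layer(x)          # __call__ -> forward + hooks, so my_hook runs
_ = layer.forward(x)  # forward only, my_hook does NOT run

print(fired)  # only one entry, from the __call__ invocation
```

This is why you should almost always invoke a module as `layer(x)` rather than `layer.forward(x)`.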
Hooks are custom functions that get executed at specific moments during the forward/backward phase. They allow us to inspect what is going on inside the network. Specific use cases include:
- Visualizing activations (this post)
If you are interested in learning more about hooks, check out the official docs.
A forward hook is a function that accepts 3 arguments:
- module_instance: instance of the layer you are attaching the hook to
- input: tuple of tensors (or other objects) that we pass as the input to the forward method
- output: tensor (or other object) that is the output of the forward method
Once you define it, you need to "register" the hook with your desired layer via the register_forward_hook method.
Once registered, the hook will be executed right after the forward method. You do not have to worry about triggering it manually!
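Putting the two steps together, here is a small sketch (the dictionary-based storage is my own choice, not from the tutorial): a hook with the 3-argument signature that saves a layer's output, registered via register_forward_hook.

```python
import torch

# Storage for captured activations (hypothetical helper, any container works)
activations = {}

def save_activations(module_instance, input, output):
    # input is a tuple of tensors; output is (usually) a tensor
    activations["linear"] = output.detach()

model = torch.nn.Linear(8, 3)
handle = model.register_forward_hook(save_activations)

_ = model(torch.randn(2, 8))        # hook runs automatically after forward
print(activations["linear"].shape)  # batch of 2, output dimension 3

handle.remove()  # detach the hook once you no longer need it
```

Note that register_forward_hook returns a handle whose remove() method unregisters the hook, which is useful when you only want to inspect a few batches.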
I created a hands-on video tutorial where I explain step by step how to use forward hooks together with TensorBoard. The goal is to visualize activations of any layer of choice (=creating a histogram of its values for a given sample / batch).
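The core idea of the tutorial can be sketched as follows (the tag and variable names are my own, and this assumes the tensorboard package is installed): a forward hook that logs a histogram of a chosen layer's activations via SummaryWriter.add_histogram every time a batch passes through.

```python
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/activations_demo")
step = 0  # global step counter for TensorBoard

def histogram_hook(module, input, output):
    global step
    # Log the distribution of this layer's activations for the current batch
    writer.add_histogram("fc1/activations", output.detach(), global_step=step)
    step += 1

model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
)
model[0].register_forward_hook(histogram_hook)

for _ in range(3):
    model(torch.randn(4, 10))  # each forward pass logs one histogram

writer.close()  # inspect with: tensorboard --logdir runs
```

Each forward pass adds one histogram under the chosen tag, so in TensorBoard you can watch how the activation distribution evolves across batches.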
The tutorial does not cover several related (and interesting) topics:
- backward hooks
- forward pre-hooks
I would encourage the reader to learn more about them :)
Hooks are hidden gems of PyTorch. Specifically, forward hooks allow you to debug and visualize what is going on inside your network. This post provided a first look into what they are and how one can use them.
Cover photo: https://unsplash.com/@steve_j