torch.nn.Module

August 11, 2023

A neural network is built from interconnected building blocks, or layers. In PyTorch these building blocks are called modules. Modules are like Lego bricks: each one performs a specific computation, such as a linear transformation or an activation, and they can be combined and nested to form complex architectures. Modules can also hold learnable parameters such as weights and biases. During the forward pass, input data flows through the modules and is transformed; during training, the backward pass computes gradients for those parameters via automatic differentiation, and the optimizer uses the gradients to update the parameters and improve the network's performance. PyTorch ships with many pre-built modules and also lets you define custom ones.
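
A minimal sketch of this loop, with a toy model, loss function, and optimizer chosen purely for illustration:

    import torch
    import torch.nn as nn

    # A tiny module with a learnable weight and bias.
    model = nn.Linear(in_features=4, out_features=2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    x = torch.randn(8, 4)            # a batch of 8 inputs
    target = torch.randn(8, 2)

    output = model(x)                # forward pass through the module
    loss = loss_fn(output, target)

    optimizer.zero_grad()
    loss.backward()                  # backward pass: autograd fills .grad
    optimizer.step()                 # gradients drive the parameter update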

class Module:
    r"""Base class for all neural network modules.

    Your models should also subclass this class.

    Modules can also contain other Modules, allowing to nest them in
    a tree structure. You can assign the submodules as regular attributes::

        import torch.nn as nn
        import torch.nn.functional as F

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 20, 5)
                self.conv2 = nn.Conv2d(20, 20, 5)

            def forward(self, x):
                x = F.relu(self.conv1(x))
                return F.relu(self.conv2(x))

    Submodules assigned in this way will be registered, and will have their
    parameters converted too when you call :meth:`to`, etc.

    .. note::
        As per the example above, an ``__init__()`` call to the parent class
        must be made before assignment on the child.

    :ivar training: Boolean represents whether this module is in training or
                    evaluation mode.
    :vartype training: bool
    """

    dump_patches: bool = False

    _version: int = 1
    r"""This allows better BC support for :meth:`load_state_dict`. In
    :meth:`state_dict`, the version number will be saved as in the attribute
    `_metadata` of the returned state dict, and thus pickled. `_metadata` is a
    dictionary with keys that follow the naming convention of state dict. See
    ``_load_from_state_dict`` on how to use this information in loading.

    If new parameters/buffers are added/removed from a module, this number shall
    be bumped, and the module's `_load_from_state_dict` method can compare the
    version number and do appropriate changes if the state dict is from before
    the change."""

    training: bool
    _parameters: Dict[str, Optional[Parameter]]
    _buffers: Dict[str, Optional[Tensor]]
    _non_persistent_buffers_set: Set[str]
    _backward_pre_hooks: Dict[int, Callable]
    _backward_hooks: Dict[int, Callable]
    _is_full_backward_hook: Optional[bool]
    _forward_hooks: Dict[int, Callable]
    # Marks whether the corresponding _forward_hooks accept kwargs or not.
    # As JIT does not support Set[int], this dict is used as a set, where all
    # hooks represented in this dict accept kwargs.
    _forward_hooks_with_kwargs: Dict[int, bool]
    _forward_pre_hooks: Dict[int, Callable]
    # Marks whether the corresponding _forward_pre_hooks accept kwargs or not.
    # As JIT does not support Set[int], this dict is used as a set, where all
    # hooks represented in this dict accept kwargs.
    _forward_pre_hooks_with_kwargs: Dict[int, bool]
    _state_dict_hooks: Dict[int, Callable]
    _load_state_dict_pre_hooks: Dict[int, Callable]
    _state_dict_pre_hooks: Dict[int, Callable]
    _load_state_dict_post_hooks: Dict[int, Callable]
    _modules: Dict[str, Optional['Module']]
    call_super_init: bool = False

    def __init__(self, *args, **kwargs) -> None:
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        # Backward compatibility: no args used to be allowed when call_super_init=False
        if self.call_super_init is False and bool(kwargs):
            raise TypeError("{}.__init__() got an unexpected keyword argument '{}'"
                            "".format(type(self).__name__, next(iter(kwargs))))

        if self.call_super_init is False and bool(args):
            raise TypeError("{}.__init__() takes 1 positional argument but {} were"
                            " given".format(type(self).__name__, len(args) + 1))

        """
        Calls super().__setattr__('a', a) instead of the typical self.a = a
        to avoid Module.__setattr__ overhead. Module's __setattr__ has special
        handling for parameters, submodules, and buffers but simply calls into
        super().__setattr__ for all other attributes.
        """
        super().__setattr__('training', True)
        super().__setattr__('_parameters', OrderedDict())
        super().__setattr__('_buffers', OrderedDict())
        super().__setattr__('_non_persistent_buffers_set', set())
        super().__setattr__('_backward_pre_hooks', OrderedDict())
        super().__setattr__('_backward_hooks', OrderedDict())
        super().__setattr__('_is_full_backward_hook', None)
        super().__setattr__('_forward_hooks', OrderedDict())
        super().__setattr__('_forward_hooks_with_kwargs', OrderedDict())
        super().__setattr__('_forward_pre_hooks', OrderedDict())
        super().__setattr__('_forward_pre_hooks_with_kwargs', OrderedDict())
        super().__setattr__('_state_dict_hooks', OrderedDict())
        super().__setattr__('_state_dict_pre_hooks', OrderedDict())
        super().__setattr__('_load_state_dict_pre_hooks', OrderedDict())
        super().__setattr__('_load_state_dict_post_hooks', OrderedDict())
        super().__setattr__('_modules', OrderedDict())

        if self.call_super_init:
            super().__init__(*args, **kwargs)

    forward: Callable[..., Any] = _forward_unimplemented

  • Constructor (__init__) Method: The __init__ method initializes the internal state of the Module class. It is called when an instance is created and sets up the attributes and data structures that manage the module's behavior, such as parameters, buffers, hooks, and submodules.

  • Attributes and Data Structures:

    • training: A boolean attribute that represents whether the module is in training or evaluation mode.

    • _parameters: A dictionary that holds the learnable parameters of the module.

    • _buffers: A dictionary that holds persistent buffers used in the module.

    • _non_persistent_buffers_set: A set of names for non-persistent buffers.
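
      • A minimal sketch of how these dictionaries get populated (the Scaler class and its attribute names are illustrative): assigning an nn.Parameter fills _parameters, register_buffer fills _buffers, and persistent=False additionally records the name in _non_persistent_buffers_set so the buffer is excluded from state_dict():

        import torch
        import torch.nn as nn

        class Scaler(nn.Module):
            def __init__(self):
                super().__init__()
                self.scale = nn.Parameter(torch.ones(1))    # -> _parameters
                self.register_buffer("running_max", torch.zeros(1))  # -> _buffers
                self.register_buffer("scratch", torch.zeros(1),
                                     persistent=False)      # non-persistent buffer

            def forward(self, x):
                # Buffers hold state that is not optimized by gradient descent.
                self.running_max = torch.maximum(self.running_max, x.max().detach())
                return x * self.scale

        m = Scaler()
        print(list(m.state_dict().keys()))   # ['scale', 'running_max'] -- no 'scratch'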

    • _backward_pre_hooks: A dictionary to store backward hooks that are executed before gradients are computed during backpropagation.

      • Backward Hooks: In PyTorch, hooks are functions that can be registered to run when a specific event occurs during the computation of a neural network. Backward hooks run during the backward pass (backpropagation); backward pre-hooks, specifically, run immediately before a module's gradients are computed. These hooks provide a means of observing, modifying, or recording information about the gradient computation process.

      • Consider _backward_pre_hooks as a mechanism that lets you observe the gradient computation and intervene immediately before the gradients are computed. It is like having a checkpoint where you can inspect and alter the values that are about to be used for parameter updates. This can help keep training stable, identify problems, and explore new ideas in neural network optimization.

      • _backward_pre_hooks: This attribute stores backward hooks that are intended to run before gradients are computed. Such hooks can serve a variety of purposes (a short sketch follows this list):

        • Gradient Modification: Backward hooks can be utilized to modify gradients before they are utilized to update parameters. For instance, one may wish to clip gradients to prevent exploding gradients, normalize gradients, or apply other gradient transformations.

        • Monitoring and Logging: Backward hooks can be employed to monitor gradient values at different layers of the network during training. This can aid in understanding how gradients change as they flow backward through the network and diagnosing potential issues.

        • Debugging and Analysis: Backward hooks provide a powerful tool for debugging neural network behavior. Custom code can be inserted to analyze gradients, activations, or any other relevant information as part of the debugging process.

        • Research and Experimentation: Researchers may use backward hooks to experiment with novel gradient-based techniques or to collect additional data during gradient computation for research purposes.
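
        • A minimal sketch, assuming PyTorch 2.0 or later where register_full_backward_pre_hook is the public way to attach such a hook; it clamps grad_output just before the module's gradients are computed:

          import torch
          import torch.nn as nn

          model = nn.Linear(4, 2)

          def clamp_grad_output(module, grad_output):
              # grad_output is a tuple; returning a new tuple replaces it.
              return tuple(g.clamp(-1.0, 1.0) for g in grad_output)

          handle = model.register_full_backward_pre_hook(clamp_grad_output)

          x = torch.randn(3, 4, requires_grad=True)
          model(x).sum().backward()   # the hook runs before grads are computed
          handle.remove()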

    • _backward_hooks: A dictionary to store backward hooks that are executed after gradients are computed during backpropagation.

    • _is_full_backward_hook: An Optional[bool] flag: it is None when no backward hooks are registered, and otherwise records whether the registered backward hooks are full backward hooks.

      • Consider _is_full_backward_hook as a flag that tells PyTorch which kind of backward hook you have registered for this module. If no full backward hooks are in use, PyTorch can skip the extra bookkeeping they require during backpropagation.

      • A full backward hook (registered with register_full_backward_hook) runs once the gradients with respect to the module's inputs have been computed. Unlike the legacy backward hook, it reliably receives the complete grad_input and grad_output tuples for the module, which it can inspect or replace.

      • The _is_full_backward_hook attribute keeps track of whether full backward hooks are registered for the module. This is useful for several reasons (a short sketch follows this list):

        • Efficient Execution: If no full backward hook is registered, PyTorch can avoid the extra work of assembling the complete grad_input and grad_output tuples, keeping the backward pass fast.

        • Resource Management: Full backward hooks can add memory and computational overhead, particularly for large networks or complex computations. Tracking whether one is registered lets PyTorch pay that cost only when needed.

        • Conditional Behavior: The presence of a full backward hook signals that additional steps must be taken during backpropagation, such as routing grad_input and grad_output through the registered hook.
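
        • A minimal sketch of a full backward hook, registered through the public register_full_backward_hook API; it logs the gradient flowing out of a layer (the printed quantity is illustrative):

          import torch
          import torch.nn as nn

          model = nn.Linear(4, 2)

          def log_grads(module, grad_input, grad_output):
              # Both arguments are tuples of Tensors (entries may be None).
              print("grad_output norm:", grad_output[0].norm().item())
              return None   # returning None leaves the gradients unchanged

          handle = model.register_full_backward_hook(log_grads)

          x = torch.randn(3, 4, requires_grad=True)
          model(x).sum().backward()
          handle.remove()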

    • _forward_hooks: A dictionary to store forward hooks that are executed during the forward pass.

      • Think of forward hooks as a way to "peek" into the computations happening within each layer of your neural network during the forward pass. It's like placing observation points to capture the state of the data as it transforms through the network's hidden layers.

      • They allow you to inspect or modify the intermediate outputs (activations) of the network as data passes through each layer.

      • Forward hooks provide a way to analyze how data is transformed layer by layer and can be used for monitoring, debugging, and research; a short sketch follows.
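
      • A minimal sketch using the public register_forward_hook API to capture an intermediate activation (the model, layer index, and dictionary key are illustrative):

        import torch
        import torch.nn as nn

        model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
        activations = {}

        def save_activation(module, inputs, output):
            # Runs right after module.forward; output is its return value.
            activations["relu"] = output.detach()

        handle = model[1].register_forward_hook(save_activation)
        _ = model(torch.randn(5, 4))
        print(activations["relu"].shape)   # torch.Size([5, 8])
        handle.remove()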

    • _forward_hooks_with_kwargs: A dictionary that indicates whether the forward hooks accept keyword arguments.

    • _forward_pre_hooks: A dictionary to store forward pre-hooks that are executed before the forward pass.

    • _forward_pre_hooks_with_kwargs: Similar to _forward_hooks_with_kwargs, but for forward pre-hooks (a short sketch of a pre-hook follows).
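
      • A minimal sketch of a forward pre-hook that rescales the positional inputs before forward runs; the kwargs-aware variant is enabled by passing with_kwargs=True to register_forward_pre_hook in recent releases, which this sketch does not use:

        import torch
        import torch.nn as nn

        model = nn.Linear(4, 2)

        def scale_input(module, args):
            # args is the tuple of positional arguments passed to forward().
            x, = args
            return (x * 0.5,)   # returning a tuple replaces the inputs

        handle = model.register_forward_pre_hook(scale_input)
        out = model(torch.randn(3, 4))
        handle.remove()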

    • _state_dict_hooks: A dictionary to store hooks related to state dictionary operations.

      • This dictionary stores hooks (functions) that are executed when the state_dict method is called on a module.

      • These hooks allow you to customize or modify the state dictionary before it is returned by the state_dict method. This can be useful for adding extra information or transformations to the state dictionary.

      • The state dictionary itself is commonly used to save and load the parameters (weights and biases) and other persistent buffers of a PyTorch model. It captures the current state of a model, including its learnable parameters and other internal stateful information.

      • Model Checkpoints: During training, you can save checkpoints of your model's state at various intervals. If training is interrupted or if you want to resume training from a specific point, you can load the model's state from the saved state dictionary.

      • Transfer Learning and Fine-Tuning: When using transfer learning, you might want to load pre-trained weights from another model and fine-tune them for a specific task. The state dictionary allows you to transfer parameters between models.

      • You can use _state_dict_hooks to append extra metadata to the state dictionary, such as model version information or training statistics; this can be helpful when sharing models or resuming training. A short checkpoint sketch follows.
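
      • A minimal checkpoint sketch; the file name "checkpoint.pt" and the model architecture are illustrative:

        import torch
        import torch.nn as nn

        model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

        # Save: the state dict maps parameter/buffer names to tensors.
        torch.save(model.state_dict(), "checkpoint.pt")

        # Load into a freshly constructed model with the same architecture.
        restored = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
        restored.load_state_dict(torch.load("checkpoint.pt"))
        restored.eval()

        # For transfer learning, strict=False tolerates missing or unexpected
        # keys and reports them in the returned object:
        # restored.load_state_dict(partial_state_dict, strict=False)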

    • _state_dict_pre_hooks: A dictionary to store pre-hooks for state dictionary operations.

    • _load_state_dict_pre_hooks: A dictionary to store pre-hooks for loading state dictionaries.

      • These pre-hooks can be used to perform tasks such as adjusting tensor shapes to match changes in the model's architecture or converting parameters to a specific data type.

    • _load_state_dict_post_hooks: A dictionary to store post-hooks that run after a state dictionary has been loaded (a short sketch follows).
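
      • A minimal sketch using register_load_state_dict_post_hook (a public method in recent PyTorch releases) to report keys that did not match after loading:

        import torch
        import torch.nn as nn

        model = nn.Linear(4, 2)

        def report_mismatches(module, incompatible_keys):
            # incompatible_keys carries .missing_keys and .unexpected_keys.
            if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys:
                print("mismatched keys:", incompatible_keys)

        model.register_load_state_dict_post_hook(report_mismatches)
        model.load_state_dict(model.state_dict())   # clean load, nothing printed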

    • _modules: A dictionary that holds submodules of the current module.

    • Forward Pass (forward Method): The forward method is declared but left as _forward_unimplemented; subclasses override it to define the module's forward pass, i.e., the computation the module performs when data passes through it (a short sketch follows).
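
      • A minimal sketch (the SquareModule name is illustrative); note that you invoke the module by calling the instance, which goes through __call__ and runs any registered hooks around forward:

        import torch
        import torch.nn as nn

        class SquareModule(nn.Module):
            def forward(self, x):
                return x * x

        m = SquareModule()
        y = m(torch.tensor([1.0, 2.0, 3.0]))   # preferred over m.forward(...)
        print(y)                               # tensor([1., 4., 9.])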

  • Check out the methods defined in class Module

    • __init__(self): The constructor to initialize the module.

    • forward(self, input): Defines the forward pass of the module. This is where the computation happens.

    • parameters(self): Returns an iterator over module parameters.

    • named_parameters(self): Returns an iterator over module parameters with names.

    • buffers(self): Returns an iterator over module buffers.

    • named_buffers(self): Returns an iterator over module buffers with names.

    • children(self): Returns an iterator over immediate child modules.

    • named_children(self): Returns an iterator over immediate child modules with names.

    • modules(self): Returns an iterator over all submodules, including the current module.

    • named_modules(self): Returns an iterator over all submodules with names, including the current module.

    • to(self, device=None, dtype=None, non_blocking=False): Moves the module and its parameters/buffers to a specified device and dtype.

    • state_dict(self, destination=None, prefix='', keep_vars=False): Returns a dictionary containing the module's state (parameters and buffers).

    • load_state_dict(self, state_dict, strict=True): Loads the module's state from a state dictionary.

    • train(self, mode=True): Sets the module to training mode.

      • Setting the module to training mode is important because it affects the behavior of certain layers (e.g., dropout, batch normalization) during training.

      • During training, dropout layers are active: on each forward pass a random fraction of units is set to zero, which helps prevent overfitting.

      • Batch normalization layers normalize inputs using the current batch's statistics (mean and variance) during the forward pass and update their running statistics, which stabilizes training.

      • Gradients are calculated during the backward pass to update the model's parameters, which is crucial for optimization through techniques like stochastic gradient descent (SGD).

      • Regularization techniques, such as weight decay, are typically employed during training to prevent overfitting.

    • eval(self): Sets the module to evaluation mode.

      • Dropout layers are deactivated, so all units are used during the forward pass; this removes randomness and yields consistent predictions during inference.

      • Batch normalization layers normalize inputs using the running statistics accumulated during training rather than the statistics of the current batch.

      • Parameters are not updated during inference, so there is no need for a backward pass; note that eval() by itself does not disable gradient tracking, which is why inference code is typically wrapped in torch.no_grad().

      • Regularization such as weight decay only takes effect when the optimizer updates parameters, so it plays no role during inference, where the sole objective is to produce accurate predictions.

    • The distinction between training and evaluation modes ensures that the network behaves appropriately during each phase of its use: during training it adjusts its parameters based on the data, while during evaluation it focuses on producing accurate predictions without introducing randomness or updating parameters. The sketch below illustrates the two modes.
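
    • A minimal sketch of the two modes (the dropout probability and tensor shapes are illustrative); note that torch.no_grad(), not eval(), is what disables gradient tracking:

      import torch
      import torch.nn as nn

      model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
      x = torch.randn(1, 4)

      model.train()                  # model.training is now True
      y1, y2 = model(x), model(x)    # dropout masks typically differ per call

      model.eval()                   # model.training is now False
      with torch.no_grad():          # skip gradient tracking for inference
          z1, z2 = model(x), model(x)
      print(torch.equal(z1, z2))     # True: eval mode is deterministic here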