stochastic gradient optimizer

August 14, 2023

The goal of optimization, in the context of training machine learning models, is to find the set of model parameters that minimizes a given loss function. This objective follows from the fundamental principle of machine learning: we want our models to make accurate predictions or classifications, which usually means minimizing the difference between predicted and actual values. Stochastic gradient descent (SGD) does this by iteratively adjusting the model's parameters based on the gradient of the loss with respect to those parameters. [ "stochastic" because each step uses a randomly sampled mini-batch rather than the full dataset; "gradient" because each small step follows the slope of the loss surface toward its lowest point, where the difference between predicted and true values is at its minimum. ]
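
In symbols, each update takes a small step against the gradient of the loss (a plain-notation sketch; theta is the parameter vector and lr the learning rate):

    theta = theta - lr * grad_L(theta)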

For each mini-batch of training data:

  • Compute the loss on the mini-batch using the current model parameters.
  • Compute the gradient of the loss with respect to the model parameters.
  • Update the model parameters in the opposite direction of the gradient scaled by a learning rate.
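
A minimal sketch of this loop in PyTorch (the toy model, fake mini-batches, and loss function here are illustrative assumptions, not part of the source discussed below):

    import torch

    model = torch.nn.Linear(10, 1)    # toy model
    loss_fn = torch.nn.MSELoss()
    lr = 0.01

    batches = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(3)]  # fake mini-batches
    for xb, yb in batches:
        model.zero_grad()                      # clear stale gradients
        loss = loss_fn(model(xb), yb)          # 1. compute the loss on the mini-batch
        loss.backward()                        # 2. compute gradients w.r.t. the parameters
        with torch.no_grad():
            for p in model.parameters():
                p -= lr * p.grad               # 3. step opposite the gradient, scaled by lr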

The randomness introduced by mini-batch sampling adds noise to each gradient estimate, which helps the optimization process escape shallow local minima and often speeds up convergence.

Momentum incorporates past gradient updates to smooth the optimization path and accelerate convergence.
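
In update-rule form (a sketch with momentum coefficient mu; PyTorch's dampening and Nesterov options are discussed later), the buffer accumulates past gradients and the parameter follows the smoothed direction:

    buf   = mu * buf + grad        # running blend of past gradients
    param = param - lr * buf       # step along the smoothed direction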

class Optimizer(object):
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Arguments:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params, defaults):
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults

        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

    def __getstate__(self):
        return {
            'defaults': self.defaults,
            'state': self.state,
            'param_groups': self.param_groups,
        }

    def __setstate__(self, state):
        self.__dict__.update(state)

    def __repr__(self):
        format_string = self.__class__.__name__ + ' ('
        for i, group in enumerate(self.param_groups):
            format_string += '\n'
            format_string += 'Parameter Group {0}\n'.format(i)
            for key in sorted(group.keys()):
                if key != 'params':
                    format_string += '    {0}: {1}\n'.format(key, group[key])
        format_string += ')'
        return format_string

    def state_dict(self):
        r"""Returns the state of the optimizer as a :class:`dict`.

        It contains two entries:

        * state - a dict holding current optimization state. Its content
            differs between optimizer classes.
        * param_groups - a dict containing all parameter groups
        """
        # Save order indices instead of Tensors
        param_mappings = {}
        start_index = 0

        def pack_group(group):
            nonlocal start_index
            packed = {k: v for k, v in group.items() if k != 'params'}
            param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
                                   if id(p) not in param_mappings})
            packed['params'] = [param_mappings[id(p)] for p in group['params']]
            start_index += len(packed['params'])
            return packed
        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use order indices as keys
        packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
                        for k, v in self.state.items()}
        return {
            'state': packed_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Update the state
        id_map = {old_id: p for old_id, p in
                  zip(chain.from_iterable((g['params'] for g in saved_groups)),
                      chain.from_iterable((g['params'] for g in groups)))}

        def cast(param, value):
            r"""Make a deep copy of value, casting all tensors to device of param."""
            if isinstance(value, torch.Tensor):
                # Floating-point types are a bit special here. They are the only ones
                # that are assumed to always match the type of params.
                if param.is_floating_point():
                    value = value.to(param.dtype)
                value = value.to(param.device)
                return value
            elif isinstance(value, dict):
                return {k: cast(param, v) for k, v in value.items()}
            elif isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict['state'].items():
            if k in id_map:
                param = id_map[k]
                state[param] = cast(param, v)
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group['params'] = group['params']
            return new_group
        param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({'state': state, 'param_groups': param_groups})

    def zero_grad(self, set_to_none: bool = False):
        r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.

        Arguments:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have lower memory footprint, and can modestly improve performance.
                However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                are guaranteed to be None for params that did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                (in one case it does the step with a gradient of 0 and in the other it skips
                the step altogether).
        """
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    if set_to_none:
                        p.grad = None
                    else:
                        if p.grad.grad_fn is not None:
                            p.grad.detach_()
                        else:
                            p.grad.requires_grad_(False)
                        p.grad.zero_()

    def step(self, closure):
        r"""Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
        """
        raise NotImplementedError

    def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.

        This can be useful when fine tuning a pre-trained network as frozen layers can be made
        trainable and added to the :class:`Optimizer` as training progresses.

        Arguments:
            param_group (dict): Specifies what Tensors should be optimized along with group
            specific optimization options.
        """
        assert isinstance(param_group, dict), "param group must be a dict"

        params = param_group['params']
        if isinstance(params, torch.Tensor):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                            'the ordering of tensors in sets will change between runs. Please use a list instead.')
        else:
            param_group['params'] = list(params)

        for param in param_group['params']:
            if not isinstance(param, torch.Tensor):
                raise TypeError("optimizer can only optimize Tensors, "
                                "but one of the params is " + torch.typename(param))
            if not param.is_leaf:
                raise ValueError("can't optimize a non-leaf Tensor")

        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                                 name)
            else:
                param_group.setdefault(name, default)

        params = param_group['params']
        if len(params) != len(set(params)):
            warnings.warn("optimizer contains a parameter group with duplicate parameters; "
                          "in future, this will cause an error; "
                          "see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3)

        param_set = set()
        for group in self.param_groups:
            param_set.update(set(group['params']))

        if not param_set.isdisjoint(set(param_group['params'])):
            raise ValueError("some parameters appear in more than one parameter group")

        self.param_groups.append(param_group)

The Optimizer class is a base class for all optimizers in PyTorch. Optimizers are used to update the parameters of a model during training in order to minimize a loss function. Here's a brief overview of the methods and their functionalities:

  • __init__(self, params, defaults): Initializes the optimizer with a list of parameter groups and default optimization options. Each parameter group is a dictionary specifying which tensors to optimize and specific optimization options.

  • __getstate__(self): Returns a dictionary containing the optimizer's state.

  • __setstate__(self, state): Sets the optimizer's state from a given dictionary.

  • __repr__(self): Returns a string representation of the optimizer's configuration.

  • state_dict(self): Returns the state of the optimizer as a dictionary. This state can be saved and loaded to resume training.

  • load_state_dict(self, state_dict): Loads the optimizer's state from a dictionary, allowing you to resume training from a saved state.

  • zero_grad(self, set_to_none=False): Sets the gradients of all optimized tensors to zero. Optionally, you can set the gradients to None for lower memory footprint.

  • step(self, closure): Performs a single optimization step, updating the parameters based on gradients. The specific optimization strategy is implemented in the derived optimizer classes (e.g., SGD, Adam).

  • add_param_group(self, param_group): Adds a parameter group to the optimizer. This can be useful when fine-tuning a pre-trained model by gradually unfreezing layers.
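
A minimal usage sketch tying these methods together (the toy model, data, and checkpoint path are assumptions for illustration):

    import torch

    model = torch.nn.Linear(10, 1)                       # toy model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    for xb, yb in [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(3)]:
        optimizer.zero_grad()                            # clear stale gradients
        loss = torch.nn.functional.mse_loss(model(xb), yb)
        loss.backward()                                  # populate .grad on each parameter
        optimizer.step()                                 # apply the update rule

    # state_dict() / load_state_dict() let you pause and resume training:
    torch.save(optimizer.state_dict(), "opt.pt")         # hypothetical checkpoint path
    optimizer.load_state_dict(torch.load("opt.pt"))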

  class SGD(Optimizer):
      def __init__(self, params, lr=required, momentum=0, dampening=0,
                   weight_decay=0, nesterov=False, *, maximize: bool = False, foreach: Optional[bool] = None,
                   differentiable: bool = False):
          if lr is not required and lr < 0.0:
              raise ValueError("Invalid learning rate: {}".format(lr))
          if momentum < 0.0:
              raise ValueError("Invalid momentum value: {}".format(momentum))
          if weight_decay < 0.0:
              raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

          defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                          weight_decay=weight_decay, nesterov=nesterov,
                          maximize=maximize, foreach=foreach,
                          differentiable=differentiable)
          if nesterov and (momentum <= 0 or dampening != 0):
              raise ValueError("Nesterov momentum requires a momentum and zero dampening")
          super().__init__(params, defaults)

      def __setstate__(self, state):
          super().__setstate__(state)
          for group in self.param_groups:
              group.setdefault('nesterov', False)
              group.setdefault('maximize', False)
              group.setdefault('foreach', None)
              group.setdefault('differentiable', False)

      def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
          has_sparse_grad = False

          for p in group['params']:
              if p.grad is not None:
                  params_with_grad.append(p)
                  d_p_list.append(p.grad)
                  if p.grad.is_sparse:
                      has_sparse_grad = True

                  state = self.state[p]
                  if 'momentum_buffer' not in state:
                      momentum_buffer_list.append(None)
                  else:
                      momentum_buffer_list.append(state['momentum_buffer'])

          return has_sparse_grad


      @_use_grad_for_differentiable
      def step(self, closure=None):
          """Performs a single optimization step.

          Args:
              closure (Callable, optional): A closure that reevaluates the model
                  and returns the loss.
          """
          loss = None
          if closure is not None:
              with torch.enable_grad():
                  loss = closure()

          for group in self.param_groups:
              params_with_grad = []
              d_p_list = []
              momentum_buffer_list = []

              has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list)

              sgd(params_with_grad,
                  d_p_list,
                  momentum_buffer_list,
                  weight_decay=group['weight_decay'],
                  momentum=group['momentum'],
                  lr=group['lr'],
                  dampening=group['dampening'],
                  nesterov=group['nesterov'],
                  maximize=group['maximize'],
                  has_sparse_grad=has_sparse_grad,
                  foreach=group['foreach'])

              # update momentum_buffers in state
              for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                  state = self.state[p]
                  state['momentum_buffer'] = momentum_buffer

          return loss


  def sgd(params: List[Tensor],
          d_p_list: List[Tensor],
          momentum_buffer_list: List[Optional[Tensor]],
          # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
          # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
          has_sparse_grad: bool = None,
          foreach: Optional[bool] = None,
          *,
          weight_decay: float,
          momentum: float,
          lr: float,
          dampening: float,
          nesterov: bool,
          maximize: bool):
      r"""Functional API that performs SGD algorithm computation.

      See :class:`~torch.optim.SGD` for details.
      """

      if foreach is None:
          # why must we be explicit about an if statement for torch.jit.is_scripting here?
          # because JIT can't handle Optionals nor fancy conditionals when scripting
          if not torch.jit.is_scripting():
              _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
          else:
              foreach = False

      if foreach and torch.jit.is_scripting():
          raise RuntimeError('torch.jit.script not supported with foreach optimizers')

      if foreach and not torch.jit.is_scripting():
          func = _multi_tensor_sgd
      else:
          func = _single_tensor_sgd

      func(params,
           d_p_list,
           momentum_buffer_list,
           weight_decay=weight_decay,
           momentum=momentum,
           lr=lr,
           dampening=dampening,
           nesterov=nesterov,
           has_sparse_grad=has_sparse_grad,
           maximize=maximize)

SGD class

  • def __init__(self, params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize: bool = False, foreach: Optional[bool] = None, differentiable: bool = False)

    • Learning Rate (lr): The learning rate determines the step size taken in the direction opposite to the gradient during optimization. A higher learning rate can lead to faster convergence, but it may also cause instability or divergence. A smaller learning rate can ensure stability but might slow down convergence. Finding an appropriate learning rate is crucial for effective optimization.

    • Dampening (dampening): Dampening scales the contribution of the current gradient to the momentum buffer: the buffer is updated as buf = momentum * buf + (1 - dampening) * grad. A higher dampening value means each new gradient contributes less to the buffer, which smooths the updates; dampening = 0 (the default) lets each gradient contribute fully.

    • Maximize (maximize): This parameter determines whether the optimization problem should be treated as maximization (maximize=True) or minimization (maximize=False). If you're maximizing a function, the gradients need to be reversed before applying updates, which is done by negating the gradients.

    • Foreach (foreach): The foreach parameter is used to decide whether to use the multi-tensor implementation (foreach=True) or the single-tensor implementation (foreach=False) of the optimization algorithm. The multi-tensor implementation can offer better performance in certain cases by parallelizing operations across multiple tensors.

    • Differentiable (differentiable): When differentiable=True, autograd is allowed to trace through the optimizer step itself, so gradients can flow through the parameter update. This is useful for higher-order optimization (e.g., meta-learning) at the cost of extra overhead; for ordinary training it should stay False.

    • Weight Decay:

      • Weight decay is a regularization technique used to prevent overfitting during training. It involves adding a penalty term to the loss function based on the magnitude of the model's weights (parameters). The purpose of weight decay is to encourage the model to have smaller parameter values, which can lead to simpler models and better generalization to unseen data.
      • In the context of the SGD optimizer, weight decay is applied during the parameter update step: a fraction of each parameter's own value (weight_decay * param) is added to its gradient, so the subsequent update shrinks the parameters toward zero, effectively implementing a form of L2 regularization.
      • The weight decay hyperparameter controls the strength of this regularization. A higher weight decay value will result in stronger regularization and smaller parameter values.
    • Momentum:

      • Momentum is a technique used to accelerate the convergence of optimization algorithms. It helps the optimizer to "remember" its previous updates and use that information to make more consistent updates in the current iteration. The idea is to give the optimization process some inertia, allowing it to move more smoothly through the optimization landscape and escape local minima.
      • In the context of the SGD optimizer, momentum is introduced by maintaining a moving average of the gradients of the loss function with respect to the parameters. This moving average is referred to as the "momentum buffer." During each iteration, the momentum term is added to the gradient update, which smooths out the gradient trajectory and helps the optimizer avoid getting stuck in sharp or noisy regions.
      • The momentum hyperparameter controls the influence of the momentum term. A higher momentum value means that the optimizer relies more on the historical gradient information and less on the current gradient, which can help avoid oscillations and speed up convergence.
    • Nesterov momentum modifies how momentum is applied during the parameter update step to improve convergence speed and stability. Standard momentum accumulates a moving average of gradients and then updates the parameters from this accumulated gradient. Nesterov momentum instead evaluates the gradient not at the current parameter values but at a "look-ahead" point adjusted by the momentum term. This lets the optimizer account for the expected momentum-induced update when computing the gradient, resulting in more accurate updates and faster convergence. (The full update rule, including weight decay, dampening, and Nesterov, is sketched after this list.)

      • Nesterov momentum has been shown to improve convergence for many optimization problems, especially when dealing with ill-conditioned or noisy gradients. It allows the optimizer to anticipate the next update and adjust its trajectory accordingly. However, it's worth noting that Nesterov momentum might not always lead to faster convergence and may require tuning of the momentum hyperparameter.
    • __setstate__(self, state):

      • Restores the state of the optimizer from a saved state.

      • Sets default values for various optimizer parameters in each parameter group.

      • Useful when resuming training from a checkpoint or when loading a pre-trained optimizer.

    • _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):

      • It iterates through the parameters in the given group and checks if each parameter has a gradient (p.grad is not None).

      • If a parameter has a gradient, it appends the parameter to params_with_grad, the gradient to d_p_list, and the corresponding momentum buffer to momentum_buffer_list.

      • This method helps organize the parameters and gradients for the specific group, preparing them for the optimization step.

      • The has_sparse_grad boolean indicates whether any gradient is sparse in the group, which can affect the optimization process.

    • step(self, closure=None):

      • Performs a single optimization step.

      • closure is an optional closure that reevaluates the model and returns the loss.

      • If closure is provided, it is executed within a context where gradients are enabled, and the loss is computed.

      • For each parameter group:

        • Initializes lists for parameters with gradients, gradient tensors, and momentum buffers.

        • Calls _init_group to populate these lists and checks if there are sparse gradients.

        • Calls the functional sgd routine with the gathered tensors and group-specific hyperparameters (step itself is wrapped by the _use_grad_for_differentiable decorator).

        • Updates momentum buffers in the optimizer's state for the parameters with gradients.

          • The momentum term in optimization helps to smooth out the update trajectory and accelerates convergence, especially in noisy or ill-conditioned optimization landscapes. The momentum term is calculated as a weighted sum of past gradients. The momentum_buffer is used to store the momentum term's contribution at each optimization step.

          • During each iteration of the optimization, the gradient of the loss function is computed with respect to the parameters. This gradient is then combined with the stored momentum term from the momentum_buffer to determine the update direction. The magnitude and direction of the parameter update are influenced by both the gradient and the momentum term.

          • If the momentum_buffer were not updated or maintained, the optimizer would lose the momentum effect, and the optimization process might become less effective. The parameter updates would only be driven by the current gradient, ignoring the historical momentum contributions. This could lead to slower convergence, especially in situations where the gradient is noisy or the optimization landscape is complex.

        • Returns the computed loss if a closure is provided.
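
Putting the hyperparameters above together, the per-parameter update that step() performs (a plain-notation sketch mirroring the _single_tensor_sgd code below, not extra behavior):

    g     = gradient of the loss w.r.t. param       (negated first if maximize)
    g     = g + weight_decay * param                (if weight_decay != 0)
    buf   = momentum * buf + (1 - dampening) * g    (buf is initialized to g on the first step)
    g     = g + momentum * buf  if nesterov  else  buf
    param = param - lr * g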

  def _single_tensor_sgd(params: List[Tensor],
                         d_p_list: List[Tensor],
                         momentum_buffer_list: List[Optional[Tensor]],
                         *,
                         weight_decay: float,
                         momentum: float,
                         lr: float,
                         dampening: float,
                         nesterov: bool,
                         maximize: bool,
                         has_sparse_grad: bool):

      for i, param in enumerate(params):
          d_p = d_p_list[i] if not maximize else -d_p_list[i]

          if weight_decay != 0:
              d_p = d_p.add(param, alpha=weight_decay)

          if momentum != 0:
              buf = momentum_buffer_list[i]

              if buf is None:
                  buf = torch.clone(d_p).detach()
                  momentum_buffer_list[i] = buf
              else:
                  buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

              if nesterov:
                  d_p = d_p.add(buf, alpha=momentum)
              else:
                  d_p = buf

          param.add_(d_p, alpha=-lr)
  • _single_tensor_sgd

    • The gradients are adjusted and momentum buffers are updated to facilitate the convergence of the optimization process.

    • Steps:

      • Loop over the list of parameters and their corresponding gradients.

      • Depending on whether we are maximizing or minimizing (controlled by the maximize parameter), we either keep the gradient as is or negate it. For maximization, we negate the gradient to move in the direction of increasing the objective function.

      • Apply weight decay: weight decay encourages smaller parameter values by adding weight_decay * param to the gradient, penalizing large weights.

      • If momentum is non-zero, we update the momentum buffer for the current parameter.

        • If the buffer does not exist (buf is None), we create a copy of the gradient tensor and detach it from the computation graph. This buffer is used to accumulate momentum over iterations.

        • If the buffer exists, we update it using a combination of the current gradient (d_p) and the previous momentum buffer. The dampening term controls how much of the new gradient contributes to the buffer.

      • If Nesterov momentum is enabled (nesterov=True), we modify the gradient d_p by adding the momentum-scaled buffer (buf) to it. Nesterov momentum is a variant of traditional momentum that helps to accelerate convergence.

      • If Nesterov momentum is not used, we assign d_p to be the value of the buffer (buf), which represents the accumulated momentum.

      • Finally, we update the parameter in place with the modified gradient d_p, passing alpha=-lr so the step is scaled by the learning rate and moves opposite the gradient.
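
A tiny worked trace of these steps, checked against torch.optim.SGD (assumed values: param = 1.0, a constant gradient of 0.5, lr = 0.1, momentum = 0.9, dampening = 0, no weight decay, no Nesterov):

    import torch

    p = torch.tensor([1.0], requires_grad=True)
    opt = torch.optim.SGD([p], lr=0.1, momentum=0.9)

    for _ in range(2):
        opt.zero_grad()
        loss = (0.5 * p).sum()   # d(loss)/dp = 0.5 on every step
        loss.backward()
        opt.step()

    # step 1: buf initialized to 0.5,       p = 1.0  - 0.1*0.5  = 0.95
    # step 2: buf = 0.9*0.5 + 0.5 = 0.95,   p = 0.95 - 0.1*0.95 = 0.855
    print(p)  # tensor([0.8550], requires_grad=True)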

  def _multi_tensor_sgd(params: List[Tensor],
                        grads: List[Tensor],
                        momentum_buffer_list: List[Optional[Tensor]],
                        *,
                        weight_decay: float,
                        momentum: float,
                        lr: float,
                        dampening: float,
                        nesterov: bool,
                        maximize: bool,
                        has_sparse_grad: bool):

      if len(params) == 0:
          return

      grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, momentum_buffer_list], with_indices=True)
      for device_params, device_grads, device_momentum_buffer_list, indices in grouped_tensors.values():
          device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)

          if maximize:
              device_grads = torch._foreach_neg(tuple(device_grads))  # type: ignore[assignment]

          if weight_decay != 0:
              device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)

          if momentum != 0:
              bufs = []

              all_states_with_momentum_buffer = True
              for i in range(len(device_momentum_buffer_list)):
                  if device_momentum_buffer_list[i] is None:
                      all_states_with_momentum_buffer = False
                      break
                  else:
                      bufs.append(device_momentum_buffer_list[i])

              if all_states_with_momentum_buffer:
                  torch._foreach_mul_(bufs, momentum)
                  torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
              else:
                  bufs = []
                  for i in range(len(device_momentum_buffer_list)):
                      if device_momentum_buffer_list[i] is None:
                          buf = device_momentum_buffer_list[i] = momentum_buffer_list[indices[i]] = \
                              torch.clone(device_grads[i]).detach()
                      else:
                          buf = device_momentum_buffer_list[i]
                          buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

                      bufs.append(buf)

              if nesterov:
                  torch._foreach_add_(device_grads, bufs, alpha=momentum)
              else:
                  device_grads = bufs

          if not device_has_sparse_grad:
              torch._foreach_add_(device_params, device_grads, alpha=-lr)
          else:
              # foreach APIs don't support sparse
              for i in range(len(device_params)):
                  device_params[i].add_(device_grads[i], alpha=-lr)
  • _group_tensors_by_device_and_dtype: groups tensors based on their device and data type to optimize memory access and computation. It returns separate lists of tensors for each device and data type, along with their indices.

  • Steps:

    • Iterate over the grouped tensors for each device and data type.

    • Negate the gradients if maximize is True, to handle maximization tasks.

    • If weight decay is non-zero, add the weight decay term (weight_decay * param) to the gradients.

    • Check whether momentum buffers exist for all parameters on the current device. If they do, update the buffers in a fused call and apply the momentum term to the gradients; if some buffers are missing, initialize and update them one by one.

    • If the device does not have sparse gradients (device_has_sparse_grad), update the parameters with the computed gradients in a single fused call, scaled by the negative learning rate (-lr).

    • If the device has sparse gradients, update the parameters in a plain loop, since the torch._foreach APIs don't support sparse tensors.
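
The core idea, sketched with the internal (underscore-prefixed, subject-to-change) torch._foreach_* APIs that this path relies on, shown for the plain no-momentum update:

    import torch

    params = [torch.ones(3), torch.ones(2)]                 # pretend parameters, same device/dtype
    grads  = [torch.full((3,), 0.5), torch.full((2,), 0.5)]

    # instead of a Python loop calling add_ once per tensor, a single fused
    # call updates every tensor in the group at once:
    torch._foreach_add_(params, grads, alpha=-0.1)
    print(params[0])   # tensor([0.9500, 0.9500, 0.9500])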

torch.autograd.backward

  • Imagine you're teaching a robot to paint a picture. You have a canvas, some paint colors, and a brush. You also have a rulebook that tells the robot how to mix colors and apply the brush strokes.

  • Setting the Scene: You start by telling the robot to paint a picture. The robot follows the rulebook step by step, mixing colors and making brush strokes. At the end, it has a beautiful painting on the canvas.

  • Checking the Mistakes: You want the robot to improve, so you show it a picture of a perfect painting and ask, "How far is your painting from this ideal one?" The robot looks at its painting and sees where it went wrong. It figures out which brush strokes and color mixes need adjustment.

  • Learning from Mistakes: Now comes the clever part. The robot analyzes its painting and changes the way it mixed colors and applied brush strokes, so it can make the painting closer to the ideal one. It learns from its mistakes and updates its technique.

  • Adjusting the Rulebook: The robot realizes that some steps in the rulebook were better than others. It modifies the rulebook to make the painting process better next time. It's like refining the instructions to get a better result.

  • In PyTorch, when you call backward on a tensor, you're asking autograd to figure out how much each parameter contributed to that tensor's value (the gradient) by walking the computation graph in reverse. Note that backward only computes and accumulates gradients into the parameters' .grad fields; the actual tweaking of the parameters to make the next prediction better happens afterwards, in the optimizer's step().
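
  • Concretely, a minimal example of what backward computes (a sketch, independent of the source below):

    import torch

    x = torch.tensor([2.0, 3.0], requires_grad=True)
    y = (x ** 2).sum()    # y = x0^2 + x1^2
    y.backward()          # walks the graph in reverse, filling x.grad
    print(x.grad)         # tensor([4., 6.]) since dy/dx = 2*x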

  def backward(
      tensors: _TensorOrTensors,
      grad_tensors: Optional[_TensorOrTensors] = None,
      retain_graph: Optional[bool] = None,
      create_graph: bool = False,
      grad_variables: Optional[_TensorOrTensors] = None,
      inputs: Optional[_TensorOrTensors] = None,
  ) -> None:
      r"""Computes the sum of gradients of given tensors with respect to graph
      leaves.

      The graph is differentiated using the chain rule. If any of ``tensors``
      are non-scalar (i.e. their data has more than one element) and require
      gradient, then the Jacobian-vector product would be computed, in this
      case the function additionally requires specifying ``grad_tensors``.
      It should be a sequence of matching length, that contains the "vector"
      in the Jacobian-vector product, usually the gradient of the differentiated
      function w.r.t. corresponding tensors (``None`` is an acceptable value for
      all tensors that don't need gradient tensors).

      This function accumulates gradients in the leaves - you might need to zero
      ``.grad`` attributes or set them to ``None`` before calling it.
      See :ref:`Default gradient layouts<default-grad-layouts>`
      for details on the memory layout of accumulated gradients.

      .. note::
          Using this method with ``create_graph=True`` will create a reference cycle
          between the parameter and its gradient which can cause a memory leak.
          We recommend using ``autograd.grad`` when creating the graph to avoid this.
          If you have to use this function, make sure to reset the ``.grad`` fields of your
          parameters to ``None`` after use to break the cycle and avoid the leak.

      .. note::

          If you run any forward ops, create ``grad_tensors``, and/or call ``backward``
          in a user-specified CUDA stream context, see
          :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

      .. note::

          When ``inputs`` are provided and a given input is not a leaf,
          the current implementation will call its grad_fn (even though it is not strictly needed to get these gradients).
          It is an implementation detail on which the user should not rely.
          See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

      Args:
          tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be
              computed.
          grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in
              the Jacobian-vector product, usually gradients w.r.t. each element of
              corresponding tensors. None values can be specified for scalar Tensors or
              ones that don't require grad. If a None value would be acceptable for all
              grad_tensors, then this argument is optional.
          retain_graph (bool, optional): If ``False``, the graph used to compute the grad
              will be freed. Note that in nearly all cases setting this option to ``True``
              is not needed and often can be worked around in a much more efficient
              way. Defaults to the value of ``create_graph``.
          create_graph (bool, optional): If ``True``, graph of the derivative will
              be constructed, allowing to compute higher order derivative products.
              Defaults to ``False``.
          inputs (Sequence[Tensor] or Tensor, optional): Inputs w.r.t. which the gradient
              be will accumulated into ``.grad``. All other Tensors will be ignored. If
              not provided, the gradient is accumulated into all the leaf Tensors that
              were used to compute the ``tensors``.
      """
      if torch._C._are_functorch_transforms_active():
          raise RuntimeError(
              "backward() called inside a functorch transform. This is not "
              "supported, please use functorch.grad or functorch.vjp instead "
              "or call backward() outside of functorch transforms.")

      if grad_variables is not None:
          warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
          if grad_tensors is None:
              grad_tensors = grad_variables
          else:
              raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
                                 "arguments both passed to backward(). Please only "
                                 "use 'grad_tensors'.")
      if inputs is not None and len(inputs) == 0:
          raise RuntimeError("'inputs' argument to backward() cannot be empty.")

      tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
      inputs = (inputs,) if isinstance(inputs, torch.Tensor) else \
          tuple(inputs) if inputs is not None else tuple()

      grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
      grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
      if retain_graph is None:
          retain_graph = create_graph

      # The reason we repeat the same comment below is that
      # some Python versions print out the first line of multi-line function
      # calls in the traceback and some print out the last line
      Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
          tensors, grad_tensors_, retain_graph, create_graph, inputs,
          allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
  • The most important part of the code is the call to Variable._execution_engine.run_backward(). This call triggers the actual backward pass computation. Here's what's happening in this step:

    • tensors: the tensors of which the derivative is computed (typically the scalar loss), not the model parameters; the parameters are the graph leaves into which the resulting gradients are accumulated.

    • grad_tensors_: These are the gradient tensors used in the Jacobian-vector product. They represent the gradients of the loss function with respect to some output tensors.

    • retain_graph and create_graph: These control whether the computation graph is retained and whether a graph for higher-order derivatives is constructed, respectively.

    • inputs: These are the tensors with respect to which the gradients are accumulated. Gradients will be propagated through the graph up to these input tensors.

  • The backward pass calculates the gradients of the loss with respect to the model's parameters by applying the chain rule. It starts from the loss, propagates gradients backward through the computation graph, and computes the gradients for each parameter. The actual gradient computations are offloaded to a C++ engine for efficiency.
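
  • A small sketch of the non-scalar case, where grad_tensors supplies the "vector" in the Jacobian-vector product (illustrative values):

    import torch

    x = torch.tensor([1.0, 2.0], requires_grad=True)
    y = 3 * x                               # non-scalar output; Jacobian = 3 * I
    v = torch.tensor([1.0, 0.5])            # the "vector" for the product
    torch.autograd.backward(y, grad_tensors=v)
    print(x.grad)                           # tensor([3.0000, 1.5000])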

  • For a deeper look at how backpropagation actually works in PyTorch, see: https://discuss.pytorch.org/t/what-does-the-backward-function-do/9944/2