nn.BCELoss and nn.BCEWithLogitsLoss

August 17, 2023

  • Entropy is a fundamental concept in information theory that quantifies the randomness or uncertainty of an event's outcome. Before delving into its applications and mathematical representations, let's establish a clear understanding of its essence.

  • Entropy reflects the unpredictability of an outcome: the more uncertain the outcome, the higher the entropy. The intuition starts with the information content (surprisal) of a single event, -log(p). A low-probability event is unexpected and surprising, so it carries high information content; a high-probability event is unsurprising and carries little. Entropy is the expected information content over all possible outcomes, H = -∑ p_i * log(p_i).
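  • A small pure-Python sketch of these definitions, using base-2 logarithms so the results are in bits (the helper names surprisal and entropy are illustrative, not from any library):

```python
import math

def surprisal(p):
    """Information content, in bits, of an outcome with probability p."""
    return -math.log2(p)

def entropy(dist):
    """Shannon entropy in bits: the expected surprisal over all outcomes."""
    # Zero-probability outcomes contribute nothing (p * log p -> 0 as p -> 0).
    return sum(p * surprisal(p) for p in dist if p > 0)

print(surprisal(0.1), surprisal(0.9))  # the rare outcome is far more surprising
print(entropy([0.5, 0.5]))             # fair coin: maximal uncertainty for two outcomes
print(entropy([0.9, 0.1]))             # biased coin: less uncertain
print(entropy([1.0, 0.0]))             # certain outcome: zero entropy
```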

  • This concept finds practical use in many fields, including machine learning, where cross-entropy plays a crucial role. The cross-entropy H(p, q) = -∑ p_i * log(q_i) measures how well a predicted distribution q matches a true distribution p over the same set of events: it equals the entropy of p when q = p and grows as q moves away from p. It is commonly used as the loss function when training classification models.
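  • A minimal sketch of cross-entropy between two discrete distributions, in natural-log units (nats); cross_entropy is a hypothetical helper. With a one-hot true distribution the sum collapses to -log of the probability assigned to the correct class, which is the form the binary loss below takes:

```python
import math

def cross_entropy(p, q):
    """H(p, q) = -sum_i p_i * log(q_i) in nats; p is the true distribution, q the prediction."""
    # Terms with p_i == 0 contribute nothing, so they are skipped.
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

true_dist = [0.0, 1.0, 0.0]          # one-hot label: class 1 is correct
good_pred = [0.05, 0.90, 0.05]
bad_pred = [0.70, 0.20, 0.10]

print(cross_entropy(true_dist, good_pred))   # equals -log(0.90)
print(cross_entropy(true_dist, bad_pred))    # equals -log(0.20), much larger

# Cross-entropy is smallest when the prediction equals the true distribution.
p = [0.5, 0.5]
print(cross_entropy(p, p), cross_entropy(p, [0.9, 0.1]))
```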

  • In binary classification, the raw output z of a neural network is passed through the sigmoid function σ(z) = 1 / (1 + e^(-z)), yielding a probability score p = σ(z). The true label t takes values in {0, 1}. The binary cross-entropy loss, also known as log loss, measures the divergence between the predicted probability and the true label:

  • L(t, p) = -[t * log(p) + (1 - t) * log(1 - p)]

  • This loss function has distinctive behaviors for different scenarios. When the true label is 1, the loss simplifies to -log(p), penalizing deviations of the predicted probability from 1. Similarly, when the true label is 0, the loss simplifies to -log(1 - p), penalizing deviations of the predicted probability from 0.

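  • The logarithm comes from a probabilistic interpretation: BCE is the negative log-likelihood of the label t under a Bernoulli distribution with parameter p, whose likelihood is p^t * (1 - p)^(1 - t). Minimizing the loss is therefore the same as maximizing the likelihood of the true labels. Concretely, for t = 1 the term -t * log(p) grows as the predicted probability decreases, so the model is pushed toward confident predictions that agree with the true label.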
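  • Written out as a short derivation (same t and p as above), the Bernoulli likelihood and its negative log give the loss directly, and the product over independent samples gives the dataset sum:

```latex
P(t \mid p) = p^{t}\,(1 - p)^{1 - t}, \qquad t \in \{0, 1\}

-\log P(t \mid p) = -\left[\, t \log p + (1 - t)\log(1 - p) \,\right] = L(t, p)

-\log \prod_{i=1}^{S} P(t_i \mid p_i) = -\sum_{i=1}^{S} \left[\, t_i \log p_i + (1 - t_i)\log(1 - p_i) \,\right]
```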

  • Graphically, the binary cross-entropy loss rises sharply as the predicted probability moves away from the true label: for t = 1, -log(p) grows without bound as p approaches 0, so confident but incorrect predictions are penalized far more heavily than uncertain ones.
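  • A quick numeric check of that shape, sketched in plain Python with a hypothetical per-sample helper bce that implements L(t, p) from above:

```python
import math

def bce(t, p):
    """Per-sample binary cross-entropy: -[t*log(p) + (1 - t)*log(1 - p)]."""
    return -(t * math.log(p) + (1 - t) * math.log(1 - p))

# True label is 1; move the prediction from confidently right to confidently wrong.
for p in (0.99, 0.9, 0.5, 0.1, 0.01, 0.001):
    print(f"p = {p:<5}  loss = {bce(1, p):.3f}")

# The t = 0 branch mirrors it: predicting p for t = 0 costs the same as 1 - p for t = 1.
print(bce(0, 0.2), bce(1, 0.8))
```

The confidently wrong prediction at p = 0.001 costs hundreds of times more than the confidently right one at p = 0.99. The formula itself is undefined at exactly p = 0 or p = 1 (log(0)); PyTorch's BCELoss handles this by clamping its log outputs to be at least -100.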

  • For a dataset with S samples, the total cross-entropy loss is the sum of the individual loss values across all samples (in practice it is often averaged instead, which is PyTorch's default):

  • L(t, p) = -∑[t_i * log(p_i) + (1 - t_i) * log(1 - p_i)] for i = 1 to S

  • Let's look at the code for BCELoss in PyTorch:

    # Excerpt from torch/nn/modules/loss.py (docstrings omitted)
    class _Loss(Module):
        reduction: str
    
        def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
            super().__init__()
            if size_average is not None or reduce is not None:
                self.reduction: str = _Reduction.legacy_get_string(size_average, reduce)
            else:
                self.reduction = reduction
    
    class _WeightedLoss(_Loss):
        def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
            super().__init__(size_average, reduce, reduction)
            self.register_buffer('weight', weight)
            self.weight: Optional[Tensor]
    
    class BCELoss(_WeightedLoss):
        __constants__ = ['reduction']
    
        def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
            super().__init__(weight, size_average, reduce, reduction)
    
        def forward(self, input: Tensor, target: Tensor) -> Tensor:
            return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
    
    
    # Excerpt from torch/nn/functional.py (docstring omitted)
    def binary_cross_entropy(
        input: Tensor,
        target: Tensor,
        weight: Optional[Tensor] = None,
        size_average: Optional[bool] = None,
        reduce: Optional[bool] = None,
        reduction: str = "mean",
    ) -> Tensor:
        # If input, target or weight defines __torch_function__ (a Tensor subclass or
        # Tensor-like object), or a torch function mode is active, dispatch through
        # handle_torch_function so the override can supply the result. Otherwise fall
        # through to the native binary cross-entropy computation below.
        if has_torch_function_variadic(input, target, weight):
            return handle_torch_function(
                binary_cross_entropy,
                (input, target, weight),
                input,
                target,
                weight=weight,
                size_average=size_average,
                reduce=reduce,
                reduction=reduction,
            )
        if size_average is not None or reduce is not None:
            reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
        else:
            reduction_enum = _Reduction.get_enum(reduction)
    
        # BCE does not broadcast: target must have exactly the same shape as input
        if target.size() != input.size():
            raise ValueError(
                "Using a target size ({}) that is different to the input size ({}) is deprecated. "
                "Please ensure they have the same size.".format(target.size(), input.size())
            )
    
        # If a weight tensor is provided, broadcast it to the target's shape:
        # _infer_size computes the broadcast shape and expand() views weight at that size.
        if weight is not None:
            new_size = _infer_size(target.size(), weight.size())
            weight = weight.expand(new_size)
    
        # Compute the loss with the native (C++) implementation
        return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
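  • A short usage sketch of the two checks above, assuming a hypothetical multi-label batch of 4 samples with 3 binary labels each: a per-label weight of shape (3,) is broadcast to the (4, 3) target shape, which scales each element's loss, and a target with a different shape is rejected with ValueError before the C++ kernel is reached:

```python
import torch
import torch.nn as nn

# Made-up predicted probabilities and labels: 4 samples x 3 binary labels.
probs = torch.tensor([[0.9, 0.2, 0.6],
                      [0.1, 0.8, 0.4],
                      [0.7, 0.3, 0.5],
                      [0.2, 0.9, 0.1]])
targets = torch.tensor([[1.0, 0.0, 1.0],
                        [0.0, 1.0, 0.0],
                        [1.0, 0.0, 1.0],
                        [0.0, 1.0, 0.0]])

# Per-label weight of shape (3,); binary_cross_entropy expands it to (4, 3).
label_weight = torch.tensor([1.0, 2.0, 0.5])
weighted = nn.BCELoss(weight=label_weight, reduction="none")(probs, targets)
unweighted = nn.BCELoss(reduction="none")(probs, targets)
print(torch.allclose(weighted, unweighted * label_weight))  # weight scales each element

# The size check: a (12,) target against a (4, 3) input raises ValueError.
try:
    nn.BCELoss()(probs, targets.flatten())
except ValueError as err:
    print(err)
```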
    
  • The reduction-handling branch of binary_cross_entropy:

      if size_average is not None or reduce is not None:
          reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
      else:
          reduction_enum = _Reduction.get_enum(reduction)
    • "mean" Reduction (the default): computes the mean of the individual losses over all elements. Averaging keeps the loss magnitude independent of batch size, which is why it is the usual choice for training.

    • "sum" Reduction: This strategy simply sums up all the individual losses. It is useful when you want to work with the total cumulative loss across all samples.

    • "none" Reduction: With this strategy, no reduction is applied, and you get a tensor of individual losses. This can be useful if you need per-sample loss values for further analysis or custom post-processing.

    • size_average and reduce are deprecated legacy arguments. If either is not None, the caller is using the old API, so the code calls _Reduction.legacy_get_enum() to translate the size_average/reduce combination into the equivalent reduction enum.

    • If neither size_average nor reduce is specified (both are None), then the code uses _Reduction.get_enum() to determine the reduction strategy enum based on the provided reduction parameter. This is the newer way of specifying the reduction strategy.
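    • A minimal sketch comparing the three reduction modes on a toy batch (the tensors are made-up values), plus the legacy size_average argument being translated into a reduction string:

```python
import warnings

import torch
import torch.nn as nn

probs = torch.tensor([0.9, 0.2, 0.6, 0.4])
targets = torch.tensor([1.0, 0.0, 1.0, 1.0])

per_sample = nn.BCELoss(reduction="none")(probs, targets)  # one loss per element
total = nn.BCELoss(reduction="sum")(probs, targets)        # scalar: sum of per_sample
average = nn.BCELoss(reduction="mean")(probs, targets)     # scalar: mean of per_sample

print(per_sample)
print(total, per_sample.sum())
print(average, per_sample.mean())

# Legacy path: size_average=False (with reduce left as None) maps to "sum".
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # the legacy arguments emit a deprecation warning
    legacy = nn.BCELoss(size_average=False)
print(legacy.reduction)
```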

  • handle_torch_function, from torch/overrides.py:

      def handle_torch_function(
              public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any:
          # Check for __torch_function__ methods.
          overloaded_args = _get_overloaded_args(relevant_args)
          # overloaded_args already have unique types.
          types = tuple(map(type, overloaded_args))
    
          # Check for __torch_function__ mode.
          if _is_torch_function_mode_enabled():
              # if we're here, the mode must be set to a TorchFunctionStackMode
              # this unsets it and calls directly into TorchFunctionStackMode's torch function
              with _pop_mode_temporarily() as mode:
                  result = mode.__torch_function__(public_api, types, args, kwargs)
              if result is not NotImplemented:
                  return result
    
          # Call overrides
          for overloaded_arg in overloaded_args:
              # This call needs to become a classmethod call in the future.
              # See https://github.com/pytorch/pytorch/issues/63767
              torch_func_method = overloaded_arg.__torch_function__
              if hasattr(torch_func_method, "__self__") and torch_func_method.__self__ is overloaded_arg and \
                      torch_func_method is not torch._C._disabled_torch_function_impl:
                  warnings.warn("Defining your `__torch_function__ as a plain method is deprecated and "
                                  "will be an error in future, please define it as a classmethod.",
                                  DeprecationWarning)
    
              # Use `public_api` instead of `implementation` so __torch_function__
              # implementations can do equality/identity comparisons.
              result = torch_func_method(public_api, types, args, kwargs)
    
              if result is not NotImplemented:
                  return result
    
          func_name = f'{public_api.__module__}.{public_api.__name__}'
          msg = (
              "no implementation found for '{}' on types that implement "
              '__torch_function__: {}'
          ).format(func_name, [type(arg) for arg in overloaded_args])
          if _is_torch_function_mode_enabled():
              msg += f" nor in mode {_get_current_function_mode()}"
          raise TypeError(msg)
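  • The final TypeError branch can be triggered with a hypothetical Tensor-like class whose __torch_function__ declines every call. handle_torch_function tries each overloaded argument in turn, no override returns a result, and it raises:

```python
import torch
import torch.nn.functional as F

class DecliningTensorLike:
    """Hypothetical Tensor-like whose __torch_function__ handles nothing."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Returning NotImplemented tells handle_torch_function to try the next
        # overloaded argument; if none produces a result, it raises TypeError.
        return NotImplemented

try:
    F.binary_cross_entropy(DecliningTensorLike(), torch.tensor([1.0]))
except TypeError as err:
    print(err)  # "no implementation found for ... on types that implement __torch_function__ ..."
```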

  • The mode check at the top of handle_torch_function:

        # Check for __torch_function__ mode.
        if _is_torch_function_mode_enabled():
            with _pop_mode_temporarily() as mode:
                result = mode.__torch_function__(public_api, types, args, kwargs)
            if result is not NotImplemented:
                return result
    • This block checks whether a torch function mode is active using _is_torch_function_mode_enabled(). If one is, _pop_mode_temporarily() removes the current mode from the mode stack for the duration of the call (so torch functions invoked inside the handler are not intercepted by the same mode again), and the mode's __torch_function__ is called with public_api, types, args, and kwargs. If the result is anything other than NotImplemented, it is returned; otherwise dispatch continues to the __torch_function__ overrides defined on the arguments themselves.
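    • To see the mode branch in action, here is a sketch assuming a PyTorch version that exposes torch.overrides.TorchFunctionMode (2.x does). RecordingMode is a hypothetical mode that records the name of every torch function called while it is active, then runs the function normally:

```python
import torch
import torch.nn.functional as F
from torch.overrides import TorchFunctionMode

class RecordingMode(TorchFunctionMode):
    """Hypothetical mode that records the name of every torch function it sees."""

    def __init__(self):
        super().__init__()
        self.seen = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.seen.append(getattr(func, "__name__", repr(func)))
        # handle_torch_function has popped this mode, so calling func here
        # runs the normal implementation instead of re-entering the mode.
        return func(*args, **kwargs)

probs = torch.tensor([0.9, 0.2])
targets = torch.tensor([1.0, 0.0])

mode = RecordingMode()
with mode:
    loss = F.binary_cross_entropy(probs, targets)

print(mode.seen)
print(loss)
```

Because the mode is popped before its __torch_function__ runs, the nested call to func is not intercepted again, which is what prevents infinite recursion.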

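      • Following the ScalarTensor example in PyTorch's "Extending PyTorch" notes: without __torch_function__, passing a ScalarTensor to a torch function such as torch.mean fails with a TypeError. Adding a __torch_function__ implementation makes that call succeed. Here is the implementation again, this time with __torch_function__ added: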

      • The dispatch table and the class, with the import the snippet needs:

        import torch

        HANDLED_FUNCTIONS = {}
        class ScalarTensor(object):
            def __init__(self, N, value):
                self._N = N
                self._value = value
        
            def __repr__(self):
                return "ScalarTensor(N={}, value={})".format(self._N, self._value)
        
            def tensor(self):
                return self._value * torch.eye(self._N)
        
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                if func not in HANDLED_FUNCTIONS or not all(
                    issubclass(t, (torch.Tensor, ScalarTensor))
                    for t in types
                ):
                    return NotImplemented
                return HANDLED_FUNCTIONS[func](*args, **kwargs)
      • The __torch_function__ method takes four arguments: func, a reference to the torch API function being overridden; types, the list of types of Tensor-likes that implement __torch_function__; args, the tuple of positional arguments passed to the function; and kwargs, the dict of keyword arguments passed to the function. It uses a global dispatch table named HANDLED_FUNCTIONS to store custom implementations: the keys are functions in the torch namespace and the values are their implementations for ScalarTensor. If func has no entry (or an unsupported type is involved), it returns NotImplemented, which tells handle_torch_function to try the next override.
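      • The dispatch table starts out empty, so nothing is handled yet. A self-contained sketch of how it gets filled, following the pattern in PyTorch's "Extending PyTorch" notes: a decorator (named implements here) registers a ScalarTensor version of torch.mean. The class is trimmed to the parts this example needs:

```python
import functools

import torch

HANDLED_FUNCTIONS = {}

class ScalarTensor(object):
    """value * identity(N), stored without materializing the N x N matrix."""

    def __init__(self, N, value):
        self._N = N
        self._value = value

    def tensor(self):
        return self._value * torch.eye(self._N)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor)) for t in types
        ):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

def implements(torch_function):
    """Register a ScalarTensor implementation of a torch function."""
    def decorator(func):
        functools.update_wrapper(func, torch_function)
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

@implements(torch.mean)
def mean(input):
    # value * I_N has N entries equal to value out of N * N, so the mean is value / N.
    return float(input._value) / input._N

d = ScalarTensor(5, 2)
print(torch.mean(d))           # dispatched to the registered mean()
print(torch.mean(d.tensor()))  # same value computed on the dense matrix
```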

  • Let's look at BCEWithLogitsLoss

    • But before we start, let's figure out what a logit is

      • Logits are the raw scores a classifier outputs before they are transformed into probabilities; keeping the model's output in this unconstrained form makes classification easier to optimize during training.

      • “Logits” refers to the raw outputs of a model before they are transformed into probabilities. Specifically, logits are the unnormalized outputs of the last layer of a neural network. In multi-class classification they are passed to a softmax function, and in binary classification to a sigmoid; either transforms them into normalized probabilities.

      • In statistics, the logit is the function logit(p) = log(p / (1 - p)), which maps a probability in (0, 1) to the whole real line; its inverse is the sigmoid (logistic) function. That is why a model's unbounded raw outputs are called logits: applying the sigmoid to them yields probabilities. The term appears in logistic regression, neural networks, and other models for binary or multi-class classification.

      • In a binary classification problem, the goal is to assign an input to one of two classes (e.g., "positive" or "negative," "spam" or "not spam"). In logistic regression, the logit of the class probability is modeled as a linear function of the input features, so logits capture that underlying linear relationship.

      • The term "logit" refers to the logarithm of the odds (the log-odds). The odds are the ratio of the probability of an event occurring to the probability of it not occurring, p / (1 - p). Taking the logarithm maps probabilities from (0, 1) onto the whole real line (-∞, ∞), a scale on which additions and other mathematical operations are unconstrained.

        • The odds quantify the likelihood of an event occurring compared to the likelihood of it not occurring; for example, p = 0.8 gives odds of 0.8 / 0.2 = 4. (Strictly, an odds ratio is the ratio of two such odds.)
      • Logits are useful because they let the model output unconstrained real numbers, which are easier to optimize and more numerically stable to compute with than probabilities squeezed into (0, 1). The sigmoid function (also known as the logistic function) then maps these logits to probabilities.
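      • A small pure-Python sketch of the relationship between probability, odds, and logit, with hypothetical helpers logit and sigmoid showing that each undoes the other:

```python
import math

def sigmoid(z):
    """Logistic function: maps any real z to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def logit(p):
    """Log-odds: maps a probability in (0, 1) to the whole real line."""
    return math.log(p / (1.0 - p))

for p in (0.1, 0.5, 0.9):
    odds = p / (1.0 - p)
    z = logit(p)
    print(f"p={p}  odds={odds:.3f}  logit={z:+.3f}  sigmoid(logit)={sigmoid(z):.3f}")
```

Probabilities symmetric around 0.5 give logits symmetric around 0, and applying sigmoid to a logit recovers the original probability.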

    • The code is structured much like BCELoss's, but the docs say that "This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability". I need to figure out how that works.

      • It combines the sigmoid activation (the inverse of the logit) and the binary cross-entropy loss into a single operation. This loss function is intended for raw logits (the model's output before any activation) rather than probability predictions.

      • BCELoss is used when you have probabilities as predictions. It takes the predicted probabilities and the target labels as inputs and computes the binary cross-entropy loss, which measures the difference between the predicted probabilities and the true labels. The binary cross-entropy loss expects probability values between 0 and 1.

      • BCEWithLogitsLoss is used when you have raw logits (the output of a model before applying any activation function, like sigmoid) as predictions. Conceptually it combines two steps into one: applying the sigmoid to convert logits into probabilities, and then computing the binary cross-entropy loss on those probabilities.

      • So it is conceptually a Sigmoid followed by BCELoss, computed in one step; which one to use depends on what the model outputs.

        • If the model outputs probabilities, use BCELoss

        • If the model outputs logits, use BCEWithLogitsLoss, which applies the sigmoid internally (so do not add a sigmoid layer before it)
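      • Working through the algebra shows where the stability comes from. Substituting p = σ(z) = 1 / (1 + e^(-z)) into L(t, p) simplifies to L = log(1 + e^(-z)) + (1 - t) * z. Evaluated naively, e^(-z) overflows for large negative z, and in the two-step version σ(z) rounds to exactly 0 or 1 for large |z|, so log(p) or log(1 - p) becomes log(0). Factoring the larger of the exponents 0 and -z out of log(e^0 + e^(-z)) (the log-sum-exp trick) gives the equivalent form max(z, 0) - z * t + log(1 + e^(-|z|)), whose exponent is never positive. The sketch below uses a hypothetical helper stable_bce_with_logits for that formula and compares it with nn.BCEWithLogitsLoss and with a sigmoid followed by nn.BCELoss; PyTorch's own kernel may arrange the terms differently while being algebraically equivalent.

```python
import torch
import torch.nn as nn

def stable_bce_with_logits(z, t):
    """Per-element BCE on logits: max(z, 0) - z*t + log(1 + exp(-|z|))."""
    # -|z| <= 0, so exp() cannot overflow, and log(sigmoid(z)) is never formed.
    return z.clamp(min=0) - z * t + torch.log1p(torch.exp(-z.abs()))

z = torch.tensor([-3.0, 0.5, 40.0])   # raw logits (float32)
t = torch.tensor([0.0, 1.0, 0.0])     # targets

# Two-step version: sigmoid, then BCELoss on the probabilities.
naive = nn.BCELoss(reduction="none")(torch.sigmoid(z), t)
# Fused version that takes logits directly.
fused = nn.BCEWithLogitsLoss(reduction="none")(z, t)
manual = stable_bce_with_logits(z, t)

print(torch.sigmoid(z))  # sigmoid(40.0) rounds to exactly 1.0 in float32
print(naive)             # last entry is capped by BCELoss's clamp of log outputs at -100
print(fused)             # last entry is approximately 40, the exact loss
print(manual)
```

In practice this means a model trained with BCEWithLogitsLoss should output raw logits from its last layer, with torch.sigmoid applied only when probabilities are needed, for example at inference time.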