Cross entropy any which way

loss functions
softmax
nll
cross entropy
Author

Lucas van Walstijn

Published

March 15, 2023

Cross entropy is one of the most commonly used loss functions. In this post, we will have a look at how it works, and compute it in a couple of different ways.

Consider a network that is built for image classification. During the forward pass, images are passed into the network, which processes the data layer by layer until it eventually returns its final activations. These final activations are called “logits” and represent the unnormalized predictions of our model.

Since we generally use mini-batches during training, these logits have shape [bs, num_classes]:

import torch
import torch.nn.functional as F

g = torch.manual_seed(42) # use a generator for reproducibility

bs = 32 # batch size of 32
num_classes = 3 # image classification with 3 different classes

logits = torch.randn(size=(bs, num_classes), generator=g) # size: [32,3]

logits[0:4] # show the logits for the first couple of samples
tensor([[ 1.9269,  1.4873,  0.9007],
        [-2.1055,  0.6784, -1.2345],
        [-0.0431, -1.6047, -0.7521],
        [ 1.6487, -0.3925, -1.4036]])

Each row of this tensor represents the unnormalized predictions for each of our samples in the batch. We can normalize these predictions by applying a softmax. The softmax function does two things:

  1. make all our logits positive by applying the exponential function
  2. divide each exponentiated logit by the sum of the exponentiated logits over all classes

This ensures that we can treat the output as probabilities, because:

  1. all individual predictions will be between 0 and 1
  2. the predictions for each sample will sum to 1
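Written as a formula (the symbols \(z_j\) for the logit of class \(j\) and \(C\) for the number of classes are notation introduced here), the softmax probability for class \(i\) is:

\[ p_i = \frac{e^{z_i}}{\sum_{j=1}^{C} e^{z_j}} \]

Every \(e^{z_j}\) is positive, so each \(p_i\) lies between 0 and 1, and the \(p_i\) sum to 1 because they all share the same denominator.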

Specifically:

# Unnormalized predictions for our first sample (3 classes)
logits[0]
tensor([1.9269, 1.4873, 0.9007])
# Exponentiated predictions, making them all positive
exp_logits = logits[0].exp()
exp_logits
tensor([6.8683, 4.4251, 2.4614])
# Turn these values into probabilities by dividing by the sum
probs = exp_logits / exp_logits.sum()

# verify that the probabilities sum to 1
assert torch.allclose(probs.sum(), torch.tensor(1.))

probs
tensor([0.4993, 0.3217, 0.1789])

So, let’s create a softmax function that does this for a whole batch:

def softmax(logits):
    exp_logits = logits.exp() # shape: [32, 3]
    exp_logits_sum = exp_logits.sum(dim=1, keepdim=True) # shape: [32, 1]
    
    # Note: this is broadcast correctly, since exp_logits_sum is expanded
    # to [32, 3], so each value in exp_logits gets divided by the sum over its row
    probs = exp_logits / exp_logits_sum # shape: [32, 3]
    
    return probs 

probs = softmax(logits)
probs[0:4]
tensor([[0.4993, 0.3217, 0.1789],
        [0.0511, 0.8268, 0.1221],
        [0.5876, 0.1233, 0.2891],
        [0.8495, 0.1103, 0.0401]])
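One caveat with the softmax above: in float32, logits.exp() overflows to inf once a logit exceeds roughly 88.7, and inf / inf turns the result into nan. A common remedy, sketched below and not part of the original walkthrough, is to subtract each row's maximum before exponentiating. This leaves the result unchanged, because the extra factor \(e^{-\max}\) cancels between numerator and denominator. The names softmax_naive and softmax_stable are made up for this illustration.

```python
import torch

def softmax_naive(logits):
    # same computation as the softmax function above
    exp_logits = logits.exp()
    return exp_logits / exp_logits.sum(dim=1, keepdim=True)

def softmax_stable(logits):
    # subtract the row-wise maximum so the largest exponent becomes exp(0) = 1
    max_logits = logits.max(dim=1, keepdim=True).values  # shape: [bs, 1]
    exp_logits = (logits - max_logits).exp()              # shape: [bs, num_classes]
    return exp_logits / exp_logits.sum(dim=1, keepdim=True)

# hypothetical, very large logits for a single sample
big_logits = torch.tensor([[1000., 999., 998.]])
print(softmax_naive(big_logits))   # tensor([[nan, nan, nan]]) because exp(1000) = inf
print(softmax_stable(big_logits))  # tensor([[0.6652, 0.2447, 0.0900]])

# for ordinary logits both versions agree
g = torch.manual_seed(42)
logits = torch.randn(size=(32, 3), generator=g)
assert torch.allclose(softmax_naive(logits), softmax_stable(logits))
```

PyTorch's own F.softmax handles large logits without producing nan, so this trick mainly matters when writing the softmax by hand.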

Next, we want to compute the loss, for which we also need our labels. These labels represent the ground truth class for each of the samples in the batch. Since we have 3 classes, each label is an integer between 0 and 2 (i.e. 0, 1 or 2):

g = torch.manual_seed(42) # use a generator for reproducibility

labels = torch.randint(low=0, high=3, size=(32,), generator=g)
labels
tensor([0, 2, 1, 1, 0, 2, 1, 2, 1, 2, 1, 1, 2, 0, 0, 1, 2, 1, 0, 1, 1, 2, 1, 2,
        2, 1, 2, 0, 1, 1, 0, 0])

For classification we use the Negative Log Likelihood (NLL) loss function, which is defined as follows:

\[ \textrm{NLL} = - \sum_{i}{q_i \cdot \log(p_i)} \]

with \(i\) being the index that runs over the classes (3 in our example) and \(q_i\) being the probability that the ground truth label is class \(i\). This is a somewhat strange formulation, since this probability is either 1 (for the correct class) or 0 (for all other classes). Finally, \(p_i\) is the probability that the model assigned to class \(i\).

For the very first row of our probs ([0.4993, 0.3217, 0.1789]) and our first label (0) we thus get:

\[\begin{align} \textrm{NLL} &= - ( (1 \cdot \log(0.4993)) + (0 \cdot \log(0.3217)) + (0 \cdot \log(0.1789)) ) \\ \textrm{NLL} &= - ( (1 \cdot \log(0.4993)) ) \\ \textrm{NLL} &= - \log(0.4993) \end{align}\]

From this we see that the NLL of a sample is just the negative log of the probability the model assigned to the ground truth class.
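We can check this numerically with a small sketch: building the one-hot vector \(q\) explicitly (here with F.one_hot) and evaluating the full sum gives the same value as taking the negative log of the correct class's probability. The probabilities are copied from the first row of probs above, and the variable names are only for illustration.

```python
import torch
import torch.nn.functional as F

# probabilities for the first sample, copied from the softmax output above
p = torch.tensor([0.4993, 0.3217, 0.1789])
label = 0

# one-hot encode the ground truth label: q = [1., 0., 0.]
q = F.one_hot(torch.tensor([label]), num_classes=3).squeeze(0).float()

# full formula: -sum_i q_i * log(p_i)
nll_full = -(q * p.log()).sum()

# shortcut: negative log of the probability of the ground truth class
nll_short = -p[label].log()

assert torch.allclose(nll_full, nll_short)
print(nll_full)  # approximately 0.6945
```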

Since this computes only the NLL per sample, we also need a way to combine the NLL across the samples in our batch. We can do this either by summing or by averaging. Averaging has the advantage that the magnitude of the loss stays the same when we change the batch size, so let’s use that:

def nll(probs, labels):
    # probs: shape [32, 3]
    # labels: shape [32]
    
    # this plucks out the probability of the ground truth label per sample, 
    # it uses "numpy's integer array indexing":
    # https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
    probs_ground_truth_class = probs[range(len(labels)), labels] # shape: [32]
    
    nll = -torch.log(probs_ground_truth_class).mean() # shape: []
    return nll

nll(probs, labels)
tensor(1.3465)
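The integer array indexing in nll can also be written with torch.gather, which picks one value per row along a given dimension. The sketch below is an alternative, not the approach used above: the hypothetical nll_gather returns the per-sample losses, and we check that they match the indexing version and that summing instead of averaging just scales the loss by the batch size.

```python
import torch

def nll_gather(probs, labels):
    # labels.unsqueeze(1) has shape [bs, 1]; gather picks probs[i, labels[i]] for each row i
    probs_ground_truth_class = probs.gather(dim=1, index=labels.unsqueeze(1)).squeeze(1)  # shape: [bs]
    return -probs_ground_truth_class.log()  # per-sample losses, shape: [bs]

# recreate the logits and labels with the same seeds as above
g = torch.manual_seed(42)
logits = torch.randn(size=(32, 3), generator=g)
g = torch.manual_seed(42)
labels = torch.randint(low=0, high=3, size=(32,), generator=g)
probs = torch.softmax(logits, dim=1)

per_sample = nll_gather(probs, labels)

# same per-sample values as the integer array indexing used in nll
indexed = -probs[range(len(labels)), labels].log()
assert torch.allclose(per_sample, indexed)

# averaging gives what nll returns; summing gives batch size times that value
print(per_sample.mean())
assert torch.allclose(per_sample.sum(), per_sample.mean() * len(labels))
```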

Using PyTorch

Instead of using our custom softmax, we can also use the built-in softmax function from PyTorch:

p = F.softmax(logits, dim=1) # dim=1 --> normalize over the classes, so each row sums to 1
nll(p, labels)
tensor(1.3465)

Instead of using our custom nll, we can also use the built-in version from PyTorch. However, F.nll_loss expects log-probabilities rather than probabilities, so instead of softmax we have to use log_softmax, which computes the log of the softmax in a numerically stable way:

p = F.log_softmax(logits, dim=1)

# Assert that indeed the log_softmax is just the softmax followed by a log
assert torch.allclose(p, F.softmax(logits, dim=1).log())

torch.nn.functional.nll_loss(p, labels)
tensor(1.3465)
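To see why the log_softmax route matters, consider a sample where one logit is much smaller than the other. In the sketch below (the logit values are made up for illustration), the softmax probability underflows to exactly 0 in float32, so taking its log gives -inf, while F.log_softmax returns the finite value directly.

```python
import torch
import torch.nn.functional as F

# hypothetical logits for one sample: class 1 is extremely unlikely
extreme_logits = torch.tensor([[0., -200.]])

# exp(-200) is about 1.4e-87, far below the smallest float32 value, so it becomes 0
via_softmax = F.softmax(extreme_logits, dim=1).log()  # second entry is -inf
via_log_softmax = F.log_softmax(extreme_logits, dim=1)  # second entry is -200

print(via_softmax)
print(via_log_softmax)
```

A -inf log-probability would make the loss infinite for that sample, whereas log_softmax keeps it finite.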

The combination of (log) softmax and NLL is called cross entropy, so we can also use PyTorch’s built-in version of that, which takes the raw logits as input:

F.cross_entropy(logits, labels)
tensor(1.3465)
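Because F.cross_entropy works directly on the logits, the probabilities never have to be formed explicitly. Below is a from-scratch sketch of the same computation. It relies on the identity \(\log p_i = z_i - \log \sum_j e^{z_j}\) and uses torch.logsumexp, which evaluates that log-sum stably. This is one way to write it, not necessarily how PyTorch implements it internally, and cross_entropy_from_logits is a hypothetical name.

```python
import torch
import torch.nn.functional as F

def cross_entropy_from_logits(logits, labels):
    # log of the softmax denominator for each row, computed stably; shape: [bs, 1]
    log_norm = torch.logsumexp(logits, dim=1, keepdim=True)
    log_probs = logits - log_norm  # this is log_softmax, shape: [bs, num_classes]
    # pick the log-probability of the ground truth class per sample, negate and average
    return -log_probs[range(len(labels)), labels].mean()

# recreate the logits and labels with the same seeds as above
g = torch.manual_seed(42)
logits = torch.randn(size=(32, 3), generator=g)
g = torch.manual_seed(42)
labels = torch.randint(low=0, high=3, size=(32,), generator=g)

loss = cross_entropy_from_logits(logits, labels)
assert torch.allclose(loss, F.cross_entropy(logits, labels))
```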

Instead of the functions in torch.nn.functional, we can also use the corresponding classes in torch.nn. For that, we first create an instance of the class, and then “call” the instance:

ce = torch.nn.CrossEntropyLoss() # create a CrossEntropyLoss instance
ce(logits, labels) # calling the instance with the arguments returns the cross entropy
tensor(1.3465)

Similarly, we can use classes for the log_softmax and nll_loss functions:

ls = torch.nn.LogSoftmax(dim=1)
nll_loss = torch.nn.NLLLoss() # a new name, so our custom nll function isn't overwritten

p = ls(logits)
nll_loss(p, labels)
tensor(1.3465)

This is practical if we want to specify custom behavior of the loss function ahead of time, before the loss function is actually called. For example, let’s say we want to compute the cross entropy loss based on ‘sums’ instead of ‘averages’. When using the function from F, we would write:

F.cross_entropy(logits, labels, reduction='sum')
tensor(43.0866)

So whenever we call the loss, we have to specify the additional reduction argument.

With the loss classes, on the other hand, we can pass the reduction argument once when instantiating the class, and then call the instance as usual, passing nothing but the logits and the labels:

# instantiate 
ce = torch.nn.CrossEntropyLoss(reduction='sum')

# at some other point in your code, call it with just the logits and labels
ce(logits, labels)
tensor(43.0866)

This is practical when the loss function is called by another object whose code we can’t easily change, so we can’t easily change the arguments of that call either. This is the case, for example, with the fastai Learner class: we pass it a loss function, which the Learner then calls with just the logits and the labels. By using the classes, we can specify the reduction argument ahead of time and pass the configured instance to the Learner.
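If you prefer to stay with the functional API, a common Python alternative (not something this post relies on) is functools.partial, which fixes the reduction argument once and returns a callable that takes only the logits and labels. In the sketch below, train_step is a hypothetical stand-in for code we don't control, such as a Learner, that calls the loss with just those two arguments.

```python
from functools import partial

import torch
import torch.nn.functional as F

# a callable that behaves like F.cross_entropy(..., reduction='sum')
loss_func = partial(F.cross_entropy, reduction='sum')

def train_step(loss_func, logits, labels):
    # hypothetical stand-in for code we don't control: it only passes logits and labels
    return loss_func(logits, labels)

# recreate the logits and labels with the same seeds as above
g = torch.manual_seed(42)
logits = torch.randn(size=(32, 3), generator=g)
g = torch.manual_seed(42)
labels = torch.randint(low=0, high=3, size=(32,), generator=g)

loss_partial = train_step(loss_func, logits, labels)
loss_class = train_step(torch.nn.CrossEntropyLoss(reduction='sum'), logits, labels)
assert torch.allclose(loss_partial, loss_class)
```

Both approaches give the same result; the class version has the small advantage of a readable repr when you print the configured loss.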