Explorations in Knowledge Distillation

March 25, 2021 · 11 minute read

Chris Zhu

Engineering

TLDR

Knowledge distillation is a common way to train compressed models by transferring the knowledge learned from a large model into a smaller model. Today we’ll be taking a look at using knowledge distillation to train a model that screens for pneumonia in chest x-rays.

What is the point?

Let’s say you’re working in an area where privacy is extremely important, like healthcare. We may not be allowed to send patient data to the cloud, where all our fancy GPUs live. This creates the need for a model that can be downloaded and run on a low-power machine.

So… what are we working with?

Let’s use the Chest X-Ray Images dataset from Kaggle. The task is to identify pneumonia in chest x-rays.

Pneumonia is a lung infection that causes coughing, fever, chills, and difficulty breathing. It’s caused by the immune response to an infection of the lungs, typically viral or bacterial. The virus behind the ongoing COVID-19 pandemic can cause pneumonia.

Basically, your lungs contain air sacs, which is where oxygen and carbon dioxide are exchanged when you breathe. When these air sacs are infected by a virus or bacteria, your body mounts an immune response that inflames the area and fills it with fluid.

Most people can recover from this, but it can cause death in some people due to respiratory failure.

Pneumonia kills many more people in developing countries. While 50,000 people died in the US from pneumonia in 2017, it caused 3 million deaths worldwide.

Access to high quality healthcare is a major factor in the lethality of pneumonia

Limited access to healthcare infrastructure motivates the need for technology to bring down costs and increase efficiency.

Let’s take a look at the data

With this in mind, given a chest x-ray, we would be looking for cloudy regions that indicate fluid in the lungs.

Here is a normal chest x-ray

Here is a chest x-ray of a patient with pneumonia

Not so easy to tell the difference, is it?

It’s not obvious why one scan is healthy while the other is infected. From what I’ve read, doctors look for white clumps around the periphery of the lungs.

So let’s model it!

The easiest thing we can do here is simply to throw this into a pre-trained convolutional ResNet model and see how far we can get.

We’ll be using PyTorch and PyTorch Lightning to build and train the models.

PyTorch Lightning is a library that lets us modularize our code, so we can separate the parts common to basically all image classification tasks from the parts specific to image distillation tasks.

Let’s start by building a generic BaseImageClassificationTask class to take care of all the boring stuff in image classification tasks, like configuring optimizers and loading datasets. See the code here and the dataset loading here.

Now, let’s create a simple ImageClassificationTask, which can consume any PyTorch image classification model and compute the cross entropy loss. This sets us up to plug in any PyTorch model that takes an image and outputs a prediction.

import torch.nn.functional as F

# BaseImageClassificationTask is the generic base class described above (code linked, not shown here).


class ImageClassificationTask(BaseImageClassificationTask):
    # `classes` is accepted for compatibility but is not used by this class
    def __init__(self, net, train_dataset, test_dataset, val_dataset, classes=10, learning_rate=1e-5):
        super().__init__()
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.val_dataset = val_dataset
        self.learning_rate = learning_rate
        self.net = net  # any nn.Module that maps a batch of images to class logits

    def training_step(self, batch, batch_idx):
        # training_step defines one step of the training loop; it is independent of forward()
        x, y = batch
        logits = self.net(x)
        loss = F.cross_entropy(logits, y)  # expects raw logits, not probabilities
        self.log('train_loss', loss)
        return loss

Magically (not really), we can now kick off a training loop. PyTorch Lightning takes care of sampling from the data loaders and backpropagating the loss.

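import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the checkpoint with the lowest validation loss. 'val_loss' and 'val_acc' must be
# logged during validation (assumed to happen in BaseImageClassificationTask).
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='./checkpoints/',
    filename='chest-xray-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}',
    mode='min',
)

# Lightning 1.x API; on Lightning 2.x use accelerator='gpu', devices=1 instead of gpus=1
trainer = pl.Trainer(max_epochs=40, gpus=1, callbacks=[
    checkpoint_callback,
])

# ResNet18 (defined elsewhere, not shown) with a 2-class output: normal vs. pneumonia
model = ImageClassificationTask(ResNet18(num_classes=2), train_dataset, test_dataset, val_dataset)
trainer.fit(model)  # actually starts training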

Here are the results of training ResNet-18 for 40 epochs:

Final test set accuracy: 91%

How “small” can we make this model?

Remember, the original goal was to build models that can be downloaded and run on low-power machines. In this case, let’s build a simple 3-layer CNN as the student model.
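The post does not show the student’s exact architecture. As a minimal sketch, assuming “3 layers” means three layers with weights (two convolutions and one linear classifier), 3-channel inputs, and 2 output classes, a student in that spirit might look like this. `SmallCNN` and its layer widths are hypothetical, and its parameter count (about 21,000) will not match the 277,000 reported below.

```python
import torch
from torch import nn


class SmallCNN(nn.Module):
    """Hypothetical 3-layer student: two conv layers plus a linear classifier."""

    def __init__(self, num_classes: int = 2, in_channels: int = 3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),  # layer 1
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # layer 2
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4),  # fixed 4x4 spatial size regardless of input resolution
        )
        self.classifier = nn.Linear(64 * 4 * 4, num_classes)  # layer 3

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))  # raw logits, as F.cross_entropy expects


logits = SmallCNN()(torch.randn(8, 3, 224, 224))
print(logits.shape)  # torch.Size([8, 2])
```

Because chest x-rays are grayscale, `in_channels=1` would also work if the data loader does not convert images to RGB; which one the original code used is not stated.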

We can measure how “small” this model is in 2 ways:

  1. Model size, measured by the number of parameters

  2. Model speed, which typically depends on the number of layers
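Both quantities are straightforward to measure in PyTorch. The helpers below are a minimal sketch, not necessarily how the numbers below were measured: `count_parameters` sums `numel()` over all parameters, and `cpu_latency_ms` (a hypothetical helper) averages the wall-clock time of forward passes on the CPU after a few untimed warm-up runs, assuming a 1×3×224×224 input by default.

```python
import time

import torch
from torch import nn


def count_parameters(model: nn.Module) -> int:
    """Total number of parameters in the model, trainable or not."""
    return sum(p.numel() for p in model.parameters())


@torch.no_grad()
def cpu_latency_ms(model: nn.Module, input_shape=(1, 3, 224, 224), warmup: int = 3, runs: int = 20) -> float:
    """Average wall-clock time of one forward pass on the CPU, in milliseconds.

    Note: moves the model to the CPU and switches it to eval mode in place.
    """
    model = model.to("cpu").eval()
    x = torch.randn(*input_shape)
    for _ in range(warmup):  # warm-up passes are not timed
        model(x)
    start = time.perf_counter()
    for _ in range(runs):
        model(x)
    return (time.perf_counter() - start) * 1000 / runs


print(count_parameters(nn.Linear(10, 2)))  # 22
```

For example, `count_parameters(torchvision.models.resnet18(num_classes=2))` should report roughly 11.2M; the commonly quoted 11.7M is for the 1,000-class ImageNet head. Latency depends heavily on the CPU, thread count, and input resolution, so expect your numbers to differ from the ones below.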

Size

The ResNet-18 model has 11.7M parameters, while the 3-layer CNN has 277,000 parameters.

This is a 97.5% reduction in model parameters.

Speed

CPU inference with ResNet-18 takes 45 ms, while the 3-layer CNN takes 3 ms.

This is a 15x inference speedup.

Do we actually need a teacher?

The first question we should ask is: do we actually need a teacher model? Let’s naively take our student model and train it with the ImageClassificationTask, as we did with the ResNet model.

Here are the results after 40 epochs:

Test set accuracy: 72%

Distillation

Now let’s build our ImageClassificationDistillationTask class.

The only meaningful difference between the ImageClassificationTask and the ImageClassificationDistillationTask is how the final training loss is computed, plus a few hyper-parameters that configure that loss.

1. Start with a trained teacher network and an untrained student network.

(We already trained the ResNet-18 teacher above.)

2. Run a forward pass through the teacher model to get its logits.

Put the teacher model into evaluation mode (eval()) and run it under torch.no_grad(), so its batch-norm statistics stay fixed and we don’t needlessly collect gradients.

3. Compute the final loss as a weighted sum of the distillation loss and the classification loss.

4. Backpropagate the loss through the student model only.

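import torch
import torch.nn.functional as F


class ImageClassificationDistillationTask(BaseImageClassificationTask):
    def __init__(self, teacher_model, student_model, train_dataset, test_dataset, val_dataset, learning_rate=0.001, temperature=2., alpha=0.5):
        super().__init__()
        self.learning_rate = learning_rate
        self.teacher_model = teacher_model
        self.teacher_model.requires_grad_(False)  # freeze the trained teacher; only the student learns
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.val_dataset = val_dataset
        self.net = student_model
        self.temperature = temperature
        self.alpha = alpha

    def training_step(self, batch, batch_idx):
        # training_step defines one step of the training loop; it is independent of forward()
        x, y = batch
        student_logits = self.net(x)

        # Standard classification loss against the ground-truth labels
        student_target_loss = F.cross_entropy(student_logits, y)

        # Lightning may switch the whole module (teacher included) back to train mode,
        # so put the teacher in eval mode here to keep its batch-norm statistics fixed
        self.teacher_model.eval()
        with torch.no_grad():
            teacher_logits = self.teacher_model(x)

        # KL divergence between the temperature-softened teacher and student distributions.
        # reduction='batchmean' gives the true per-sample KL; the original nn.KLDivLoss()
        # call used the default 'mean', which also divides by the number of classes.
        distillation_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1),
            reduction='batchmean',
        )

        loss = (1 - self.alpha) * student_target_loss + self.alpha * distillation_loss
        self.log('train_loss', loss)
        return loss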

How does the loss function work?

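The loss function is a weighted sum of 2 things:

  • The normal classification loss, referred to as student_target_loss in the code above.

  • The divergence between the student’s and the teacher’s temperature-softened predictions, referred to as distillation_loss in the code above. The code computes it as a KL divergence, which differs from the cross entropy between the two distributions only by the teacher’s entropy, a term that does not depend on the student. The loss is typically expressed in the literature like this: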
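Written out to match the training step above, with student logits z_s, teacher logits z_t, ground-truth label y, softmax σ, temperature T, and distillation weight α:

```latex
\mathcal{L} = (1 - \alpha)\,\mathrm{CE}\big(y,\ \sigma(z_s)\big)
            + \alpha\,\mathrm{KL}\big(\sigma(z_t / T)\,\big\|\,\sigma(z_s / T)\big),
\qquad
\mathrm{KL}(p \,\|\, q) = \sum_i p_i \log \frac{p_i}{q_i}
```

Many implementations, following Hinton et al. (2015), also multiply the distillation term by T² so its gradient magnitude stays comparable as T changes; the code above does not.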

The first part is the classification loss and the second is the distillation loss

The distillation loss between the student and the teacher is the main innovation. Intuitively, its purpose is to teach the student how the teacher “thinks”: in addition to training the student on the ground-truth label, we also train it on the teacher’s learned uncertainty about that label.

If the teacher outputs a prediction of 51% pneumonia and 49% not pneumonia, we also want the student to be equally uncertain.
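To make this concrete, here is a minimal pure-Python sketch (with T = 1) of the per-example quantity the distillation term measures, using the 51%/49% teacher from the example above and two hypothetical students. A student that matches the teacher’s uncertainty incurs zero distillation loss, while an overconfident student is penalized even though it picks the same class. The last lines check that cross entropy and KL divergence against the teacher differ only by the teacher’s entropy, so both push the student in the same direction.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def cross_entropy(p, q):
    """Cross entropy H(p, q) = -sum_i p_i log q_i."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)


def entropy(p):
    """Shannon entropy H(p) = -sum_i p_i log p_i."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


teacher = [0.51, 0.49]            # 51% pneumonia, 49% not pneumonia
uncertain_student = [0.51, 0.49]  # hypothetical student that matches the teacher
confident_student = [0.95, 0.05]  # hypothetical student: same top class, overconfident

print(kl_divergence(teacher, uncertain_student))  # 0.0
print(kl_divergence(teacher, confident_student))  # positive: the mismatch is penalized

# Cross entropy = KL divergence + teacher entropy, for any student distribution
gap = cross_entropy(teacher, confident_student) - kl_divergence(teacher, confident_student)
print(math.isclose(gap, entropy(teacher)))  # True
```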

An intuitive visualization of distillation loss

This motivates two parameters that adjust the behavior of this loss:

  • Alpha: How much weight we put on the student-teacher loss relative to the normal classification loss

  • Temperature: How much we scale the uncertainty of the teacher model

Alpha

The alpha parameter controls the weight that is put on the distillation loss. An alpha of 1 means we only consider the distillation loss while an alpha of 0 means we completely ignore the distillation loss.
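The effect of alpha is easy to check in isolation. Below is a minimal sketch that pulls the loss out of the training step into a standalone function, `distillation_objective` (a hypothetical name), so both extremes can be verified: with alpha = 0 it reduces to plain cross entropy, and with alpha = 1 it is (numerically) zero when the student’s logits equal the teacher’s. It uses `F.kl_div` with `reduction="batchmean"`, which PyTorch’s documentation recommends because it matches the mathematical definition of KL divergence; `nn.KLDivLoss()` defaults to `reduction="mean"`, which also divides by the number of classes.

```python
import torch
import torch.nn.functional as F


def distillation_objective(student_logits, teacher_logits, labels, alpha=0.5, temperature=2.0):
    """(1 - alpha) * cross entropy + alpha * KL(teacher || student) at the given temperature."""
    # Classification loss against the ground-truth labels
    target_loss = F.cross_entropy(student_logits, labels)
    # F.kl_div expects log-probabilities as input and probabilities as target
    distill_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction="batchmean",
    )
    return (1 - alpha) * target_loss + alpha * distill_loss


torch.manual_seed(0)
student = torch.randn(4, 2)  # hypothetical logits for a batch of 4 x-rays, 2 classes
teacher = torch.randn(4, 2)
labels = torch.tensor([0, 1, 1, 0])

# alpha = 0: only the classification loss remains
print(torch.allclose(distillation_objective(student, teacher, labels, alpha=0.0),
                     F.cross_entropy(student, labels)))  # True
# alpha = 1: only the distillation loss remains, close to 0 when student == teacher
print(distillation_objective(teacher, teacher, labels, alpha=1.0).item())
```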

Temperature

The temperature is a more interesting parameter which scales how “uncertain” the teacher predictions are.

Here’s an example for a model that outputs 3 classes:

Here is how these predictions change at various temperatures:

T < 1 makes the model more certain of its predictions

T > 1 makes the model less certain of its predictions

At T = 4, the model is very uncertain compared to the original predictions.

The purpose of the temperature parameter is to control how uncertain the teacher predictions are.
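The scaling itself is just a softmax over the logits divided by T. Here is a minimal pure-Python sketch using hypothetical logits for a 3-class model (not the values from the original example):

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature, shifted by the max for numerical stability."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [3.0, 1.0, 0.2]  # hypothetical logits for 3 classes
for t in (0.5, 1, 2, 4):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: " + ", ".join(f"{p:.2f}" for p in probs))
```

The highest-scoring class stays the same at every temperature, because dividing by a positive T preserves the ordering of the logits; only how sharply the probability mass concentrates on that class changes.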

Which hyper-parameters work best?

Here are the final results for the 3-layer CNN student with different hyper-parameter settings:

Something weird happened at alpha = 0.75, temperature = 4. Better performance seems to skew toward the upper left of this table.

The best-performing setting by far was alpha = 0.25, temperature = 1, which achieves 86% on the test set. This is an improvement over the original 72% we got by training the student model from scratch, without distillation.

Here are the final results:

In Summary

We were able to train a model with 97.5% fewer parameters and 15x faster CPU inference than ResNet-18, at the cost of about 5 percentage points of test accuracy relative to the teacher (86% vs. 91%).


Like what you see? Join the team.

Mage is making AI and ML accessible to product developers. Join us and build beautiful and intuitive devtools.
