Knowledge Distillation (KD)

Pre-requisites: Training in SG, Training with Configuration Files

Knowledge distillation is a deep learning technique that transfers the knowledge of a large, pre-trained neural network (the "teacher") to a smaller, more computationally efficient model (the "student"). This is accomplished by training the student to match both the teacher's predictions and the ground-truth labels. The student can also have a different architecture from the teacher, which makes it possible to distill a complex teacher network into a lighter, faster student for deployment in real-world applications.

The training flow with knowledge distillation in SG is similar to regular training. For standard training, we use SG's Trainer class, which is in charge of training the model, evaluating test data, making predictions, and saving checkpoints. Equivalently, for knowledge distillation, we use the KDTrainer class, which inherits from Trainer.

For regular training with Trainer, the general flow is:

from super_gradients import Trainer
...
trainer = Trainer("my_experiment")
train_dataloader = ...
valid_dataloader = ...
model = ...

train_params = {...}
trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)

Then for training with knowledge distillation, the general flow is:

from super_gradients.training.kd_trainer import KDTrainer
...
kd_trainer = KDTrainer("my_experiment")
train_dataloader = ...
valid_dataloader = ...

student_model = ...
teacher_model = ...

train_params = {...}
kd_trainer.train(student=student_model, teacher=teacher_model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)

Check out our knowledge distillation tutorial notebook to see a practical example.

Knowledge Distillation Training: Key Components

KDModule

The most apparent difference in the training flow using knowledge distillation is that it requires two networks: the "teacher" and the "student". The relation between the two is also configurable: for example, we may decide that the teacher model should preprocess the inputs differently. For this purpose, SG introduces a new torch.nn.Module that wraps both the student and the teacher models: KDModule. Upon calling KDTrainer.train(), the teacher and student models are passed along with kd_arch_params to initialize a KDModule instance.
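The wrapping idea can be illustrated without any framework. The sketch below is a toy stand-in, not SG's KDModule: the class name ToyKDWrapper, the teacher_input_adapter argument name, and the plain-function "models" are all assumptions made for illustration. It only shows the shape of the flow: one input goes in, the teacher optionally sees an adapted copy of it, and both outputs come back together.

```python
from collections import namedtuple

# Local stand-in for the two-field output container (not imported from SG).
KDOutput = namedtuple("KDOutput", ["student_output", "teacher_output"])


class ToyKDWrapper:
    """Hypothetical, framework-free analogue of a student/teacher wrapper."""

    def __init__(self, student, teacher, teacher_input_adapter=None):
        self.student = student
        self.teacher = teacher
        # Optional callable applied to the input before the teacher sees it.
        self.teacher_input_adapter = teacher_input_adapter

    def __call__(self, x):
        # The student always receives the input as the data pipeline produced it.
        student_out = self.student(x)
        # The teacher may need its input re-processed first.
        teacher_in = x if self.teacher_input_adapter is None else self.teacher_input_adapter(x)
        teacher_out = self.teacher(teacher_in)
        return KDOutput(student_output=student_out, teacher_output=teacher_out)


# Toy "models": plain functions standing in for networks.
def student(x):
    return [2.0 * v for v in x]


def teacher(x):
    return [v + 1.0 for v in x]


def halve(x):
    return [0.5 * v for v in x]


wrapper = ToyKDWrapper(student, teacher, teacher_input_adapter=halve)
out = wrapper([2.0, 4.0])
print(out.student_output)  # [4.0, 8.0]
print(out.teacher_output)  # [2.0, 3.0]
```

Keeping the adapter inside the wrapper means the data loaders only need to serve the student's preprocessing.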

It is also possible to pass a KDModule instance explicitly to KDTrainer.train() through the model argument, instead of the student and teacher models, which lets users customize KD to their needs.

A high-level example of KD customization:

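import torch.nn
from super_gradients.training.kd_trainer import KDTrainer
# Import path assumed here; verify it against your installed SG version.
from super_gradients.training.models.kd_modules.kd_module import KDModule, KDOutput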

...


class MyKDModule(KDModule):
    ...

    def forward(self, x: torch.Tensor) -> KDOutput:
        # extract_intermediate_output is a placeholder for a user-defined method on both models
        intermediate_output_student = self.student.extract_intermediate_output(x, layer_ids=[1, 3, -1])
        intermediate_output_teacher = self.teacher.extract_intermediate_output(x, layer_ids=[1, 3, -1])
        return KDOutput(student_output=intermediate_output_student, teacher_output=intermediate_output_teacher)


class MyKDLoss(torch.nn.Module):
    ...
    def forward(self, preds: KDOutput, target: torch.Tensor):
        # does something with the intermediate outputs
        ...

kd_trainer = KDTrainer("my_customized_kd_experiment")
train_dataloader = ...
valid_dataloader = ...

student_model = ...
teacher_model = ...

kd_model = MyKDModule(student=student_model, teacher=teacher_model)

train_params = {'loss': MyKDLoss(),
                ...}

kd_trainer.train(model=kd_model, training_params=train_params,
                 train_loader=train_dataloader, valid_loader=valid_dataloader)

KDOutput

KDOutput defines the structure of the output of KDModule and has two self-explanatory attributes: student_output and teacher_output. KDTrainer uses these attributes behind the scenes to perform the usual operations of regular training, such as metrics calculations. This means that when customizing KD, it's essential for the custom KDModule to stick to this output format.
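To make the role of the two attributes concrete, here is a minimal sketch of a metric that reads a KDOutput-shaped prediction. It assumes a local namedtuple stand-in for KDOutput and assumes that the metric should score the student only; the helper names top1_accuracy and kd_accuracy are hypothetical and are not part of SG.

```python
from collections import namedtuple

# Local stand-in with the two attributes described above (not imported from SG).
KDOutput = namedtuple("KDOutput", ["student_output", "teacher_output"])


def top1_accuracy(logits_batch, targets):
    """Fraction of samples whose highest-scoring class equals the target."""
    correct = 0
    for logits, target in zip(logits_batch, targets):
        predicted = max(range(len(logits)), key=lambda i: logits[i])
        correct += int(predicted == target)
    return correct / len(targets)


def kd_accuracy(preds, targets):
    # Assumption: the reported metric describes the student, so only
    # student_output is scored and teacher_output is ignored here.
    return top1_accuracy(preds.student_output, targets)


preds = KDOutput(
    student_output=[[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 0.9]],
    teacher_output=[[0.0, 3.0, 0.1], [0.2, 2.5, 0.1], [0.1, 0.0, 2.0]],
)
targets = [1, 1, 2]

student_acc = kd_accuracy(preds, targets)
teacher_acc = top1_accuracy(preds.teacher_output, targets)
print(student_acc, teacher_acc)
```

A custom module that returned a bare tensor or a differently named structure would break this kind of attribute access, which is why the output format matters.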

KD Losses

KDLogitsLoss is currently the only supported loss function in SG's KD losses bank, but more are planned. Note that during KD training, the KDModule outputs (which are KDOutput instances) are passed to the loss's forward method as predictions.
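As an illustration of a logits-based KD loss that consumes a KDOutput-shaped prediction, the sketch below blends a cross-entropy task term with a KL-divergence term between the teacher's and the student's softmax distributions. It is a single-sample, pure-Python toy and is not SG's KDLogitsLoss; the function name toy_kd_logits_loss and the distillation_coeff and temperature parameters are assumptions, and the T² rescaling used in some formulations is omitted.

```python
import math
from collections import namedtuple

# Local stand-in for the prediction container (not imported from SG).
KDOutput = namedtuple("KDOutput", ["student_output", "teacher_output"])


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Negative log-probability of the ground-truth class."""
    return -math.log(softmax(logits)[target])


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def toy_kd_logits_loss(preds, target, distillation_coeff=0.5, temperature=1.0):
    # Task term: the student against the ground-truth label.
    task_loss = cross_entropy(preds.student_output, target)
    # Distillation term: the student's distribution against the teacher's.
    teacher_probs = softmax(preds.teacher_output, temperature)
    student_probs = softmax(preds.student_output, temperature)
    distillation_loss = kl_divergence(teacher_probs, student_probs)
    return (1.0 - distillation_coeff) * task_loss + distillation_coeff * distillation_loss


agreeing = KDOutput(student_output=[1.0, 2.0, 3.0], teacher_output=[1.0, 2.0, 3.0])
disagreeing = KDOutput(student_output=[1.0, 2.0, 3.0], teacher_output=[3.0, 2.0, 1.0])
print(toy_kd_logits_loss(agreeing, target=2))
print(toy_kd_logits_loss(disagreeing, target=2))
```

When the student already matches the teacher, the distillation term vanishes and only the weighted task term remains; any disagreement adds a positive penalty.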

Knowledge Distillation Training: Checkpoints

Checkpointing during KD training is generally the same as checkpointing without KD. Nevertheless, there are a few differences worth mentioning:

  • ckpt_latest.pth contains the state dict of the entire KDModule.
  • ckpt_best.pth contains the state dict of the student only.
  • When training with EMA, the net entry of ckpt_best.pth holds the EMA network.
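The difference between the two files can be pictured with plain dictionaries. The sketch below assumes that the wrapper's parameters are keyed with "student." and "teacher." prefixes (following the self.student and self.teacher attribute names used in the customization example) and that ckpt_latest.pth also stores its weights under a net entry; both are assumptions, and the helper extract_student_state is hypothetical rather than an SG utility. Real checkpoints hold tensors, not lists.

```python
def extract_student_state(kd_state_dict, prefix="student."):
    """Keep only the student's entries and strip the wrapper prefix."""
    return {
        key[len(prefix):]: value
        for key, value in kd_state_dict.items()
        if key.startswith(prefix)
    }


# Toy state dict of a student/teacher wrapper (lists stand in for tensors).
kd_state = {
    "student.conv1.weight": [0.1, 0.2],
    "student.fc.bias": [0.0],
    "teacher.conv1.weight": [0.9, 0.8],
}

# Shape of a "latest" checkpoint: the entire wrapper, teacher included.
ckpt_latest = {"net": kd_state}
# Shape of a "best" checkpoint: the student alone, loadable without the wrapper.
ckpt_best = {"net": extract_student_state(kd_state)}

print(sorted(ckpt_best["net"]))  # ['conv1.weight', 'fc.bias']
```

A student-only checkpoint can then be loaded directly into the standalone student architecture for deployment.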

Knowledge Distillation Training with Configuration Files

As with training without knowledge distillation, training with configuration files is done by calling the KDTrainer.train_from_config method, which assumes a specific configuration structure. When training with KD, the same structure and required fields hold, but we introduce a few additions:

  • arch_params are passed to the KDModule constructor. For example, in our ResNet50 KD training on ImageNet, the teacher expects a different input normalization, and we handle this difference by passing the KDModule a normalization adapter module:

    # super_gradients/recipes/imagenet_resnet50_kd.yaml
    ...
    arch_params:
      teacher_input_adapter:
        _target_: super_gradients.training.utils.kd_trainer_utils.NormalizationAdapter
        mean_original: [0.485, 0.456, 0.406]
        std_original: [0.229, 0.224, 0.225]
        mean_required: [0.5, 0.5, 0.5]
        std_required: [0.5, 0.5, 0.5]
    

    Warning: Remember to distinguish the arch params being passed to the KDModule constructor from the student ones.

  • student_architecture, teacher_architecture, student_arch_params, teacher_arch_params, student_checkpoint_params, and teacher_checkpoint_params play the same role as architecture, arch_params, and checkpoint_params do for instantiating our model in non-KD training, and are passed to models.get(...) to instantiate the student and the teacher:

...
student_architecture: resnet50
teacher_architecture: beit_base_patch16_224

student_arch_params:
  num_classes: 1000

teacher_arch_params:
  num_classes: 1000
  image_size: [224, 224]
  patch_size: [16, 16]

teacher_checkpoint_params:
...
  pretrained_weights: imagenet

student_checkpoint_params:
...
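The teacher_input_adapter entry in the arch_params example above converts inputs that were normalized with one set of statistics into the normalization the teacher expects. A minimal sketch of that arithmetic for a single pixel follows; it is a worked illustration under the assumption that the adapter undoes the original normalization and then applies the required one, not SG's NormalizationAdapter implementation, and the function name renormalize is hypothetical.

```python
def renormalize(pixel, mean_original, std_original, mean_required, std_required):
    """Map per-channel values normalized with the original stats to the required stats."""
    converted = []
    for value, mo, so, mr, sr in zip(pixel, mean_original, std_original, mean_required, std_required):
        raw = value * so + mo              # undo the original normalization
        converted.append((raw - mr) / sr)  # apply the normalization the teacher expects
    return converted


# Statistics copied from the recipe snippet above.
mean_original = [0.485, 0.456, 0.406]
std_original = [0.229, 0.224, 0.225]
mean_required = [0.5, 0.5, 0.5]
std_required = [0.5, 0.5, 0.5]

# A mid-gray raw pixel (0.5 per channel) as the student's pipeline would normalize it.
raw_pixel = [0.5, 0.5, 0.5]
student_view = [(v - m) / s for v, m, s in zip(raw_pixel, mean_original, std_original)]

# The same pixel as the teacher should see it: approximately zero in every channel.
teacher_view = renormalize(student_view, mean_original, std_original, mean_required, std_required)
print(teacher_view)
```

Because the conversion is an affine map per channel, it can be applied on the fly to each batch without touching the data loaders.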

Any KD recipe can be launched with our train_from_kd_recipe_example script.