
Phase Callbacks

Integrating your own code into an existing training pipeline can take considerable effort. To address this, you can pass a list of callables, each triggered at specific points of the training code, through phase_callbacks inside training_params when calling Trainer.train(...).

SG's super_gradients.training.utils.callbacks module implements some common use cases as callbacks:

ModelConversionCheckCallback
LRCallbackBase
EpochStepWarmupLRCallback
BatchStepLinearWarmupLRCallback
StepLRCallback
ExponentialLRCallback
PolyLRCallback
CosineLRCallback
FunctionLRCallback
LRSchedulerCallback
DetectionVisualizationCallback
BinarySegmentationVisualizationCallback
TrainingStageSwitchCallbackBase
YoloXTrainingStageSwitchCallback

For example, YoloX's COCO detection training recipe uses YoloXTrainingStageSwitchCallback to turn off augmentations and add an L1 loss term starting from epoch 285:

super_gradients/recipes/training_hyperparams/coco2017_yolox_train_params.yaml:

```yaml
max_epochs: 300
...

loss: yolox_loss

...

phase_callbacks:
  - YoloXTrainingStageSwitchCallback:
      next_stage_start_epoch: 285
...
```
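The stage-switch pattern itself is simple enough to sketch in plain Python. The class below is a toy written for illustration, not SG's TrainingStageSwitchCallbackBase or YoloXTrainingStageSwitchCallback: `ToyStageSwitchCallback`, its `apply_stage_change` hook and the `"no_aug_with_l1"` stage label are hypothetical, and `SimpleNamespace` stands in for PhaseContext. It assumes the check runs at the start of every training epoch (the on_train_loader_start hook listed later on this page) and switches exactly once, when the epoch reaches `next_stage_start_epoch`.

```python
from types import SimpleNamespace


class ToyStageSwitchCallback:
    """Toy stand-in: switches the training stage once, at a given epoch."""

    def __init__(self, next_stage_start_epoch: int):
        self.next_stage_start_epoch = next_stage_start_epoch
        self.switched_at = None  # epoch at which the switch happened, if any

    def on_train_loader_start(self, context) -> None:
        # Called at the start of every training epoch; act only on the target epoch
        if context.epoch == self.next_stage_start_epoch:
            self.apply_stage_change(context)

    def apply_stage_change(self, context) -> None:
        # Hypothetical hook: a real callback would e.g. disable augmentations
        # and enable an extra loss term here
        context.stage = "no_aug_with_l1"
        self.switched_at = context.epoch


callback = ToyStageSwitchCallback(next_stage_start_epoch=285)
context = SimpleNamespace(epoch=0, stage="default")  # stand-in for PhaseContext
for epoch in range(300):
    context.epoch = epoch
    callback.on_train_loader_start(context)

print(callback.switched_at, context.stage)  # 285 no_aug_with_l1
```

Keeping the epoch check in one hook and the actual change in a separate method keeps the "when" and the "what" apart, so a subclass only needs to say what changes.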

Another example is how we use BinarySegmentationVisualizationCallback to visualize predictions during training in our Segmentation Transfer Learning Notebook.

Integrating Your Code with Callbacks

Integrating your code requires implementing a callback that the Trainer triggers at the appropriate phases of SG's training pipeline.

So let's first get familiar with the super_gradients.training.utils.callbacks.base_callbacks.Callback class.

It implements the following methods:

    on_training_start(self, context: PhaseContext) -> None
    on_train_loader_start(self, context: PhaseContext) -> None
    on_train_batch_start(self, context: PhaseContext) -> None
    on_train_batch_loss_end(self, context: PhaseContext) -> None
    on_train_batch_backward_end(self, context: PhaseContext) -> None
    on_train_batch_gradient_step_start(self, context: PhaseContext) -> None
    on_train_batch_gradient_step_end(self, context: PhaseContext) -> None
    on_train_batch_end(self, context: PhaseContext) -> None
    on_train_loader_end(self, context: PhaseContext) -> None
    on_validation_loader_start(self, context: PhaseContext) -> None
    on_validation_batch_start(self, context: PhaseContext) -> None
    on_validation_batch_end(self, context: PhaseContext) -> None
    on_validation_loader_end(self, context: PhaseContext) -> None
    on_validation_end_best_epoch(self, context: PhaseContext) -> None
    on_test_loader_start(self, context: PhaseContext) -> None
    on_test_batch_start(self, context: PhaseContext) -> None
    on_test_batch_end(self, context: PhaseContext) -> None
    on_test_loader_end(self, context: PhaseContext) -> None
    on_training_end(self, context: PhaseContext) -> None

Our callback must inherit from this class and override the methods that correspond to the points at which we want it to run.

To decide which methods to override, we first need to understand when each of them is triggered.

According to the class docstring, the events are triggered in the following order:

    on_training_start(context)                              # called once before training starts, good for setting up the warmup LR

        for epoch in range(epochs):
            on_train_loader_start(context)
                for batch in train_loader:
                    on_train_batch_start(context)
                    on_train_batch_loss_end(context)               # called after loss has been computed
                    on_train_batch_backward_end(context)           # called after .backward() was called
                    on_train_batch_gradient_step_start(context)    # called before the optimizer step about to happen (gradient clipping, logging of gradients)
                    on_train_batch_gradient_step_end(context)      # called after gradient step was done, good place to update LR (for step-based schedulers)
                    on_train_batch_end(context)
            on_train_loader_end(context)

            on_validation_loader_start(context)
                for batch in validation_loader:
                    on_validation_batch_start(context)
                    on_validation_batch_end(context)
            on_validation_loader_end(context)
            on_validation_end_best_epoch(context)

        on_test_loader_start(context)
            for batch in test_loader:
                on_test_batch_start(context)
                on_test_batch_end(context)
        on_test_loader_end(context)

    on_training_end(context)                    # called once after training ends.
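To make the ordering concrete, here is a toy loop that fires a recording callback's hooks in the order listed above. `toy_fit`, `RecordingCallback` and `TRAIN_BATCH_HOOKS` are hypothetical names for this sketch, and `SimpleNamespace` stands in for PhaseContext. It is a simplification, not SG's Trainer: there is no model, loss or optimizer between hooks, the test loop is omitted, and on_validation_end_best_epoch is fired every epoch, which may not match how the real trainer uses it.

```python
from types import SimpleNamespace

# Per-batch training hooks, in the order given by the Callback docstring
TRAIN_BATCH_HOOKS = [
    "on_train_batch_start",
    "on_train_batch_loss_end",
    "on_train_batch_backward_end",
    "on_train_batch_gradient_step_start",
    "on_train_batch_gradient_step_end",
    "on_train_batch_end",
]


class RecordingCallback:
    """Toy callback: records (hook_name, epoch, batch_idx) for every hook it receives."""

    def __init__(self):
        self.events = []

    def __getattr__(self, name):
        # Only invoked for attributes not found normally, i.e. every on_* hook here
        if not name.startswith("on_"):
            raise AttributeError(name)

        def record(context):
            self.events.append((name, context.epoch, context.batch_idx))

        return record


def toy_fit(callback, epochs, n_train_batches, n_valid_batches):
    """Fire hooks in the documented order; no actual training happens."""
    ctx = SimpleNamespace(epoch=None, batch_idx=None)  # stand-in for PhaseContext

    def fire(hook):
        getattr(callback, hook)(ctx)

    fire("on_training_start")
    for epoch in range(epochs):
        ctx.epoch, ctx.batch_idx = epoch, None
        fire("on_train_loader_start")
        for batch_idx in range(n_train_batches):
            ctx.batch_idx = batch_idx
            for hook in TRAIN_BATCH_HOOKS:
                fire(hook)
        ctx.batch_idx = None
        fire("on_train_loader_end")

        fire("on_validation_loader_start")
        for batch_idx in range(n_valid_batches):
            ctx.batch_idx = batch_idx
            fire("on_validation_batch_start")
            fire("on_validation_batch_end")
        ctx.batch_idx = None
        fire("on_validation_loader_end")
        fire("on_validation_end_best_epoch")  # fired every epoch in this toy
    fire("on_training_end")


cb = RecordingCallback()
toy_fit(cb, epochs=2, n_train_batches=2, n_valid_batches=1)
print(len(cb.events))  # 40
print(cb.events[:3])
# [('on_training_start', None, None), ('on_train_loader_start', 0, None), ('on_train_batch_start', 0, 0)]
```

Each epoch here fires 19 hooks (1 + 2 batches x 6 + 1 + 1 + 1 batch x 2 + 1 + 1), plus the two training-level hooks, hence 40 events.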

As you may have noticed, all of Callback's methods take a single argument: a PhaseContext instance. Through its attributes, this argument exposes the variables available at each of the points above. To find out which attributes are available in a given method, check that method's docstring in Callback.

For example:

```python
...
    def on_training_start(self, context: PhaseContext) -> None:
        """
        Called once before the start of the first epoch
        At this point, the context argument is guaranteed to have the following attributes:
        - optimizer
        - net
        - checkpoints_dir_path
        - criterion
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - checkpoint_params
        - architecture
        - arch_params
        - metric_to_watch
        - device
        - ema_model
        ...
        :return:
        """
```

Now let's implement our callback. Suppose we want a simple callback that saves the first batch of images of each epoch, for both training and validation, to a new folder called "batch_images" under our local checkpoints directory.

Our callback needs to be triggered in three places:

1. At the start of training, to create the "batch_images" folder under our local checkpoints directory.
2. Before a training batch is passed through the network.
3. Before a validation batch is passed through the network.

Therefore, our callback will override Callback's on_training_start, on_train_batch_start, and on_validation_batch_start methods:

```python
from super_gradients.training.utils.callbacks import Callback, PhaseContext
from super_gradients.common.environment.ddp_utils import multi_process_safe
import os
from torchvision.utils import save_image


class SaveFirstBatchCallback(Callback):
    def __init__(self):
        self.outputs_path = None

    @multi_process_safe
    def on_training_start(self, context: PhaseContext) -> None:
        # Keep the folder on the instance so the batch hooks below can use it
        self.outputs_path = os.path.join(context.ckpt_dir, "batch_images")
        os.makedirs(self.outputs_path, exist_ok=True)

    @multi_process_safe
    def on_train_batch_start(self, context: PhaseContext) -> None:
        # Save only the first training batch of each epoch
        if context.batch_idx == 0:
            save_image(context.inputs, os.path.join(self.outputs_path, f"first_train_batch_epoch_{context.epoch}.png"))

    @multi_process_safe
    def on_validation_batch_start(self, context: PhaseContext) -> None:
        # Save only the first validation batch of each epoch
        if context.batch_idx == 0:
            save_image(context.inputs, os.path.join(self.outputs_path, f"first_validation_batch_epoch_{context.epoch}.png"))
```

Note the @multi_process_safe decorator: when running distributed (DDP) training, it makes the decorated method run only on the main process, so each image is saved once rather than once per process.

For training scripts written in code (i.e., not driven by configuration files), pass an instance of the callback through phase_callbacks:
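The idea behind such a decorator can be sketched in a few lines. `main_process_only` below is a hypothetical approximation written for illustration, not SG's multi_process_safe implementation; it assumes the launcher exports the LOCAL_RANK environment variable (torchrun does) and treats an unset variable as a non-distributed run. Because non-main processes skip the call entirely, this pattern only suits functions whose return value is not needed, like the callback hooks above.

```python
import functools
import os


def main_process_only(func):
    """Hypothetical rank-0-only decorator (not SG's multi_process_safe)."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
        if local_rank <= 0:  # -1: not distributed, 0: main process
            return func(*args, **kwargs)
        return None  # non-main ranks skip the call, so return values are lost

    return wrapper


@main_process_only
def save_batch(log: list) -> None:
    log.append("saved")  # stands in for writing an image to disk


log = []
for rank in ("0", "1", "2"):  # simulate three DDP processes calling the same hook
    os.environ["LOCAL_RANK"] = rank
    save_batch(log)
os.environ.pop("LOCAL_RANK")
print(log)  # ['saved']
```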

```python
from super_gradients import Trainer

...

trainer = Trainer("my_experiment")
train_dataloader = ...
valid_dataloader = ...
model = ...

train_params = {
    # ...
    "loss": "cross_entropy",
    "criterion_params": {},
    # ...
    "phase_callbacks": [SaveFirstBatchCallback()],
}

trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)
```

Otherwise, when training with configuration files, we need to register our new callback by decorating it with the `register_callback` decorator:

```python
from super_gradients.training.utils.callbacks import Callback, PhaseContext
from super_gradients.common.environment.ddp_utils import multi_process_safe
import os
from torchvision.utils import save_image
from super_gradients.common.registry.registry import register_callback


@register_callback()
class SaveFirstBatchCallback(Callback):
    def __init__(self):
        self.outputs_path = None

    @multi_process_safe
    def on_training_start(self, context: PhaseContext) -> None:
        # Keep the folder on the instance so the batch hooks below can use it
        self.outputs_path = os.path.join(context.ckpt_dir, "batch_images")
        os.makedirs(self.outputs_path, exist_ok=True)

    @multi_process_safe
    def on_train_batch_start(self, context: PhaseContext) -> None:
        # Save only the first training batch of each epoch
        if context.batch_idx == 0:
            save_image(context.inputs, os.path.join(self.outputs_path, f"first_train_batch_epoch_{context.epoch}.png"))

    @multi_process_safe
    def on_validation_batch_start(self, context: PhaseContext) -> None:
        # Save only the first validation batch of each epoch
        if context.batch_idx == 0:
            save_image(context.inputs, os.path.join(self.outputs_path, f"first_validation_batch_epoch_{context.epoch}.png"))
```

Then, in your `my_training_hyperparams.yaml`, use SaveFirstBatchCallback in the same way as any other phase callback supported in SG:

```yaml
defaults:
  - default_train_params

max_epochs: 250

...

phase_callbacks:
  - SaveFirstBatchCallback
```
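Why can a bare class name in YAML be turned into a callback object, and why must the class be imported first? A common pattern is that the registration decorator stores the class in a name-to-class dictionary when the class definition executes (i.e., at import time), and the config loader later looks names up there. The sketch below re-creates that pattern with hypothetical names (`CALLBACK_REGISTRY`, `register_callback_sketch`, `build_callbacks`, `StageSwitchCallback`); it is an assumption about the general mechanism, not SG's internal code. It handles both entry styles used on this page: a bare name, and a single-key mapping from a name to keyword arguments, as in the YoloX recipe.

```python
from typing import Any, Dict, List, Optional, Union

CALLBACK_REGISTRY: Dict[str, type] = {}


def register_callback_sketch(name: Optional[str] = None):
    """Class decorator factory: stores the class under `name` (default: class name)."""

    def decorator(cls: type) -> type:
        CALLBACK_REGISTRY[name or cls.__name__] = cls  # runs when the class is defined
        return cls

    return decorator


@register_callback_sketch()
class SaveFirstBatchCallback:  # stand-in with the same name as the real callback
    pass


@register_callback_sketch()
class StageSwitchCallback:  # hypothetical callback that takes keyword arguments
    def __init__(self, next_stage_start_epoch: int = 285):
        self.next_stage_start_epoch = next_stage_start_epoch


def build_callbacks(entries: List[Union[str, Dict[str, Dict[str, Any]]]]) -> list:
    """Instantiate callbacks from YAML-style phase_callbacks entries."""
    callbacks = []
    for entry in entries:
        if isinstance(entry, str):
            # Bare name, e.g. "- SaveFirstBatchCallback"
            callbacks.append(CALLBACK_REGISTRY[entry]())
        else:
            # Single-key mapping from name to keyword arguments
            cb_name, kwargs = next(iter(entry.items()))
            callbacks.append(CALLBACK_REGISTRY[cb_name](**(kwargs or {})))
    return callbacks


cbs = build_callbacks(["SaveFirstBatchCallback", {"StageSwitchCallback": {"next_stage_start_epoch": 250}}])
print([type(cb).__name__ for cb in cbs])  # ['SaveFirstBatchCallback', 'StageSwitchCallback']
print(cbs[1].next_stage_start_epoch)  # 250
```

In this sketch, if the module that defines a class is never imported, its decorator never runs and the lookup raises KeyError. That is the reason the training script has to import the callback module even when it never uses the class directly.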

Finally, in your `my_train_from_recipe_script.py` file, import the newly registered class. The class itself is not used directly; the import is only there so that the registration decorator runs:

```python
from omegaconf import DictConfig
import hydra
import pkg_resources
from my_callbacks import SaveFirstBatchCallback  # imported only so that @register_callback() runs
from super_gradients import Trainer, init_trainer


@hydra.main(config_path=pkg_resources.resource_filename("super_gradients.recipes", ""), version_base="1.2")
def main(cfg: DictConfig) -> None:
    Trainer.train_from_config(cfg)


def run():
    init_trainer()
    main()


if __name__ == "__main__":
    run()
```