Exponential Moving Average (EMA)

Exponential Moving Average (EMA) is a technique used during training to smooth out noise in the training process and improve the model's generalization. It is simple and works with any model and optimizer. Here is a recap of how EMA works:

  • At the start of training, the model parameters are copied to the EMA parameters.
  • At each gradient update step, the EMA parameters are updated using the following formula:
    ema_param = ema_param * decay + param * (1 - decay)
    
  • At the start of each validation epoch, the model parameters are replaced by the EMA parameters; they are restored at the end of the validation epoch.
  • At the end of training, the model parameters are replaced by the EMA parameters.
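The steps above can be sketched in plain Python. This is a toy illustration that tracks a dict of floats rather than real model tensors; the `SimpleEMA` class and its method names are hypothetical and are not part of SuperGradients.

```python
import copy


class SimpleEMA:
    """Toy EMA tracker over a dict of float parameters (illustrative only)."""

    def __init__(self, params: dict, decay: float):
        self.decay = decay
        # Start of training: copy the model parameters to the EMA parameters.
        self.ema_params = copy.deepcopy(params)
        self._backup = None

    def update(self, params: dict) -> None:
        # After each gradient update:
        # ema_param = ema_param * decay + param * (1 - decay)
        for name, value in params.items():
            self.ema_params[name] = self.ema_params[name] * self.decay + value * (1 - self.decay)

    def swap_in(self, params: dict) -> None:
        # Start of validation: back up the raw values, then load the EMA values.
        self._backup = copy.deepcopy(params)
        params.update(self.ema_params)

    def restore(self, params: dict) -> None:
        # End of validation: put the raw training values back.
        params.update(self._backup)
        self._backup = None


params = {"w": 1.0}
ema = SimpleEMA(params, decay=0.9)

params["w"] = 2.0       # pretend an optimizer step changed the weight
ema.update(params)      # EMA becomes 1.0 * 0.9 + 2.0 * 0.1

ema.swap_in(params)     # validation would run here with the averaged weight
validation_w = params["w"]
ema.restore(params)     # training continues from the raw weight
```

At the end of training, the same `swap_in` call without a matching `restore` would leave the model holding the EMA parameters.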

To enable EMA in SuperGradients, add the following parameters to training_params:

from super_gradients import Trainer
trainer = Trainer(...)
trainer.train(
    training_params={"ema": True, "ema_params": {"decay": 0.9999, "decay_type": "constant"}, ...}, 
    ...
)

The decay is a hyperparameter that controls the speed of the EMA update. Its value must be in the (0, 1) range. Larger decay values result in slower updates of the EMA model.
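To get a feel for what "slower" means, note that under the update formula a parameter value from k updates ago contributes to the average with a weight proportional to decay^k. The sketch below computes the number of updates after which that weight halves; the helper name `ema_half_life` is hypothetical and only serves as an illustration.

```python
import math


def ema_half_life(decay: float) -> float:
    """Number of updates after which a past value's weight in the average halves."""
    if not 0.0 < decay < 1.0:
        raise ValueError("decay must be in the (0, 1) range")
    # Solve decay ** k == 0.5 for k.
    return math.log(0.5) / math.log(decay)


for decay in (0.9, 0.999, 0.9999):
    print(f"decay={decay}: half-life ~ {ema_half_life(decay):.0f} updates")
```

With a decay of 0.9999, the half-life is in the thousands of updates, so the EMA model reacts very slowly when training runs are short.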

It is usually beneficial to use a smaller decay value at the start of training and increase it as training progresses. SuperGradients supports several decay schedules:

  • constant: "ema_params": {"decay": 0.9999, "decay_type": "constant"}
  • threshold: "ema_params": {"decay": 0.9999, "decay_type": "threshold"}
  • exp: "ema_params": {"decay": 0.9999, "decay_type": "exp", "beta": 15}

[Figure: EMA decay schedules]
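For intuition, here is a minimal sketch of how the three schedule types could behave. The formulas are assumptions chosen for illustration (a fixed value, a step-based warm-up capped at decay, and an exponential ramp controlled by beta) and may differ from the exact formulas SuperGradients uses; the function names are hypothetical.

```python
import math


def constant_decay(decay: float, step: int, total_steps: int) -> float:
    # The decay value never changes.
    return decay


def threshold_decay(decay: float, step: int, total_steps: int) -> float:
    # Assumed warm-up: starts small and grows with the step count, capped at `decay`.
    return min(decay, (1 + step) / (10 + step))


def exp_decay(decay: float, step: int, total_steps: int, beta: float = 15) -> float:
    # Assumed exponential ramp from 0 towards `decay`; larger beta ramps up faster.
    progress = step / total_steps
    return decay * (1 - math.exp(-progress * beta))


total_steps = 1000
for step in (0, 10, 100, 1000):
    print(
        step,
        constant_decay(0.9999, step, total_steps),
        round(threshold_decay(0.9999, step, total_steps), 4),
        round(exp_decay(0.9999, step, total_steps), 4),
    )
```

Under these assumptions, both non-constant schedules let the EMA model track the raw weights closely during the first steps and average over a longer history later on.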

Adding your own decay schedule

SuperGradients lets you bring your own decay schedule. Subclass IDecayFunction to implement a custom function, then register it in EMA_DECAY_FUNCTIONS:

from super_gradients.training.utils.ema_decay_schedules import IDecayFunction, EMA_DECAY_FUNCTIONS

class LinearDecay(IDecayFunction):
    def __init__(self, **kwargs):
        pass

    def __call__(self, decay: float, step: int, total_steps: int) -> float:
        """
        Compute the EMA decay for a specific training step following a linear scaling rule [0..decay).

        :param decay: The maximum decay value.
        :param step: Current training step. The unit-range training percentage can be obtained by `step / total_steps`.
        :param total_steps: Total number of training steps.
        :return: Computed decay value for a given step.
        """
        training_progress = step / total_steps
        return decay * training_progress

EMA_DECAY_FUNCTIONS["linear"] = LinearDecay
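Once registered, the new schedule can presumably be selected by its key in the same way as the built-in ones. The fragment below is a sketch under that assumption: it only builds the training_params dictionary, and the "linear" key is the one registered in the example above.

```python
# Assumes the "linear" key was registered in EMA_DECAY_FUNCTIONS beforehand.
training_params = {
    "ema": True,
    "ema_params": {"decay": 0.9999, "decay_type": "linear"},
    # ... remaining training parameters go here
}
```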

How EMA weights are saved

When EMA is enabled, saved checkpoints contain an additional ema_net key that holds the EMA model weights. The regular (non-averaged) model weights are saved under the net key, as usual.

When a model is instantiated via models.get, the loader checks whether ema_net is present in the checkpoint. If it is, the model is initialized from the EMA weights; otherwise, it is initialized from the regular weights saved in net.
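The selection rule described above can be illustrated with a small sketch. It treats the checkpoint as a plain dictionary; the `select_weights` helper is hypothetical and is not the actual SuperGradients loading code.

```python
def select_weights(checkpoint: dict) -> dict:
    """Return the EMA weights when the checkpoint has them, else the regular weights."""
    if "ema_net" in checkpoint:
        return checkpoint["ema_net"]
    return checkpoint["net"]


# Toy checkpoints with a single scalar "weight" each.
with_ema = {"net": {"w": 2.0}, "ema_net": {"w": 1.1}}
without_ema = {"net": {"w": 2.0}}

ema_choice = select_weights(with_ema)         # picks the averaged weights
regular_choice = select_weights(without_ema)  # falls back to the regular weights
```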

Knowledge Distillation

EMA is also supported in knowledge distillation. To enable it, add the same parameters to training_params as in the example above:

from super_gradients import KDTrainer
trainer = KDTrainer(...)
trainer.train(
    training_params={"ema": True, "ema_params": {"decay": 0.9999, "decay_type": "constant"}, ...}, 
    ...
)