

Metrics let you monitor and quantify the training process, which makes them an essential component of every deep learning training pipeline. SuperGradients (SG) relies on the torchmetrics library for this. From the torchmetrics homepage:

"TorchMetrics is a collection of 90+ PyTorch metrics implementations and an easy-to-use API to create custom metrics. It offers:

- A standardized interface to increase reproducibility

- Reduces Boilerplate

- Distributed-training compatible

- Rigorously tested

- Automatic accumulation over batches

- Automatic synchronization between multiple devices."
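The following sketch shows why automatic accumulation over batches matters. It is a minimal pure-Python illustration of the update/compute/reset pattern that torchmetrics metrics follow. The class `RunningAccuracy` is hypothetical and only mimics that pattern; it is not the torchmetrics implementation. Accumulating raw counts gives the exact dataset-level value. Averaging per-batch scores, by contrast, would weight a small last batch as heavily as a full one.

```python
from typing import List


class RunningAccuracy:
    """Hypothetical, torch-free illustration of the update/compute/reset pattern."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # State persists across batches until it is explicitly reset (e.g. at epoch end)
        self.correct = 0
        self.total = 0

    def update(self, preds: List[int], target: List[int]) -> None:
        # Accumulate raw counts per batch instead of storing per-batch scores
        assert len(preds) == len(target)
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        # One dataset-level value computed from the accumulated state
        return self.correct / self.total


metric = RunningAccuracy()
# A full batch (3 of 4 correct), then a smaller last batch (0 of 1 correct)
batches = [([1, 0, 1, 1], [1, 0, 0, 1]), ([2], [0])]
for preds, target in batches:
    metric.update(preds, target)

# 3 correct out of 5 gives 0.6; averaging the per-batch scores would give (0.75 + 0.0) / 2 = 0.375
print(metric.compute())  # 0.6
```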

SG is compatible with any module metric implemented by torchmetrics (see the complete list here). In addition to the native torchmetrics implementations, SG implements some metrics of its own as torchmetrics.Metric objects:


Basic Usage of Implemented Metrics

For coded training scripts (i.e., not using configuration files), the most basic usage is simply passing the metric objects through train_metrics_list and valid_metrics_list:

from super_gradients import Trainer
from super_gradients.training.metrics import Accuracy, Top5

trainer = Trainer("my_experiment")
train_dataloader = ...
valid_dataloader = ...
model = ...
train_params = {
    "train_metrics_list": [Accuracy(), Top5()],
    "valid_metrics_list": [Accuracy(), Top5()],
    "metric_to_watch": "Accuracy",
    "greater_metric_to_watch_is_better": True,
}
trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)

The metrics' progress over the training and validation epochs will now be displayed and logged to TensorBoard and to any third-party SG logger (see the Weights & Biases and ClearML integrations on the repo homepage). Metric names are logged in lowercase with the appropriate prefix: train_accuracy, train_top5, valid_accuracy, valid_top5. Also note that metric_to_watch is set to Accuracy and greater_metric_to_watch_is_better=True. Together, these mean that we monitor the validation accuracy and save checkpoints according to it, treating higher values as better. Open any of the tutorial notebooks to see metric monitoring in action. For more information on checkpoints and logs, follow our SG checkpoints tutorial.
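The sketch below illustrates the naming convention and the role of metric_to_watch. The helpers `logged_metric_names` and `best_epoch` are hypothetical and only mirror the behavior described above; SG's actual logging and checkpointing code is more involved. The per-epoch accuracy values are toy numbers, not results.

```python
from typing import List, Sequence


def logged_metric_names(metric_class_names: Sequence[str], phase: str) -> List[str]:
    # Hypothetical helper: lowercase each metric name and prefix it with the phase
    return [f"{phase}_{name.lower()}" for name in metric_class_names]


def best_epoch(watched_values: Sequence[float], greater_is_better: bool) -> int:
    # Hypothetical helper: index of the epoch whose watched validation value is "best"
    pick = max if greater_is_better else min
    return pick(range(len(watched_values)), key=lambda i: watched_values[i])


names = ["Accuracy", "Top5"]
print(logged_metric_names(names, "train"))  # ['train_accuracy', 'train_top5']
print(logged_metric_names(names, "valid"))  # ['valid_accuracy', 'valid_top5']

valid_accuracy_per_epoch = [0.61, 0.72, 0.70]  # toy values
print(best_epoch(valid_accuracy_per_epoch, greater_is_better=True))   # 1
print(best_epoch(valid_accuracy_per_epoch, greater_is_better=False))  # 0
```

Flipping greater_is_better changes which epoch counts as best. For this reason, a loss-like metric to watch would need greater_metric_to_watch_is_better set to False.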

Equivalently, for training with configuration files, your my_training_hyperparams.yaml would contain:

defaults:
  - default_train_params
metric_to_watch: Accuracy
greater_metric_to_watch_is_better: True

train_metrics_list:                               # metrics for evaluation
  - Accuracy
  - Top5

valid_metrics_list:                               # metrics for evaluation
  - Accuracy
  - Top5

Using Custom Metrics

Suppose you have implemented your own MyAccuracy metric (more information on how to do so here). For coded training, you can pass an instance of it as in the previous sub-section. For training with configuration files, first decorate your metric class with SG's @register_metric decorator:

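from torchmetrics import Metric
import torch
from super_gradients.common.registry import register_metric

@register_metric("my_accuracy")  # the name used to refer to this metric in YAML recipes
class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        # States accumulate across batches and are summed across devices
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # Convert (N, C) logits/probabilities into (N,) predicted class indices
        if preds.ndim == target.ndim + 1:
            preds = preds.argmax(dim=1)
        assert preds.shape == target.shape

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        return self.correct.float() / self.total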
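Both states use dist_reduce_fx="sum", so in distributed training each device's counts are summed before compute() runs. The result is therefore the accuracy over the whole dataset rather than an average of per-device accuracies. The torch-free sketch below mimics that reduction, with plain dictionaries standing in for per-process states. `reduce_states` and `compute_accuracy` are hypothetical helpers, not part of torchmetrics or SG, and the counts are toy values.

```python
from typing import Dict, List


def reduce_states(per_device_states: List[Dict[str, int]]) -> Dict[str, int]:
    # Mimics dist_reduce_fx="sum": each state is summed across devices
    keys = per_device_states[0].keys()
    return {k: sum(state[k] for state in per_device_states) for k in keys}


def compute_accuracy(state: Dict[str, int]) -> float:
    # Same formula as MyAccuracy.compute(): correct / total on the reduced state
    return state["correct"] / state["total"]


# Toy states after each of two devices has called update() on its own data shard
device_states = [{"correct": 45, "total": 50}, {"correct": 10, "total": 20}]
global_state = reduce_states(device_states)
print(global_state)  # {'correct': 55, 'total': 70}

# 55 / 70 is about 0.786, whereas averaging per-device accuracies would give (0.9 + 0.5) / 2 = 0.7
print(round(compute_accuracy(global_state), 3))  # 0.786
```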

Next, use the registered metric in your my_training_hyperparams.yaml by plugging in its registered name, just as you would for any other metric:

defaults:
  - default_train_params
metric_to_watch: my_accuracy
greater_metric_to_watch_is_better: True

train_metrics_list:                               # metrics for evaluation
  - my_accuracy

valid_metrics_list:                               # metrics for evaluation
  - my_accuracy

Finally, import the newly registered class in your training script. The class itself is never used directly; the import is there only to trigger the registration:

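from omegaconf import DictConfig
import hydra
import pkg_resources
from my_accuracy import MyAccuracy  # unused, but importing it registers "my_accuracy"
from super_gradients import Trainer, init_trainer

# config_path should point to the directory holding your recipe YAML files
@hydra.main(config_path=pkg_resources.resource_filename("", ""), version_base="1.2")
def main(cfg: DictConfig) -> None:
    Trainer.train_from_config(cfg)

def run():
    init_trainer()
    main()

if __name__ == "__main__":
    run()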
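Why is the otherwise unused import needed? A registration decorator runs when the module defining the class is imported, and that is the moment it adds the class to a name-to-class mapping consulted later by the config loader. The sketch below illustrates this general pattern with a hypothetical `METRIC_REGISTRY`, `register` decorator and `build_metrics` loader. It is not SG's actual registry code, whose internals may differ.

```python
from typing import Callable, Dict, List, Type

# Hypothetical name -> class mapping, filled in as decorated classes are defined
METRIC_REGISTRY: Dict[str, Type] = {}


def register(name: str) -> Callable[[Type], Type]:
    # Decorator factory: the inner function runs at class-definition (module import) time
    def decorator(cls: Type) -> Type:
        if name in METRIC_REGISTRY:
            raise ValueError(f"Metric name {name!r} is already registered")
        METRIC_REGISTRY[name] = cls
        return cls

    return decorator


def build_metrics(names: List[str]) -> List[object]:
    # What a config loader could do with the names listed under train_metrics_list
    return [METRIC_REGISTRY[n]() for n in names]


@register("my_accuracy")
class MyAccuracy:
    pass


metrics = build_metrics(["my_accuracy"])
print(type(metrics[0]).__name__)  # MyAccuracy
```

If the module defining the class were never imported, the decorator would never run, and looking up "my_accuracy" would raise a KeyError.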