Skip to content

Utils

get_builtin_activation_type(activation, **kwargs)

Returns the activation class by its name from the torch.nn namespace. This function supports all modules available in torch.nn and also their lower-case aliases. On top of that, it supports a few extra aliases: leaky_relu (LeakyReLU), swish (SiLU).

act_cls = get_builtin_activation_type("LeakyReLU", inplace=True, negative_slope=0.01)
act = act_cls()

Parameters:

Name Type Description Default
activation Union[str, None]

Activation function name (E.g. ReLU). If None - return nn.Identity

required

Returns:

Type Description
Type[nn.Module]

Type of the activation function that is ready to be instantiated

Source code in V3_3/src/super_gradients/training/utils/activations_utils.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def get_builtin_activation_type(activation: Union[str, None], **kwargs) -> Type[nn.Module]:
    """
    Returns activation class by its name from torch.nn namespace. This function supports all modules available from
    torch.nn and also their lower-case aliases.
    On top of that, it supports a few extra aliases: leaky_relu (LeakyReLU), swish (SiLU), none (Identity).

    >>> act_cls = get_builtin_activation_type("LeakyReLU", inplace=True, negative_slope=0.01)
    >>> act = act_cls()

    :param activation: Activation function name (E.g. ReLU). If None - return nn.Identity
    :param **kwargs  : Extra arguments to pass to constructor during instantiation (E.g. inplace=True).
                       When any are given, the result is a functools.partial wrapping the class rather than the
                       class itself. They are not applied when activation is None.

    :returns         : Type of the activation function that is ready to be instantiated
    :raises KeyError : If the (alias-resolved) name is not an attribute of torch.nn
    """
    if activation is None:
        return nn.Identity

    nn_namespace = vars(torch.nn)

    # Map the lower-cased spelling of every torch.nn attribute back to its real name.
    aliases: Dict[str, str] = {name.lower(): name for name in nn_namespace}

    # Register additional aliases
    aliases["leaky_relu"] = "LeakyReLU"  # LeakyReLU in snake_case
    aliases["swish"] = "SiLU"  # Swish, which is equivalent to SiLU
    aliases["none"] = "Identity"

    # Names that are not aliases are looked up verbatim (e.g. "ReLU").
    activation = aliases.get(activation, activation)

    if activation not in nn_namespace:
        raise KeyError(f"Requested activation function {activation} is not known")

    activation_cls = nn_namespace[activation]
    if kwargs:
        # Bind the constructor arguments now so the caller can instantiate with a plain call.
        activation_cls = partial(activation_cls, **kwargs)

    return activation_cls

batch_distance2bbox(points, distance, max_shapes=None)

Decode distance prediction to bounding box for batch.

Parameters:

Name Type Description Default
points Tensor

[B, ..., 2], "xy" format

required
distance Tensor

[B, ..., 4], "ltrb" format

required
max_shapes Optional[Tensor]

[B, 2], "h,w" format, Shape of the image.

None

Returns:

Type Description
Tensor

Tensor: Decoded bboxes, "x1y1x2y2" format.

Source code in V3_3/src/super_gradients/training/utils/bbox_utils.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def batch_distance2bbox(points: Tensor, distance: Tensor, max_shapes: Optional[Tensor] = None) -> Tensor:
    """Turn per-point "ltrb" distances into corner-format boxes for a whole batch.

    :param points: [B, ..., 2], "xy" format
    :param distance: [B, ..., 4], "ltrb" format
    :param max_shapes: [B, 2], "h,w" format, Shape of the image. When given, the boxes are clipped to
                       the range [0, w] on x and [0, h] on y.
    :return: Tensor: Decoded bboxes, "x1y1x2y2" format.
    """
    left_top, right_bottom = torch.split(distance, 2, dim=-1)

    # Keep `points` as the second operand of each addition, as in the reference implementation.
    top_left = -left_top + points
    bottom_right = right_bottom + points
    boxes = torch.cat([top_left, bottom_right], dim=-1)

    if max_shapes is None:
        return boxes

    # (h, w) -> (w, h) -> (w, h, w, h) so the limits line up with (x1, y1, x2, y2).
    limits = max_shapes.flip(-1).tile([1, 2])

    # Insert singleton axes after the batch axis until the limits broadcast against the boxes.
    for _ in range(boxes.ndim - limits.ndim):
        limits = limits.unsqueeze(1)

    boxes = torch.where(boxes < limits, boxes, limits)
    return torch.where(boxes > 0, boxes, torch.zeros_like(boxes))

Callback

Base callback class with all the callback methods. Derived classes may override one or many of the available events to receive callbacks when such events are triggered by the training loop.

The order of the events is as follows:

on_training_start(context) # called once before training starts, good for setting up the warmup LR

for epoch in range(epochs):
    on_train_loader_start(context)
        for batch in train_loader:
            on_train_batch_start(context)
            on_train_batch_loss_end(context)               # called after loss has been computed
            on_train_batch_backward_end(context)           # called after .backward() was called
            on_train_batch_gradient_step_start(context)    # called before the optimizer step about to happen (gradient clipping, logging of gradients)
            on_train_batch_gradient_step_end(context)      # called after gradient step was done, good place to update LR (for step-based schedulers)
            on_train_batch_end(context)
    on_train_loader_end(context)

    on_validation_loader_start(context)
        for batch in validation_loader:
            on_validation_batch_start(context)
            on_validation_batch_end(context)
    on_validation_loader_end(context)
    on_validation_end_best_epoch(context)

on_test_loader_start(context)
    for batch in test_loader:
        on_test_batch_start(context)
        on_test_batch_end(context)
on_test_loader_end(context)

on_average_best_models_validation_start
on_average_best_models_validation_end

on_training_end(context) # called once after training ends.

Correspondence mapping from the old callback API:

on_training_start(context)                                 <-> Phase.PRE_TRAINING
for epoch in range(epochs):
    on_train_loader_start(context)                         <-> Phase.TRAIN_EPOCH_START
        for batch in train_loader:
            on_train_batch_start(context)
            on_train_batch_loss_end(context)
            on_train_batch_backward_end(context)           <-> Phase.TRAIN_BATCH_END
            on_train_batch_gradient_step_start(context)
            on_train_batch_gradient_step_end(context)      <-> Phase.TRAIN_BATCH_STEP
            on_train_batch_end(context)
    on_train_loader_end(context)                           <-> Phase.TRAIN_EPOCH_END

on_validation_loader_start(context)
    for batch in validation_loader:
        on_validation_batch_start(context)
        on_validation_batch_end(context)               <-> Phase.VALIDATION_BATCH_END
on_validation_loader_end(context)                      <-> Phase.VALIDATION_EPOCH_END
on_validation_end_best_epoch(context)                  <-> Phase.VALIDATION_END_BEST_EPOCH

on_test_loader_start(context)
    for batch in test_loader:
        on_test_batch_start(context)
        on_test_batch_end(context)                         <-> Phase.TEST_BATCH_END
on_test_loader_end(context)                                <-> Phase.TEST_END

on_training_end(context) <-> Phase.POST_TRAINING

Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
class Callback:
    """
    Base class for training-loop callbacks. Every hook below is a no-op; subclasses override whichever
    events they want to be notified about when the training loop triggers them.

    Order in which the events are fired:

    on_training_start(context)                              # called once before training starts, good for setting up the warmup LR

        for epoch in range(epochs):
            on_train_loader_start(context)
                for batch in train_loader:
                    on_train_batch_start(context)
                    on_train_batch_loss_end(context)               # called after loss has been computed
                    on_train_batch_backward_end(context)           # called after .backward() was called
                    on_train_batch_gradient_step_start(context)    # called before the optimizer step about to happen (gradient clipping, logging of gradients)
                    on_train_batch_gradient_step_end(context)      # called after gradient step was done, good place to update LR (for step-based schedulers)
                    on_train_batch_end(context)
            on_train_loader_end(context)

            on_validation_loader_start(context)
                for batch in validation_loader:
                    on_validation_batch_start(context)
                    on_validation_batch_end(context)
            on_validation_loader_end(context)
            on_validation_end_best_epoch(context)

        on_test_loader_start(context)
            for batch in test_loader:
                on_test_batch_start(context)
                on_test_batch_end(context)
        on_test_loader_end(context)

        on_average_best_models_validation_start
        on_average_best_models_validation_end

    on_training_end(context)                    # called once after training ends.

    Correspondence mapping from the old callback API:

    on_training_start(context)                                 <-> Phase.PRE_TRAINING
    for epoch in range(epochs):
        on_train_loader_start(context)                         <-> Phase.TRAIN_EPOCH_START
            for batch in train_loader:
                on_train_batch_start(context)
                on_train_batch_loss_end(context)
                on_train_batch_backward_end(context)           <-> Phase.TRAIN_BATCH_END
                on_train_batch_gradient_step_start(context)
                on_train_batch_gradient_step_end(context)      <-> Phase.TRAIN_BATCH_STEP
                on_train_batch_end(context)
        on_train_loader_end(context)                           <-> Phase.TRAIN_EPOCH_END

        on_validation_loader_start(context)
            for batch in validation_loader:
                on_validation_batch_start(context)
                on_validation_batch_end(context)               <-> Phase.VALIDATION_BATCH_END
        on_validation_loader_end(context)                      <-> Phase.VALIDATION_EPOCH_END
        on_validation_end_best_epoch(context)                  <-> Phase.VALIDATION_END_BEST_EPOCH

    on_test_loader_start(context)
        for batch in test_loader:
            on_test_batch_start(context)
            on_test_batch_end(context)                         <-> Phase.TEST_BATCH_END
    on_test_loader_end(context)                                <-> Phase.TEST_END

    on_training_end(context)                                   <-> Phase.POST_TRAINING

    NOTE(review): the table above pairs Phase.TRAIN_BATCH_END with on_train_batch_backward_end, while the
    docstring of on_train_batch_loss_end names the same phase - confirm which hook is the real counterpart.
    """

    def on_training_start(self, context: PhaseContext) -> None:
        """
        Fired a single time, before the first epoch begins.

        Attributes available on ``context`` at this point:
            optimizer, criterion, device, experiment_name, ckpt_dir, net, sg_logger, train_loader,
            valid_loader, training_params, checkpoint_params, arch_params, metric_to_watch, valid_metrics

        Old-API counterpart: Phase.PRE_TRAINING.

        :param context: Phase context carrying the attributes listed above.
        """

    def on_train_loader_start(self, context: PhaseContext) -> None:
        """
        Fired every epoch when the train data loader starts, before the first batch is fetched.

        Attributes available on ``context`` at this point:
            optimizer, criterion, device, experiment_name, ckpt_dir, net, sg_logger, train_loader,
            valid_loader, training_params, checkpoint_params, arch_params, metric_to_watch, valid_metrics

        Old-API counterpart: Phase.TRAIN_EPOCH_START.

        :param context: Phase context carrying the attributes listed above.
        """

    def on_train_batch_start(self, context: PhaseContext) -> None:
        """
        Fired for every train batch once it has been fetched from the data loader and moved to the target device.
        The event is triggered AFTER the Trainer.pre_prediction_callback call (if one was defined).

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, target, metrics_compute_fn, loss_avg_meter, criterion,
            device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger,
            train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params,
            metric_to_watch, valid_metrics

        :param context: Phase context carrying the attributes listed above.
        """

    def on_train_batch_loss_end(self, context: PhaseContext) -> None:
        """
        Fired once the model forward pass and the loss computation are done.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_compute_fn, loss_avg_meter,
            loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net,
            lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode,
            checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names

        Old-API counterpart: Phase.TRAIN_BATCH_END.

        :param context: Phase context carrying the attributes listed above.
        """

    def on_train_batch_backward_end(self, context: PhaseContext) -> None:
        """
        Fired after loss.backward() has been called for the current batch.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_compute_fn, loss_avg_meter,
            loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net,
            lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode,
            checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names

        :param context: Phase context carrying the attributes listed above.
        """

    def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None:
        """
        Fired right before the gradient step is taken.
        A good place to clip gradients (with respect to the scaler), log the gradients-to-data ratio, etc.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_compute_fn, loss_avg_meter,
            loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net,
            lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode,
            checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names

        :param context: Phase context carrying the attributes listed above.
        """

    def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
        """
        Fired after the gradient step has been taken. A good place to update the LR (for step-based schedulers).

        Attributes available on ``context`` at this point:
            epoch, batch_idx, inputs, target, metrics_compute_fn, loss_avg_meter, criterion, device,
            stop_training, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader,
            loss_logging_items_names

        Old-API counterpart: Phase.TRAIN_BATCH_STEP.

        :param context: Phase context carrying the attributes listed above.
        """

    def on_train_batch_end(self, context: PhaseContext) -> None:
        """
        Fired when every forward/backward/optimizer step for the current batch is done and nothing is left to do.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn,
            loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir,
            net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params,
            ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics,
            loss_logging_items_names

        :param context: Phase context carrying the attributes listed above.
        """

    def on_train_loader_end(self, context: PhaseContext) -> None:
        """
        Fired every epoch when the train data loader is exhausted, after the last batch was processed.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn,
            loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir,
            net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params,
            ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics,
            loss_logging_items_names

        Old-API counterpart: Phase.TRAIN_EPOCH_END.

        :param context: Phase context carrying the attributes listed above.
        """

    def on_validation_loader_start(self, context: PhaseContext) -> None:
        """
        Fired every epoch when the validation data loader starts, before the first batch is fetched.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn,
            loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir,
            net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params,
            ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics,
            loss_logging_items_names

        :param context: Phase context carrying the attributes listed above.
        """

    def on_validation_batch_start(self, context: PhaseContext) -> None:
        """
        Fired for every validation batch once it has been fetched from the validation loader and moved to the
        target device.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, inputs, target, metrics_compute_fn, loss_avg_meter, criterion, device,
            stop_training, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader,
            loss_logging_items_names

        :param context: Phase context carrying the attributes listed above.
        """

    def on_validation_batch_end(self, context: PhaseContext) -> None:
        """
        Fired when the forward step / loss / metric computation for the current validation batch is done and
        nothing is left to do.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, inputs, preds, target, metrics_compute_fn, loss_avg_meter, loss_log_items,
            criterion, device, stop_training, net, lr_warmup_epochs, sg_logger, train_loader,
            valid_loader, loss_logging_items_names

        Old-API counterpart: Phase.VALIDATION_BATCH_END.

        :param context: Phase context carrying the attributes listed above.
        """

    def on_validation_loader_end(self, context: PhaseContext) -> None:
        """
        Fired every epoch when the validation data loader is exhausted, after the last batch was processed.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn,
            loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir,
            net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params,
            ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics,
            loss_logging_items_names

        Old-API counterpart: Phase.VALIDATION_EPOCH_END.

        :param context: Phase context carrying the attributes listed above.
        """

    def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
        """
        Fired each epoch after validation has been performed and the best metric has been achieved.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn,
            loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir,
            net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params,
            ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics,
            loss_logging_items_names

        Old-API counterpart: Phase.VALIDATION_END_BEST_EPOCH.

        :param context: Phase context carrying the attributes listed above.
        """

    def on_test_loader_start(self, context: PhaseContext) -> None:
        """
        Fired once when the test data loader starts, before the first batch is fetched.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn,
            loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir,
            net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params,
            ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics,
            loss_logging_items_names

        :param context: Phase context carrying the attributes listed above.
        """

    def on_test_batch_start(self, context: PhaseContext) -> None:
        """
        Fired for every test batch once it has been fetched from the test loader and moved to the target device.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn,
            loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir,
            net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params,
            ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics,
            loss_logging_items_names

        :param context: Phase context carrying the attributes listed above.
        """

    def on_test_batch_end(self, context: PhaseContext) -> None:
        """
        Fired when the forward step for the current test batch is done and nothing is left to do.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn,
            loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir,
            net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params,
            ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics,
            loss_logging_items_names

        Old-API counterpart: Phase.TEST_BATCH_END.

        :param context: Phase context carrying the attributes listed above.
        """

    def on_test_loader_end(self, context: PhaseContext) -> None:
        """
        Fired once when the test data loader is exhausted, after the last batch was processed.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn,
            loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir,
            net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params,
            ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics,
            loss_logging_items_names

        Old-API counterpart: Phase.TEST_END.

        :param context: Phase context carrying the attributes listed above.
        """

    def on_average_best_models_validation_start(self, context: PhaseContext) -> None:
        """
        Fired once after testing has ended and before the training loop finishes.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn,
            loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir,
            net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params,
            ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics,
            loss_logging_items_names

        Old-API counterpart: Phase.AVERAGE_BEST_MODELS_VALIDATION_START.

        :param context: Phase context carrying the attributes listed above.
        """

    def on_average_best_models_validation_end(self, context: PhaseContext) -> None:
        """
        Fired once after the validation of the averaged model has finished.

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn,
            loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir,
            net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params,
            ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics,
            loss_logging_items_names

        Old-API counterpart: presumably Phase.AVERAGE_BEST_MODELS_VALIDATION_END.
        NOTE(review): this docstring previously named ..._VALIDATION_START, identical to the *_start hook,
        which looks like a copy-paste slip - confirm against the Phase enum.

        :param context: Phase context carrying the attributes listed above.
        """

    def on_training_end(self, context: PhaseContext) -> None:
        """
        Fired once after the training loop has finished (due to reaching the optimization criterion or because
        of an error).

        Attributes available on ``context`` at this point:
            epoch, batch_idx, optimizer, inputs, preds, target, metrics_compute_fn, loss_avg_meter,
            loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net,
            lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode,
            checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names

        Old-API counterpart: Phase.POST_TRAINING.

        :param context: Phase context carrying the attributes listed above.
        """

on_average_best_models_validation_end(context)

Called once after the average model validation has finished. At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_dict - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

The corresponding Phase enum value for this event is Phase.AVERAGE_BEST_MODELS_VALIDATION_END.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
def on_average_best_models_validation_end(self, context: PhaseContext) -> None:
    """
    Called once after the average model validation has finished.
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.AVERAGE_BEST_MODELS_VALIDATION_END.

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_average_best_models_validation_start(context)

Called once after the test has ended, before the training loop has finished. At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_dict - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

The corresponding Phase enum value for this event is Phase.AVERAGE_BEST_MODELS_VALIDATION_START.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
def on_average_best_models_validation_start(self, context: PhaseContext) -> None:
    """
    Called once after the test has ended, before the training loop has finished.
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.AVERAGE_BEST_MODELS_VALIDATION_START.

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_test_batch_end(context)

Called after the forward step has been performed for a given batch and there is nothing left to do. At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_dict - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

The corresponding Phase enum value for this event is Phase.TEST_BATCH_END.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
def on_test_batch_end(self, context: PhaseContext) -> None:
    """
    Called after the forward step has been performed for a given batch and there is nothing left to do.
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.TEST_BATCH_END.

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_test_batch_start(context)

Called at each batch after getting batch of data from test loader and moving it to target device. At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_dict - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
def on_test_batch_start(self, context: PhaseContext) -> None:
    """
    Called at each batch after getting a batch of data from the test loader and moving it to the target device.
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_test_loader_end(context)

Called once at the end of test data loader (after processing the last batch). At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_dict - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

The corresponding Phase enum value for this event is Phase.TEST_END.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
def on_test_loader_end(self, context: PhaseContext) -> None:
    """
    Called once at the end of the test data loader (after processing the last batch).
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.TEST_END.

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_test_loader_start(context)

Called once at the start of test data loader (before getting the first batch). At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_dict - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
def on_test_loader_start(self, context: PhaseContext) -> None:
    """
    Called once at the start of the test data loader (before getting the first batch).
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_train_batch_backward_end(context)

Called after loss.backward() method was called for a given batch At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
def on_train_batch_backward_end(self, context: PhaseContext) -> None:
    """
    Called after the loss.backward() method was called for a given batch.
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_train_batch_end(context)

Called after all forward/backward/optimizer steps have been performed for a given batch and there is nothing left to do. At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_dict - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
def on_train_batch_end(self, context: PhaseContext) -> None:
    """
    Called after all forward/backward/optimizer steps have been performed for a given batch and there is nothing left to do.
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_train_batch_gradient_step_end(context)

Called after gradient step has been performed. Good place to update LR (for step-based schedulers) At this point, the context argument will have the following attributes: - epoch - batch_idx - inputs - target - metrics_compute_fn - loss_avg_meter - criterion - device - stop_training - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - loss_logging_items_names

The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_STEP.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
    """
    Called after the gradient step has been performed. Good place to update LR (for step-based schedulers).
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - inputs
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - criterion
        - device
        - stop_training
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_STEP.

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_train_batch_gradient_step_start(context)

Called before the gradient step is about to happen. Good place to clip gradients (with respect to scaler), log gradients to data ratio, etc. At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None:
    """
    Called before the gradient step is about to happen.
    Good place to clip gradients (with respect to scaler), log gradients to data ratio, etc.
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_train_batch_loss_end(context)

Called after model forward and loss computation has been done. At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_END.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
def on_train_batch_loss_end(self, context: PhaseContext) -> None:
    """
    Called after the model forward and loss computation have been done.
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_END.

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_train_batch_start(context)

Called at each batch after getting batch of data from data loader and moving it to target device. This event triggered AFTER Trainer.pre_prediction_callback call (If it was defined).

At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - target - metrics_compute_fn - loss_avg_meter - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
def on_train_batch_start(self, context: PhaseContext) -> None:
    """
    Called at each batch after getting a batch of data from the data loader and moving it to the target device.
    This event is triggered AFTER the Trainer.pre_prediction_callback call (if it was defined).
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_train_loader_end(context)

Called each epoch at the end of train data loader (after processing the last batch). At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_dict - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_END.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
def on_train_loader_end(self, context: PhaseContext) -> None:
    """
    Called each epoch at the end of the train data loader (after processing the last batch).
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_END.

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_train_loader_start(context)

Called each epoch at the start of train data loader (before getting the first batch). At this point, the context argument will have the following attributes: - optimizer - criterion - device - experiment_name - ckpt_dir - net - sg_logger - train_loader - valid_loader - training_params - checkpoint_params - arch_params - metric_to_watch - valid_metrics The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_START.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
def on_train_loader_start(self, context: PhaseContext) -> None:
    """
    Called each epoch at the start of the train data loader (before getting the first batch).
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - optimizer
        - criterion
        - device
        - experiment_name
        - ckpt_dir
        - net
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics

    The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_START.

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_training_end(context)

Called once after the training loop has finished (Due to reaching optimization criterion or because of an error.) At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

The corresponding Phase enum value for this event is Phase.POST_TRAINING.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
def on_training_end(self, context: PhaseContext) -> None:
    """
    Called once after the training loop has finished (due to reaching the optimization criterion or because of an error).
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.POST_TRAINING.

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_training_start(context)

Called once before start of the first epoch At this point, the context argument will have the following attributes: - optimizer - criterion - device - experiment_name - ckpt_dir - net - sg_logger - train_loader - valid_loader - training_params - checkpoint_params - arch_params - metric_to_watch - valid_metrics

The corresponding Phase enum value for this event is Phase.PRE_TRAINING.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
def on_training_start(self, context: PhaseContext) -> None:
    """
    Called once before the start of the first epoch.
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - optimizer
        - criterion
        - device
        - experiment_name
        - ckpt_dir
        - net
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics

    The corresponding Phase enum value for this event is Phase.PRE_TRAINING.

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_validation_batch_end(context)

Called after all forward step / loss / metric computation have been performed for a given batch and there is nothing left to do. At this point, the context argument will have the following attributes: - epoch - batch_idx - inputs - preds - target - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - loss_logging_items_names

The corresponding Phase enum value for this event is Phase.VALIDATION_BATCH_END.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
def on_validation_batch_end(self, context: PhaseContext) -> None:
    """
    Called after the forward step / loss / metric computation have been performed for a given batch and there is nothing left to do.
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - inputs
        - preds
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.VALIDATION_BATCH_END.

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_validation_batch_start(context)

Called at each batch after getting batch of data from validation loader and moving it to target device. At this point, the context argument will have the following attributes: - epoch - batch_idx - inputs - target - metrics_compute_fn - loss_avg_meter - criterion - device - stop_training - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - loss_logging_items_names

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
def on_validation_batch_start(self, context: PhaseContext) -> None:
    """
    Called at each batch after getting a batch of data from the validation loader and moving it to the target device.
    The base implementation does nothing.

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - inputs
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - criterion
        - device
        - stop_training
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - loss_logging_items_names

    :param context: Phase context exposing the attributes listed above.
    """
    pass

on_validation_end_best_epoch(context)

Called each epoch after validation has been performed and the best metric has been achieved. At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_dict - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

The corresponding Phase enum value for this event is Phase.VALIDATION_END_BEST_EPOCH.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
    """
    Hook fired each epoch after validation has been performed and the best metric has been achieved.

    Attributes available on ``context`` at this point:
        epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn,
        loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir,
        net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params,
        ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics,
        loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.VALIDATION_END_BEST_EPOCH.
    This implementation is intentionally a no-op.

    :param context: Phase context exposing the attributes listed above.
    """
    return None

on_validation_loader_end(context)

Called each epoch at the end of validation data loader (after processing the last batch). At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_dict - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

The corresponding Phase enum value for this event is Phase.VALIDATION_EPOCH_END.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
def on_validation_loader_end(self, context: PhaseContext) -> None:
    """
    Hook fired each epoch at the end of the validation data loader (after processing the last batch).

    Attributes available on ``context`` at this point:
        epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn,
        loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir,
        net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params,
        ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics,
        loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.VALIDATION_EPOCH_END.
    This implementation is intentionally a no-op.

    :param context: Phase context exposing the attributes listed above.
    """
    return None

on_validation_loader_start(context)

Called each epoch at the start of validation data loader (before getting the first batch). At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_dict - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

Parameters:

Name Type Description Default
context PhaseContext required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
def on_validation_loader_start(self, context: PhaseContext) -> None:
    """
    Hook fired each epoch at the start of the validation data loader (before getting the first batch).

    Attributes available on ``context`` at this point:
        epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn,
        loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir,
        net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params,
        ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics,
        loss_logging_items_names

    This implementation is intentionally a no-op.

    :param context: Phase context exposing the attributes listed above.
    """
    return None

CallbackHandler

Bases: Callback

Runs all callbacks

Parameters:

Name Type Description Default
callbacks List[Callback]

Callbacks to be run.

required
Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
class CallbackHandler(Callback):
    """
    Composite callback that forwards every event it receives to each wrapped callback.

    Callbacks are invoked in the order in which they appear in ``callbacks``.

    :param callbacks: Callbacks to be run.
    """

    def __init__(self, callbacks: List[Callback]):
        # TODO: Reorder callbacks so that inter-dependent callbacks always run in the right order.
        # Example: Gradient Clipping & Gradient Logging - gradients must be clipped before they are logged,
        # even if the user registered the two callbacks the other way around.
        # One way to achieve it is a per-callback priority property, e.g.:
        #   Forward   = 0
        #   Loss      = 100
        #   Backward  = 200
        #   Metrics   = 300
        #   Scheduler = 400
        #   Logging   = 500
        # Sorting by that priority would run all Forward-related callbacks (for a given event) first,
        # then Backward-related ones, and only then the Logging-related ones.
        self.callbacks = callbacks

    def _dispatch(self, event_name: str, context: PhaseContext) -> None:
        """Call the hook named ``event_name`` on every wrapped callback, in registration order."""
        for registered in self.callbacks:
            getattr(registered, event_name)(context)

    def on_training_start(self, context: PhaseContext) -> None:
        self._dispatch("on_training_start", context)

    def on_train_loader_start(self, context: PhaseContext) -> None:
        self._dispatch("on_train_loader_start", context)

    def on_train_batch_start(self, context: PhaseContext) -> None:
        self._dispatch("on_train_batch_start", context)

    def on_train_batch_loss_end(self, context: PhaseContext) -> None:
        self._dispatch("on_train_batch_loss_end", context)

    def on_train_batch_backward_end(self, context: PhaseContext) -> None:
        self._dispatch("on_train_batch_backward_end", context)

    def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None:
        self._dispatch("on_train_batch_gradient_step_start", context)

    def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
        self._dispatch("on_train_batch_gradient_step_end", context)

    def on_train_batch_end(self, context: PhaseContext) -> None:
        self._dispatch("on_train_batch_end", context)

    def on_validation_loader_start(self, context: PhaseContext) -> None:
        self._dispatch("on_validation_loader_start", context)

    def on_validation_batch_start(self, context: PhaseContext) -> None:
        self._dispatch("on_validation_batch_start", context)

    def on_validation_batch_end(self, context: PhaseContext) -> None:
        self._dispatch("on_validation_batch_end", context)

    def on_validation_loader_end(self, context: PhaseContext) -> None:
        self._dispatch("on_validation_loader_end", context)

    def on_train_loader_end(self, context: PhaseContext) -> None:
        self._dispatch("on_train_loader_end", context)

    def on_training_end(self, context: PhaseContext) -> None:
        self._dispatch("on_training_end", context)

    def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
        self._dispatch("on_validation_end_best_epoch", context)

    def on_test_loader_start(self, context: PhaseContext) -> None:
        self._dispatch("on_test_loader_start", context)

    def on_test_batch_start(self, context: PhaseContext) -> None:
        self._dispatch("on_test_batch_start", context)

    def on_test_batch_end(self, context: PhaseContext) -> None:
        self._dispatch("on_test_batch_end", context)

    def on_test_loader_end(self, context: PhaseContext) -> None:
        self._dispatch("on_test_loader_end", context)

    def on_average_best_models_validation_start(self, context: PhaseContext) -> None:
        self._dispatch("on_average_best_models_validation_start", context)

    def on_average_best_models_validation_end(self, context: PhaseContext) -> None:
        self._dispatch("on_average_best_models_validation_end", context)

PhaseCallback

Bases: Callback

Kept here to keep backward compatibility with old code. New callbacks should use Callback class instead. This callback supports receiving only a subset of events defined in Phase enum:

PRE_TRAINING = "PRE_TRAINING" TRAIN_EPOCH_START = "TRAIN_EPOCH_START" TRAIN_BATCH_END = "TRAIN_BATCH_END" TRAIN_BATCH_STEP = "TRAIN_BATCH_STEP" TRAIN_EPOCH_END = "TRAIN_EPOCH_END"

VALIDATION_BATCH_END = "VALIDATION_BATCH_END" VALIDATION_EPOCH_END = "VALIDATION_EPOCH_END" VALIDATION_END_BEST_EPOCH = "VALIDATION_END_BEST_EPOCH"

TEST_BATCH_END = "TEST_BATCH_END" TEST_END = "TEST_END" AVERAGE_BEST_MODELS_VALIDATION_START = "AVERAGE_BEST_MODELS_VALIDATION_START" AVERAGE_BEST_MODELS_VALIDATION_END = "AVERAGE_BEST_MODELS_VALIDATION_END" POST_TRAINING = "POST_TRAINING"

Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
class PhaseCallback(Callback):
    """
    Legacy, single-phase callback kept for backward compatibility with old code.
    New callbacks should use the Callback class instead.

    An instance is bound to one ``Phase`` value; when the matching event hook fires, the instance itself
    is called with the phase context. Only the following subset of Phase enum events is supported:

    PRE_TRAINING = "PRE_TRAINING"
    TRAIN_EPOCH_START = "TRAIN_EPOCH_START"
    TRAIN_BATCH_END = "TRAIN_BATCH_END"
    TRAIN_BATCH_STEP = "TRAIN_BATCH_STEP"
    TRAIN_EPOCH_END = "TRAIN_EPOCH_END"

    VALIDATION_BATCH_END = "VALIDATION_BATCH_END"
    VALIDATION_EPOCH_END = "VALIDATION_EPOCH_END"
    VALIDATION_END_BEST_EPOCH = "VALIDATION_END_BEST_EPOCH"

    TEST_BATCH_END = "TEST_BATCH_END"
    TEST_END = "TEST_END"
    AVERAGE_BEST_MODELS_VALIDATION_START = "AVERAGE_BEST_MODELS_VALIDATION_START"
    AVERAGE_BEST_MODELS_VALIDATION_END = "AVERAGE_BEST_MODELS_VALIDATION_END"
    POST_TRAINING = "POST_TRAINING"

    :param phase: Phase at which this callback is triggered.
    """

    def __init__(self, phase: Phase):
        self.phase = phase

    def __call__(self, *args, **kwargs):
        # Concrete callbacks must provide the actual behaviour.
        raise NotImplementedError

    def __repr__(self) -> str:
        own_class = self.__class__
        return own_class.__name__

    def _invoke_when(self, expected_phase: Phase, context: PhaseContext) -> None:
        """Call this callback with ``context`` only if it is bound to ``expected_phase``."""
        if self.phase == expected_phase:
            self(context)

    def on_training_start(self, context: PhaseContext) -> None:
        self._invoke_when(Phase.PRE_TRAINING, context)

    def on_train_loader_start(self, context: PhaseContext) -> None:
        self._invoke_when(Phase.TRAIN_EPOCH_START, context)

    def on_train_batch_loss_end(self, context: PhaseContext) -> None:
        self._invoke_when(Phase.TRAIN_BATCH_END, context)

    def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
        self._invoke_when(Phase.TRAIN_BATCH_STEP, context)

    def on_train_loader_end(self, context: PhaseContext) -> None:
        self._invoke_when(Phase.TRAIN_EPOCH_END, context)

    def on_validation_batch_end(self, context: PhaseContext) -> None:
        self._invoke_when(Phase.VALIDATION_BATCH_END, context)

    def on_validation_loader_end(self, context: PhaseContext) -> None:
        self._invoke_when(Phase.VALIDATION_EPOCH_END, context)

    def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
        self._invoke_when(Phase.VALIDATION_END_BEST_EPOCH, context)

    def on_test_batch_end(self, context: PhaseContext) -> None:
        self._invoke_when(Phase.TEST_BATCH_END, context)

    def on_test_loader_end(self, context: PhaseContext) -> None:
        self._invoke_when(Phase.TEST_END, context)

    def on_average_best_models_validation_start(self, context: PhaseContext) -> None:
        self._invoke_when(Phase.AVERAGE_BEST_MODELS_VALIDATION_START, context)

    def on_average_best_models_validation_end(self, context: PhaseContext) -> None:
        self._invoke_when(Phase.AVERAGE_BEST_MODELS_VALIDATION_END, context)

    def on_training_end(self, context: PhaseContext) -> None:
        self._invoke_when(Phase.POST_TRAINING, context)

PhaseContext

Represents the input for phase callbacks, and is constantly updated after callback calls.

Source code in V3_3/src/super_gradients/training/utils/callbacks/base_callbacks.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
class PhaseContext:
    """
    Mutable bag of training state handed to phase callbacks.

    Every constructor argument is stored as an attribute of the same name. ``stop_training`` is not a
    constructor argument: it always starts as ``False``. Attributes are refreshed between callback calls
    through ``update_context``.
    """

    def __init__(
        self,
        epoch: Optional[int] = None,
        batch_idx: Optional[int] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        metrics_dict=None,
        inputs: Optional[torch.Tensor] = None,
        preds: Optional[torch.Tensor] = None,
        target: Optional[torch.Tensor] = None,
        metrics_compute_fn: Optional[MetricCollection] = None,
        loss_avg_meter: Optional["AverageMeter"] = None,  # noqa: ignore
        loss_log_items: Optional[torch.Tensor] = None,
        criterion: Optional[_Loss] = None,
        device: Optional[str] = None,
        experiment_name: Optional[str] = None,
        ckpt_dir: Optional[str] = None,
        net: Optional["SgModule"] = None,  # noqa: ignore
        lr_warmup_epochs: Optional[int] = None,
        sg_logger: Optional["BaseSGLogger"] = None,  # noqa: ignore
        train_loader: Optional[DataLoader] = None,
        valid_loader: Optional[DataLoader] = None,
        test_loader: Optional[DataLoader] = None,
        training_params: Optional["TrainingParams"] = None,  # noqa: ignore
        ddp_silent_mode: Optional[bool] = None,
        checkpoint_params: Optional["HpmStruct"] = None,  # noqa: ignore
        architecture: Optional = None,
        arch_params: Optional["HpmStruct"] = None,  # noqa: ignore
        metric_to_watch: Optional[str] = None,
        valid_metrics: Optional[MetricCollection] = None,  # noqa: ignore
        ema_model: Optional["SgModule"] = None,  # noqa: ignore
        loss_logging_items_names: Optional[List[str]] = None,
        additional_batch_items: Optional[Any] = None,
    ):
        # Progress counters and optimizer.
        self.epoch = epoch
        self.batch_idx = batch_idx
        self.optimizer = optimizer

        # Current batch: inputs, predictions and targets.
        self.inputs = inputs
        self.preds = preds
        self.target = target

        # Metrics and loss bookkeeping.
        self.metrics_dict = metrics_dict
        self.metrics_compute_fn = metrics_compute_fn
        self.loss_avg_meter = loss_avg_meter
        self.loss_log_items = loss_log_items
        self.criterion = criterion
        self.device = device

        # Always starts False; not configurable through the constructor.
        self.stop_training = False

        # Experiment, model, logger and data loaders.
        self.experiment_name = experiment_name
        self.ckpt_dir = ckpt_dir
        self.net = net
        self.lr_warmup_epochs = lr_warmup_epochs
        self.sg_logger = sg_logger
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        # Hyper-parameters and remaining run configuration.
        self.training_params = training_params
        self.ddp_silent_mode = ddp_silent_mode
        self.checkpoint_params = checkpoint_params
        self.architecture = architecture
        self.arch_params = arch_params
        self.metric_to_watch = metric_to_watch
        self.valid_metrics = valid_metrics
        self.ema_model = ema_model
        self.loss_logging_items_names = loss_logging_items_names
        self.additional_batch_items = additional_batch_items

    def update_context(self, **kwargs):
        """
        Set (or create) one attribute on this context per keyword argument.

        :param kwargs: Attribute names mapped to their new values.
        """
        for name in kwargs:
            setattr(self, name, kwargs[name])

BinarySegmentationVisualizationCallback

Bases: PhaseCallback

A callback that adds a visualization of a batch of segmentation predictions to context.sg_logger

Parameters:

Name Type Description Default
phase Phase

When to trigger the callback.

required
freq int

Frequency (in epochs) to perform this callback.

required
batch_idx int

Batch index to perform visualization for.

0
last_img_idx_in_batch int

Last image index to add to log. (default=-1, will take entire batch).

-1
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
class BinarySegmentationVisualizationCallback(PhaseCallback):
    """
    A callback that adds a visualization of a batch of segmentation predictions to context.sg_logger

    :param phase:                   When to trigger the callback.
    :param freq:                    Frequency (in epochs) to perform this callback.
    :param batch_idx:               Batch index to perform visualization for.
    :param last_img_idx_in_batch:   Last image index to add to log. (default=-1, will take entire batch).
    """

    def __init__(self, phase: Phase, freq: int, batch_idx: int = 0, last_img_idx_in_batch: int = -1):
        super().__init__(phase)
        self.freq = freq
        self.batch_idx = batch_idx
        self.last_img_idx_in_batch = last_img_idx_in_batch

    def __call__(self, context: PhaseContext):
        """
        Log the visualized batch to ``context.sg_logger`` when the epoch is a multiple of ``freq`` and the
        current batch index equals ``batch_idx``; otherwise do nothing.

        :param context: Phase context; ``epoch``, ``batch_idx``, ``inputs``, ``preds``, ``target`` and
                        ``sg_logger`` are read.
        """
        if context.epoch % self.freq != 0 or context.batch_idx != self.batch_idx:
            return

        # When predictions come as a tuple, only the first element is visualized.
        raw_preds = context.preds[0] if isinstance(context.preds, tuple) else context.preds
        preds = raw_preds.clone()

        batch_imgs = BinarySegmentationVisualization.visualize_batch(context.inputs, preds, context.target, self.batch_idx)
        batch_imgs = np.stack([cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in batch_imgs])

        # The documented default (-1) means "entire batch". A plain ``[:-1]`` slice would silently drop
        # the last image, so -1 is mapped to an open-ended slice.
        stop = None if self.last_img_idx_in_batch == -1 else self.last_img_idx_in_batch

        tag = "batch_" + str(self.batch_idx) + "_images"
        context.sg_logger.add_images(tag=tag, images=batch_imgs[:stop], global_step=context.epoch, data_format="NHWC")

CosineLRScheduler

Bases: LRCallbackBase

Hard-coded step cosine annealing learning rate scheduling.

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
@register_lr_scheduler(LRSchedulers.COSINE, deprecated_name="cosine")
class CosineLRScheduler(LRCallbackBase):
    """
    Hard-coded, per-step cosine annealing learning rate scheduling.

    :param max_epochs:             Total number of epochs used to size the cosine schedule.
    :param cosine_final_lr_ratio:  Ratio between the final and the initial learning rate.
    :param kwargs:                 Forwarded to the base class constructor.
    """

    def __init__(self, max_epochs, cosine_final_lr_ratio, **kwargs):
        super().__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
        self.max_epochs = max_epochs
        self.cosine_final_lr_ratio = cosine_final_lr_ratio

    def perform_scheduling(self, context):
        """Compute the learning rate for the current step, store it in ``self.lr`` and apply it via ``update_lr``."""
        params = self.training_params
        steps_per_epoch = self.train_loader_len

        # Epoch counters are shifted so that the cosine starts after the warmup epochs
        # and its length excludes both warmup and cooldown epochs.
        effective_epoch = context.epoch - params.lr_warmup_epochs
        effective_max_epochs = self.max_epochs - params.lr_warmup_epochs - params.lr_cooldown_epochs

        # Per-step warmup is subtracted as well; the current step is clamped at zero.
        current_iter = max(0, steps_per_epoch * effective_epoch + context.batch_idx - params.lr_warmup_steps)
        max_iter = steps_per_epoch * effective_max_epochs - params.lr_warmup_steps

        scheduled_lr = self.compute_learning_rate(current_iter, max_iter, self.initial_lr, self.cosine_final_lr_ratio)
        self.lr = float(scheduled_lr)
        self.update_lr(context.optimizer, context.epoch, context.batch_idx)

    def is_lr_scheduling_enabled(self, context):
        """Return whether the cosine schedule should be applied at the current epoch / batch."""
        params = self.training_params

        if params.lr_warmup_steps > 0:
            # Per-step warmup: scheduling is enabled once the global step reaches the warmup step count.
            current_step = self.train_loader_len * context.epoch + context.batch_idx
            return current_step >= params.lr_warmup_steps

        # Per-epoch warmup: enabled from the end of warmup until the cooldown epochs begin.
        post_warmup_epochs = params.max_epochs - params.lr_cooldown_epochs
        return params.lr_warmup_epochs <= context.epoch < post_warmup_epochs

    @classmethod
    def compute_learning_rate(cls, step: Union[float, np.ndarray], total_steps: float, initial_lr: float, final_lr_ratio: float):
        """
        Evaluate the cosine schedule at ``step``.

        The cosine starts from ``initial_lr`` and decays towards ``initial_lr * final_lr_ratio``.

        :param step:            Current step (scalar or array of steps).
        :param total_steps:     Total number of steps of the schedule.
        :param initial_lr:      Learning rate at step 0.
        :param final_lr_ratio:  Ratio between the final and the initial learning rate.
        :return:                Learning rate(s) for ``step``.
        """
        angle = step / (total_steps + 1) * math.pi
        cosine_lr = 0.5 * initial_lr * (1.0 + np.cos(angle))
        return cosine_lr * (1 - final_lr_ratio) + (initial_lr * final_lr_ratio)

DeciLabUploadCallback

Bases: PhaseCallback

Post-training callback for uploading and optimizing a model.

Parameters:

Name Type Description Default
model_meta_data

Model's meta-data object. Type: ModelMetadata

required
optimization_request_form

Optimization request form object. Type: OptimizationRequestForm

required
ckpt_name str

Checkpoint filename, inside the checkpoint directory.

'ckpt_best.pth'
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
@register_callback(Callbacks.DECI_LAB_UPLOAD)
class DeciLabUploadCallback(PhaseCallback):
    """
    Post-training callback for uploading and optimizing a model.

    :param model_name:                  Name under which the model is uploaded.
    :param input_dimensions:            Input dimensions forwarded on upload; also passed as ``input_size`` to
                                        ``prep_model_for_conversion`` when the model defines that method.
    :param target_hardware_types:       Forwarded to the platform client on upload.
    :param target_batch_size:           Forwarded to the platform client on upload.
    :param target_quantization_level:   Forwarded to the platform client on upload.
    :param ckpt_name:                   Checkpoint filename, inside the checkpoint directory.
    :param kwargs:                      Accepted but not used by this callback.
    """

    def __init__(
        self,
        model_name: str,
        input_dimensions: Sequence[int],
        target_hardware_types: "Optional[List[str]]" = None,
        target_batch_size: "Optional[int]" = None,
        target_quantization_level: "Optional[str]" = None,
        ckpt_name: str = "ckpt_best.pth",
        **kwargs,
    ):
        super().__init__(phase=Phase.POST_TRAINING)
        self.input_dimensions = input_dimensions
        self.model_name = model_name
        self.target_hardware_types = target_hardware_types
        self.target_batch_size = target_batch_size
        self.target_quantization_level = target_quantization_level
        self.ckpt_name = ckpt_name
        self.platform_client = DeciClient()

    @staticmethod
    def log_optimization_failed():
        logger.info("We couldn't finish your model optimization. Visit https://console.deci.ai for details")

    def upload_model(self, model):
        """
        This function will upload the trained model to the Deci Lab

        :param model: The resulting model from the training process
        """
        self.platform_client.upload_model(
            model=model,
            name=self.model_name,
            input_dimensions=self.input_dimensions,
            target_hardware_types=self.target_hardware_types,
            target_batch_size=self.target_batch_size,
            target_quantization_level=self.target_quantization_level,
        )

    def get_optimization_status(self, optimized_model_name: str, timeout_seconds: float = 1800, poll_interval_seconds: float = 30) -> bool:
        """
        Poll the platform until the optimized model is no longer being benchmarked, or until the timeout expires.

        The status is checked against the server every ``poll_interval_seconds`` (30 seconds by default) and
        the wait gives up after ``timeout_seconds`` (30 minutes by default).

        :param optimized_model_name:   Optimized model name
        :param timeout_seconds:        Maximum time to wait, in seconds.
        :param poll_interval_seconds:  Delay between two status checks, in seconds.

        :return: True if benchmarking finished before the timeout, False if the wait timed out.
        """
        # A monotonic deadline is used instead of SIGALRM: a signal handler's return value is ignored, so the
        # previous handler could never abort the loop; SIGALRM is also POSIX-only and main-thread-only.
        deadline = time.monotonic() + timeout_seconds

        while self.platform_client.is_model_benchmarking(name=optimized_model_name):
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                logger.error("Process timed out. Visit https://console.deci.ai for details")
                return False
            # Never sleep past the deadline.
            time.sleep(min(poll_interval_seconds, remaining))

        return True

    def __call__(self, context: PhaseContext) -> None:
        """
        This function will attempt to upload the trained model and schedule an optimization for it.

        Any exception is logged rather than propagated.

        :param context: Training phase context
        """
        try:
            model = copy.deepcopy(unwrap_model(context.net))
            model_state_dict_path = os.path.join(context.ckpt_dir, self.ckpt_name)
            # The model is moved to CPU right below, so load the checkpoint on CPU as well; this also works
            # when the checkpoint was saved from a device that is not available here.
            model_state_dict = torch.load(model_state_dict_path, map_location="cpu")["net"]
            model.load_state_dict(state_dict=model_state_dict)

            model = model.cpu()
            if hasattr(model, "prep_model_for_conversion"):
                model.prep_model_for_conversion(input_size=self.input_dimensions)

            self.upload_model(model=model)
            model_name = self.model_name
            logger.info(f"Successfully added {model_name} to the model repository")

            optimized_model_name = f"{model_name}_1_1"
            logger.info("We'll wait for the scheduled optimization to finish. Please don't close this window")
            success = self.get_optimization_status(optimized_model_name=optimized_model_name)
            if success:
                logger.info("Successfully finished your model optimization. Visit https://console.deci.ai for details")
            else:
                DeciLabUploadCallback.log_optimization_failed()
        except Exception as ex:
            # Best-effort step at the end of training: report the failure, never crash the run.
            DeciLabUploadCallback.log_optimization_failed()
            logger.error(ex)

__call__(context)

This function will attempt to upload the trained model and schedule an optimization for it.

Parameters:

Name Type Description Default
context PhaseContext

Training phase context

required
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
def __call__(self, context: PhaseContext) -> None:
    """
    This function will attempt to upload the trained model and schedule an optimization for it.

    Any exception is logged rather than propagated.

    :param context: Training phase context
    """
    try:
        model = copy.deepcopy(unwrap_model(context.net))
        model_state_dict_path = os.path.join(context.ckpt_dir, self.ckpt_name)
        # The model is moved to CPU right below, so load the checkpoint on CPU as well; this also works
        # when the checkpoint was saved from a device that is not available here.
        model_state_dict = torch.load(model_state_dict_path, map_location="cpu")["net"]
        model.load_state_dict(state_dict=model_state_dict)

        model = model.cpu()
        if hasattr(model, "prep_model_for_conversion"):
            model.prep_model_for_conversion(input_size=self.input_dimensions)

        self.upload_model(model=model)
        model_name = self.model_name
        logger.info(f"Successfully added {model_name} to the model repository")

        optimized_model_name = f"{model_name}_1_1"
        logger.info("We'll wait for the scheduled optimization to finish. Please don't close this window")
        success = self.get_optimization_status(optimized_model_name=optimized_model_name)
        if success:
            logger.info("Successfully finished your model optimization. Visit https://console.deci.ai for details")
        else:
            DeciLabUploadCallback.log_optimization_failed()
    except Exception as ex:
        # Best-effort step at the end of training: report the failure, never crash the run.
        DeciLabUploadCallback.log_optimization_failed()
        logger.error(ex)

get_optimization_status(optimized_model_name)

This function will fetch the optimized version of the trained model and check on its benchmark status. The status will be checked against the server every 30 seconds and the process will time out after 30 minutes or log about the successful optimization - whichever happens first.

Parameters:

Name Type Description Default
optimized_model_name str

Optimized model name

required

Returns:

Type Description

Whether or not the optimized model has been benchmarked

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
def get_optimization_status(self, optimized_model_name: str, timeout_seconds: float = 1800, poll_interval_seconds: float = 30) -> bool:
    """
    Poll the platform until the optimized model is no longer being benchmarked, or until the timeout expires.

    The status is checked against the server every ``poll_interval_seconds`` (30 seconds by default) and
    the wait gives up after ``timeout_seconds`` (30 minutes by default).

    :param optimized_model_name:   Optimized model name
    :param timeout_seconds:        Maximum time to wait, in seconds.
    :param poll_interval_seconds:  Delay between two status checks, in seconds.

    :return: True if benchmarking finished before the timeout, False if the wait timed out.
    """
    # A monotonic deadline is used instead of SIGALRM: a signal handler's return value is ignored, so the
    # previous handler could never abort the loop; SIGALRM is also POSIX-only and main-thread-only.
    deadline = time.monotonic() + timeout_seconds

    while self.platform_client.is_model_benchmarking(name=optimized_model_name):
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            logger.error("Process timed out. Visit https://console.deci.ai for details")
            return False
        # Never sleep past the deadline.
        time.sleep(min(poll_interval_seconds, remaining))

    return True

upload_model(model)

This function will upload the trained model to the Deci Lab

Parameters:

Name Type Description Default
model

The resulting model from the training process

required
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def upload_model(self, model):
    """
    Upload the trained model to the Deci Lab through the platform client.

    The model is sent together with the name, input dimensions and optimization targets
    (hardware types, batch size, quantization level) stored on this object.

    :param model: The resulting model from the training process
    """
    upload = self.platform_client.upload_model
    request = dict(
        model=model,
        name=self.model_name,
        input_dimensions=self.input_dimensions,
        target_hardware_types=self.target_hardware_types,
        target_batch_size=self.target_batch_size,
        target_quantization_level=self.target_quantization_level,
    )
    upload(**request)

DetectionVisualizationCallback

Bases: PhaseCallback

A callback that adds a visualization of a batch of detection predictions to context.sg_logger

Parameters:

Name Type Description Default
phase Phase

When to trigger the callback.

required
freq int

Frequency (in epochs) to perform this callback.

required
batch_idx int

Batch index to perform visualization for.

0
classes list

Class list of the dataset.

required
last_img_idx_in_batch int

Last image index to add to log. (default=-1, will take entire batch).

-1
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
@register_callback(Callbacks.DETECTION_VISUALIZATION_CALLBACK)
class DetectionVisualizationCallback(PhaseCallback):
    """
    A callback that adds a visualization of a batch of detection predictions to context.sg_logger

    :param phase:                    When to trigger the callback.
    :param freq:                     Frequency (in epochs) to perform this callback.
    :param post_prediction_callback: Callback applied to the (cloned) raw predictions before visualization.
    :param classes:                  Class list of the dataset.
    :param batch_idx:                Batch index to perform visualization for.
    :param last_img_idx_in_batch:    Last image index to add to log. (default=-1, will take entire batch).
    """

    def __init__(
        self,
        phase: Phase,
        freq: int,
        post_prediction_callback: DetectionPostPredictionCallback,
        classes: list,
        batch_idx: int = 0,
        last_img_idx_in_batch: int = -1,
    ):
        super(DetectionVisualizationCallback, self).__init__(phase)
        self.freq = freq
        self.post_prediction_callback = post_prediction_callback
        self.batch_idx = batch_idx
        self.classes = classes
        self.last_img_idx_in_batch = last_img_idx_in_batch

    def __call__(self, context: PhaseContext):
        """
        Visualize the configured batch and send the images to context.sg_logger.

        Runs only when context.epoch is a multiple of self.freq and context.batch_idx equals self.batch_idx.

        :param context: Phase context providing preds, inputs, target, epoch, batch_idx and sg_logger.
        """
        if context.epoch % self.freq == 0 and context.batch_idx == self.batch_idx:
            # SOME CALCULATIONS ARE IN-PLACE IN NMS, SO CLONE THE PREDICTIONS
            preds = (context.preds[0].clone(), None)
            preds = self.post_prediction_callback(preds)
            batch_imgs = DetectionVisualization.visualize_batch(context.inputs, preds, context.target, self.batch_idx, self.classes)
            batch_imgs = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in batch_imgs]
            batch_imgs = np.stack(batch_imgs)
            tag = "batch_" + str(self.batch_idx) + "_images"
            # -1 is the documented "entire batch" sentinel. Slicing with [:-1] would silently drop
            # the last image, so map the sentinel to an open-ended slice instead.
            end_idx = None if self.last_img_idx_in_batch == -1 else self.last_img_idx_in_batch
            context.sg_logger.add_images(tag=tag, images=batch_imgs[:end_idx], global_step=context.epoch, data_format="NHWC")

ExponentialLRScheduler

Bases: LRCallbackBase

Exponential decay learning rate scheduling. Decays the learning rate by lr_decay_factor every epoch.

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
@register_lr_scheduler(LRSchedulers.EXP, deprecated_name="exp")
class ExponentialLRScheduler(LRCallbackBase):
    """
    Learning rate scheduler with exponential decay: the learning rate is multiplied by
    `lr_decay_factor` once per epoch, applied smoothly at every training batch step.
    """

    def __init__(self, lr_decay_factor: float, **kwargs):
        """
        :param lr_decay_factor: Multiplicative decay applied to the learning rate per epoch.
        :param kwargs:          Forwarded to the base scheduler constructor.
        """
        super().__init__(phase=Phase.TRAIN_BATCH_STEP, **kwargs)
        self.lr_decay_factor = lr_decay_factor

    def perform_scheduling(self, context):
        """Compute the decayed learning rate for the current iteration and apply it to the optimizer."""
        # Epochs are counted from the end of the warmup period.
        epochs_since_warmup = context.epoch - self.training_params.lr_warmup_epochs
        iters_since_warmup = self.train_loader_len * epochs_since_warmup + context.batch_idx
        # Fractional epoch count, so the decay progresses within an epoch as well.
        decay_exponent = iters_since_warmup / self.train_loader_len
        self.lr = self.initial_lr * self.lr_decay_factor**decay_exponent
        self.update_lr(context.optimizer, context.epoch, context.batch_idx)

    def is_lr_scheduling_enabled(self, context):
        """Return True for epochs at or after the warmup end and before the cooldown start."""
        cooldown_start_epoch = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
        warmup_end_epoch = self.training_params.lr_warmup_epochs
        return warmup_end_epoch <= context.epoch < cooldown_start_epoch

ExtremeBatchCaseVisualizationCallback

Bases: Callback, ABC

ExtremeBatchCaseVisualizationCallback

A base class for visualizing worst/best validation batches in an epoch according to some metric or loss value, with Full DDP support.

Images are saved with training_hyperparams.sg_logger.

Parameters:

Name Type Description Default
metric Optional[Metric]

Metric, will be the metric which is monitored.

None
metric_component_name Optional[str]

In case metric returns multiple values (as Mapping), the value at metric.compute()[metric_component_name] will be the one monitored.

None
loss_to_monitor Optional[str]

str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...). Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss, and should be: if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple: `<LOSS_CLASS.__name__>"/"<COMPONENT_NAME>`. If a single item is returned rather than a tuple: `<LOSS_CLASS.__name__>`. When there is no such attribute and criterion.forward(..) returns a tuple: `<LOSS_CLASS.__name__>"/"Loss_"<IDX>`

None
max bool

bool, Whether to take the batch corresponding to the max value of the metric/loss or the minimum (default=False).

False
freq int

int, epoch frequency to perform all of the above (default=1). Inheritors should implement process_extreme_batch which returns an image, as np.ndarray (uint8) with shape BHWC.

1
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
class ExtremeBatchCaseVisualizationCallback(Callback, ABC):
    """
    ExtremeBatchCaseVisualizationCallback

    A base class for visualizing worst/best validation batches in an epoch
     according to some metric or loss value, with Full DDP support.

    Images are saved with training_hyperparams.sg_logger.

    :param metric: Metric, will be the metric which is monitored.

    :param metric_component_name: In case metric returns multiple values (as Mapping),
     the value at metric.compute()[metric_component_name] will be the one monitored.

    :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
     Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:

        if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.

        If a single item is returned rather than a tuple:
            <LOSS_CLASS.__name__>.

        When there is no such attribute and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"Loss_"<IDX>

    :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
     the minimum (default=False).

    :param freq: int, epoch frequency to perform all of the above (default=1).

     Inheritors should implement process_extreme_batch which returns an image, as np.ndarray (uint8) with shape BHWC.
    """

    @resolve_param("metric", MetricsFactory())
    def __init__(
        self,
        metric: Optional[Metric] = None,
        metric_component_name: Optional[str] = None,
        loss_to_monitor: Optional[str] = None,
        max: bool = False,
        freq: int = 1,
        enable_on_train_loader: bool = False,
        enable_on_valid_loader: bool = True,
        max_images: int = -1,
    ):
        """
        :param metric: Metric, will be the metric which is monitored.

        :param metric_component_name: In case metric returns multiple values (as Mapping),
         the value at metric.compute()[metric_component_name] will be the one monitored.

        :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
         Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:

        if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.

        If a single item is returned rather than a tuple:
            <LOSS_CLASS.__name__>.

        When there is no such attribute and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"Loss_"<IDX>

        :param max:                    bool, Whether to take the batch corresponding to the max value of the metric/loss or
        the minimum (default=False).

        :param freq:                   int, epoch frequency to perform all of the above (default=1).
        :param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
        :param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
        :param max_images:             Maximum images to save. If -1, save all images.

        :raises RuntimeError: When both or neither of `metric` and `loss_to_monitor` are given.
        """
        super(ExtremeBatchCaseVisualizationCallback, self).__init__()

        # Exactly one monitoring source (metric or loss) must be configured.
        if (metric and loss_to_monitor) or (metric is None and loss_to_monitor is None):
            raise RuntimeError("Must pass exactly one of: loss, metric != None")

        self._set_tag_attr(loss_to_monitor, max, metric, metric_component_name)
        self.metric = metric
        if self.metric:
            # Wrap in a MetricCollection and move it to the configured device.
            self.metric = MetricCollection(self.metric)
            self.metric.to(device_config.device)

        self.metric_component_name = metric_component_name

        self.loss_to_monitor = loss_to_monitor
        self.max = max
        self.freq = freq

        # State of the most extreme batch seen so far in the current loader pass (cleared by _reset).
        self.extreme_score = None
        self.extreme_batch = None
        self.extreme_preds = None
        self.extreme_targets = None
        self.extreme_additional_batch_items = None

        # The index of the monitored loss inside the loss tuple is resolved lazily on the first batch.
        self._first_call = True
        self._idx_loss_tuple = None

        self.enable_on_train_loader = enable_on_train_loader
        self.enable_on_valid_loader = enable_on_valid_loader
        self.max_images = max_images

    def _set_tag_attr(self, loss_to_monitor, max, metric, metric_component_name):
        """
        Build the logging tag (self._tag) as "max_<name>_batch" or "min_<name>_batch".

        <name> is the metric component name if given, otherwise the metric class name,
        otherwise the monitored loss name.
        """
        if metric_component_name:
            monitored_val_name = metric_component_name
        elif metric:
            monitored_val_name = metric.__class__.__name__
        else:
            monitored_val_name = loss_to_monitor
        self._tag = f"max_{monitored_val_name}_batch" if max else f"min_{monitored_val_name}_batch"

    @abstractmethod
    def process_extreme_batch(self) -> np.ndarray:
        """
        This method is called right before adding the images to the SGLogger
         (from _gather_extreme_batch_images_and_log, which runs at the end of the train/validation loader).
         It should process self.extreme_batch, self.extreme_preds and self.extreme_targets and output the images, as np.ndarray.
         Output should be of shape N,H,W,3 and uint8.
        :return: images to save, np.ndarray
        """
        raise NotImplementedError

    def on_train_loader_start(self, context: PhaseContext) -> None:
        """Clear the stored extreme batch state before a training loader pass."""
        self._reset()

    def on_train_batch_end(self, context: PhaseContext) -> None:
        """Score the training batch when enabled for the train loader and the epoch matches `freq`."""
        if self.enable_on_train_loader and context.epoch % self.freq == 0:
            self._on_batch_end(context)

    def on_train_loader_end(self, context: PhaseContext) -> None:
        """Log the extreme training batch under the "train" prefix, then clear the stored state."""
        if self.enable_on_train_loader and context.epoch % self.freq == 0:
            self._gather_extreme_batch_images_and_log(context, "train")
            self._reset()

    def on_validation_loader_start(self, context: PhaseContext) -> None:
        """Clear the stored extreme batch state before a validation loader pass."""
        self._reset()

    def on_validation_batch_end(self, context: PhaseContext) -> None:
        """Score the validation batch when enabled for the valid loader and the epoch matches `freq`."""
        if self.enable_on_valid_loader and context.epoch % self.freq == 0:
            self._on_batch_end(context)

    def on_validation_loader_end(self, context: PhaseContext) -> None:
        """Log the extreme validation batch under the "valid" prefix, then clear the stored state."""
        if self.enable_on_valid_loader and context.epoch % self.freq == 0:
            self._gather_extreme_batch_images_and_log(context, "valid")
            self._reset()

    def _gather_extreme_batch_images_and_log(self, context, loader_name: str):
        """
        Render the stored extreme batch, gather the images, optionally truncate them and log them.

        :param context:     Phase context providing sg_logger, epoch and ddp_silent_mode.
        :param loader_name: Prefix of the logging tag ("train" or "valid" in this class).
        """
        images_to_save = self.process_extreme_batch()
        # NOTE(review): called before the ddp_silent_mode check, presumably so that every rank
        # takes part in the gather - confirm against maybe_all_gather_np_images.
        images_to_save = maybe_all_gather_np_images(images_to_save)
        # Non-positive max_images (e.g. the default -1) means "keep all images".
        if self.max_images > 0:
            images_to_save = images_to_save[: self.max_images]
        if not context.ddp_silent_mode:
            context.sg_logger.add_images(tag=f"{loader_name}/{self._tag}", images=images_to_save, global_step=context.epoch, data_format="NHWC")

    def _on_batch_end(self, context: PhaseContext) -> None:
        """
        Compute the monitored score for the current batch and, if it is more extreme than the
        stored one, keep CPU copies of the batch inputs, predictions, targets and additional items.

        :raises RuntimeError: When the metric result does not contain `metric_component_name`, or
                              when it holds several values and no component name was configured.
        """
        if self.metric is not None:
            # Per-batch metric value: update with this batch only, compute, then reset below.
            self.metric.update(**context.__dict__)
            score = self.metric.compute()
            if self.metric_component_name is not None:
                if not isinstance(score, Mapping) or (isinstance(score, Mapping) and self.metric_component_name not in score.keys()):
                    raise RuntimeError(
                        f"metric_component_name: {self.metric_component_name} is not a component of the monitored metric: {self.metric.__class__.__name__}"
                    )
                score = score[self.metric_component_name]
            elif len(score) > 1:
                raise RuntimeError(f"returned multiple values from {self.metric} but no metric_component_name has been passed to __init__.")
            else:
                # Single-entry result: take its only value.
                score = score.pop(list(score.keys())[0])
            self.metric.reset()

        else:

            # FOR LOSS VALUES, GET THE RIGHT COMPONENT, DERIVE IT ON THE FIRST PASS
            loss_tuple = context.loss_log_items
            if self._first_call:
                self._init_loss_attributes(context)
            score = loss_tuple[self._idx_loss_tuple].detach().cpu().item()

            # UNLIKE METRICS - LOSS VALUES NEED TO BE REDUCED IN DDP
            device = infer_model_device(context.net)
            score = torch.tensor(score, device=device)
            score = maybe_all_reduce_tensor_average(score)

        if self._is_more_extreme(score):
            # Store detached CPU copies of everything needed to render this batch later.
            self.extreme_score = tensor_container_to_device(score, device="cpu", detach=True, non_blocking=False)
            self.extreme_batch = tensor_container_to_device(context.inputs, device="cpu", detach=True, non_blocking=False)
            self.extreme_preds = tensor_container_to_device(context.preds, device="cpu", detach=True, non_blocking=False)
            self.extreme_targets = tensor_container_to_device(context.target, device="cpu", detach=True, non_blocking=False)
            self.extreme_additional_batch_items = tensor_container_to_device(context.additional_batch_items, device="cpu", detach=True, non_blocking=False)

    def _init_loss_attributes(self, context: PhaseContext):
        """
        Resolve the position of `loss_to_monitor` inside context.loss_logging_items_names.

        :raises ValueError: When `loss_to_monitor` is not one of the logged loss item names.
        """
        if self.loss_to_monitor not in context.loss_logging_items_names:
            raise ValueError(f"{self.loss_to_monitor} not a loss or loss component.")
        self._idx_loss_tuple = context.loss_logging_items_names.index(self.loss_to_monitor)
        self._first_call = False

    def _reset(self):
        """Clear the stored extreme batch state and reset the metric, if one is configured."""
        self.extreme_score = None
        self.extreme_batch = None
        self.extreme_preds = None
        self.extreme_targets = None
        self.extreme_additional_batch_items = None
        if self.metric is not None:
            self.metric.reset()

    def _is_more_extreme(self, score: Union[float, torch.Tensor]) -> bool:
        """
        Checks whether computed score is more extreme than the current extreme score.
        If the current score is None (first call), returns True.
        :param score: A newly computed score.
        :return:      True if score is more extreme than the current extreme score, False otherwise.
        """
        # A score can be NaN/Inf (rare but possible event when training diverges).
        # In such a case both the < and > operators would return False according to IEEE 754.
        # Accepting the first score unconditionally guarantees that self.extreme_batch / self.extreme_preds
        # are populated, so the later attempt to visualize the batch does not crash.
        if self.extreme_score is None:
            return True

        if self.max:
            return self.extreme_score < score
        else:
            return self.extreme_score > score

__init__(metric=None, metric_component_name=None, loss_to_monitor=None, max=False, freq=1, enable_on_train_loader=False, enable_on_valid_loader=True, max_images=-1)

Parameters:

Name Type Description Default
metric Optional[Metric]

Metric, will be the metric which is monitored.

None
metric_component_name Optional[str]

In case metric returns multiple values (as Mapping), the value at metric.compute()[metric_component_name] will be the one monitored.

None
loss_to_monitor Optional[str]

str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...). Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss, and should be: if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple: `<LOSS_CLASS.__name__>"/"<COMPONENT_NAME>`. If a single item is returned rather than a tuple: `<LOSS_CLASS.__name__>`. When there is no such attribute and criterion.forward(..) returns a tuple: `<LOSS_CLASS.__name__>"/"Loss_"<IDX>`

None
max bool

bool, Whether to take the batch corresponding to the max value of the metric/loss or the minimum (default=False).

False
freq int

int, epoch frequency to perform all of the above (default=1).

1
enable_on_train_loader bool

Controls whether to enable this callback on the train loader. Default is False.

False
enable_on_valid_loader bool

Controls whether to enable this callback on the valid loader. Default is True.

True
max_images int

Maximum images to save. If -1, save all images.

-1
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
@resolve_param("metric", MetricsFactory())
def __init__(
    self,
    metric: Optional[Metric] = None,
    metric_component_name: Optional[str] = None,
    loss_to_monitor: Optional[str] = None,
    max: bool = False,
    freq: int = 1,
    enable_on_train_loader: bool = False,
    enable_on_valid_loader: bool = True,
    max_images: int = -1,
):
    """
    Configure which value is monitored (a metric or a loss) and on which loaders the callback runs.

    :param metric:                 Metric to monitor. Mutually exclusive with `loss_to_monitor`.
    :param metric_component_name:  In case metric returns multiple values (as Mapping),
                                   the value at metric.compute()[metric_component_name] will be the one monitored.
    :param loss_to_monitor:        Name of the loss (or loss component) to monitor, corresponding to the 'criterion'
                                   passed through training_params in Trainer.train(...). It follows the same naming
                                   logic as metric_to_watch in Trainer.train(..):

                                   if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
                                       <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.

                                   If a single item is returned rather than a tuple:
                                       <LOSS_CLASS.__name__>.

                                   When there is no such attribute and criterion.forward(..) returns a tuple:
                                       <LOSS_CLASS.__name__>"/"Loss_"<IDX>
    :param max:                    Whether to take the batch corresponding to the max value of the metric/loss or
                                   the minimum (default=False).
    :param freq:                   Epoch frequency to perform all of the above (default=1).
    :param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
    :param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
    :param max_images:             Maximum images to save. If -1, save all images.

    :raises RuntimeError: When both or neither of `metric` and `loss_to_monitor` are given.
    """
    super(ExtremeBatchCaseVisualizationCallback, self).__init__()

    # Exactly one monitoring source (metric or loss) must be configured.
    both_given = metric and loss_to_monitor
    neither_given = metric is None and loss_to_monitor is None
    if both_given or neither_given:
        raise RuntimeError("Must pass exactly one of: loss, metric != None")

    self._set_tag_attr(loss_to_monitor, max, metric, metric_component_name)

    # Wrap the metric in a MetricCollection and move it to the configured device.
    if metric:
        metric = MetricCollection(metric)
        metric.to(device_config.device)
    self.metric = metric
    self.metric_component_name = metric_component_name

    self.loss_to_monitor = loss_to_monitor
    self.max = max
    self.freq = freq

    # Nothing has been recorded yet for the most extreme batch.
    self.extreme_score = None
    self.extreme_batch = None
    self.extreme_preds = None
    self.extreme_targets = None
    self.extreme_additional_batch_items = None

    # The index of the monitored loss inside the loss tuple is resolved on the first batch.
    self._first_call = True
    self._idx_loss_tuple = None

    self.enable_on_train_loader = enable_on_train_loader
    self.enable_on_valid_loader = enable_on_valid_loader
    self.max_images = max_images

process_extreme_batch() abstractmethod

This method is called right before adding the images to the SGLogger (inside the on_validation_loader_end call). It should process self.extreme_batch, self.extreme_preds and self.extreme_targets and output the images, as np.ndarray. Output should be of shape N,H,W,3 and uint8.

Returns:

Type Description
np.ndarray

images to save, np.ndarray

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
1152
1153
1154
1155
1156
1157
1158
1159
1160
@abstractmethod
def process_extreme_batch(self) -> np.ndarray:
    """
    Render the stored extreme batch into images; called right before the images are added to the SGLogger.

    Implementations should process self.extreme_batch, self.extreme_preds and self.extreme_targets
    and output the images as an np.ndarray of shape N,H,W,3 and dtype uint8.

    :return: images to save, np.ndarray
    """
    raise NotImplementedError()

ExtremeBatchDetectionVisualizationCallback

Bases: ExtremeBatchCaseVisualizationCallback

ExtremeBatchDetectionVisualizationCallback

Visualizes worst/best batch in an epoch for Object detection. For clarity, the batch is saved twice in the SG Logger, once with the model's predictions and once with ground truth targets.

Assumptions on bbox formats: - After applying post_prediction_callback on context.preds, the predictions are a list/Tensor s.t: predictions[i] is a tensor of shape nx6 - (x1, y1, x2, y2, confidence, class) where x and y are in pixel units.

  • context.targets is a tensor of shape (total_num_targets, 6), in LABEL_CXCYWH format: (index, label, cx, cy, w, h).

Example usage in Yaml config:

training_hyperparams:
  phase_callbacks:
    - ExtremeBatchDetectionVisualizationCallback:
        metric:
          DetectionMetrics_050:
            score_thres: 0.1
            top_k_predictions: 300
            num_cls: ${num_classes}
            normalize_targets: True
            post_prediction_callback:
              _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
              score_threshold: 0.01
              nms_top_k: 1000
              max_predictions: 300
              nms_threshold: 0.7
        metric_component_name: 'mAP@0.50'
        post_prediction_callback:
          _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
          score_threshold: 0.25
          nms_top_k: 1000
          max_predictions: 300
          nms_threshold: 0.7
        normalize_targets: True

Parameters:

Name Type Description Default
metric Optional[Metric]

Metric, will be the metric which is monitored.

None
metric_component_name Optional[str]

In case metric returns multiple values (as Mapping), the value at metric.compute()[metric_component_name] will be the one monitored.

None
loss_to_monitor Optional[str]

str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...). Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss, and should be: if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple: `<LOSS_CLASS.__name__>"/"<COMPONENT_NAME>`. If a single item is returned rather than a tuple: `<LOSS_CLASS.__name__>`. When there is no such attribute and criterion.forward(..) returns a tuple: `<LOSS_CLASS.__name__>"/"Loss_"<IDX>`

None
max bool

bool, Whether to take the batch corresponding to the max value of the metric/loss or the minimum (default=False).

False
freq int

int, epoch frequency to perform all of the above (default=1).

1
classes Optional[List[str]]

List[str], a list of class names corresponding to the class indices for display. When None, will try to fetch this through a "classes" attribute of the validation dataset. If such an attribute does not exist, an error will be raised (default=None).

None
normalize_targets bool

bool, whether to scale the target bboxes. If the bboxes returned by the validation data loader are in pixel values range, this needs to be set to True (default=False)

False
enable_on_train_loader bool

Controls whether to enable this callback on the train loader. Default is False.

False
enable_on_valid_loader bool

Controls whether to enable this callback on the valid loader. Default is True.

True
max_images int

Maximum images to save. If -1, save all images.

-1
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
@register_callback("ExtremeBatchDetectionVisualizationCallback")
class ExtremeBatchDetectionVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
    """
    ExtremeBatchDetectionVisualizationCallback

    Visualizes worst/best batch in an epoch for Object detection.
    For clarity, the batch is saved twice in the SG Logger, once with the model's predictions and once with
     ground truth targets.

    Assumptions on bbox formats:
     - After applying post_prediction_callback on context.preds, the predictions are a list/Tensor s.t:
        predictions[i] is a tensor of shape nx6 - (x1, y1, x2, y2, confidence, class) where x and y are in pixel units.

     - context.targets is a tensor of shape (total_num_targets, 6), in LABEL_CXCYWH format:  (index, label, cx, cy, w, h).



    Example usage in Yaml config:

        training_hyperparams:
          phase_callbacks:
            - ExtremeBatchDetectionVisualizationCallback:
                metric:
                  DetectionMetrics_050:
                    score_thres: 0.1
                    top_k_predictions: 300
                    num_cls: ${num_classes}
                    normalize_targets: True
                    post_prediction_callback:
                      _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
                      score_threshold: 0.01
                      nms_top_k: 1000
                      max_predictions: 300
                      nms_threshold: 0.7
                metric_component_name: 'mAP@0.50'
                post_prediction_callback:
                  _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
                  score_threshold: 0.25
                  nms_top_k: 1000
                  max_predictions: 300
                  nms_threshold: 0.7
                normalize_targets: True

    :param post_prediction_callback: DetectionPostPredictionCallback, applied to the raw predictions of the extreme
     batch (together with the batch's device) before they are drawn.

    :param metric: Metric, will be the metric which is monitored.

    :param metric_component_name: In case metric returns multiple values (as Mapping),
     the value at metric.compute()[metric_component_name] will be the one monitored.

    :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
     Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:

        if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.

        If a single item is returned rather than a tuple:
            <LOSS_CLASS.__name__>.

        When there is no such attribute and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"Loss_"<IDX>

    :param max:                    bool, Whether to take the batch corresponding to the max value of the metric/loss or
    the minimum (default=False).

    :param freq:                   int, epoch frequency to perform all of the above (default=1).

    :param classes:                List[str], a list of class names corresponding to the class indices for display.
    When None, will try to fetch this through a "classes" attribute of the validation dataset. If such an attribute
    does not exist, an error will be raised (default=None).

    :param normalize_targets:      bool, whether to scale the target bboxes. If the bboxes returned by the validation data loader
     are in pixel values range, this needs to be set to True (default=False)

    :param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
    :param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
    :param max_images:             Maximum images to save. If -1, save all images.
    """

    def __init__(
        self,
        post_prediction_callback: DetectionPostPredictionCallback,
        metric: Optional[Metric] = None,
        metric_component_name: Optional[str] = None,
        loss_to_monitor: Optional[str] = None,
        max: bool = False,
        freq: int = 1,
        classes: Optional[List[str]] = None,
        normalize_targets: bool = False,
        enable_on_train_loader: bool = False,
        enable_on_valid_loader: bool = True,
        max_images: int = -1,
    ):
        super(ExtremeBatchDetectionVisualizationCallback, self).__init__(
            metric=metric,
            metric_component_name=metric_component_name,
            loss_to_monitor=loss_to_monitor,
            max=max,
            freq=freq,
            enable_on_valid_loader=enable_on_valid_loader,
            enable_on_train_loader=enable_on_train_loader,
            max_images=max_images,
        )
        self.post_prediction_callback = post_prediction_callback
        if classes is None:
            logger.info(
                "No classes have been passed to ExtremeBatchDetectionVisualizationCallback. "
                "Will try to fetch them through context.valid_loader.dataset classes attribute if it exists."
            )
        # May still be None here; resolved lazily in on_validation_loader_start.
        self.classes = classes
        self.normalize_targets = normalize_targets

    @staticmethod
    def universal_undo_preprocessing_fn(inputs: torch.Tensor) -> np.ndarray:
        """
        A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg.
        This function scales input tensor to 0..255 range, and cast it to uint8 dtype.

        The input tensor is left untouched: all arithmetic is done out-of-place. process_extreme_batch hands the same
        stored batch to the visualizer twice, so an in-place rescale could overwrite that batch between the two renders.

        :param inputs: Input 4D tensor of images in BCHW format with unknown normalization.
        :return:       Numpy 4D tensor of images in BHWC format, normalized to 0..255 range (uint8).
                       The channel axis is reversed relative to the input.
        """
        scaled = inputs.detach()
        if not torch.is_floating_point(scaled):
            # True division below needs a floating dtype.
            scaled = scaled.to(torch.float32)
        scaled = scaled - scaled.min()
        # The epsilon keeps a constant-valued batch from dividing by zero.
        scaled = scaled / (scaled.max() + 1e-8)
        scaled = (scaled * 255).to(torch.uint8)
        images = scaled.cpu().numpy()
        # Reverse the channel axis, then move it last: BCHW -> BHWC.
        images = images[:, ::-1, :, :].transpose(0, 2, 3, 1)
        return np.ascontiguousarray(images, dtype=np.uint8)

    def process_extreme_batch(self) -> np.ndarray:
        """
        Processes the extreme batch, and returns list of images for visualization.
        The default implementation stacks GT and prediction overlays horizontally.

        :return: np.ndarray A 4D tensor of BHWC shape with visualizations of the extreme batch.
        """
        inputs = self.extreme_batch
        preds = self.post_prediction_callback(self.extreme_preds, self.extreme_batch.device)
        # Work on a copy so the stored targets are not rescaled.
        targets = self.extreme_targets.clone()
        if self.normalize_targets:
            # Columns 2.. hold the box (cx, cy, w, h); divide x by image width and y by image height.
            target_bboxes = targets[:, 2:]
            target_bboxes = cxcywh2xyxy(target_bboxes)
            _, _, height, width = inputs.shape
            target_bboxes[:, [0, 2]] /= width
            target_bboxes[:, [1, 3]] /= height
            target_bboxes = xyxy2cxcywh(target_bboxes)
            targets[:, 2:] = target_bboxes

        images_to_save_preds = DetectionVisualization.visualize_batch(
            inputs, preds, targets, "extreme_batch_preds", self.classes, gt_alpha=0.0, undo_preprocessing_func=self.universal_undo_preprocessing_fn
        )
        images_to_save_preds = np.stack(images_to_save_preds)

        images_to_save_gt = DetectionVisualization.visualize_batch(
            inputs, None, targets, "extreme_batch_gt", self.classes, gt_alpha=1.0, undo_preprocessing_func=self.universal_undo_preprocessing_fn
        )
        images_to_save_gt = np.stack(images_to_save_gt)

        # Stack the predictions and GT images together
        return np.concatenate([images_to_save_gt, images_to_save_preds], axis=2)

    def on_validation_loader_start(self, context: PhaseContext) -> None:
        """
        Resolves the class names from the validation dataset when none were passed, then defers to the parent hook.

        :param context: PhaseContext: current phase's context.
        :raises RuntimeError: If no classes were passed and the validation dataset has no "classes" attribute.
        """
        if self.classes is None:
            if hasattr(context.valid_loader.dataset, "classes"):
                self.classes = context.valid_loader.dataset.classes
            else:
                raise RuntimeError("Couldn't fetch classes from valid_loader, please pass classes explicitly")
        super().on_validation_loader_start(context)

process_extreme_batch()

Processes the extreme batch, and returns list of images for visualization. The default implementation stacks GT and prediction overlays horizontally.

Returns:

Type Description
np.ndarray

np.ndarray A 4D tensor of BHWC shape with visualizations of the extreme batch.

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
def process_extreme_batch(self) -> np.ndarray:
    """
    Renders the stored extreme batch twice, once with ground truth boxes and once with predicted boxes,
    and joins both renderings along axis 2.

    :return: np.ndarray A 4D tensor of BHWC shape with visualizations of the extreme batch.
    """
    batch_images = self.extreme_batch
    predictions = self.post_prediction_callback(self.extreme_preds, self.extreme_batch.device)

    # Cloned so that the rescaling below never touches the stored targets.
    gt_targets = self.extreme_targets.clone()
    if self.normalize_targets:
        # Columns 2.. carry the box coordinates; x values are divided by width, y values by height.
        boxes = cxcywh2xyxy(gt_targets[:, 2:])
        _, _, height, width = batch_images.shape
        boxes[:, [0, 2]] /= width
        boxes[:, [1, 3]] /= height
        gt_targets[:, 2:] = xyxy2cxcywh(boxes)

    undo_fn = self.universal_undo_preprocessing_fn

    # Predictions only: ground truth is drawn with alpha 0.
    preds_panel = np.stack(
        DetectionVisualization.visualize_batch(
            batch_images, predictions, gt_targets, "extreme_batch_preds", self.classes, gt_alpha=0.0, undo_preprocessing_func=undo_fn
        )
    )

    # Ground truth only: no predictions are passed.
    gt_panel = np.stack(
        DetectionVisualization.visualize_batch(
            batch_images, None, gt_targets, "extreme_batch_gt", self.classes, gt_alpha=1.0, undo_preprocessing_func=undo_fn
        )
    )

    # GT first, predictions second, joined along axis 2.
    return np.concatenate([gt_panel, preds_panel], axis=2)

universal_undo_preprocessing_fn(inputs) staticmethod

A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg. This function scales input tensor to 0..255 range, and cast it to uint8 dtype.

Parameters:

Name Type Description Default
inputs torch.Tensor

Input 4D tensor of images in BCHW format with unknown normalization.

required

Returns:

Type Description
np.ndarray

Numpy 4D tensor of images in BHWC format, normalized to 0..255 range (uint8).

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
@staticmethod
def universal_undo_preprocessing_fn(inputs: torch.Tensor) -> np.ndarray:
    """
    A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg.
    This function scales input tensor to 0..255 range, and cast it to uint8 dtype.

    :param inputs: Input 4D tensor of images in BCHW format with unknown normalization.
    :return:       Numpy 4D tensor of images in BHWC format, normalized to 0..255 range (uint8).
    """
    inputs -= inputs.min()
    inputs /= inputs.max() + 1e-8
    inputs *= 255
    inputs = inputs.to(torch.uint8)
    inputs = inputs.cpu().numpy()
    inputs = inputs[:, ::-1, :, :].transpose(0, 2, 3, 1)
    inputs = np.ascontiguousarray(inputs, dtype=np.uint8)
    return inputs

ExtremeBatchSegVisualizationCallback

Bases: ExtremeBatchCaseVisualizationCallback

ExtremeBatchSegVisualizationCallback

Visualizes worst/best batch in an epoch, for segmentation. Assumes context.preds in validation is a score tensor of shape BCHW, or a tuple whose first item is one.

True predictions will be marked with green, false ones with red.

Example usage in training_params definition:

training_hyperparams ={
  ...
  "phase_callbacks":
    [ExtremeBatchSegVisualizationCallback(
        metric=IoU(20, ignore_idx=19)
        max=False
        ignore_idx=19),
    ExtremeBatchSegVisualizationCallback(
        loss_to_monitor="CrossEntropyLoss"
        max=True
        ignore_idx=19)]
        ...}

Example usage in Yaml config:

training_hyperparams:
  phase_callbacks:
    - ExtremeBatchSegVisualizationCallback:
        loss_to_monitor: DiceCEEdgeLoss/aux_loss0
        ignore_idx: 19

Parameters:

Name Type Description Default
metric Optional[Metric]

Metric, will be the metric which is monitored.

None
metric_component_name Optional[str]

In case metric returns multiple values (as Mapping), the value at metric.compute()[metric_component_name] will be the one monitored.

None
loss_to_monitor Optional[str]

str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...). Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be: if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple: <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>. If a single item is returned rather than a tuple: <LOSS_CLASS.__name__>. When there is no such attribute and criterion.forward(..) returns a tuple: <LOSS_CLASS.__name__>"/"Loss_"<IDX>

None
max bool

bool, Whether to take the batch corresponding to the max value of the metric/loss or the minimum (default=False).

False
freq int

int, epoch frequency to perform all of the above (default=1).

1
enable_on_train_loader bool

Controls whether to enable this callback on the train loader. Default is False.

False
enable_on_valid_loader bool

Controls whether to enable this callback on the valid loader. Default is True.

True
max_images int

Maximum images to save. If -1, save all images.

-1
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
@register_callback("ExtremeBatchSegVisualizationCallback")
class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
    """
    ExtremeBatchSegVisualizationCallback

    Visualizes worst/best batch in an epoch, for segmentation.
    Assumes context.preds in validation is a score tensor of shape BCHW, or a tuple whose first item is one.

    True predictions will be marked with green, false ones with red.

    Example usage in training_params definition:

        training_hyperparams ={
          ...
          "phase_callbacks":
            [ExtremeBatchSegVisualizationCallback(
                metric=IoU(20, ignore_idx=19),
                max=False,
                ignore_idx=19),
            ExtremeBatchSegVisualizationCallback(
                loss_to_monitor="CrossEntropyLoss",
                max=True,
                ignore_idx=19)]
                ...}

    Example usage in Yaml config:

        training_hyperparams:
          phase_callbacks:
            - ExtremeBatchSegVisualizationCallback:
                loss_to_monitor: DiceCEEdgeLoss/aux_loss0
                ignore_idx: 19

    :param metric: Metric, will be the metric which is monitored.

    :param metric_component_name: In case metric returns multiple values (as Mapping),
     the value at metric.compute()[metric_component_name] will be the one monitored.

    :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
     Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:

        if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.

        If a single item is returned rather than a tuple:
            <LOSS_CLASS.__name__>.

        When there is no such attribute and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"Loss_"<IDX>

    :param max:                    bool, Whether to take the batch corresponding to the max value of the metric/loss or
    the minimum (default=False).

    :param freq:                   int, epoch frequency to perform all of the above (default=1).

    :param ignore_idx:             int, target value to leave out of the overlay: pixels whose target equals it are
    marked neither green nor red (default=-1).

    :param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
    :param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
    :param max_images:             Maximum images to save. If -1, save all images.
    """

    def __init__(
        self,
        metric: Optional[Metric] = None,
        metric_component_name: Optional[str] = None,
        loss_to_monitor: Optional[str] = None,
        max: bool = False,
        freq: int = 1,
        ignore_idx: int = -1,
        enable_on_train_loader: bool = False,
        enable_on_valid_loader: bool = True,
        max_images: int = -1,
    ):
        super(ExtremeBatchSegVisualizationCallback, self).__init__(
            metric=metric,
            metric_component_name=metric_component_name,
            loss_to_monitor=loss_to_monitor,
            max=max,
            freq=freq,
            enable_on_valid_loader=enable_on_valid_loader,
            enable_on_train_loader=enable_on_train_loader,
            max_images=max_images,
        )
        self.ignore_idx = ignore_idx

    @torch.no_grad()
    def process_extreme_batch(self) -> np.ndarray:
        """
        Draws the correct/incorrect prediction masks over every image of the extreme batch.

        The stored batch is rescaled out-of-place, so self.extreme_batch keeps its original values.

        :return: np.ndarray with one HWC image per batch item, stacked along a new first axis.
        """
        images = self.extreme_batch
        images = images - images.min()
        # Clamping the divisor avoids NaNs for a constant-valued batch and leaves every other batch unchanged.
        images = images / images.max().clamp(min=1e-8)
        images = (images * 255).to(torch.uint8)

        preds = self.extreme_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        # Scores -> predicted class index per pixel.
        preds = preds.argmax(1)

        targets = self.extreme_targets
        # Ignored pixels belong to neither the "correct" nor the "wrong" mask.
        not_ignored = targets != self.ignore_idx
        p_mask = (preds == targets) & not_ignored
        n_mask = (preds != targets) & not_ignored
        # Mask order matches the colors list below: index 0 = correct, index 1 = wrong.
        overlay = torch.stack([p_mask, n_mask], dim=1)
        colors = ["green", "red"]

        images_to_save = []
        for image, masks in zip(images, overlay):
            drawn = draw_segmentation_masks(image.cpu(), masks.cpu(), colors=colors, alpha=0.4).numpy()
            # CHW -> HWC
            images_to_save.append(np.transpose(drawn, (1, 2, 0)))
        return np.stack(images_to_save)

FunctionLRScheduler

Bases: LRCallbackBase

Hard coded rate scheduling for user defined lr scheduling function.

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
@register_lr_scheduler(LRSchedulers.FUNCTION, deprecated_name="function")
class FunctionLRScheduler(LRCallbackBase):
    """
    Learning rate callback that delegates the schedule to a user supplied function.

    The callback is registered for the TRAIN_BATCH_STEP phase; the function is called with the keyword arguments
    initial_lr, epoch, iter, max_epoch and iters_per_epoch, and its return value becomes the new learning rate.
    """

    @deprecated(deprecated_since="3.2.0", removed_from="3.5.0", reason="This callback is deprecated and will be removed in future versions.")
    def __init__(self, max_epochs, lr_schedule_function, **kwargs):
        super().__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
        assert callable(lr_schedule_function), "self.lr_function must be callable"
        self.max_epochs = max_epochs
        self.lr_schedule_function = lr_schedule_function

    def is_lr_scheduling_enabled(self, context):
        """Scheduling is active from the end of warmup until the cooldown epochs begin."""
        params = self.training_params
        cooldown_start_epoch = params.max_epochs - params.lr_cooldown_epochs
        return params.lr_warmup_epochs <= context.epoch < cooldown_start_epoch

    def perform_scheduling(self, context):
        """Computes the learning rate with the user function and writes it to the optimizer."""
        warmup_epochs = self.training_params.lr_warmup_epochs
        cooldown_epochs = self.training_params.lr_cooldown_epochs
        # Both the epoch and the epoch budget passed to the function exclude the warmup (and cooldown) epochs.
        schedule_kwargs = dict(
            initial_lr=self.initial_lr,
            epoch=context.epoch - warmup_epochs,
            iter=context.batch_idx,
            max_epoch=self.max_epochs - warmup_epochs - cooldown_epochs,
            iters_per_epoch=self.train_loader_len,
        )
        self.lr = self.lr_schedule_function(**schedule_kwargs)
        self.update_lr(context.optimizer, context.epoch, context.batch_idx)

IllegalLRSchedulerMetric

Bases: Exception

Exception raised for an illegal combination of training parameters.

Parameters:

Name Type Description Default
metric_name str

Name of the metric that is not supported.

required
metrics_dict dict

Dictionary of metrics that are supported.

required
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
563
564
565
566
567
568
569
570
571
572
class IllegalLRSchedulerMetric(Exception):
    """Exception raised when a requested metric name is not a key of the available metrics dictionary.

    The formatted text is stored on ``self.message`` and is also the exception's string value.

    :param metric_name: Name of the metric that is not supported.
    :param metrics_dict: Dictionary of metrics that are supported.
    """

    def __init__(self, metric_name: str, metrics_dict: dict):
        # f-string formatting reports a non-str metric_name instead of failing with a TypeError on concatenation;
        # list(...) prints the plain key names rather than a dict_keys(...) repr.
        self.message = f"Illegal metric name: {metric_name}. Expected one of metrics_dict keys: {list(metrics_dict.keys())}"
        super().__init__(self.message)

LRCallbackBase

Bases: PhaseCallback

Base class for hard coded learning rate scheduling regimes, implemented as callbacks.

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
@register_callback(Callbacks.LR_CALLBACK_BASE)
class LRCallbackBase(PhaseCallback):
    """
    Common base for callbacks that implement a hard coded learning rate schedule.

    Subclasses provide is_lr_scheduling_enabled and perform_scheduling; calling the callback runs the latter
    whenever the former returns a truthy value.
    """

    def __init__(self, phase, initial_lr, update_param_groups, train_loader_len, net, training_params, **kwargs):
        """
        :param phase:               Phase passed on to PhaseCallback.
        :param initial_lr:          Stored as self.initial_lr and used as the starting value of self.lr.
        :param update_param_groups: When truthy, update_lr delegates to the (unwrapped) net's update_param_groups.
        :param train_loader_len:    Stored as self.train_loader_len.
        :param net:                 Stored as self.net.
        :param training_params:     Stored as self.training_params.
        :param kwargs:              Accepted and ignored.
        """
        super(LRCallbackBase, self).__init__(phase)
        self.lr = self.initial_lr = initial_lr
        self.update_param_groups = update_param_groups
        self.train_loader_len = train_loader_len
        self.net = net
        self.training_params = training_params

    def __call__(self, context: PhaseContext, **kwargs):
        if not self.is_lr_scheduling_enabled(context):
            return
        self.perform_scheduling(context)

    def is_lr_scheduling_enabled(self, context: PhaseContext):
        """
        Tells whether lr scheduling should run for the given context. Must be overridden.

        :param context: PhaseContext: current phase's context.
        :return: bool, whether to apply lr scheduling or not.
        """
        raise NotImplementedError

    def perform_scheduling(self, context: PhaseContext):
        """
        Applies the lr schedule for the given context. Must be overridden.

        :param context: PhaseContext: current phase's context.
        """
        raise NotImplementedError

    def update_lr(self, optimizer, epoch, batch_idx=None):
        """
        Writes self.lr into the optimizer.

        :param optimizer: Optimizer whose param_groups are updated.
        :param epoch:     Forwarded to the net's update_param_groups.
        :param batch_idx: Forwarded to the net's update_param_groups.
        """
        if not self.update_param_groups:
            # Plain case: every parameter group gets the same learning rate.
            for group in optimizer.param_groups:
                group["lr"] = self.lr
            return

        # The net decides how the learning rate maps onto its parameter groups.
        model = unwrap_model(self.net)
        optimizer.param_groups = model.update_param_groups(
            optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
        )

is_lr_scheduling_enabled(context)

Predicate that controls whether to perform lr scheduling based on values in context.

Parameters:

Name Type Description Default
context PhaseContext

PhaseContext: current phase's context.

required

Returns:

Type Description

bool, whether to apply lr scheduling or not.

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
250
251
252
253
254
255
256
257
def is_lr_scheduling_enabled(self, context: PhaseContext):
    """
    Tells whether lr scheduling should run for the given context.

    This implementation always raises NotImplementedError.

    :param context: PhaseContext: current phase's context.
    :return: bool, whether to apply lr scheduling or not.
    """
    raise NotImplementedError()

perform_scheduling(context)

Performs lr scheduling based on values in context.

Parameters:

Name Type Description Default
context PhaseContext

PhaseContext: current phase's context.

required
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
259
260
261
262
263
264
265
def perform_scheduling(self, context: PhaseContext):
    """
    Applies the lr schedule using the values found in context.

    This implementation always raises NotImplementedError.

    :param context: PhaseContext: current phase's context.
    """
    raise NotImplementedError()

LRSchedulerCallback

Bases: PhaseCallback

Learning rate scheduler callback.

When passing __call__ a metrics_dict with a key=self.metric_name, the value of that metric will be monitored for ReduceLROnPlateau (i.e. step(metrics_dict[self.metric_name])).

Parameters:

Name Type Description Default
scheduler torch.optim.lr_scheduler._LRScheduler

Learning rate scheduler to be called step() with.

required
metric_name str

Metric name for ReduceLROnPlateau learning rate scheduler.

None
phase Phase

Phase of when to trigger it.

required
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
@register_callback(Callbacks.LR_SCHEDULER)
class LRSchedulerCallback(PhaseCallback):
    """
    Callback that steps a learning rate scheduler.

    When metric_name is set and present in context.metrics_dict, the scheduler is stepped with that metric's value
         (as needed by ReduceLROnPlateau, i.e. step(metrics_dict[self.metric_name])); when metric_name is None the
         scheduler is stepped without arguments.

    :param scheduler:       Learning rate scheduler to be called step() with.
    :param metric_name:     Metric name for ReduceLROnPlateau learning rate scheduler.
    :param phase:           Phase of when to trigger it.
    """

    def __init__(self, scheduler: torch.optim.lr_scheduler._LRScheduler, phase: Phase, metric_name: str = None):
        super(LRSchedulerCallback, self).__init__(phase)
        self.metric_name = metric_name
        self.scheduler = scheduler

    def __call__(self, context: PhaseContext):
        # Nothing is stepped while the warmup epochs are still running.
        if not context.lr_warmup_epochs <= context.epoch:
            return

        name = self.metric_name
        if name is None:
            self.scheduler.step()
        elif name and name in context.metrics_dict.keys():
            self.scheduler.step(context.metrics_dict[name])
        else:
            # A metric name was given but it is not among the available metrics.
            raise IllegalLRSchedulerMetric(name, context.metrics_dict)

    def __repr__(self):
        return f"LRSchedulerCallback: {self.scheduler!r}"

LinearBatchLRWarmup

Bases: Callback

LR scheduling callback for linear step warmup on each batch step. LR climbs linearly from warmup_initial_lr to initial_lr.

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
@register_lr_warmup(LRWarmups.LINEAR_BATCH_STEP, deprecated_name="linear_batch_step")
class LinearBatchLRWarmup(Callback):
    """
    Linear learning rate warmup that advances on every train batch.
    The learning rate moves in equal steps from warmup_initial_lr to initial_lr.
    """

    def __init__(
        self,
        warmup_initial_lr: float,
        initial_lr: float,
        train_loader_len: int,
        update_param_groups: bool,
        lr_warmup_steps: int,
        training_params,
        net,
        **kwargs,
    ):
        """

        :param warmup_initial_lr: Starting learning rate
        :param initial_lr: Target learning rate after warmup
        :param train_loader_len: Length of train data loader
        :param update_param_groups: When truthy, update_lr delegates to the (unwrapped) net's update_param_groups.
        :param lr_warmup_steps: Number of batch steps the warmup lasts. Capped to train_loader_len.
        :param training_params: Forwarded to the net's update_param_groups.
        :param net: Network whose update_param_groups is used when update_param_groups is truthy.
        :param kwargs: Accepted and ignored.
        """

        super().__init__()

        if lr_warmup_steps > train_loader_len:
            logger.warning(
                f"Number of warmup steps ({lr_warmup_steps}) is greater than number of steps in epoch ({train_loader_len}). "
                f"Warmup steps will be capped to number of steps in epoch to avoid interfering with any pre-epoch LR schedulers."
            )
            # Cap only after the warning has reported the requested value.
            lr_warmup_steps = train_loader_len

        self.initial_lr = initial_lr
        self.lr = initial_lr
        self.lr_warmup_steps = lr_warmup_steps
        self.train_loader_len = train_loader_len
        # One learning rate per warmup step, the last one being initial_lr itself.
        self.learning_rates = np.linspace(start=warmup_initial_lr, stop=initial_lr, num=lr_warmup_steps, endpoint=True)
        self.update_param_groups = update_param_groups
        self.training_params = training_params
        self.net = net

    def on_train_batch_start(self, context: PhaseContext) -> None:
        # Index of the current batch counted from the very first training batch.
        step = context.batch_idx + context.epoch * self.train_loader_len
        if step >= self.lr_warmup_steps:
            return
        self.lr = float(self.learning_rates[step])
        self.update_lr(context.optimizer, context.epoch, context.batch_idx)

    def update_lr(self, optimizer, epoch, batch_idx=None):
        """
        Same as in LRCallbackBase: writes self.lr into the optimizer.
        :param optimizer: Optimizer whose param_groups are updated.
        :param epoch: Forwarded to the net's update_param_groups.
        :param batch_idx: Forwarded to the net's update_param_groups.
        :return: None
        """
        if not self.update_param_groups:
            # Plain case: every parameter group gets the same learning rate.
            for group in optimizer.param_groups:
                group["lr"] = self.lr
            return

        model = unwrap_model(self.net)
        optimizer.param_groups = model.update_param_groups(
            optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
        )

__init__(warmup_initial_lr, initial_lr, train_loader_len, update_param_groups, lr_warmup_steps, training_params, net, **kwargs)

Parameters:

Name Type Description Default
warmup_initial_lr float

Starting learning rate

required
initial_lr float

Target learning rate after warmup

required
train_loader_len int

Length of train data loader

required
lr_warmup_steps int

Optional. If passed, will use fixed number of warmup steps to warmup LR. Default is None.

required
kwargs {}
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
def __init__(
    self,
    warmup_initial_lr: float,
    initial_lr: float,
    train_loader_len: int,
    update_param_groups: bool,
    lr_warmup_steps: int,
    training_params,
    net,
    **kwargs,
):
    """
    Precompute the per-batch warmup learning rates and store the scheduling settings.

    :param warmup_initial_lr: Learning rate at the first warmup step.
    :param initial_lr: Learning rate reached at the last warmup step.
    :param train_loader_len: Length of the train data loader (steps per epoch).
    :param update_param_groups: Stored as-is; selects which branch `update_lr` takes.
    :param lr_warmup_steps: Requested number of warmup steps; capped at `train_loader_len`.
    :param training_params: Training parameters, stored on the instance.
    :param net: The network, stored on the instance.
    :param kwargs: Accepted but not used here.
    """

    super().__init__()

    if lr_warmup_steps > train_loader_len:
        logger.warning(
            f"Number of warmup steps ({lr_warmup_steps}) is greater than number of steps in epoch ({train_loader_len}). "
            f"Warmup steps will be capped to number of steps in epoch to avoid interfering with any pre-epoch LR schedulers."
        )
        # Warmup never spans more than a single epoch.
        lr_warmup_steps = train_loader_len

    # One learning rate per warmup step, both endpoints included.
    warmup_schedule = np.linspace(start=warmup_initial_lr, stop=initial_lr, num=lr_warmup_steps, endpoint=True)

    self.net = net
    self.training_params = training_params
    self.update_param_groups = update_param_groups
    self.train_loader_len = train_loader_len
    self.lr_warmup_steps = lr_warmup_steps
    self.learning_rates = warmup_schedule
    self.initial_lr = initial_lr
    self.lr = initial_lr

update_lr(optimizer, epoch, batch_idx=None)

Same as in LRCallbackBase

Parameters:

Name Type Description Default
optimizer required
epoch required
batch_idx None

Returns:

Type Description
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
def update_lr(self, optimizer, epoch, batch_idx=None):
    """
    Push the current `self.lr` into the optimizer.

    :param optimizer: Optimizer whose `param_groups` are updated.
    :param epoch: Current epoch, forwarded to the network's `update_param_groups`.
    :param batch_idx: Current batch index (or None), forwarded to the network's `update_param_groups`.
    :return: None
    """
    if not self.update_param_groups:
        # Plain case: every parameter group gets the same learning rate.
        for group in optimizer.param_groups:
            group["lr"] = self.lr
        return

    # The network defines its own per-group update rule; let it rebuild the groups.
    optimizer.param_groups = unwrap_model(self.net).update_param_groups(
        optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
    )

LinearEpochLRWarmup

Bases: LRCallbackBase

LR scheduling callback for linear step warmup. This scheduler uses a whole epoch as single step. LR climbs from warmup_initial_lr with even steps to initial lr. When warmup_initial_lr is None - LR climb starts from initial_lr/(1+warmup_epochs).

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
@register_lr_warmup(LRWarmups.LINEAR_EPOCH_STEP, deprecated_name="linear_epoch_step")
class LinearEpochLRWarmup(LRCallbackBase):
    """
    LR scheduling callback for linear step warmup. This scheduler uses a whole epoch as single step.
    LR climbs from warmup_initial_lr with even steps to initial lr. When warmup_initial_lr is None - LR climb starts from
     initial_lr/(1+warmup_epochs).

    """

    def __init__(self, **kwargs):
        """
        :param kwargs: Forwarded unchanged to the LRCallbackBase constructor.
        """
        super().__init__(Phase.TRAIN_EPOCH_START, **kwargs)
        warmup_epochs = self.training_params.lr_warmup_epochs

        # Only fall back to the derived starting LR when no value was given at all.
        # A plain truthiness test would also discard an explicitly requested warmup_initial_lr of 0.
        warmup_initial_lr = self.training_params.warmup_initial_lr
        if warmup_initial_lr is None:
            warmup_initial_lr = self.initial_lr / (warmup_epochs + 1)
        self.warmup_initial_lr = warmup_initial_lr

        # LR increment per epoch; zero when warmup is disabled (also avoids dividing by zero).
        self.warmup_step_size = (self.initial_lr - self.warmup_initial_lr) / warmup_epochs if warmup_epochs > 0 else 0

    def perform_scheduling(self, context):
        """
        Set the LR for the current epoch: warmup_initial_lr plus one step per elapsed epoch.

        :param context: Phase context; its `epoch` and `optimizer` are read here.
        """
        self.lr = self.warmup_initial_lr + context.epoch * self.warmup_step_size
        self.update_lr(context.optimizer, context.epoch, None)

    def is_lr_scheduling_enabled(self, context):
        """
        :param context: Phase context; its `epoch` is read here.
        :return: True while warmup is configured (lr_warmup_epochs > 0) and the current epoch
                 is not past lr_warmup_epochs (inclusive).
        """
        return self.training_params.lr_warmup_epochs > 0 and self.training_params.lr_warmup_epochs >= context.epoch

ModelConversionCheckCallback

Bases: PhaseCallback

Pre-training callback that verifies model conversion to onnx given specified conversion parameters.

The model is converted, then inference is applied with onnx runtime.

Use this callback with the same args as DeciPlatformCallback to prevent conversion fails at the end of training.

Parameters:

Name Type Description Default
model_name str

Model's name

required
input_dimensions Sequence[int]

Model's input dimensions

required
primary_batch_size int

Model's primary batch size

required
opset_version

(default=11)

required
do_constant_folding

(default=True)

required
dynamic_axes

(default={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})

required
input_names

(default=["input"])

required
output_names

(default=["output"])

required
rtol

(default=1e-03)

required
atol

(default=1e-05)

required
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
@register_callback(Callbacks.MODEL_CONVERSION_CHECK)
class ModelConversionCheckCallback(PhaseCallback):
    """
    Pre-training callback that verifies model conversion to onnx given specified conversion parameters.

    The model is converted, then inference is applied with onnx runtime.

    Use this callback with the same args as DeciPlatformCallback to prevent conversion fails at the end of training.

    :param model_name:              Model's name
    :param input_dimensions:        Model's input dimensions
    :param primary_batch_size:      Model's primary batch size
    :param opset_version:           (default=10)
    :param do_constant_folding:     (default=True)
    :param dynamic_axes:            (default={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
    :param input_names:             (default=["input"])
    :param output_names:            (default=["output"])
    :param rtol:                    (default=1e-03)
    :param atol:                    (default=1e-05)
    """

    def __init__(self, model_name: str, input_dimensions: Sequence[int], primary_batch_size: int, **kwargs):
        super(ModelConversionCheckCallback, self).__init__(phase=Phase.PRE_TRAINING)
        self.model_name = model_name
        self.input_dimensions = input_dimensions
        self.primary_batch_size = primary_batch_size

        self.opset_version = kwargs.get("opset_version", 10)

        # Fall back to True only when the option is missing or None, so that an explicit
        # do_constant_folding=False is honoured instead of being silently replaced by True.
        do_constant_folding = kwargs.get("do_constant_folding")
        self.do_constant_folding = True if do_constant_folding is None else do_constant_folding

        self.input_names = kwargs.get("input_names") or ["input"]
        self.output_names = kwargs.get("output_names") or ["output"]
        self.dynamic_axes = kwargs.get("dynamic_axes") or {"input": {0: "batch_size"}, "output": {0: "batch_size"}}

        self.rtol = kwargs.get("rtol", 1e-03)
        self.atol = kwargs.get("atol", 1e-05)

    def __call__(self, context: PhaseContext):
        """
        Export a CPU copy of the network to a temporary ONNX file, run it with ONNX Runtime and compare
        the first output against the PyTorch output.

        :param context: Phase context; its `net` and `ckpt_dir` are read here.
        :raises AssertionError: When the two outputs differ by more than rtol/atol.
        """
        # Work on a copy so that the network being trained is left untouched.
        model = copy.deepcopy(unwrap_model(context.net))
        model = model.cpu()
        model.eval()  # Put model into eval mode

        if hasattr(model, "prep_model_for_conversion"):
            model.prep_model_for_conversion(input_size=self.input_dimensions)

        x = torch.randn(self.primary_batch_size, *self.input_dimensions, requires_grad=False)

        tmp_model_path = os.path.join(context.ckpt_dir, self.model_name + "_tmp.onnx")

        with torch.no_grad():
            torch_out = model(x)

        try:
            torch.onnx.export(
                model,  # Model being run
                x,  # Model input (or a tuple for multiple inputs)
                tmp_model_path,  # Where to save the model (can be a file or file-like object)
                export_params=True,  # Store the trained parameter weights inside the model file
                opset_version=self.opset_version,
                do_constant_folding=self.do_constant_folding,
                input_names=self.input_names,
                output_names=self.output_names,
                dynamic_axes=self.dynamic_axes,
            )

            onnx_model = onnx.load(tmp_model_path)
            onnx.checker.check_model(onnx_model)

            ort_session = onnxruntime.InferenceSession(tmp_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

            # compute ONNX Runtime output prediction
            ort_inputs = {ort_session.get_inputs()[0].name: x.cpu().numpy()}
            ort_outs = ort_session.run(None, ort_inputs)

            # TODO: Ideally we don't want to check this but have the certainty of just calling torch_out.cpu()
            if isinstance(torch_out, (list, tuple)):
                torch_out = torch_out[0]
            # compare ONNX Runtime and PyTorch results
            np.testing.assert_allclose(torch_out.cpu().numpy(), ort_outs[0], rtol=self.rtol, atol=self.atol)
        finally:
            # Always clean up the temporary export, including when the export, check or comparison fails.
            # The existence test keeps a failed export from masking its own error with FileNotFoundError.
            if os.path.exists(tmp_model_path):
                os.remove(tmp_model_path)

        logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")

PhaseContextTestCallback

Bases: PhaseCallback

A callback that saves the phase context for testing.

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
628
629
630
631
632
633
634
635
636
637
638
class PhaseContextTestCallback(PhaseCallback):
    """
    Test helper: keeps a reference to the phase context it was last called with.
    """

    def __init__(self, phase: Phase):
        """
        :param phase: Phase passed on to the PhaseCallback constructor.
        """
        super().__init__(phase)
        # Stays None until the callback is invoked for the first time.
        self.context = None

    def __call__(self, context: PhaseContext):
        """
        :param context: Phase context to remember; replaces any previously stored one.
        """
        self.context = context

PolyLRScheduler

Bases: LRCallbackBase

Hard-coded polynomial decay learning rate scheduling, applied on every training batch step: lr = initial_lr * (1 - current_iter / max_iter) ** 0.9.

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
@register_lr_scheduler(LRSchedulers.POLY, deprecated_name="poly")
class PolyLRScheduler(LRCallbackBase):
    """
    Polynomial decay learning rate scheduling with a fixed power of 0.9:
    `lr = initial_lr * (1 - current_iter / max_iter) ** 0.9`.
    """

    def __init__(self, max_epochs, **kwargs):
        """
        :param max_epochs: Total number of epochs, used to derive the length of the decay.
        :param kwargs: Forwarded unchanged to the LRCallbackBase constructor.
        """
        super().__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
        self.max_epochs = max_epochs

    def perform_scheduling(self, context):
        """
        Compute the decayed LR for the current batch and write it to the optimizer.

        :param context: Phase context; its `epoch`, `batch_idx` and `optimizer` are read here.
        """
        params = self.training_params
        warmup_epochs = params.lr_warmup_epochs
        accumulate = params.batch_accumulate

        # Warmup and cooldown epochs are excluded from the decay span.
        epochs_done = context.epoch - warmup_epochs
        decay_epochs = self.max_epochs - warmup_epochs - params.lr_cooldown_epochs

        # Both counters are divided by batch_accumulate, as is the total below.
        steps_done = (self.train_loader_len * epochs_done + context.batch_idx) / accumulate
        total_steps = self.train_loader_len * decay_epochs / accumulate

        remaining_fraction = 1.0 - (steps_done / total_steps)
        self.lr = self.initial_lr * remaining_fraction**0.9
        self.update_lr(context.optimizer, context.epoch, context.batch_idx)

    def is_lr_scheduling_enabled(self, context):
        """
        :param context: Phase context; its `epoch` is read here.
        :return: True from epoch `lr_warmup_epochs` (inclusive) up to `max_epochs - lr_cooldown_epochs` (exclusive).
        """
        first_cooldown_epoch = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
        warmup_finished = self.training_params.lr_warmup_epochs <= context.epoch
        return warmup_finished and context.epoch < first_cooldown_epoch

RoboflowResultCallback

Bases: Callback

Append the training results to a csv file. Be aware that this does not fully overwrite the existing file, just appends.

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
@register_callback(Callbacks.ROBOFLOW_RESULT_CALLBACK)
class RoboflowResultCallback(Callback):
    """Append the training results to a csv file. Be aware that this does not fully overwrite the existing file, just appends."""

    def __init__(self, dataset_name: str, output_path: Optional[str] = None):
        """
        :param dataset_name:    Name of the dataset that was used to train the model.
        :param output_path:     Full path to the output csv file. By default, save at 'checkpoint_dir/results.csv'
        :raises ValueError:     When no output path is given and no default checkpoints directory is available.
        """
        self.dataset_name = dataset_name

        if not output_path:
            # Validate the default directory BEFORE joining: os.path.join never returns None, so checking
            # the joined path afterwards could never trigger.
            # NOTE(review): assumes get_project_checkpoints_dir_path() can return None - confirm against its definition.
            checkpoints_dir = get_project_checkpoints_dir_path()
            if checkpoints_dir is None:
                raise ValueError("Output path must be specified")
            output_path = os.path.join(checkpoints_dir, "results.csv")
        self.output_path = output_path

        super(RoboflowResultCallback, self).__init__()

    @multi_process_safe
    def on_training_end(self, context: PhaseContext):
        """
        Append one `[dataset_name, mAP]` row to the csv file.

        :param context: Phase context; `metrics_dict["mAP@0.50:0.95"]` is read from it.
        """
        # Append mode: earlier rows in the file are kept.
        with open(self.output_path, mode="a", newline="") as csv_file:
            writer = csv.writer(csv_file)

            mAP = context.metrics_dict["mAP@0.50:0.95"].item()
            writer.writerow([self.dataset_name, mAP])

__init__(dataset_name, output_path=None)

Parameters:

Name Type Description Default
dataset_name str

Name of the dataset that was used to train the model.

required
output_path Optional[str]

Full path to the output csv file. By default, save at 'checkpoint_dir/results.csv'

None
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
763
764
765
766
767
768
769
770
771
772
773
774
def __init__(self, dataset_name: str, output_path: Optional[str] = None):
    """
    :param dataset_name:    Name of the dataset that was used to train the model.
    :param output_path:     Full path to the output csv file. By default, save at 'checkpoint_dir/results.csv'
    :raises ValueError:     When no output path is given and no default checkpoints directory is available.
    """
    self.dataset_name = dataset_name

    if not output_path:
        # Validate the default directory BEFORE joining: os.path.join never returns None, so checking
        # the joined path afterwards could never trigger.
        # NOTE(review): assumes get_project_checkpoints_dir_path() can return None - confirm against its definition.
        checkpoints_dir = get_project_checkpoints_dir_path()
        if checkpoints_dir is None:
            raise ValueError("Output path must be specified")
        output_path = os.path.join(checkpoints_dir, "results.csv")
    self.output_path = output_path

    super(RoboflowResultCallback, self).__init__()

SlidingWindowValidationCallback

Bases: Callback

Performing single-scale sliding window during inference at the last epoch on the validation set and on the average model.

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
@register_callback(Callbacks.SLIDING_WINDOW_VALIDATION)
class SlidingWindowValidationCallback(Callback):
    """
    Runs single-scale sliding-window inference for the last-epoch validation, the averaged-model validation
    and testing, by swapping the dataset transforms and toggling sliding-window mode on the network.
    """

    def __init__(self, transforms_for_sliding_window) -> None:
        """
        :param transforms_for_sliding_window: Transforms installed on the dataset while sliding-window mode is on.
        """
        # Transforms found on the loaders before the swap, kept so they can be put back later.
        self.valid_loader_transforms = []
        self.test_loader_transforms = []
        self.transforms_for_sliding_window = transforms_for_sliding_window

    def on_validation_loader_start(self, context: PhaseContext) -> None:
        """Last epoch only: enable sliding-window mode and install the sliding-window transforms."""
        if context.training_params.max_epochs - 1 != context.epoch:
            return

        unwrap_model(context.net).enable_sliding_window_validation()
        pipeline = context.valid_loader.dataset.transforms
        self.valid_loader_transforms = pipeline.transforms
        pipeline.transforms = self.transforms_for_sliding_window
        # A fresh iterator is created after the swap; presumably so loader workers see the new transforms - TODO confirm.
        iter(context.valid_loader)

    def on_validation_loader_end(self, context: PhaseContext) -> None:
        """Last epoch only: disable sliding-window mode (the transforms are not restored here)."""
        if context.training_params.max_epochs - 1 != context.epoch:
            return

        unwrap_model(context.net).disable_sliding_window_validation()

    def on_average_best_models_validation_start(self, context: PhaseContext) -> None:
        """Last epoch with `average_best_models` set: enable sliding-window mode for the averaged model."""
        if not (context.training_params.max_epochs - 1 == context.epoch and context.training_params.average_best_models):
            return

        unwrap_model(context.net).enable_sliding_window_validation()
        context.valid_loader.dataset.transforms.transforms = self.transforms_for_sliding_window
        iter(context.valid_loader)

    def on_average_best_models_validation_end(self, context: PhaseContext) -> None:
        """With `average_best_models` set: disable sliding-window mode and restore the validation transforms."""
        # NOTE(review): this hook compares against max_epochs, while the matching *_start hook compares
        # against max_epochs - 1 - confirm the asymmetry is intentional.
        if not (context.training_params.max_epochs == context.epoch and context.training_params.average_best_models):
            return

        unwrap_model(context.net).disable_sliding_window_validation()
        context.valid_loader.dataset.transforms.transforms = self.valid_loader_transforms
        iter(context.valid_loader)

    def on_test_loader_start(self, context: PhaseContext) -> None:
        """Enable sliding-window mode and install the sliding-window transforms on the test loader."""
        unwrap_model(context.net).enable_sliding_window_validation()
        pipeline = context.test_loader.dataset.transforms
        self.test_loader_transforms = pipeline.transforms
        pipeline.transforms = self.transforms_for_sliding_window
        iter(context.test_loader)

    def on_test_loader_end(self, context: PhaseContext) -> None:
        """Disable sliding-window mode and put the saved test transforms back."""
        unwrap_model(context.net).disable_sliding_window_validation()
        context.test_loader.dataset.transforms.transforms = self.test_loader_transforms
        iter(context.test_loader)

StepLRScheduler

Bases: LRCallbackBase

Hard coded step learning rate scheduling (i.e at specific milestones).

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
@register_lr_scheduler(LRSchedulers.STEP, deprecated_name="step")
class StepLRScheduler(LRCallbackBase):
    """
    Hard coded step learning rate scheduling (i.e at specific milestones).
    """

    def __init__(self, lr_updates, lr_decay_factor, step_lr_update_freq=None, **kwargs):
        """
        :param lr_updates: Epoch numbers at which the LR is decayed. None is treated as an empty list.
        :param lr_decay_factor: Multiplicative factor applied to the LR at every milestone.
        :param step_lr_update_freq: When given, milestones are generated every `step_lr_update_freq` epochs
                                    (rounded up) instead of being taken from `lr_updates`.
        :param kwargs: Forwarded unchanged to the LRCallbackBase constructor.
        :raises ValueError: When both a non-empty `lr_updates` and `step_lr_update_freq` are passed.
        """
        super().__init__(Phase.TRAIN_EPOCH_END, **kwargs)

        # Normalize a missing milestone list, so that neither the len() check below nor
        # perform_scheduling fails with a TypeError on None.
        if lr_updates is None:
            lr_updates = []

        if step_lr_update_freq and len(lr_updates):
            raise ValueError("Only one of [lr_updates, step_lr_update_freq] should be passed to StepLRScheduler constructor")

        if step_lr_update_freq:
            max_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
            warmup_epochs = self.training_params.lr_warmup_epochs
            # Candidate milestones: multiples of the update frequency, rounded up to whole epochs.
            candidates = (int(np.ceil(step_lr_update_freq * x)) for x in range(1, max_epochs))
            # Keep only milestones that fall after warmup and before cooldown.
            lr_updates = [milestone for milestone in candidates if warmup_epochs <= milestone < max_epochs]
        elif self.training_params.lr_cooldown_epochs > 0:
            logger.warning("Specific lr_updates were passed along with cooldown_epochs > 0," " cooldown will have no effect.")
        self.lr_updates = lr_updates
        self.lr_decay_factor = lr_decay_factor

    def perform_scheduling(self, context):
        """
        Set the LR to `initial_lr * lr_decay_factor ** k`, where k is the number of milestones
        that are less than or equal to the current epoch.

        :param context: Phase context; its `epoch` and `optimizer` are read here.
        """
        num_updates_passed = sum(1 for milestone in self.lr_updates if milestone <= context.epoch)
        self.lr = self.initial_lr * self.lr_decay_factor**num_updates_passed
        self.update_lr(context.optimizer, context.epoch, None)

    def is_lr_scheduling_enabled(self, context):
        """
        :param context: Phase context; its `epoch` is read here.
        :return: True once the current epoch has reached `lr_warmup_epochs`.
        """
        return self.training_params.lr_warmup_epochs <= context.epoch

TestLRCallback

Bases: PhaseCallback

Phase callback that collects the learning rates in lr_placeholder at the end of each epoch (used for testing). In the case of multiple parameter groups (i.e multiple learning rates) the learning rate is collected from the first one. The phase is VALIDATION_EPOCH_END to ensure all lr updates have been performed before calling this callback.

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
785
786
787
788
789
790
791
792
793
794
795
796
797
class TestLRCallback(PhaseCallback):
    """
    Test helper that records learning rates: on every call, the LR of the optimizer's first parameter
    group is appended to `lr_placeholder`. Registered on VALIDATION_EPOCH_END, so that all LR updates
    of the epoch have been performed by the time the value is read.
    """

    def __init__(self, lr_placeholder):
        """
        :param lr_placeholder: List-like object (anything with `append`) that receives the collected LRs.
        """
        super().__init__(Phase.VALIDATION_EPOCH_END)
        self.lr_placeholder = lr_placeholder

    def __call__(self, context: PhaseContext):
        """
        :param context: Phase context; its `optimizer` is read here.
        """
        # With several parameter groups only the first group's LR is recorded.
        first_group = context.optimizer.param_groups[0]
        self.lr_placeholder.append(first_group["lr"])

TrainingStageSwitchCallbackBase

Bases: PhaseCallback

TrainingStageSwitchCallback

A phase callback that is called at a specific epoch (epoch start) to support multi-stage training. It does so by manipulating the objects inside the context.

Parameters:

Name Type Description Default
next_stage_start_epoch int

Epoch idx to apply the stage change.

required
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
class TrainingStageSwitchCallbackBase(PhaseCallback):
    """
    TrainingStageSwitchCallback

    Base class for multi-stage training: a phase callback registered on epoch start that invokes
    `apply_stage_change` once the configured epoch is reached. Subclasses implement the actual change
    by manipulating the objects inside the context.

    :param next_stage_start_epoch: Epoch idx to apply the stage change.
    """

    def __init__(self, next_stage_start_epoch: int):
        super().__init__(phase=Phase.TRAIN_EPOCH_START)
        self.next_stage_start_epoch = next_stage_start_epoch

    def __call__(self, context: PhaseContext):
        """
        :param context: Phase context; its `epoch` decides whether the stage change is applied.
        """
        if context.epoch != self.next_stage_start_epoch:
            # Not the switching epoch - nothing to do.
            return
        self.apply_stage_change(context)

    def apply_stage_change(self, context: PhaseContext):
        """
        Hook holding the stage change logic; invoked when the callback fires on next_stage_start_epoch.
        Must be overridden by subclasses.

        :param context: PhaseContext, context of current phase
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

apply_stage_change(context)

This method is called when the callback is fired on the next_stage_start_epoch, and holds the stage change logic that should be applied to the context's objects.

Parameters:

Name Type Description Default
context PhaseContext

PhaseContext, context of current phase

required
Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
728
729
730
731
732
733
734
735
def apply_stage_change(self, context: PhaseContext):
    """
    Hook holding the stage change logic; invoked when the callback fires on next_stage_start_epoch.
    Must be overridden by subclasses.

    :param context: PhaseContext, context of current phase
    :raises NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

YoloXTrainingStageSwitchCallback

Bases: TrainingStageSwitchCallbackBase

YoloXTrainingStageSwitchCallback

Training stage switch for YoloX training. Disables mosaic, and manipulates YoloX loss to use L1.

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
@register_callback(Callbacks.YOLOX_TRAINING_STAGE_SWITCH)
class YoloXTrainingStageSwitchCallback(TrainingStageSwitchCallbackBase):
    """
    YoloXTrainingStageSwitchCallback

    Stage switch used for YoloX training: closes every closable train transform
    and turns on the L1 term of the criterion.
    """

    def __init__(self, next_stage_start_epoch: int = 285):
        """
        :param next_stage_start_epoch: Epoch idx to apply the stage change.
        """
        super(YoloXTrainingStageSwitchCallback, self).__init__(next_stage_start_epoch=next_stage_start_epoch)

    def apply_stage_change(self, context: PhaseContext):
        """
        :param context: Phase context; its `train_loader` and `criterion` are modified here.
        """
        for augmentation in context.train_loader.dataset.transforms:
            if not hasattr(augmentation, "close"):
                # Transforms without a close() method are left as they are.
                continue
            augmentation.close()

        # A fresh iterator is created after closing; presumably so loader workers pick up the change - TODO confirm.
        iter(context.train_loader)
        context.criterion.use_l1 = True

create_lr_scheduler_callback(lr_mode, train_loader, net, training_params, update_param_groups, optimizer)

Creates the phase callback in charge of LR scheduling, to be used by Trainer.

Parameters:

Name Type Description Default
lr_mode Union[str, Mapping]

Union[str, Mapping], When str: Learning rate scheduling policy, one of ['StepLRScheduler','PolyLRScheduler','CosineLRScheduler','FunctionLRScheduler']. 'StepLRScheduler' refers to constant updates at epoch numbers passed through lr_updates. Each update decays the learning rate by lr_decay_factor. 'CosineLRScheduler' refers to the Cosine Annealing policy as mentioned in https://arxiv.org/abs/1608.03983. The final learning rate ratio is controlled by the cosine_final_lr_ratio training parameter. 'PolyLRScheduler' refers to the polynomial decrease: in each epoch iteration self.lr = self.initial_lr * pow((1.0 - (current_iter / max_iter)), 0.9). 'FunctionLRScheduler' refers to a user-defined learning rate scheduling function, that is passed through lr_schedule_function. When Mapping, refers to a torch.optim.lr_scheduler._LRScheduler, following the below API: lr_mode = {LR_SCHEDULER_CLASS_NAME: {**LR_SCHEDULER_KWARGS, "phase": XXX, "metric_name": XXX}} Where "phase" (of Phase type) controls when to call torch.optim.lr_scheduler._LRScheduler.step(). For instance, in order to: - Update LR on each batch: Use phase: Phase.TRAIN_BATCH_END - Update LR after each epoch: Use phase: Phase.TRAIN_EPOCH_END The "metric_name" refers to the metric to watch (See docs for "metric_to_watch" in train(...) https://docs.deci.ai/super-gradients/docstring/training/sg_trainer.html) when using ReduceLROnPlateau. In any other case this kwarg is ignored. **LR_SCHEDULER_KWARGS are simply passed to the torch scheduler's __init__.

required
train_loader DataLoader

DataLoader, the Trainer.train_loader used for training.

required
net torch.nn.Module

torch.nn.Module, the Trainer.net used for training.

required
training_params Mapping

Mapping, Trainer.training_params.

required
update_param_groups bool

bool, Whether the Trainer.net has a specific way of updating its parameter groups.

required
optimizer torch.optim.Optimizer

The optimizer used for training. Will be passed to the LR callback's init (or the torch scheduler's init, depending on the lr_mode value as described above).

required

Returns:

Type Description
PhaseCallback

a PhaseCallback instance to be used by Trainer for LR scheduling.

Source code in V3_3/src/super_gradients/training/utils/callbacks/callbacks.py
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
def create_lr_scheduler_callback(
    lr_mode: Union[str, Mapping],
    train_loader: DataLoader,
    net: torch.nn.Module,
    training_params: Mapping,
    update_param_groups: bool,
    optimizer: torch.optim.Optimizer,
) -> PhaseCallback:
    """
    Creates the phase callback in charge of LR scheduling, to be used by Trainer.

    :param lr_mode: Union[str, Mapping],

                    When str:

                    Learning rate scheduling policy, one of ['StepLRScheduler','PolyLRScheduler','CosineLRScheduler','FunctionLRScheduler'].

                    'StepLRScheduler' refers to constant updates at epoch numbers passed through `lr_updates`.
                        Each update decays the learning rate by `lr_decay_factor`.

                    'CosineLRScheduler' refers to the Cosine Annealing policy as mentioned in https://arxiv.org/abs/1608.03983.
                      The final learning rate ratio is controlled by the `cosine_final_lr_ratio` training parameter.

                    'PolyLRScheduler' refers to the polynomial decrease:
                        in each epoch iteration `self.lr = self.initial_lr * pow((1.0 - (current_iter / max_iter)), 0.9)`

                    'FunctionLRScheduler' refers to a user-defined learning rate scheduling function, that is passed through `lr_schedule_function`.



                    When Mapping, refers to a torch.optim.lr_scheduler._LRScheduler, following the below API:

                        lr_mode = {LR_SCHEDULER_CLASS_NAME: {**LR_SCHEDULER_KWARGS, "phase": XXX, "metric_name": XXX}}

                        Only the first key of the mapping is considered.

                        Where "phase" (of Phase type) controls when to call torch.optim.lr_scheduler._LRScheduler.step().

                        For instance, in order to:
                        - Update LR on each batch: Use phase: Phase.TRAIN_BATCH_END
                        - Update LR after each epoch: Use phase: Phase.TRAIN_EPOCH_END

                        The "metric_name" refers to the metric to watch (See docs for "metric_to_watch" in train(...)
                         https://docs.deci.ai/super-gradients/docstring/training/sg_trainer.html) when using
                          ReduceLROnPlateau. In any other case this kwarg is ignored.

                        **LR_SCHEDULER_KWARGS are simply passed to the torch scheduler's __init__.

    :param train_loader: DataLoader, the Trainer.train_loader used for training.

    :param net: torch.nn.Module, the Trainer.net used for training.

    :param training_params: Mapping, Trainer.training_params.

    :param update_param_groups: bool, Whether the Trainer.net has a specific way of updating its parameter group.

    :param optimizer: The optimizer used for training. Will be passed to the torch scheduler's __init__ when lr_mode
     is a Mapping (it is not used when lr_mode is a str).

    :return: a PhaseCallback instance to be used by Trainer for LR scheduling.

    :raises ValueError: When lr_mode is neither a known scheduler name nor a Mapping whose first key is a known
     torch scheduler (this includes an empty Mapping), when "phase" is missing for a torch scheduler, or when
     "metric_name" is missing for ReduceLROnPlateau.
    """

    if isinstance(lr_mode, str) and lr_mode in LR_SCHEDULERS_CLS_DICT:
        sg_lr_callback_cls = LR_SCHEDULERS_CLS_DICT[lr_mode]
        return sg_lr_callback_cls(
            train_loader_len=len(train_loader),
            net=net,
            training_params=training_params,
            update_param_groups=update_param_groups,
            **training_params.to_dict(),
        )

    # next(iter(...), None) yields the first key but, unlike list(...)[0], does not blow up
    # with an IndexError on an empty mapping - that case falls through to the ValueError below.
    lr_scheduler_name = next(iter(lr_mode), None) if isinstance(lr_mode, Mapping) else None
    if lr_scheduler_name is None or lr_scheduler_name not in TORCH_LR_SCHEDULERS:
        raise ValueError(f"Unknown lr_mode: {lr_mode}")

    if update_param_groups:
        logger.warning(
            "The network's way of updataing (i.e update_param_groups) is not supported with native " "torch lr schedulers and will have no effect."
        )

    scheduler_config = lr_mode[lr_scheduler_name]
    metric_name = get_param(scheduler_config, "metric_name")

    # Validate the callback-level keys before building the scheduler, so that a misconfigured
    # lr_mode is reported as such rather than masked by whatever the scheduler constructor does.
    if get_param(scheduler_config, "phase") is None:
        raise ValueError("Phase is required argument when working with torch schedulers.")

    if lr_scheduler_name == "ReduceLROnPlateau" and metric_name is None:
        raise ValueError("metric_name is required argument when working with ReduceLROnPlateau schedulers.")

    # "phase" and "metric_name" are consumed by LRSchedulerCallback; everything else goes to the torch scheduler.
    torch_scheduler_params = {k: v for k, v in scheduler_config.items() if k != "phase" and k != "metric_name"}
    torch_scheduler_params["optimizer"] = optimizer
    torch_scheduler = TORCH_LR_SCHEDULERS[lr_scheduler_name](**torch_scheduler_params)

    return LRSchedulerCallback(scheduler=torch_scheduler, phase=scheduler_config["phase"], metric_name=metric_name)

ExtremeBatchPoseEstimationVisualizationCallback

Bases: ExtremeBatchCaseVisualizationCallback

ExtremeBatchPoseEstimationVisualizationCallback

Visualizes the worst/best batch in an epoch for the pose estimation task. This class visualizes horizontally-stacked GT and predicted poses. It requires a key 'gt_samples' (List[PoseEstimationSample]) to be present in the additional_batch_items dictionary.

Supported models: YoloNASPose Supported datasets: COCOPoseEstimationDataset, CrowdPoseEstimationDataset, AnimalPoseEstimationDataset

Example usage in Yaml config:

training_hyperparams:
  phase_callbacks:
      - ExtremeBatchPoseEstimationVisualizationCallback:
          keypoint_colors: ${dataset_params.keypoint_colors}
          edge_colors: ${dataset_params.edge_colors}
          edge_links: ${dataset_params.edge_links}
          loss_to_monitor: YoloNASPoseLoss/loss
          max: True
          freq: 1
          max_images: 16
          enable_on_train_loader: True
          enable_on_valid_loader: True
          post_prediction_callback:
            _target_: super_gradients.training.models.pose_estimation_models.yolo_nas_pose.YoloNASPosePostPredictionCallback
            pose_confidence_threshold: 0.01
            nms_iou_threshold: 0.7
            pre_nms_max_predictions: 300
            post_nms_max_predictions: 30

Parameters:

Name Type Description Default
metric Optional[Metric]

Metric, will be the metric which is monitored.

None
metric_component_name Optional[str]

In case metric returns multiple values (as Mapping), the value at metric.compute()[metric_component_name] will be the one monitored.

None
loss_to_monitor Optional[str]

str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...). Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be: if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple: <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>. If a single item is returned rather than a tuple: <LOSS_CLASS.__name__>. When there is no such attribute and criterion.forward(..) returns a tuple: <LOSS_CLASS.__name__>"/"Loss_"<IDX>

None
max bool

bool, Whether to take the batch corresponding to the max value of the metric/loss or the minimum (default=False).

False
freq int

int, epoch frequency to perform all of the above (default=1).

1
classes

List[str], a list of class names corresponding to the class indices for display. When None, will try to fetch this through a "classes" attribute of the validation dataset. If such an attribute does not exist an error will be raised (default=None).

required
normalize_targets

bool, whether to scale the target bboxes. If the bboxes returned by the validation data loader are in pixel values range, this needs to be set to True (default=False)

required
Source code in V3_3/src/super_gradients/training/utils/callbacks/extreme_batch_pose_visualization_callback.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
@register_callback("ExtremeBatchPoseEstimationVisualizationCallback")
class ExtremeBatchPoseEstimationVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
    """
    ExtremeBatchPoseEstimationVisualizationCallback

    Visualizes worst/best batch in an epoch for pose estimation task.
    This class visualizes horizontally-stacked GT and predicted poses.
    It requires a key 'gt_samples' (List[PoseEstimationSample]) to be present in additional_batch_items dictionary.

    Supported models: YoloNASPose
    Supported datasets: COCOPoseEstimationDataset, CrowdPoseEstimationDataset, AnimalPoseEstimationDataset

    Example usage in Yaml config:

        training_hyperparams:
          phase_callbacks:
              - ExtremeBatchPoseEstimationVisualizationCallback:
                  keypoint_colors: ${dataset_params.keypoint_colors}
                  edge_colors: ${dataset_params.edge_colors}
                  edge_links: ${dataset_params.edge_links}
                  loss_to_monitor: YoloNASPoseLoss/loss
                  max: True
                  freq: 1
                  max_images: 16
                  enable_on_train_loader: True
                  enable_on_valid_loader: True
                  post_prediction_callback:
                    _target_: super_gradients.training.models.pose_estimation_models.yolo_nas_pose.YoloNASPosePostPredictionCallback
                    pose_confidence_threshold: 0.01
                    nms_iou_threshold: 0.7
                    pre_nms_max_predictions: 300
                    post_nms_max_predictions: 30

    :param post_prediction_callback: Callable applied to the raw predictions of the extreme batch. Must return one
     item per image, each exposing `poses`, `bboxes_xyxy` and `scores` attributes.

    :param keypoint_colors: Keypoint colors, one (int, int, int) triplet per joint.
     Either a plain list or an OmegaConf container (which is converted to a plain container).

    :param edge_colors: Edge colors, one (int, int, int) triplet per link.
     Either a plain list or an OmegaConf container (which is converted to a plain container).

    :param edge_links: Pairs of joint indices describing the skeleton links.
     Either a plain list or an OmegaConf container (which is converted to a plain container).

    :param metric: Metric, will be the metric which is monitored.

    :param metric_component_name: In case metric returns multiple values (as Mapping),
     the value at metric.compute()[metric_component_name] will be the one monitored.

    :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
     Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:

        if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.

        If a single item is returned rather than a tuple:
            <LOSS_CLASS.__name__>.

        When there is no such attributes and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"Loss_"<IDX>

    :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
     the minimum (default=False).

    :param freq: int, epoch frequency to perform all of the above (default=1).

    :param max_images: Stored on the instance as `self.max_images` (default=None). It is not read anywhere in this
     class; presumably an upper bound on the number of visualized images - confirm where it is consumed.

    :param enable_on_train_loader: bool, forwarded to the base class (default=False).
     Presumably enables the visualization for the training loader - confirm against the base class.

    :param enable_on_valid_loader: bool, forwarded to the base class (default=True).
     Presumably enables the visualization for the validation loader - confirm against the base class.
    """

    def __init__(
        self,
        post_prediction_callback: Callable,
        keypoint_colors: List[Tuple[int, int, int]],
        edge_colors: List[Tuple[int, int, int]],
        edge_links: List[Tuple[int, int]],
        metric: Optional[Metric] = None,
        metric_component_name: Optional[str] = None,
        loss_to_monitor: Optional[str] = None,
        max: bool = False,
        freq: int = 1,
        max_images: Optional[int] = None,
        enable_on_train_loader: bool = False,
        enable_on_valid_loader: bool = True,
    ):
        super().__init__(
            metric=metric,
            metric_component_name=metric_component_name,
            loss_to_monitor=loss_to_monitor,
            max=max,
            freq=freq,
            enable_on_train_loader=enable_on_train_loader,
            enable_on_valid_loader=enable_on_valid_loader,
        )
        self.post_prediction_callback = post_prediction_callback
        self.keypoint_colors = self._to_plain_container(keypoint_colors)
        self.edge_colors = self._to_plain_container(edge_colors)
        self.edge_links = self._to_plain_container(edge_links)
        self.max_images = max_images

    @staticmethod
    def _to_plain_container(value):
        """
        Convert an OmegaConf container to plain Python containers; return anything else untouched.

        OmegaConf.to_container rejects objects that are not OmegaConf configs, so calling it unconditionally
        would make the callback unusable with the plain lists advertised by the constructor's type hints.
        """
        return OmegaConf.to_container(value) if OmegaConf.is_config(value) else value

    @classmethod
    def universal_undo_preprocessing_fn(cls, inputs: torch.Tensor) -> np.ndarray:
        """
        A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg.

        Min-max rescales the whole batch to [0, 255], casts it to uint8, reverses the order of axis 1 and moves
        that axis last.

        :param inputs: 4-dimensional tensor; presumably an image batch of [Batch Size, Channels, H, W] shape - confirm against caller.
        :return: Contiguous uint8 array with the axes reordered as (0, 2, 3, 1) and the original axis 1 reversed.
        """
        inputs = inputs - inputs.min()
        value_range = inputs.max()
        # A constant batch has zero range: dividing would produce NaN and an undefined uint8 cast,
        # so the (all-zero) batch is left as is.
        if value_range > 0:
            inputs /= value_range
        inputs *= 255
        inputs = inputs.to(torch.uint8)
        inputs = inputs.cpu().numpy()
        inputs = inputs[:, ::-1, :, :].transpose(0, 2, 3, 1)
        inputs = np.ascontiguousarray(inputs, dtype=np.uint8)
        return inputs

    @classmethod
    def _visualize_batch(
        cls,
        image_tensor: np.ndarray,
        keypoints: List[Union[np.ndarray, Tensor]],
        bboxes: List[Union[None, np.ndarray, Tensor]],
        scores: Optional[List[Union[None, np.ndarray, Tensor]]],
        is_crowd: Optional[List[Union[None, np.ndarray, Tensor]]],
        keypoint_colors: List[Tuple[int, int, int]],
        edge_colors: List[Tuple[int, int, int]],
        edge_links: List[Tuple[int, int]],
        show_keypoint_confidence: bool,
    ) -> List[np.ndarray]:
        """
        Generate list of samples visualization of a batch of images with keypoints and bounding boxes.

        :param image_tensor:             Images batch with values in [0, 255] range, indexed by image along axis 0.
                                         The images should be scaled to [0, 255] range and converted to uint8 type beforehand.
                                         NOTE(review): process_extreme_batch passes the output of
                                         universal_undo_preprocessing_fn, which moves axis 1 last - confirm the layout
                                         expected by PoseVisualization.draw_poses.
        :param keypoints:                Per-image keypoints in XY format. Shape [Num Instances, Num Joints, 2]. Can be None.
        :param bboxes:                   Per-image bounding boxes in XYXY format. Shape [Num Instances, 4]. Can be None.
        :param scores:                   Per-image keypoint scores. Shape [Num Instances, Num Joints]. Can be None.
        :param is_crowd:                 Per-image crowd flags. Shape [Num Instances]. Can be None.
        :param keypoint_colors:          Keypoint colors. Shape [Num Joints, 3]
        :param edge_colors:              Edge colors between joints. Shape [Num Links, 3]
        :param edge_links:               Edge links between joints. Shape [Num Links, 2]
        :param show_keypoint_confidence: Whether to show confidence for each keypoint. Requires `scores` to be not None.
        :return:                         List of visualization images, one per image in the batch.
        """

        out_images = []
        for i in range(image_tensor.shape[0]):
            keypoints_i = keypoints[i]
            bboxes_i = bboxes[i]
            scores_i = scores[i] if scores is not None else None
            is_crowd_i = is_crowd[i] if is_crowd is not None else None

            # Tensors are moved to CPU and converted to numpy before being handed to the drawing routine.
            if torch.is_tensor(keypoints_i):
                keypoints_i = keypoints_i.detach().cpu().numpy()
            if torch.is_tensor(bboxes_i):
                bboxes_i = bboxes_i.detach().cpu().numpy()
            if torch.is_tensor(scores_i):
                scores_i = scores_i.detach().cpu().numpy()
            if torch.is_tensor(is_crowd_i):
                is_crowd_i = is_crowd_i.detach().cpu().numpy()

            res_image = image_tensor[i]
            res_image = PoseVisualization.draw_poses(
                image=res_image,
                poses=keypoints_i,
                boxes=bboxes_i,
                scores=scores_i,
                is_crowd=is_crowd_i,
                show_keypoint_confidence=show_keypoint_confidence,
                edge_links=edge_links,
                edge_colors=edge_colors,
                keypoint_colors=keypoint_colors,
                keypoint_confidence_threshold=0.01,
            )

            out_images.append(res_image)

        return out_images

    @torch.no_grad()
    def process_extreme_batch(self) -> np.ndarray:
        """
        Processes the extreme batch, and returns a batch of images for visualization - GT and predicted poses
        stacked along axis 2 (GT first, predictions second).

        :return: np.ndarray - the visualization of GT and predictions
        :raises RuntimeError: When 'gt_samples' is missing from the extreme batch's additional batch items.
        """
        if "gt_samples" not in self.extreme_additional_batch_items:
            raise RuntimeError(
                "ExtremeBatchPoseEstimationVisualizationCallback requires 'gt_samples' to be present in additional_batch_items."
                "Currently only YoloNASPose model is supported. Old DEKR recipe is not supported at the moment."
            )

        inputs = self.universal_undo_preprocessing_fn(self.extreme_batch)
        gt_samples: List[PoseEstimationSample] = self.extreme_additional_batch_items["gt_samples"]
        predictions: List[PoseEstimationPredictions] = self.post_prediction_callback(self.extreme_preds)

        images_to_save_preds = self._visualize_batch(
            image_tensor=inputs,
            keypoints=[p.poses for p in predictions],
            bboxes=[p.bboxes_xyxy for p in predictions],
            scores=[p.scores for p in predictions],
            is_crowd=None,
            edge_links=self.edge_links,
            edge_colors=self.edge_colors,
            keypoint_colors=self.keypoint_colors,
            show_keypoint_confidence=True,
        )
        images_to_save_preds = np.stack(images_to_save_preds)

        images_to_save_gt = self._visualize_batch(
            image_tensor=inputs,
            keypoints=[gt.joints for gt in gt_samples],
            # GT boxes are stored as XYWH while the drawing routine is given XYXY boxes.
            bboxes=[xywh_to_xyxy(gt.bboxes_xywh, image_shape=None) if gt.bboxes_xywh is not None else None for gt in gt_samples],
            scores=None,
            is_crowd=[gt.is_crowd for gt in gt_samples],
            edge_links=self.edge_links,
            edge_colors=self.edge_colors,
            keypoint_colors=self.keypoint_colors,
            show_keypoint_confidence=False,
        )
        images_to_save_gt = np.stack(images_to_save_gt)

        # Stack the GT and predictions images together
        return np.concatenate([images_to_save_gt, images_to_save_preds], axis=2)

process_extreme_batch()

Processes the extreme batch, and returns a batch of images for visualization - predictions and GT poses stacked horizontally.

Returns:

Type Description
np.ndarray

np.ndarray - the visualization of predictions and GT

Source code in V3_3/src/super_gradients/training/utils/callbacks/extreme_batch_pose_visualization_callback.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
@torch.no_grad()
def process_extreme_batch(self) -> np.ndarray:
    """
    Render the extreme batch twice - once with the predicted poses and once with the GT poses - and join the
    two renderings along axis 2 (GT first, predictions second).

    :return: np.ndarray - the visualization of GT and predictions
    :raises RuntimeError: When 'gt_samples' is missing from the extreme batch's additional batch items.
    """
    batch_items = self.extreme_additional_batch_items
    if "gt_samples" not in batch_items:
        raise RuntimeError(
            "ExtremeBatchPoseEstimationVisualizationCallback requires 'gt_samples' to be present in additional_batch_items."
            "Currently only YoloNASPose model is supported. Old DEKR recipe is not supported at the moment."
        )

    images = self.universal_undo_preprocessing_fn(self.extreme_batch)
    gt_samples: List[PoseEstimationSample] = batch_items["gt_samples"]
    predictions: List[PoseEstimationPredictions] = self.post_prediction_callback(self.extreme_preds)

    # Predictions panel: per-keypoint confidences are drawn, crowd flags are not available.
    pred_poses = [prediction.poses for prediction in predictions]
    pred_boxes = [prediction.bboxes_xyxy for prediction in predictions]
    pred_scores = [prediction.scores for prediction in predictions]
    preds_panel = np.stack(
        self._visualize_batch(
            image_tensor=images,
            keypoints=pred_poses,
            bboxes=pred_boxes,
            scores=pred_scores,
            is_crowd=None,
            edge_links=self.edge_links,
            edge_colors=self.edge_colors,
            keypoint_colors=self.keypoint_colors,
            show_keypoint_confidence=True,
        )
    )

    # Ground-truth panel: boxes are stored as XYWH and converted to XYXY; samples without boxes stay None.
    gt_poses = [sample.joints for sample in gt_samples]
    gt_boxes = []
    for sample in gt_samples:
        if sample.bboxes_xywh is None:
            gt_boxes.append(None)
        else:
            gt_boxes.append(xywh_to_xyxy(sample.bboxes_xywh, image_shape=None))
    gt_crowd_flags = [sample.is_crowd for sample in gt_samples]
    gt_panel = np.stack(
        self._visualize_batch(
            image_tensor=images,
            keypoints=gt_poses,
            bboxes=gt_boxes,
            scores=None,
            is_crowd=gt_crowd_flags,
            edge_links=self.edge_links,
            edge_colors=self.edge_colors,
            keypoint_colors=self.keypoint_colors,
            show_keypoint_confidence=False,
        )
    )

    # GT goes first, predictions second.
    return np.concatenate([gt_panel, preds_panel], axis=2)

universal_undo_preprocessing_fn(inputs) classmethod

A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg.

Parameters:

Name Type Description Default
inputs torch.Tensor required

Returns:

Type Description
np.ndarray
Source code in V3_3/src/super_gradients/training/utils/callbacks/extreme_batch_pose_visualization_callback.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
@classmethod
def universal_undo_preprocessing_fn(cls, inputs: torch.Tensor) -> np.ndarray:
    """
    A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg.

    Min-max rescales the whole batch to [0, 255], casts it to uint8, reverses the order of axis 1 and moves
    that axis last. The input tensor itself is not modified.

    :param inputs: 4-dimensional tensor; presumably an image batch of [Batch Size, Channels, H, W] shape - confirm against caller.
    :return: Contiguous uint8 array with the axes reordered as (0, 2, 3, 1) and the original axis 1 reversed.
    """
    # The subtraction allocates a new tensor, so the in-place operations below never touch the caller's data.
    inputs = inputs - inputs.min()
    value_range = inputs.max()
    # A constant batch has zero range: dividing would produce NaN and an undefined uint8 cast,
    # so the (all-zero) batch is left as is.
    if value_range > 0:
        inputs /= value_range
    inputs *= 255
    inputs = inputs.to(torch.uint8)
    inputs = inputs.cpu().numpy()
    inputs = inputs[:, ::-1, :, :].transpose(0, 2, 3, 1)
    inputs = np.ascontiguousarray(inputs, dtype=np.uint8)
    return inputs

PPYoloETrainingStageSwitchCallback

Bases: TrainingStageSwitchCallbackBase

PPYoloETrainingStageSwitchCallback

Training stage switch for PPYolo training. It switches from the static bbox assigner to the task-aligned assigner after a certain number of epochs has passed.

Source code in V3_3/src/super_gradients/training/utils/callbacks/ppyoloe_switch_callback.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
@register_callback(Callbacks.PPYOLOE_TRAINING_STAGE_SWITCH)
class PPYoloETrainingStageSwitchCallback(TrainingStageSwitchCallbackBase):
    """
    PPYoloETrainingStageSwitchCallback

    Training stage switch for PPYoloE training.
    When the stage change is applied, the criterion's `use_static_assigner` flag is turned off, i.e. the loss
    stops using the static bbox assigner (the original documentation describes this as switching to the
    task-aligned assigner).

    :param static_assigner_end_epoch: Forwarded to the base class as `next_stage_start_epoch` (default=30).
    """

    def __init__(self, static_assigner_end_epoch: int = 30):
        super().__init__(next_stage_start_epoch=static_assigner_end_epoch)

    def apply_stage_change(self, context: PhaseContext):
        """
        Disable the static assigner on the training criterion.

        :param context: Phase context whose `criterion` must be a PPYoloELoss instance.
        :raises RuntimeError: When `context.criterion` is not a PPYoloELoss.
        """
        # The loss class is imported at call time, inside the method, rather than at module level.
        from super_gradients.training.losses import PPYoloELoss

        criterion_is_supported = isinstance(context.criterion, PPYoloELoss)
        if not criterion_is_supported:
            message = f"A criterion must be an instance of PPYoloELoss when using PPYoloETrainingStageSwitchCallback. " f"Got criterion {repr(context.criterion)}"
            raise RuntimeError(message)

        context.criterion.use_static_assigner = False

DefaultCheckpointSolver

Implements the default behavior from adaptive_load_state_dict. If the model state dict and the checkpoint state dict have no 1:1 match by name, then the default solver uses simple ordered matching. It assumes that the order of layers in the checkpoint is the same as in the model and iterates over them simultaneously. If the shapes of the source and recipient tensors differ, the solver raises an error.

Source code in V3_3/src/super_gradients/training/utils/checkpoint_utils.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
class DefaultCheckpointSolver:
    """
    Implements the default behavior from adaptive_load_state_dict.

    The solver relies on simple ordered matching: it assumes the layers appear in the checkpoint in the same
    order as in the model and walks both state dicts simultaneously, re-keying each checkpoint tensor with the
    name of the model entry at the same position.
    A shape mismatch between a paired source and recipient tensor is reported as an error.
    """

    def __call__(self, model_state_dict: Mapping[str, Tensor], checkpoint_state_dict: Mapping[str, Tensor]) -> Mapping[str, Tensor]:
        """
        Map checkpoint state_dict to model state_dict.

        Entries are paired by position only; pairing stops at the end of the shorter state dict.

        :param model_state_dict: (Mapping[str, Tensor]) A model state dict
        :param checkpoint_state_dict: (Mapping[str, Tensor]) A checkpoint state dict
        :return: (Mapping[str, Tensor]) New checkpoint state dict with keys/values converted to match model state_dict
        :raises ValueError: When two paired tensors have different shapes.
        """
        remapped = {}
        paired_entries = zip(checkpoint_state_dict.items(), model_state_dict.items())
        for (source_name, source_tensor), (target_name, target_tensor) in paired_entries:
            if source_tensor.shape == target_tensor.shape:
                # The checkpoint tensor is kept as is, only its key changes.
                remapped[target_name] = source_tensor
                continue
            raise ValueError(
                f"ckpt layer {source_name} with shape {source_tensor.shape} does not match {target_name}" f" with shape {target_tensor.shape} in the model"
            )
        return remapped

__call__(model_state_dict, checkpoint_state_dict)

Map checkpoint state_dict to model state_dict.

Parameters:

Name Type Description Default
model_state_dict Mapping[str, Tensor]

(Mapping[str, Tensor]) A model state dict

required
checkpoint_state_dict Mapping[str, Tensor]

(Mapping[str, Tensor]) A checkpoint state dict

required

Returns:

Type Description
Mapping[str, Tensor]

(Mapping[str, Tensor]) New checkpoint state dict with keys/values converted to match model state_dict

Source code in V3_3/src/super_gradients/training/utils/checkpoint_utils.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
def __call__(self, model_state_dict: Mapping[str, Tensor], checkpoint_state_dict: Mapping[str, Tensor]) -> Mapping[str, Tensor]:
    """
    Map checkpoint state_dict to model state_dict.

    Entries are paired by position only; pairing stops at the end of the shorter state dict.

    :param model_state_dict: (Mapping[str, Tensor]) A model state dict
    :param checkpoint_state_dict: (Mapping[str, Tensor]) A checkpoint state dict
    :return: (Mapping[str, Tensor]) New checkpoint state dict with keys/values converted to match model state_dict
    :raises ValueError: When two paired tensors have different shapes.
    """
    remapped = {}
    checkpoint_entries = checkpoint_state_dict.items()
    model_entries = model_state_dict.items()
    for checkpoint_entry, model_entry in zip(checkpoint_entries, model_entries):
        source_name, source_tensor = checkpoint_entry
        target_name, target_tensor = model_entry
        shapes_differ = source_tensor.shape != target_tensor.shape
        if shapes_differ:
            raise ValueError(
                f"ckpt layer {source_name} with shape {source_tensor.shape} does not match {target_name}" f" with shape {target_tensor.shape} in the model"
            )
        # The checkpoint tensor is kept as is, only its key changes.
        remapped[target_name] = source_tensor
    return remapped

MissingPretrainedWeightsException

Bases: Exception

Exception raised by unsupported pretrained model.

Parameters:

Name Type Description Default
desc

explanation of the error

required
Source code in V3_3/src/super_gradients/training/utils/checkpoint_utils.py
1529
1530
1531
1532
1533
1534
1535
1536
1537
class MissingPretrainedWeightsException(Exception):
    """Exception raised by unsupported pretrained model.

    The full text is also kept on the instance as `message`.

    :param desc: explanation of the error (concatenated to a fixed prefix, so it must be a str)
    """

    def __init__(self, desc):
        full_message = "Missing pretrained wights: " + desc
        self.message = full_message
        super().__init__(full_message)

YoloXCheckpointSolver

Implementation of checkpoint solver for old YoloX model checkpoints.

Source code in V3_3/src/super_gradients/training/utils/checkpoint_utils.py
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
class YoloXCheckpointSolver:
    """
    Implementation of checkpoint solver for old YoloX model checkpoints.
    """

    @classmethod
    def generate_mapping_table(cls) -> Mapping[str, str]:
        """
        Helper method to generate mapping table between old YoloX checkpoints and the current YoloX layer names.

        For each of the YoloX N/T/S/M/L variants, the "<model_name>_coco" checkpoint is loaded from MODEL_URLS and
        its entries (minus a fixed set of ignored keys) are paired positionally with the entries of the state dict
        of the model returned by `models.get(model_name, num_classes=80)`. A pair is accepted only when both tensors
        have the same size and both keys end with the same dot-separated component. The accepted pairs of all
        variants are merged into a single dictionary.

        :return: A mapping dictionary {checkpoint_key: model_key}
        :raises RuntimeError: If a positionally paired model/checkpoint entry differs in size or in the last key
                              component, or if the same checkpoint key maps to different model keys across variants.
        """
        from super_gradients.common.object_names import Models
        from super_gradients.training import models

        # Checkpoint entries that are dropped before pairing with the model state dict.
        ignored_checkpoint_keys = {"stride", "_head.anchors._anchors", "_head.anchors._anchor_grid", "_head.anchors._stride", "_head._modules_list.14.stride"}

        all_mapping_keys = {}
        model_names = [Models.YOLOX_N, Models.YOLOX_T, Models.YOLOX_S, Models.YOLOX_M, Models.YOLOX_L]

        for model_name in model_names:
            model_url = MODEL_URLS[model_name + "_coco"]
            state_dict = load_state_dict_from_url(model_url, progress=True, map_location="cpu")

            model = models.get(model_name, num_classes=80)
            model_state_dict = model.state_dict()
            checkpoint_state_dict = maybe_remove_module_prefix(state_dict["net"])
            new_sd = {k: v for k, v in checkpoint_state_dict.items() if k not in ignored_checkpoint_keys}

            # The pairing relies purely on iteration order of both state dicts.
            # NOTE(review): zip() stops at the shorter of the two dicts, so trailing unmatched entries are
            # silently ignored - confirm whether both dicts are expected to have the same length.
            for (model_key, model_value), (checkpoint_key, checkpoint_value) in zip(model_state_dict.items(), new_sd.items()):
                same_size = model_value.size() == checkpoint_value.size()
                same_suffix = model_key.split(".")[-1] == checkpoint_key.split(".")[-1]
                if not (same_size and same_suffix):
                    raise RuntimeError(
                        "Detected mismatch between model and checkpoint state dict keys. "
                        f"Model key {model_key} of shape {model_value.size()} does not "
                        f"match checkpoint key {checkpoint_key} of shape {checkpoint_value.size()}"
                    )

                # Every model variant must agree on the target of a given checkpoint key. An explicit raise is
                # used instead of `assert` so that the check is not stripped when running with `python -O`.
                previous_model_key = all_mapping_keys.setdefault(checkpoint_key, model_key)
                if previous_model_key != model_key:
                    raise RuntimeError(
                        f"Inconsistent mapping for checkpoint key {checkpoint_key}: it was already mapped to "
                        f"{previous_model_key}, but model {model_name} maps it to {model_key}"
                    )

        return all_mapping_keys

    def __init__(self):
        # The layers_rename_table below is a result of a manual mapping between the checkpoint keys and the model keys.
        # It was code-generated using YoloXCheckpointSolver.generate_mapping_table() method and tested for
        # correctness with:
        # tests.unit_tests.yolox_unit_test.TestYOLOX.test_yolo_x_checkpoint_solver.
        # tests.unit_tests.test_predict.TestModelPredict.test_detection_models

        self.layers_rename_table = {
            "_backbone._modules_list.0.conv.bn.bias": "_backbone._modules_list.0.bn.bias",
            "_backbone._modules_list.0.conv.bn.num_batches_tracked": "_backbone._modules_list.0.bn.num_batches_tracked",
            "_backbone._modules_list.0.conv.bn.running_mean": "_backbone._modules_list.0.bn.running_mean",
            "_backbone._modules_list.0.conv.bn.running_var": "_backbone._modules_list.0.bn.running_var",
            "_backbone._modules_list.0.conv.bn.weight": "_backbone._modules_list.0.bn.weight",
            "_backbone._modules_list.1.bn.bias": "_backbone._modules_list.1.bn.bias",
            "_backbone._modules_list.1.bn.num_batches_tracked": "_backbone._modules_list.1.bn.num_batches_tracked",
            "_backbone._modules_list.1.bn.running_mean": "_backbone._modules_list.1.bn.running_mean",
            "_backbone._modules_list.1.bn.running_var": "_backbone._modules_list.1.bn.running_var",
            "_backbone._modules_list.1.bn.weight": "_backbone._modules_list.1.bn.weight",
            "_backbone._modules_list.1.conv.bn.bias": "_backbone._modules_list.1.conv.bn.bias",
            "_backbone._modules_list.1.conv.bn.num_batches_tracked": "_backbone._modules_list.1.conv.bn.num_batches_tracked",
            "_backbone._modules_list.1.conv.bn.running_mean": "_backbone._modules_list.1.conv.bn.running_mean",
            "_backbone._modules_list.1.conv.bn.running_var": "_backbone._modules_list.1.conv.bn.running_var",
            "_backbone._modules_list.1.conv.bn.weight": "_backbone._modules_list.1.conv.bn.weight",
            "_backbone._modules_list.1.conv.conv.weight": "_backbone._modules_list.1.conv.conv.weight",
            "_backbone._modules_list.1.conv.weight": "_backbone._modules_list.1.conv.weight",
            "_backbone._modules_list.1.dconv.bn.bias": "_backbone._modules_list.1.dconv.bn.bias",
            "_backbone._modules_list.1.dconv.bn.num_batches_tracked": "_backbone._modules_list.1.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.1.dconv.bn.running_mean": "_backbone._modules_list.1.dconv.bn.running_mean",
            "_backbone._modules_list.1.dconv.bn.running_var": "_backbone._modules_list.1.dconv.bn.running_var",
            "_backbone._modules_list.1.dconv.bn.weight": "_backbone._modules_list.1.dconv.bn.weight",
            "_backbone._modules_list.1.dconv.conv.weight": "_backbone._modules_list.1.dconv.conv.weight",
            "_backbone._modules_list.2.cv1.bn.bias": "_backbone._modules_list.2.conv1.bn.bias",
            "_backbone._modules_list.2.cv1.bn.num_batches_tracked": "_backbone._modules_list.2.conv1.bn.num_batches_tracked",
            "_backbone._modules_list.2.cv1.bn.running_mean": "_backbone._modules_list.2.conv1.bn.running_mean",
            "_backbone._modules_list.2.cv1.bn.running_var": "_backbone._modules_list.2.conv1.bn.running_var",
            "_backbone._modules_list.2.cv1.bn.weight": "_backbone._modules_list.2.conv1.bn.weight",
            "_backbone._modules_list.2.cv1.conv.weight": "_backbone._modules_list.2.conv1.conv.weight",
            "_backbone._modules_list.2.cv2.bn.bias": "_backbone._modules_list.2.conv2.bn.bias",
            "_backbone._modules_list.2.cv2.bn.num_batches_tracked": "_backbone._modules_list.2.conv2.bn.num_batches_tracked",
            "_backbone._modules_list.2.cv2.bn.running_mean": "_backbone._modules_list.2.conv2.bn.running_mean",
            "_backbone._modules_list.2.cv2.bn.running_var": "_backbone._modules_list.2.conv2.bn.running_var",
            "_backbone._modules_list.2.cv2.bn.weight": "_backbone._modules_list.2.conv2.bn.weight",
            "_backbone._modules_list.2.cv2.conv.weight": "_backbone._modules_list.2.conv2.conv.weight",
            "_backbone._modules_list.2.cv3.bn.bias": "_backbone._modules_list.2.conv3.bn.bias",
            "_backbone._modules_list.2.cv3.bn.num_batches_tracked": "_backbone._modules_list.2.conv3.bn.num_batches_tracked",
            "_backbone._modules_list.2.cv3.bn.running_mean": "_backbone._modules_list.2.conv3.bn.running_mean",
            "_backbone._modules_list.2.cv3.bn.running_var": "_backbone._modules_list.2.conv3.bn.running_var",
            "_backbone._modules_list.2.cv3.bn.weight": "_backbone._modules_list.2.conv3.bn.weight",
            "_backbone._modules_list.2.cv3.conv.weight": "_backbone._modules_list.2.conv3.conv.weight",
            "_backbone._modules_list.2.m.0.cv1.bn.bias": "_backbone._modules_list.2.bottlenecks.0.cv1.bn.bias",
            "_backbone._modules_list.2.m.0.cv1.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.0.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.0.cv1.bn.running_mean": "_backbone._modules_list.2.bottlenecks.0.cv1.bn.running_mean",
            "_backbone._modules_list.2.m.0.cv1.bn.running_var": "_backbone._modules_list.2.bottlenecks.0.cv1.bn.running_var",
            "_backbone._modules_list.2.m.0.cv1.bn.weight": "_backbone._modules_list.2.bottlenecks.0.cv1.bn.weight",
            "_backbone._modules_list.2.m.0.cv1.conv.weight": "_backbone._modules_list.2.bottlenecks.0.cv1.conv.weight",
            "_backbone._modules_list.2.m.0.cv2.bn.bias": "_backbone._modules_list.2.bottlenecks.0.cv2.bn.bias",
            "_backbone._modules_list.2.m.0.cv2.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.0.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.0.cv2.bn.running_mean": "_backbone._modules_list.2.bottlenecks.0.cv2.bn.running_mean",
            "_backbone._modules_list.2.m.0.cv2.bn.running_var": "_backbone._modules_list.2.bottlenecks.0.cv2.bn.running_var",
            "_backbone._modules_list.2.m.0.cv2.bn.weight": "_backbone._modules_list.2.bottlenecks.0.cv2.bn.weight",
            "_backbone._modules_list.2.m.0.cv2.conv.bn.bias": "_backbone._modules_list.2.bottlenecks.0.cv2.conv.bn.bias",
            "_backbone._modules_list.2.m.0.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.0.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.0.cv2.conv.bn.running_mean": "_backbone._modules_list.2.bottlenecks.0.cv2.conv.bn.running_mean",
            "_backbone._modules_list.2.m.0.cv2.conv.bn.running_var": "_backbone._modules_list.2.bottlenecks.0.cv2.conv.bn.running_var",
            "_backbone._modules_list.2.m.0.cv2.conv.bn.weight": "_backbone._modules_list.2.bottlenecks.0.cv2.conv.bn.weight",
            "_backbone._modules_list.2.m.0.cv2.conv.conv.weight": "_backbone._modules_list.2.bottlenecks.0.cv2.conv.conv.weight",
            "_backbone._modules_list.2.m.0.cv2.conv.weight": "_backbone._modules_list.2.bottlenecks.0.cv2.conv.weight",
            "_backbone._modules_list.2.m.0.cv2.dconv.bn.bias": "_backbone._modules_list.2.bottlenecks.0.cv2.dconv.bn.bias",
            "_backbone._modules_list.2.m.0.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.0.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.0.cv2.dconv.bn.running_mean": "_backbone._modules_list.2.bottlenecks.0.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.2.m.0.cv2.dconv.bn.running_var": "_backbone._modules_list.2.bottlenecks.0.cv2.dconv.bn.running_var",
            "_backbone._modules_list.2.m.0.cv2.dconv.bn.weight": "_backbone._modules_list.2.bottlenecks.0.cv2.dconv.bn.weight",
            "_backbone._modules_list.2.m.0.cv2.dconv.conv.weight": "_backbone._modules_list.2.bottlenecks.0.cv2.dconv.conv.weight",
            "_backbone._modules_list.2.m.1.cv1.bn.bias": "_backbone._modules_list.2.bottlenecks.1.cv1.bn.bias",
            "_backbone._modules_list.2.m.1.cv1.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.1.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.1.cv1.bn.running_mean": "_backbone._modules_list.2.bottlenecks.1.cv1.bn.running_mean",
            "_backbone._modules_list.2.m.1.cv1.bn.running_var": "_backbone._modules_list.2.bottlenecks.1.cv1.bn.running_var",
            "_backbone._modules_list.2.m.1.cv1.bn.weight": "_backbone._modules_list.2.bottlenecks.1.cv1.bn.weight",
            "_backbone._modules_list.2.m.1.cv1.conv.weight": "_backbone._modules_list.2.bottlenecks.1.cv1.conv.weight",
            "_backbone._modules_list.2.m.1.cv2.bn.bias": "_backbone._modules_list.2.bottlenecks.1.cv2.bn.bias",
            "_backbone._modules_list.2.m.1.cv2.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.1.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.1.cv2.bn.running_mean": "_backbone._modules_list.2.bottlenecks.1.cv2.bn.running_mean",
            "_backbone._modules_list.2.m.1.cv2.bn.running_var": "_backbone._modules_list.2.bottlenecks.1.cv2.bn.running_var",
            "_backbone._modules_list.2.m.1.cv2.bn.weight": "_backbone._modules_list.2.bottlenecks.1.cv2.bn.weight",
            "_backbone._modules_list.2.m.1.cv2.conv.weight": "_backbone._modules_list.2.bottlenecks.1.cv2.conv.weight",
            "_backbone._modules_list.2.m.2.cv1.bn.bias": "_backbone._modules_list.2.bottlenecks.2.cv1.bn.bias",
            "_backbone._modules_list.2.m.2.cv1.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.2.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.2.cv1.bn.running_mean": "_backbone._modules_list.2.bottlenecks.2.cv1.bn.running_mean",
            "_backbone._modules_list.2.m.2.cv1.bn.running_var": "_backbone._modules_list.2.bottlenecks.2.cv1.bn.running_var",
            "_backbone._modules_list.2.m.2.cv1.bn.weight": "_backbone._modules_list.2.bottlenecks.2.cv1.bn.weight",
            "_backbone._modules_list.2.m.2.cv1.conv.weight": "_backbone._modules_list.2.bottlenecks.2.cv1.conv.weight",
            "_backbone._modules_list.2.m.2.cv2.bn.bias": "_backbone._modules_list.2.bottlenecks.2.cv2.bn.bias",
            "_backbone._modules_list.2.m.2.cv2.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.2.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.2.cv2.bn.running_mean": "_backbone._modules_list.2.bottlenecks.2.cv2.bn.running_mean",
            "_backbone._modules_list.2.m.2.cv2.bn.running_var": "_backbone._modules_list.2.bottlenecks.2.cv2.bn.running_var",
            "_backbone._modules_list.2.m.2.cv2.bn.weight": "_backbone._modules_list.2.bottlenecks.2.cv2.bn.weight",
            "_backbone._modules_list.2.m.2.cv2.conv.weight": "_backbone._modules_list.2.bottlenecks.2.cv2.conv.weight",
            "_backbone._modules_list.3.bn.bias": "_backbone._modules_list.3.bn.bias",
            "_backbone._modules_list.3.bn.num_batches_tracked": "_backbone._modules_list.3.bn.num_batches_tracked",
            "_backbone._modules_list.3.bn.running_mean": "_backbone._modules_list.3.bn.running_mean",
            "_backbone._modules_list.3.bn.running_var": "_backbone._modules_list.3.bn.running_var",
            "_backbone._modules_list.3.bn.weight": "_backbone._modules_list.3.bn.weight",
            "_backbone._modules_list.3.conv.bn.bias": "_backbone._modules_list.3.conv.bn.bias",
            "_backbone._modules_list.3.conv.bn.num_batches_tracked": "_backbone._modules_list.3.conv.bn.num_batches_tracked",
            "_backbone._modules_list.3.conv.bn.running_mean": "_backbone._modules_list.3.conv.bn.running_mean",
            "_backbone._modules_list.3.conv.bn.running_var": "_backbone._modules_list.3.conv.bn.running_var",
            "_backbone._modules_list.3.conv.bn.weight": "_backbone._modules_list.3.conv.bn.weight",
            "_backbone._modules_list.3.conv.conv.weight": "_backbone._modules_list.3.conv.conv.weight",
            "_backbone._modules_list.3.conv.weight": "_backbone._modules_list.3.conv.weight",
            "_backbone._modules_list.3.dconv.bn.bias": "_backbone._modules_list.3.dconv.bn.bias",
            "_backbone._modules_list.3.dconv.bn.num_batches_tracked": "_backbone._modules_list.3.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.3.dconv.bn.running_mean": "_backbone._modules_list.3.dconv.bn.running_mean",
            "_backbone._modules_list.3.dconv.bn.running_var": "_backbone._modules_list.3.dconv.bn.running_var",
            "_backbone._modules_list.3.dconv.bn.weight": "_backbone._modules_list.3.dconv.bn.weight",
            "_backbone._modules_list.3.dconv.conv.weight": "_backbone._modules_list.3.dconv.conv.weight",
            "_backbone._modules_list.4.cv1.bn.bias": "_backbone._modules_list.4.conv1.bn.bias",
            "_backbone._modules_list.4.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.conv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.cv1.bn.running_mean": "_backbone._modules_list.4.conv1.bn.running_mean",
            "_backbone._modules_list.4.cv1.bn.running_var": "_backbone._modules_list.4.conv1.bn.running_var",
            "_backbone._modules_list.4.cv1.bn.weight": "_backbone._modules_list.4.conv1.bn.weight",
            "_backbone._modules_list.4.cv1.conv.weight": "_backbone._modules_list.4.conv1.conv.weight",
            "_backbone._modules_list.4.cv2.bn.bias": "_backbone._modules_list.4.conv2.bn.bias",
            "_backbone._modules_list.4.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.conv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.cv2.bn.running_mean": "_backbone._modules_list.4.conv2.bn.running_mean",
            "_backbone._modules_list.4.cv2.bn.running_var": "_backbone._modules_list.4.conv2.bn.running_var",
            "_backbone._modules_list.4.cv2.bn.weight": "_backbone._modules_list.4.conv2.bn.weight",
            "_backbone._modules_list.4.cv2.conv.weight": "_backbone._modules_list.4.conv2.conv.weight",
            "_backbone._modules_list.4.cv3.bn.bias": "_backbone._modules_list.4.conv3.bn.bias",
            "_backbone._modules_list.4.cv3.bn.num_batches_tracked": "_backbone._modules_list.4.conv3.bn.num_batches_tracked",
            "_backbone._modules_list.4.cv3.bn.running_mean": "_backbone._modules_list.4.conv3.bn.running_mean",
            "_backbone._modules_list.4.cv3.bn.running_var": "_backbone._modules_list.4.conv3.bn.running_var",
            "_backbone._modules_list.4.cv3.bn.weight": "_backbone._modules_list.4.conv3.bn.weight",
            "_backbone._modules_list.4.cv3.conv.weight": "_backbone._modules_list.4.conv3.conv.weight",
            "_backbone._modules_list.4.m.0.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.0.cv1.bn.bias",
            "_backbone._modules_list.4.m.0.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.0.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.0.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.0.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.0.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.0.cv1.bn.running_var",
            "_backbone._modules_list.4.m.0.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.0.cv1.bn.weight",
            "_backbone._modules_list.4.m.0.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.0.cv1.conv.weight",
            "_backbone._modules_list.4.m.0.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.0.cv2.bn.bias",
            "_backbone._modules_list.4.m.0.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.0.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.0.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.0.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.0.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.0.cv2.bn.running_var",
            "_backbone._modules_list.4.m.0.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.0.cv2.bn.weight",
            "_backbone._modules_list.4.m.0.cv2.conv.bn.bias": "_backbone._modules_list.4.bottlenecks.0.cv2.conv.bn.bias",
            "_backbone._modules_list.4.m.0.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.0.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.0.cv2.conv.bn.running_mean": "_backbone._modules_list.4.bottlenecks.0.cv2.conv.bn.running_mean",
            "_backbone._modules_list.4.m.0.cv2.conv.bn.running_var": "_backbone._modules_list.4.bottlenecks.0.cv2.conv.bn.running_var",
            "_backbone._modules_list.4.m.0.cv2.conv.bn.weight": "_backbone._modules_list.4.bottlenecks.0.cv2.conv.bn.weight",
            "_backbone._modules_list.4.m.0.cv2.conv.conv.weight": "_backbone._modules_list.4.bottlenecks.0.cv2.conv.conv.weight",
            "_backbone._modules_list.4.m.0.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.0.cv2.conv.weight",
            "_backbone._modules_list.4.m.0.cv2.dconv.bn.bias": "_backbone._modules_list.4.bottlenecks.0.cv2.dconv.bn.bias",
            "_backbone._modules_list.4.m.0.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.0.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.0.cv2.dconv.bn.running_mean": "_backbone._modules_list.4.bottlenecks.0.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.4.m.0.cv2.dconv.bn.running_var": "_backbone._modules_list.4.bottlenecks.0.cv2.dconv.bn.running_var",
            "_backbone._modules_list.4.m.0.cv2.dconv.bn.weight": "_backbone._modules_list.4.bottlenecks.0.cv2.dconv.bn.weight",
            "_backbone._modules_list.4.m.0.cv2.dconv.conv.weight": "_backbone._modules_list.4.bottlenecks.0.cv2.dconv.conv.weight",
            "_backbone._modules_list.4.m.1.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.1.cv1.bn.bias",
            "_backbone._modules_list.4.m.1.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.1.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.1.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.1.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.1.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.1.cv1.bn.running_var",
            "_backbone._modules_list.4.m.1.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.1.cv1.bn.weight",
            "_backbone._modules_list.4.m.1.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.1.cv1.conv.weight",
            "_backbone._modules_list.4.m.1.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.1.cv2.bn.bias",
            "_backbone._modules_list.4.m.1.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.1.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.1.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.1.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.1.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.1.cv2.bn.running_var",
            "_backbone._modules_list.4.m.1.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.1.cv2.bn.weight",
            "_backbone._modules_list.4.m.1.cv2.conv.bn.bias": "_backbone._modules_list.4.bottlenecks.1.cv2.conv.bn.bias",
            "_backbone._modules_list.4.m.1.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.1.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.1.cv2.conv.bn.running_mean": "_backbone._modules_list.4.bottlenecks.1.cv2.conv.bn.running_mean",
            "_backbone._modules_list.4.m.1.cv2.conv.bn.running_var": "_backbone._modules_list.4.bottlenecks.1.cv2.conv.bn.running_var",
            "_backbone._modules_list.4.m.1.cv2.conv.bn.weight": "_backbone._modules_list.4.bottlenecks.1.cv2.conv.bn.weight",
            "_backbone._modules_list.4.m.1.cv2.conv.conv.weight": "_backbone._modules_list.4.bottlenecks.1.cv2.conv.conv.weight",
            "_backbone._modules_list.4.m.1.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.1.cv2.conv.weight",
            "_backbone._modules_list.4.m.1.cv2.dconv.bn.bias": "_backbone._modules_list.4.bottlenecks.1.cv2.dconv.bn.bias",
            "_backbone._modules_list.4.m.1.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.1.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.1.cv2.dconv.bn.running_mean": "_backbone._modules_list.4.bottlenecks.1.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.4.m.1.cv2.dconv.bn.running_var": "_backbone._modules_list.4.bottlenecks.1.cv2.dconv.bn.running_var",
            "_backbone._modules_list.4.m.1.cv2.dconv.bn.weight": "_backbone._modules_list.4.bottlenecks.1.cv2.dconv.bn.weight",
            "_backbone._modules_list.4.m.1.cv2.dconv.conv.weight": "_backbone._modules_list.4.bottlenecks.1.cv2.dconv.conv.weight",
            "_backbone._modules_list.4.m.2.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.2.cv1.bn.bias",
            "_backbone._modules_list.4.m.2.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.2.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.2.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.2.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.2.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.2.cv1.bn.running_var",
            "_backbone._modules_list.4.m.2.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.2.cv1.bn.weight",
            "_backbone._modules_list.4.m.2.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.2.cv1.conv.weight",
            "_backbone._modules_list.4.m.2.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.2.cv2.bn.bias",
            "_backbone._modules_list.4.m.2.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.2.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.2.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.2.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.2.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.2.cv2.bn.running_var",
            "_backbone._modules_list.4.m.2.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.2.cv2.bn.weight",
            "_backbone._modules_list.4.m.2.cv2.conv.bn.bias": "_backbone._modules_list.4.bottlenecks.2.cv2.conv.bn.bias",
            "_backbone._modules_list.4.m.2.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.2.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.2.cv2.conv.bn.running_mean": "_backbone._modules_list.4.bottlenecks.2.cv2.conv.bn.running_mean",
            "_backbone._modules_list.4.m.2.cv2.conv.bn.running_var": "_backbone._modules_list.4.bottlenecks.2.cv2.conv.bn.running_var",
            "_backbone._modules_list.4.m.2.cv2.conv.bn.weight": "_backbone._modules_list.4.bottlenecks.2.cv2.conv.bn.weight",
            "_backbone._modules_list.4.m.2.cv2.conv.conv.weight": "_backbone._modules_list.4.bottlenecks.2.cv2.conv.conv.weight",
            "_backbone._modules_list.4.m.2.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.2.cv2.conv.weight",
            "_backbone._modules_list.4.m.2.cv2.dconv.bn.bias": "_backbone._modules_list.4.bottlenecks.2.cv2.dconv.bn.bias",
            "_backbone._modules_list.4.m.2.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.2.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.2.cv2.dconv.bn.running_mean": "_backbone._modules_list.4.bottlenecks.2.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.4.m.2.cv2.dconv.bn.running_var": "_backbone._modules_list.4.bottlenecks.2.cv2.dconv.bn.running_var",
            "_backbone._modules_list.4.m.2.cv2.dconv.bn.weight": "_backbone._modules_list.4.bottlenecks.2.cv2.dconv.bn.weight",
            "_backbone._modules_list.4.m.2.cv2.dconv.conv.weight": "_backbone._modules_list.4.bottlenecks.2.cv2.dconv.conv.weight",
            "_backbone._modules_list.4.m.3.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.3.cv1.bn.bias",
            "_backbone._modules_list.4.m.3.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.3.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.3.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.3.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.3.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.3.cv1.bn.running_var",
            "_backbone._modules_list.4.m.3.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.3.cv1.bn.weight",
            "_backbone._modules_list.4.m.3.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.3.cv1.conv.weight",
            "_backbone._modules_list.4.m.3.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.3.cv2.bn.bias",
            "_backbone._modules_list.4.m.3.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.3.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.3.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.3.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.3.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.3.cv2.bn.running_var",
            "_backbone._modules_list.4.m.3.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.3.cv2.bn.weight",
            "_backbone._modules_list.4.m.3.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.3.cv2.conv.weight",
            "_backbone._modules_list.4.m.4.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.4.cv1.bn.bias",
            "_backbone._modules_list.4.m.4.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.4.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.4.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.4.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.4.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.4.cv1.bn.running_var",
            "_backbone._modules_list.4.m.4.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.4.cv1.bn.weight",
            "_backbone._modules_list.4.m.4.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.4.cv1.conv.weight",
            "_backbone._modules_list.4.m.4.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.4.cv2.bn.bias",
            "_backbone._modules_list.4.m.4.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.4.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.4.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.4.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.4.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.4.cv2.bn.running_var",
            "_backbone._modules_list.4.m.4.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.4.cv2.bn.weight",
            "_backbone._modules_list.4.m.4.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.4.cv2.conv.weight",
            "_backbone._modules_list.4.m.5.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.5.cv1.bn.bias",
            "_backbone._modules_list.4.m.5.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.5.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.5.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.5.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.5.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.5.cv1.bn.running_var",
            "_backbone._modules_list.4.m.5.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.5.cv1.bn.weight",
            "_backbone._modules_list.4.m.5.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.5.cv1.conv.weight",
            "_backbone._modules_list.4.m.5.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.5.cv2.bn.bias",
            "_backbone._modules_list.4.m.5.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.5.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.5.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.5.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.5.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.5.cv2.bn.running_var",
            "_backbone._modules_list.4.m.5.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.5.cv2.bn.weight",
            "_backbone._modules_list.4.m.5.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.5.cv2.conv.weight",
            "_backbone._modules_list.4.m.6.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.6.cv1.bn.bias",
            "_backbone._modules_list.4.m.6.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.6.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.6.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.6.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.6.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.6.cv1.bn.running_var",
            "_backbone._modules_list.4.m.6.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.6.cv1.bn.weight",
            "_backbone._modules_list.4.m.6.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.6.cv1.conv.weight",
            "_backbone._modules_list.4.m.6.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.6.cv2.bn.bias",
            "_backbone._modules_list.4.m.6.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.6.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.6.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.6.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.6.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.6.cv2.bn.running_var",
            "_backbone._modules_list.4.m.6.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.6.cv2.bn.weight",
            "_backbone._modules_list.4.m.6.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.6.cv2.conv.weight",
            "_backbone._modules_list.4.m.7.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.7.cv1.bn.bias",
            "_backbone._modules_list.4.m.7.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.7.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.7.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.7.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.7.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.7.cv1.bn.running_var",
            "_backbone._modules_list.4.m.7.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.7.cv1.bn.weight",
            "_backbone._modules_list.4.m.7.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.7.cv1.conv.weight",
            "_backbone._modules_list.4.m.7.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.7.cv2.bn.bias",
            "_backbone._modules_list.4.m.7.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.7.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.7.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.7.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.7.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.7.cv2.bn.running_var",
            "_backbone._modules_list.4.m.7.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.7.cv2.bn.weight",
            "_backbone._modules_list.4.m.7.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.7.cv2.conv.weight",
            "_backbone._modules_list.4.m.8.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.8.cv1.bn.bias",
            "_backbone._modules_list.4.m.8.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.8.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.8.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.8.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.8.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.8.cv1.bn.running_var",
            "_backbone._modules_list.4.m.8.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.8.cv1.bn.weight",
            "_backbone._modules_list.4.m.8.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.8.cv1.conv.weight",
            "_backbone._modules_list.4.m.8.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.8.cv2.bn.bias",
            "_backbone._modules_list.4.m.8.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.8.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.8.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.8.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.8.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.8.cv2.bn.running_var",
            "_backbone._modules_list.4.m.8.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.8.cv2.bn.weight",
            "_backbone._modules_list.4.m.8.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.8.cv2.conv.weight",
            "_backbone._modules_list.5.bn.bias": "_backbone._modules_list.5.bn.bias",
            "_backbone._modules_list.5.bn.num_batches_tracked": "_backbone._modules_list.5.bn.num_batches_tracked",
            "_backbone._modules_list.5.bn.running_mean": "_backbone._modules_list.5.bn.running_mean",
            "_backbone._modules_list.5.bn.running_var": "_backbone._modules_list.5.bn.running_var",
            "_backbone._modules_list.5.bn.weight": "_backbone._modules_list.5.bn.weight",
            "_backbone._modules_list.5.conv.bn.bias": "_backbone._modules_list.5.conv.bn.bias",
            "_backbone._modules_list.5.conv.bn.num_batches_tracked": "_backbone._modules_list.5.conv.bn.num_batches_tracked",
            "_backbone._modules_list.5.conv.bn.running_mean": "_backbone._modules_list.5.conv.bn.running_mean",
            "_backbone._modules_list.5.conv.bn.running_var": "_backbone._modules_list.5.conv.bn.running_var",
            "_backbone._modules_list.5.conv.bn.weight": "_backbone._modules_list.5.conv.bn.weight",
            "_backbone._modules_list.5.conv.conv.weight": "_backbone._modules_list.5.conv.conv.weight",
            "_backbone._modules_list.5.conv.weight": "_backbone._modules_list.5.conv.weight",
            "_backbone._modules_list.5.dconv.bn.bias": "_backbone._modules_list.5.dconv.bn.bias",
            "_backbone._modules_list.5.dconv.bn.num_batches_tracked": "_backbone._modules_list.5.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.5.dconv.bn.running_mean": "_backbone._modules_list.5.dconv.bn.running_mean",
            "_backbone._modules_list.5.dconv.bn.running_var": "_backbone._modules_list.5.dconv.bn.running_var",
            "_backbone._modules_list.5.dconv.bn.weight": "_backbone._modules_list.5.dconv.bn.weight",
            "_backbone._modules_list.5.dconv.conv.weight": "_backbone._modules_list.5.dconv.conv.weight",
            "_backbone._modules_list.6.cv1.bn.bias": "_backbone._modules_list.6.conv1.bn.bias",
            "_backbone._modules_list.6.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.conv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.cv1.bn.running_mean": "_backbone._modules_list.6.conv1.bn.running_mean",
            "_backbone._modules_list.6.cv1.bn.running_var": "_backbone._modules_list.6.conv1.bn.running_var",
            "_backbone._modules_list.6.cv1.bn.weight": "_backbone._modules_list.6.conv1.bn.weight",
            "_backbone._modules_list.6.cv1.conv.weight": "_backbone._modules_list.6.conv1.conv.weight",
            "_backbone._modules_list.6.cv2.bn.bias": "_backbone._modules_list.6.conv2.bn.bias",
            "_backbone._modules_list.6.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.conv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.cv2.bn.running_mean": "_backbone._modules_list.6.conv2.bn.running_mean",
            "_backbone._modules_list.6.cv2.bn.running_var": "_backbone._modules_list.6.conv2.bn.running_var",
            "_backbone._modules_list.6.cv2.bn.weight": "_backbone._modules_list.6.conv2.bn.weight",
            "_backbone._modules_list.6.cv2.conv.weight": "_backbone._modules_list.6.conv2.conv.weight",
            "_backbone._modules_list.6.cv3.bn.bias": "_backbone._modules_list.6.conv3.bn.bias",
            "_backbone._modules_list.6.cv3.bn.num_batches_tracked": "_backbone._modules_list.6.conv3.bn.num_batches_tracked",
            "_backbone._modules_list.6.cv3.bn.running_mean": "_backbone._modules_list.6.conv3.bn.running_mean",
            "_backbone._modules_list.6.cv3.bn.running_var": "_backbone._modules_list.6.conv3.bn.running_var",
            "_backbone._modules_list.6.cv3.bn.weight": "_backbone._modules_list.6.conv3.bn.weight",
            "_backbone._modules_list.6.cv3.conv.weight": "_backbone._modules_list.6.conv3.conv.weight",
            "_backbone._modules_list.6.m.0.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.0.cv1.bn.bias",
            "_backbone._modules_list.6.m.0.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.0.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.0.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.0.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.0.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.0.cv1.bn.running_var",
            "_backbone._modules_list.6.m.0.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.0.cv1.bn.weight",
            "_backbone._modules_list.6.m.0.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.0.cv1.conv.weight",
            "_backbone._modules_list.6.m.0.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.0.cv2.bn.bias",
            "_backbone._modules_list.6.m.0.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.0.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.0.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.0.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.0.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.0.cv2.bn.running_var",
            "_backbone._modules_list.6.m.0.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.0.cv2.bn.weight",
            "_backbone._modules_list.6.m.0.cv2.conv.bn.bias": "_backbone._modules_list.6.bottlenecks.0.cv2.conv.bn.bias",
            "_backbone._modules_list.6.m.0.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.0.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.0.cv2.conv.bn.running_mean": "_backbone._modules_list.6.bottlenecks.0.cv2.conv.bn.running_mean",
            "_backbone._modules_list.6.m.0.cv2.conv.bn.running_var": "_backbone._modules_list.6.bottlenecks.0.cv2.conv.bn.running_var",
            "_backbone._modules_list.6.m.0.cv2.conv.bn.weight": "_backbone._modules_list.6.bottlenecks.0.cv2.conv.bn.weight",
            "_backbone._modules_list.6.m.0.cv2.conv.conv.weight": "_backbone._modules_list.6.bottlenecks.0.cv2.conv.conv.weight",
            "_backbone._modules_list.6.m.0.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.0.cv2.conv.weight",
            "_backbone._modules_list.6.m.0.cv2.dconv.bn.bias": "_backbone._modules_list.6.bottlenecks.0.cv2.dconv.bn.bias",
            "_backbone._modules_list.6.m.0.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.0.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.0.cv2.dconv.bn.running_mean": "_backbone._modules_list.6.bottlenecks.0.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.6.m.0.cv2.dconv.bn.running_var": "_backbone._modules_list.6.bottlenecks.0.cv2.dconv.bn.running_var",
            "_backbone._modules_list.6.m.0.cv2.dconv.bn.weight": "_backbone._modules_list.6.bottlenecks.0.cv2.dconv.bn.weight",
            "_backbone._modules_list.6.m.0.cv2.dconv.conv.weight": "_backbone._modules_list.6.bottlenecks.0.cv2.dconv.conv.weight",
            "_backbone._modules_list.6.m.1.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.1.cv1.bn.bias",
            "_backbone._modules_list.6.m.1.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.1.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.1.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.1.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.1.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.1.cv1.bn.running_var",
            "_backbone._modules_list.6.m.1.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.1.cv1.bn.weight",
            "_backbone._modules_list.6.m.1.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.1.cv1.conv.weight",
            "_backbone._modules_list.6.m.1.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.1.cv2.bn.bias",
            "_backbone._modules_list.6.m.1.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.1.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.1.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.1.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.1.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.1.cv2.bn.running_var",
            "_backbone._modules_list.6.m.1.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.1.cv2.bn.weight",
            "_backbone._modules_list.6.m.1.cv2.conv.bn.bias": "_backbone._modules_list.6.bottlenecks.1.cv2.conv.bn.bias",
            "_backbone._modules_list.6.m.1.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.1.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.1.cv2.conv.bn.running_mean": "_backbone._modules_list.6.bottlenecks.1.cv2.conv.bn.running_mean",
            "_backbone._modules_list.6.m.1.cv2.conv.bn.running_var": "_backbone._modules_list.6.bottlenecks.1.cv2.conv.bn.running_var",
            "_backbone._modules_list.6.m.1.cv2.conv.bn.weight": "_backbone._modules_list.6.bottlenecks.1.cv2.conv.bn.weight",
            "_backbone._modules_list.6.m.1.cv2.conv.conv.weight": "_backbone._modules_list.6.bottlenecks.1.cv2.conv.conv.weight",
            "_backbone._modules_list.6.m.1.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.1.cv2.conv.weight",
            "_backbone._modules_list.6.m.1.cv2.dconv.bn.bias": "_backbone._modules_list.6.bottlenecks.1.cv2.dconv.bn.bias",
            "_backbone._modules_list.6.m.1.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.1.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.1.cv2.dconv.bn.running_mean": "_backbone._modules_list.6.bottlenecks.1.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.6.m.1.cv2.dconv.bn.running_var": "_backbone._modules_list.6.bottlenecks.1.cv2.dconv.bn.running_var",
            "_backbone._modules_list.6.m.1.cv2.dconv.bn.weight": "_backbone._modules_list.6.bottlenecks.1.cv2.dconv.bn.weight",
            "_backbone._modules_list.6.m.1.cv2.dconv.conv.weight": "_backbone._modules_list.6.bottlenecks.1.cv2.dconv.conv.weight",
            "_backbone._modules_list.6.m.2.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.2.cv1.bn.bias",
            "_backbone._modules_list.6.m.2.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.2.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.2.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.2.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.2.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.2.cv1.bn.running_var",
            "_backbone._modules_list.6.m.2.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.2.cv1.bn.weight",
            "_backbone._modules_list.6.m.2.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.2.cv1.conv.weight",
            "_backbone._modules_list.6.m.2.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.2.cv2.bn.bias",
            "_backbone._modules_list.6.m.2.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.2.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.2.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.2.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.2.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.2.cv2.bn.running_var",
            "_backbone._modules_list.6.m.2.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.2.cv2.bn.weight",
            "_backbone._modules_list.6.m.2.cv2.conv.bn.bias": "_backbone._modules_list.6.bottlenecks.2.cv2.conv.bn.bias",
            "_backbone._modules_list.6.m.2.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.2.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.2.cv2.conv.bn.running_mean": "_backbone._modules_list.6.bottlenecks.2.cv2.conv.bn.running_mean",
            "_backbone._modules_list.6.m.2.cv2.conv.bn.running_var": "_backbone._modules_list.6.bottlenecks.2.cv2.conv.bn.running_var",
            "_backbone._modules_list.6.m.2.cv2.conv.bn.weight": "_backbone._modules_list.6.bottlenecks.2.cv2.conv.bn.weight",
            "_backbone._modules_list.6.m.2.cv2.conv.conv.weight": "_backbone._modules_list.6.bottlenecks.2.cv2.conv.conv.weight",
            "_backbone._modules_list.6.m.2.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.2.cv2.conv.weight",
            "_backbone._modules_list.6.m.2.cv2.dconv.bn.bias": "_backbone._modules_list.6.bottlenecks.2.cv2.dconv.bn.bias",
            "_backbone._modules_list.6.m.2.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.2.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.2.cv2.dconv.bn.running_mean": "_backbone._modules_list.6.bottlenecks.2.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.6.m.2.cv2.dconv.bn.running_var": "_backbone._modules_list.6.bottlenecks.2.cv2.dconv.bn.running_var",
            "_backbone._modules_list.6.m.2.cv2.dconv.bn.weight": "_backbone._modules_list.6.bottlenecks.2.cv2.dconv.bn.weight",
            "_backbone._modules_list.6.m.2.cv2.dconv.conv.weight": "_backbone._modules_list.6.bottlenecks.2.cv2.dconv.conv.weight",
            "_backbone._modules_list.6.m.3.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.3.cv1.bn.bias",
            "_backbone._modules_list.6.m.3.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.3.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.3.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.3.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.3.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.3.cv1.bn.running_var",
            "_backbone._modules_list.6.m.3.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.3.cv1.bn.weight",
            "_backbone._modules_list.6.m.3.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.3.cv1.conv.weight",
            "_backbone._modules_list.6.m.3.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.3.cv2.bn.bias",
            "_backbone._modules_list.6.m.3.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.3.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.3.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.3.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.3.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.3.cv2.bn.running_var",
            "_backbone._modules_list.6.m.3.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.3.cv2.bn.weight",
            "_backbone._modules_list.6.m.3.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.3.cv2.conv.weight",
            "_backbone._modules_list.6.m.4.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.4.cv1.bn.bias",
            "_backbone._modules_list.6.m.4.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.4.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.4.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.4.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.4.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.4.cv1.bn.running_var",
            "_backbone._modules_list.6.m.4.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.4.cv1.bn.weight",
            "_backbone._modules_list.6.m.4.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.4.cv1.conv.weight",
            "_backbone._modules_list.6.m.4.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.4.cv2.bn.bias",
            "_backbone._modules_list.6.m.4.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.4.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.4.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.4.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.4.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.4.cv2.bn.running_var",
            "_backbone._modules_list.6.m.4.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.4.cv2.bn.weight",
            "_backbone._modules_list.6.m.4.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.4.cv2.conv.weight",
            "_backbone._modules_list.6.m.5.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.5.cv1.bn.bias",
            "_backbone._modules_list.6.m.5.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.5.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.5.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.5.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.5.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.5.cv1.bn.running_var",
            "_backbone._modules_list.6.m.5.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.5.cv1.bn.weight",
            "_backbone._modules_list.6.m.5.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.5.cv1.conv.weight",
            "_backbone._modules_list.6.m.5.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.5.cv2.bn.bias",
            "_backbone._modules_list.6.m.5.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.5.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.5.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.5.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.5.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.5.cv2.bn.running_var",
            "_backbone._modules_list.6.m.5.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.5.cv2.bn.weight",
            "_backbone._modules_list.6.m.5.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.5.cv2.conv.weight",
            "_backbone._modules_list.6.m.6.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.6.cv1.bn.bias",
            "_backbone._modules_list.6.m.6.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.6.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.6.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.6.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.6.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.6.cv1.bn.running_var",
            "_backbone._modules_list.6.m.6.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.6.cv1.bn.weight",
            "_backbone._modules_list.6.m.6.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.6.cv1.conv.weight",
            "_backbone._modules_list.6.m.6.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.6.cv2.bn.bias",
            "_backbone._modules_list.6.m.6.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.6.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.6.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.6.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.6.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.6.cv2.bn.running_var",
            "_backbone._modules_list.6.m.6.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.6.cv2.bn.weight",
            "_backbone._modules_list.6.m.6.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.6.cv2.conv.weight",
            "_backbone._modules_list.6.m.7.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.7.cv1.bn.bias",
            "_backbone._modules_list.6.m.7.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.7.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.7.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.7.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.7.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.7.cv1.bn.running_var",
            "_backbone._modules_list.6.m.7.cv1.bn.weight"