
Utils

get_builtin_activation_type(activation, **kwargs)

Returns an activation class by its name from the torch.nn namespace. This function supports all modules available from torch.nn as well as their lower-case aliases. On top of that, it supports a few extra aliases: leaky_relu (LeakyReLU) and swish (SiLU).

act_cls = get_builtin_activation_type("LeakyReLU", inplace=True, negative_slope=0.01)
act = act_cls()

Parameters:

Name Type Description Default
activation Union[str, None]

Activation function name (e.g. ReLU). If None, nn.Identity is returned.

required

Returns:

Type Description
Type[nn.Module]

Type of the activation function that is ready to be instantiated

Source code in src/super_gradients/training/utils/activations_utils.py
def get_builtin_activation_type(activation: Union[str, None], **kwargs) -> Type[nn.Module]:
    """
    Returns an activation class by its name from the torch.nn namespace. This function supports all modules available
    from torch.nn as well as their lower-case aliases.
    On top of that, it supports a few extra aliases: leaky_relu (LeakyReLU), swish (SiLU).

    >>> act_cls = get_builtin_activation_type("LeakyReLU", inplace=True, negative_slope=0.01)
    >>> act = act_cls()


    :param activation: Activation function name (e.g. ReLU). If None, nn.Identity is returned.
    :param **kwargs  : Extra arguments to pass to the constructor during instantiation (e.g. inplace=True)

    :returns         : Type of the activation function that is ready to be instantiated
    """

    if activation is None:
        activation_cls = nn.Identity
    else:
        lowercase_aliases: Dict[str, str] = dict((k.lower(), k) for k in torch.nn.__dict__.keys())

        # Register additional aliases
        lowercase_aliases["leaky_relu"] = "LeakyReLU"  # LeakyRelu in snake_case
        lowercase_aliases["swish"] = "SiLU"  # Swish shich is equivalent to SiLU
        lowercase_aliases["none"] = "Identity"

        if activation in lowercase_aliases:
            activation = lowercase_aliases[activation]

        if activation not in torch.nn.__dict__:
            raise KeyError(f"Requested activation function {activation} is not known")

        activation_cls = torch.nn.__dict__[activation]
        if len(kwargs):
            activation_cls = partial(activation_cls, **kwargs)

    return activation_cls
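
A small usage sketch of how the returned class can be dropped into a model (assuming the import path below, which mirrors the source path shown above):

import torch.nn as nn
from super_gradients.training.utils.activations_utils import get_builtin_activation_type

relu_cls = get_builtin_activation_type("relu", inplace=True)  # lower-case alias of nn.ReLU
swish_cls = get_builtin_activation_type("swish")              # alias resolved to nn.SiLU
identity_cls = get_builtin_activation_type(None)              # None falls back to nn.Identity

block = nn.Sequential(nn.Linear(16, 16), relu_cls(), nn.Linear(16, 16), swish_cls())
assert isinstance(identity_cls(), nn.Identity)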

batch_distance2bbox(points, distance, max_shapes=None)

Decode distance prediction to bounding box for batch.

Parameters:

Name Type Description Default
points Tensor

[B, ..., 2], "xy" format

required
distance Tensor

[B, ..., 4], "ltrb" format

required
max_shapes Optional[Tensor]

[B, 2], "h,w" format, Shape of the image.

None

Returns:

Type Description
Tensor

Tensor: Decoded bboxes, "x1y1x2y2" format.

Source code in src/super_gradients/training/utils/bbox_utils.py
def batch_distance2bbox(points: Tensor, distance: Tensor, max_shapes: Optional[Tensor] = None) -> Tensor:
    """Decode distance prediction to bounding box for batch.

    :param points: [B, ..., 2], "xy" format
    :param distance: [B, ..., 4], "ltrb" format
    :param max_shapes: [B, 2], "h,w" format, Shape of the image.
    :return: Tensor: Decoded bboxes, "x1y1x2y2" format.
    """
    lt, rb = torch.split(distance, 2, dim=-1)
    # When adding tensors and parameters, the parameter tensor is better placed in the second position
    x1y1 = points - lt
    x2y2 = rb + points
    out_bbox = torch.cat([x1y1, x2y2], dim=-1)
    if max_shapes is not None:
        max_shapes = max_shapes.flip(-1).tile([1, 2])
        delta_dim = out_bbox.ndim - max_shapes.ndim
        for _ in range(delta_dim):
            max_shapes.unsqueeze_(1)
        out_bbox = torch.where(out_bbox < max_shapes, out_bbox, max_shapes)
        out_bbox = torch.where(out_bbox > 0, out_bbox, torch.zeros_like(out_bbox))
    return out_bbox
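
As a quick sanity check of the decoding and clipping described above (a minimal sketch; the import path mirrors the source path shown above):

import torch
from super_gradients.training.utils.bbox_utils import batch_distance2bbox

points = torch.tensor([[[10.0, 10.0], [50.0, 60.0]]])      # [B=1, N=2, 2], "xy"
distance = torch.tensor([[[5.0, 5.0, 5.0, 5.0],
                          [20.0, 30.0, 80.0, 70.0]]])      # [B=1, N=2, 4], "ltrb"
max_shapes = torch.tensor([[100.0, 120.0]])                 # [B=1, 2], "h,w"

boxes = batch_distance2bbox(points, distance, max_shapes)
# First box: (10-5, 10-5, 10+5, 10+5) -> (5, 5, 15, 15).
# Second box is clipped to the image: x2 = 130 -> 120 (w), y2 = 130 -> 100 (h).
print(boxes)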

Callback

Base callback class with all the callback methods. Derived classes may override one or more of the available events to receive callbacks when such events are triggered by the training loop.

The order of the events is as follows:

on_training_start(context) # called once before training starts, good for setting up the warmup LR

for epoch in range(epochs):
    on_train_loader_start(context)
        for batch in train_loader:
            on_train_batch_start(context)
            on_train_batch_loss_end(context)               # called after loss has been computed
            on_train_batch_backward_end(context)           # called after .backward() was called
            on_train_batch_gradient_step_start(context)    # called before the optimizer step about to happen (gradient clipping, logging of gradients)
            on_train_batch_gradient_step_end(context)      # called after gradient step was done, good place to update LR (for step-based schedulers)
            on_train_batch_end(context)
    on_train_loader_end(context)

    on_validation_loader_start(context)
        for batch in validation_loader:
            on_validation_batch_start(context)
            on_validation_batch_end(context)
    on_validation_loader_end(context)
    on_validation_end_best_epoch(context)

on_test_start(context)
    for batch in test_loader:
        on_test_batch_start(context)
        on_test_batch_end(context)
on_test_end(context)

on_average_best_models_validation_start
on_average_best_models_validation_end

on_training_end(context) # called once after training ends.

Correspondence mapping from the old callback API:

on_training_start(context)                                 <-> Phase.PRE_TRAINING
for epoch in range(epochs):
    on_train_loader_start(context)                         <-> Phase.TRAIN_EPOCH_START
        for batch in train_loader:
            on_train_batch_start(context)
            on_train_batch_loss_end(context)
            on_train_batch_backward_end(context)           <-> Phase.TRAIN_BATCH_END
            on_train_batch_gradient_step_start(context)
            on_train_batch_gradient_step_end(context)      <-> Phase.TRAIN_BATCH_STEP
            on_train_batch_end(context)
    on_train_loader_end(context)                           <-> Phase.TRAIN_EPOCH_END

on_validation_loader_start(context)
    for batch in validation_loader:
        on_validation_batch_start(context)
        on_validation_batch_end(context)               <-> Phase.VALIDATION_BATCH_END
on_validation_loader_end(context)                      <-> Phase.VALIDATION_EPOCH_END
on_validation_end_best_epoch(context)                  <-> Phase.VALIDATION_END_BEST_EPOCH

on_test_start(context)
    for batch in test_loader:
        on_test_batch_start(context)
        on_test_batch_end(context)                         <-> Phase.TEST_BATCH_END
on_test_end(context)                                       <-> Phase.TEST_END

on_training_end(context) <-> Phase.POST_TRAINING
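
In practice a custom callback subclasses Callback and overrides only the events it needs; in SuperGradients such callbacks are typically registered through the phase_callbacks list of training_params. A minimal sketch (assuming the import path shown in the source reference below) that prints a one-line summary after each validation pass:

from super_gradients.training.utils.callbacks.base_callbacks import Callback, PhaseContext

class EpochSummaryCallback(Callback):
    """Prints a one-line summary at the end of every validation loader pass."""

    def on_validation_loader_end(self, context: PhaseContext) -> None:
        # metrics_dict is listed among the attributes available at this event.
        metrics = getattr(context, "metrics_dict", None) or {}
        summary = ", ".join(f"{name}={float(value):.4f}" for name, value in metrics.items())
        print(f"epoch {context.epoch}: {summary}")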

Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
class Callback:
    """
    Base callback class with all the callback methods. Derived classes may override one or more of the available events
    to receive callbacks when such events are triggered by the training loop.

    The order of the events is as follows:

    on_training_start(context)                              # called once before training starts, good for setting up the warmup LR

        for epoch in range(epochs):
            on_train_loader_start(context)
                for batch in train_loader:
                    on_train_batch_start(context)
                    on_train_batch_loss_end(context)               # called after loss has been computed
                    on_train_batch_backward_end(context)           # called after .backward() was called
                    on_train_batch_gradient_step_start(context)    # called before the optimizer step about to happen (gradient clipping, logging of gradients)
                    on_train_batch_gradient_step_end(context)      # called after gradient step was done, good place to update LR (for step-based schedulers)
                    on_train_batch_end(context)
            on_train_loader_end(context)

            on_validation_loader_start(context)
                for batch in validation_loader:
                    on_validation_batch_start(context)
                    on_validation_batch_end(context)
            on_validation_loader_end(context)
            on_validation_end_best_epoch(context)

        on_test_start(context)
            for batch in test_loader:
                on_test_batch_start(context)
                on_test_batch_end(context)
        on_test_end(context)

        on_average_best_models_validation_start
        on_average_best_models_validation_end

    on_training_end(context)                    # called once after training ends.

    Correspondence mapping from the old callback API:

    on_training_start(context)                                 <-> Phase.PRE_TRAINING
    for epoch in range(epochs):
        on_train_loader_start(context)                         <-> Phase.TRAIN_EPOCH_START
            for batch in train_loader:
                on_train_batch_start(context)
                on_train_batch_loss_end(context)
                on_train_batch_backward_end(context)           <-> Phase.TRAIN_BATCH_END
                on_train_batch_gradient_step_start(context)
                on_train_batch_gradient_step_end(context)      <-> Phase.TRAIN_BATCH_STEP
                on_train_batch_end(context)
        on_train_loader_end(context)                           <-> Phase.TRAIN_EPOCH_END

        on_validation_loader_start(context)
            for batch in validation_loader:
                on_validation_batch_start(context)
                on_validation_batch_end(context)               <-> Phase.VALIDATION_BATCH_END
        on_validation_loader_end(context)                      <-> Phase.VALIDATION_EPOCH_END
        on_validation_end_best_epoch(context)                  <-> Phase.VALIDATION_END_BEST_EPOCH

    on_test_start(context)
        for batch in test_loader:
            on_test_batch_start(context)
            on_test_batch_end(context)                         <-> Phase.TEST_BATCH_END
    on_test_end(context)                                       <-> Phase.TEST_END

    on_training_end(context)                                   <-> Phase.POST_TRAINING
    """

    def on_training_start(self, context: PhaseContext) -> None:
        """
        Called once before start of the first epoch
        At this point, the context argument will have the following attributes:
            - optimizer
            - criterion
            - device
            - experiment_name
            - ckpt_dir
            - net
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics

        The corresponding Phase enum value for this event is Phase.PRE_TRAINING.
        :param context:
        """
        pass

    def on_train_loader_start(self, context: PhaseContext) -> None:
        """
        Called each epoch at the start of train data loader (before getting the first batch).
        At this point, the context argument will have the following attributes:
            - optimizer
            - criterion
            - device
            - experiment_name
            - ckpt_dir
            - net
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
        The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_START.
        :param context:
        """
        pass

    def on_train_batch_start(self, context: PhaseContext) -> None:
        """
        Called at each batch after getting batch of data from data loader and moving it to target device.
        This event is triggered AFTER the Trainer.pre_prediction_callback call (if it was defined).

        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - target
            - metrics_compute_fn
            - loss_avg_meter
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics

        :param context:
        """
        pass

    def on_train_batch_loss_end(self, context: PhaseContext) -> None:
        """
        Called after the model forward pass and loss computation have been done.
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names
        The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_END.

        :param context:
        """
        pass

    def on_train_batch_backward_end(self, context: PhaseContext) -> None:
        """
        Called after loss.backward() method was called for a given batch
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names

        :param context:
        """
        pass

    def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None:
        """
        Called before the gradient step is about to happen.
        Good place to clip gradients (with respect to scaler), log gradients to data ratio, etc.
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names

        :param context:
        """
        pass

    def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
        """
        Called after gradient step has been performed. Good place to update LR (for step-based schedulers)
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - inputs
            - target
            - metrics_compute_fn
            - loss_avg_meter
            - criterion
            - device
            - stop_training
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - loss_logging_items_names

        The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_STEP.
        :param context:
        """
        pass

    def on_train_batch_end(self, context: PhaseContext) -> None:
        """
        Called after all forward/backward/optimizer steps have been performed for a given batch and there is nothing left to do.
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_dict
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names

        :param context:
        """
        pass

    def on_train_loader_end(self, context: PhaseContext) -> None:
        """
        Called each epoch at the end of train data loader (after processing the last batch).
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_dict
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names

        The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_END.
        :param context:
        """
        pass

    def on_validation_loader_start(self, context: PhaseContext) -> None:
        """
        Called each epoch at the start of validation data loader (before getting the first batch).
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_dict
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names

        :param context:
        """
        pass

    def on_validation_batch_start(self, context: PhaseContext) -> None:
        """
        Called at each batch after getting batch of data from validation loader and moving it to target device.
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - inputs
            - target
            - metrics_compute_fn
            - loss_avg_meter
            - criterion
            - device
            - stop_training
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - loss_logging_items_names

        :param context:
        """
        pass

    def on_validation_batch_end(self, context: PhaseContext) -> None:
        """
        Called after the forward step, loss, and metric computation have been performed for a given batch and there is nothing left to do.
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - inputs
            - preds
            - target
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - loss_logging_items_names

        The corresponding Phase enum value for this event is Phase.VALIDATION_BATCH_END.
        :param context:
        """
        pass

    def on_validation_loader_end(self, context: PhaseContext) -> None:
        """
        Called each epoch at the end of validation data loader (after processing the last batch).
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_dict
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names

        The corresponding Phase enum value for this event is Phase.VALIDATION_EPOCH_END.
        :param context:
        """
        pass

    def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
        """
        Called each epoch after validation has been performed and the best metric has been achieved.
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_dict
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names

        The corresponding Phase enum value for this event is Phase.VALIDATION_END_BEST_EPOCH.
        :param context:
        """
        pass

    def on_test_loader_start(self, context: PhaseContext) -> None:
        """
        Called once at the start of test data loader (before getting the first batch).
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_dict
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names

        :param context:
        """
        pass

    def on_test_batch_start(self, context: PhaseContext) -> None:
        """
        Called at each batch after getting batch of data from test loader and moving it to target device.
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_dict
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names

        :param context:
        """
        pass

    def on_test_batch_end(self, context: PhaseContext) -> None:
        """
        Called after the forward step has been performed for a given batch and there is nothing left to do.
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_dict
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names

        The corresponding Phase enum value for this event is Phase.TEST_BATCH_END.
        :param context:
        """
        pass

    def on_test_loader_end(self, context: PhaseContext) -> None:
        """
        Called once at the end of test data loader (after processing the last batch).
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_dict
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names

        The corresponding Phase enum value for this event is Phase.TEST_END.
        :param context:
        """
        pass

    def on_average_best_models_validation_start(self, context: PhaseContext) -> None:
        """
        Called once after the test has ended, before the training loop has finished.
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_dict
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names

        The corresponding Phase enum value for this event is Phase.AVERAGE_BEST_MODELS_VALIDATION_START.
        :param context:
        """
        pass

    def on_average_best_models_validation_end(self, context: PhaseContext) -> None:
        """
        Called once after the average model validation has finished.
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_dict
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names

        The corresponding Phase enum value for this event is Phase.AVERAGE_BEST_MODELS_VALIDATION_START.
        :param context:
        """
        pass

    def on_training_end(self, context: PhaseContext) -> None:
        """
        Called once after the training loop has finished (due to reaching the optimization criterion or because of an error).
        At this point, the context argument will have the following attributes:
            - epoch
            - batch_idx
            - optimizer
            - inputs
            - preds
            - target
            - metrics_compute_fn
            - loss_avg_meter
            - loss_log_items
            - criterion
            - device
            - stop_training
            - experiment_name
            - ckpt_dir
            - net
            - lr_warmup_epochs
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - ddp_silent_mode
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics
            - loss_logging_items_names

        The corresponding Phase enum value for this event is Phase.POST_TRAINING.
        :param context:
        """
        pass

on_average_best_models_validation_end(context)

Called once after the average model validation has finished. At this point, the context argument will have the following attributes: - epoch - batch_idx - optimizer - inputs - preds - target - metrics_dict - metrics_compute_fn - loss_avg_meter - loss_log_items - criterion - device - stop_training - experiment_name - ckpt_dir - net - lr_warmup_epochs - sg_logger - train_loader - valid_loader - training_params - ddp_silent_mode - checkpoint_params - arch_params - metric_to_watch - valid_metrics - loss_logging_items_names

The corresponding Phase enum value for this event is Phase.AVERAGE_BEST_MODELS_VALIDATION_START.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_average_best_models_validation_end(self, context: PhaseContext) -> None:
    """
    Called once after the average model validation has finished.
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.AVERAGE_BEST_MODELS_VALIDATION_START.
    :param context:
    """
    pass

on_average_best_models_validation_start(context)

Called once after the test has ended, before the training loop has finished. At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names.

The corresponding Phase enum value for this event is Phase.AVERAGE_BEST_MODELS_VALIDATION_START.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_average_best_models_validation_start(self, context: PhaseContext) -> None:
    """
    Called once after the test has ended, before the training loop has finished.
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.AVERAGE_BEST_MODELS_VALIDATION_START.
    :param context:
    """
    pass

on_test_batch_end(context)

Called after the forward step has been performed for a given batch and there is nothing left to do. At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names.

The corresponding Phase enum value for this event is Phase.TEST_BATCH_END.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_test_batch_end(self, context: PhaseContext) -> None:
    """
    Called after the forward step has been performed for a given batch and there is nothing left to do.
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.TEST_BATCH_END.
    :param context:
    """
    pass

on_test_batch_start(context)

Called at each batch after getting batch of data from test loader and moving it to target device. At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_test_batch_start(self, context: PhaseContext) -> None:
    """
    Called at each batch after getting batch of data from test loader and moving it to target device.
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    :param context:
    """
    pass

on_test_loader_end(context)

Called once at the end of test data loader (after processing the last batch). At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names.

The corresponding Phase enum value for this event is Phase.TEST_END.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_test_loader_end(self, context: PhaseContext) -> None:
    """
    Called once at the end of test data loader (after processing the last batch).
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.TEST_END.
    :param context:
    """
    pass

on_test_loader_start(context)

Called once at the start of test data loader (before getting the first batch). At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_test_loader_start(self, context: PhaseContext) -> None:
    """
    Called once at the start of test data loader (before getting the first batch).
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    :param context:
    """
    pass

on_train_batch_backward_end(context)

Called after loss.backward() method was called for a given batch. At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, preds, target, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_train_batch_backward_end(self, context: PhaseContext) -> None:
    """
    Called after loss.backward() method was called for a given batch
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    :param context:
    """
    pass

on_train_batch_end(context)

Called after all forward/backward/optimizer steps have been performed for a given batch and there is nothing left to do. At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_train_batch_end(self, context: PhaseContext) -> None:
    """
    Called after all forward/backward/optimizer steps have been performed for a given batch and there is nothing left to do.
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    :param context:
    """
    pass

on_train_batch_gradient_step_end(context)

Called after gradient step has been performed. Good place to update LR (for step-based schedulers). At this point, the context argument will have the following attributes: epoch, batch_idx, inputs, target, metrics_compute_fn, loss_avg_meter, criterion, device, stop_training, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, loss_logging_items_names.

The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_STEP.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
    """
    Called after gradient step has been performed. Good place to update LR (for step-based schedulers)
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - inputs
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - criterion
        - device
        - stop_training
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_STEP.
    :param context:
    """
    pass
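
Since this event fires once per optimizer step, it is a natural hook for a per-step LR schedule. A minimal sketch, assuming the user builds the torch scheduler themselves and passes it in (nothing below is a built-in SuperGradients API):

from super_gradients.training.utils.callbacks.base_callbacks import Callback, PhaseContext

class PerStepLRCallback(Callback):
    def __init__(self, scheduler):
        # Any torch.optim.lr_scheduler instance built around the training optimizer.
        self.scheduler = scheduler

    def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
        # Advance the schedule once per gradient step, as suggested by the docstring above.
        self.scheduler.step()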

on_train_batch_gradient_step_start(context)

Called before the gradient step is about to happen. Good place to clip gradients (with respect to scaler), log gradients to data ratio, etc. At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, preds, target, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None:
    """
    Called before the gradient step is about to happen.
    Good place to clip gradients (with respect to scaler), log gradients to data ratio, etc.
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    :param context:
    """
    pass
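
As the docstring notes, this is a reasonable place for gradient clipping. A minimal, illustrative sketch for plain FP32 training (with AMP, gradients would have to be unscaled before clipping; this callback is not a built-in):

import torch
from super_gradients.training.utils.callbacks.base_callbacks import Callback, PhaseContext

class GradClipCallback(Callback):
    def __init__(self, max_norm: float = 1.0):
        self.max_norm = max_norm

    def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None:
        # `net` is listed among the context attributes available at this event.
        torch.nn.utils.clip_grad_norm_(context.net.parameters(), self.max_norm)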

on_train_batch_loss_end(context)

Called after the model forward pass and loss computation have been done. At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, preds, target, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names. The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_END.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_train_batch_loss_end(self, context: PhaseContext) -> None:
    """
    Called after the model forward pass and loss computation have been done.
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names
    The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_END.

    :param context:
    """
    pass

on_train_batch_start(context)

Called at each batch after getting batch of data from data loader and moving it to target device. This event is triggered AFTER the Trainer.pre_prediction_callback call (if it was defined).

At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, target, metrics_compute_fn, loss_avg_meter, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_train_batch_start(self, context: PhaseContext) -> None:
    """
    Called at each batch after getting batch of data from data loader and moving it to target device.
    This event is triggered AFTER the Trainer.pre_prediction_callback call (if it was defined).

    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics

    :param context:
    """
    pass

on_train_loader_end(context)

Called each epoch at the end of train data loader (after processing the last batch). At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names.

The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_END.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_train_loader_end(self, context: PhaseContext) -> None:
    """
    Called each epoch at the end of train data loader (after processing the last batch).
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_END.
    :param context:
    """
    pass

on_train_loader_start(context)

Called each epoch at the start of the train data loader (before getting the first batch). At this point, the context argument will have the following attributes: optimizer, criterion, device, experiment_name, ckpt_dir, net, sg_logger, train_loader, valid_loader, training_params, checkpoint_params, arch_params, metric_to_watch, valid_metrics. The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_START.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_train_loader_start(self, context: PhaseContext) -> None:
    """
    Called each epoch at the start of train data loader (before getting the first batch).
    At this point, the context argument will have the following attributes:
        - optimizer
        - criterion
        - device
        - experiment_name
        - ckpt_dir
        - net
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
    The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_START.
    :param context:
    """
    pass

on_training_end(context)

Called once after the training loop has finished (due to reaching the optimization criterion or because of an error). At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, preds, target, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names.

The corresponding Phase enum value for this event is Phase.POST_TRAINING.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_training_end(self, context: PhaseContext) -> None:
    """
    Called once after the training loop has finished (due to reaching the optimization criterion or because of an error).
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.POST_TRAINING.
    :param context:
    """
    pass

on_training_start(context)

Called once before the start of the first epoch. At this point, the context argument will have the following attributes: optimizer, criterion, device, experiment_name, ckpt_dir, net, sg_logger, train_loader, valid_loader, training_params, checkpoint_params, arch_params, metric_to_watch, valid_metrics.

The corresponding Phase enum value for this event is Phase.PRE_TRAINING.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_training_start(self, context: PhaseContext) -> None:
    """
    Called once before the start of the first epoch.
    At this point, the context argument will have the following attributes:
        - optimizer
        - criterion
        - device
        - experiment_name
        - ckpt_dir
        - net
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics

    The corresponding Phase enum value for this event is Phase.PRE_TRAINING.
    :param context:
    """
    pass
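
As a sketch (not library code), a callback can use this event to record basic run information once before training starts; experiment_name and ckpt_dir are among the attributes listed above.

from super_gradients.training.utils.callbacks.base_callbacks import Callback, PhaseContext


class RunInfoCallback(Callback):
    """Hypothetical example: print basic run information once, before the first epoch."""

    def on_training_start(self, context: PhaseContext) -> None:
        # experiment_name, ckpt_dir and training_params are available at this event
        print(f"Starting experiment '{context.experiment_name}'; checkpoints will be written to {context.ckpt_dir}")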

on_validation_batch_end(context)

Called after the forward step, loss, and metric computations have been performed for a given batch and there is nothing left to do. At this point, the context argument will have the following attributes: epoch, batch_idx, inputs, preds, target, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, loss_logging_items_names.

The corresponding Phase enum value for this event is Phase.VALIDATION_BATCH_END.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_validation_batch_end(self, context: PhaseContext) -> None:
    """
    Called after the forward step / loss / metric computations have been performed for a given batch and there is nothing left to do.
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - inputs
        - preds
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.VALIDATION_BATCH_END.
    :param context:
    """
    pass

on_validation_batch_start(context)

Called at each batch after getting a batch of data from the validation loader and moving it to the target device. At this point, the context argument will have the following attributes: epoch, batch_idx, inputs, target, metrics_compute_fn, loss_avg_meter, criterion, device, stop_training, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, loss_logging_items_names.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_validation_batch_start(self, context: PhaseContext) -> None:
    """
    Called at each batch after getting batch of data from validation loader and moving it to target device.
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - inputs
        - target
        - metrics_compute_fn
        - loss_avg_meter
        - criterion
        - device
        - stop_training
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - loss_logging_items_names

    :param context:
    """
    pass

on_validation_end_best_epoch(context)

Called each epoch after validation has been performed and the best metric has been achieved. At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names.

The corresponding Phase enum value for this event is Phase.VALIDATION_END_BEST_EPOCH.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
    """
    Called each epoch after validation has been performed and the best metric has been achieved.
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.VALIDATION_END_BEST_EPOCH.
    :param context:
    """
    pass
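
A sketch of how this event might be used (hypothetical example, relying only on attributes listed above): remember the epoch at which the monitored metric last improved.

from super_gradients.training.utils.callbacks.base_callbacks import Callback, PhaseContext


class BestEpochTrackerCallback(Callback):
    """Hypothetical example: keep track of the epoch that produced the best monitored metric."""

    def __init__(self):
        self.best_epoch = None

    def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
        # metric_to_watch and metrics_dict are available at this event (see the attribute list above)
        self.best_epoch = context.epoch
        print(f"New best {context.metric_to_watch} reached at epoch {context.epoch}: {context.metrics_dict}")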

on_validation_loader_end(context)

Called each epoch at the end of the validation data loader (after processing the last batch). At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names.

The corresponding Phase enum value for this event is Phase.VALIDATION_EPOCH_END.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_validation_loader_end(self, context: PhaseContext) -> None:
    """
    Called each epoch at the end of validation data loader (after processing the last batch).
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    The corresponding Phase enum value for this event is Phase.VALIDATION_EPOCH_END.
    :param context:
    """
    pass

on_validation_loader_start(context)

Called each epoch at the start of the validation data loader (before getting the first batch). At this point, the context argument will have the following attributes: epoch, batch_idx, optimizer, inputs, preds, target, metrics_dict, metrics_compute_fn, loss_avg_meter, loss_log_items, criterion, device, stop_training, experiment_name, ckpt_dir, net, lr_warmup_epochs, sg_logger, train_loader, valid_loader, training_params, ddp_silent_mode, checkpoint_params, arch_params, metric_to_watch, valid_metrics, loss_logging_items_names.

Parameters:

Name Type Description Default
context PhaseContext required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
def on_validation_loader_start(self, context: PhaseContext) -> None:
    """
    Called each epoch at the start of validation data loader (before getting the first batch).
    At this point, the context argument will have the following attributes:
        - epoch
        - batch_idx
        - optimizer
        - inputs
        - preds
        - target
        - metrics_dict
        - metrics_compute_fn
        - loss_avg_meter
        - loss_log_items
        - criterion
        - device
        - stop_training
        - experiment_name
        - ckpt_dir
        - net
        - lr_warmup_epochs
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - ddp_silent_mode
        - checkpoint_params
        - arch_params
        - metric_to_watch
        - valid_metrics
        - loss_logging_items_names

    :param context:
    """
    pass

CallbackHandler

Bases: Callback

Runs all callbacks

Parameters:

Name Type Description Default
callbacks List[Callback]

Callbacks to be run.

required
Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
class CallbackHandler(Callback):
    """
    Runs all callbacks

    :param callbacks: Callbacks to be run.
    """

    def __init__(self, callbacks: List[Callback]):
        # TODO: Add reordering of callbacks to make sure that they are called in the right order
        # For instance, two callbacks may depend on each other, so the first one should be called first.
        # Example: Gradient Clipping & Gradient Logging callbacks. We first need to clip the gradients, and then log them,
        # so even if the user added them in the wrong order we could guarantee their execution order is correct.
        # We can achieve this by adding a property to the callback indicating its priority:
        # Forward   = 0
        # Loss      = 100
        # Backward  = 200
        # Metrics   = 300
        # Scheduler = 400
        # Logging   = 500
        # Ordering callbacks by this priority would ensure that we first run all Forward-related callbacks (for a given event),
        # then Backward, and only then Logging.
        self.callbacks = callbacks

    def on_training_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_training_start(context)

    def on_train_loader_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_loader_start(context)

    def on_train_batch_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_batch_start(context)

    def on_train_batch_loss_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_batch_loss_end(context)

    def on_train_batch_backward_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_batch_backward_end(context)

    def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_batch_gradient_step_start(context)

    def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_batch_gradient_step_end(context)

    def on_train_batch_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_batch_end(context)

    def on_validation_loader_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_validation_loader_start(context)

    def on_validation_batch_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_validation_batch_start(context)

    def on_validation_batch_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_validation_batch_end(context)

    def on_validation_loader_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_validation_loader_end(context)

    def on_train_loader_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_loader_end(context)

    def on_training_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_training_end(context)

    def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_validation_end_best_epoch(context)

    def on_test_loader_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_test_loader_start(context)

    def on_test_batch_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_test_batch_start(context)

    def on_test_batch_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_test_batch_end(context)

    def on_test_loader_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_test_loader_end(context)

    def on_average_best_models_validation_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_average_best_models_validation_start(context)

    def on_average_best_models_validation_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_average_best_models_validation_end(context)
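
A minimal sketch of how CallbackHandler fans a single event out to its callbacks (the Trainer normally drives this internally); both callbacks below are hypothetical examples.

from super_gradients.training.utils.callbacks.base_callbacks import Callback, CallbackHandler, PhaseContext


class PrintEpochCallback(Callback):
    def on_train_loader_start(self, context: PhaseContext) -> None:
        print(f"Starting epoch {context.epoch}")


class PrintDeviceCallback(Callback):
    def on_train_loader_start(self, context: PhaseContext) -> None:
        print(f"Training on device: {context.device}")


handler = CallbackHandler(callbacks=[PrintEpochCallback(), PrintDeviceCallback()])
context = PhaseContext(epoch=0, device="cpu")
handler.on_train_loader_start(context)  # invokes on_train_loader_start on each callback, in registration order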

PhaseCallback

Bases: Callback

Kept here for backward compatibility with old code. New callbacks should use the Callback class instead. This callback supports receiving only a subset of the events defined in the Phase enum:

PRE_TRAINING = "PRE_TRAINING"
TRAIN_EPOCH_START = "TRAIN_EPOCH_START"
TRAIN_BATCH_END = "TRAIN_BATCH_END"
TRAIN_BATCH_STEP = "TRAIN_BATCH_STEP"
TRAIN_EPOCH_END = "TRAIN_EPOCH_END"

VALIDATION_BATCH_END = "VALIDATION_BATCH_END"
VALIDATION_EPOCH_END = "VALIDATION_EPOCH_END"
VALIDATION_END_BEST_EPOCH = "VALIDATION_END_BEST_EPOCH"

TEST_BATCH_END = "TEST_BATCH_END"
TEST_END = "TEST_END"
AVERAGE_BEST_MODELS_VALIDATION_START = "AVERAGE_BEST_MODELS_VALIDATION_START"
AVERAGE_BEST_MODELS_VALIDATION_END = "AVERAGE_BEST_MODELS_VALIDATION_END"
POST_TRAINING = "POST_TRAINING"

Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
class PhaseCallback(Callback):
    """
    Kept here to keep backward compatibility with old code. New callbacks should use Callback class instead.
    This callback supports receiving only a subset of events defined in Phase enum:

    PRE_TRAINING = "PRE_TRAINING"
    TRAIN_EPOCH_START = "TRAIN_EPOCH_START"
    TRAIN_BATCH_END = "TRAIN_BATCH_END"
    TRAIN_BATCH_STEP = "TRAIN_BATCH_STEP"
    TRAIN_EPOCH_END = "TRAIN_EPOCH_END"

    VALIDATION_BATCH_END = "VALIDATION_BATCH_END"
    VALIDATION_EPOCH_END = "VALIDATION_EPOCH_END"
    VALIDATION_END_BEST_EPOCH = "VALIDATION_END_BEST_EPOCH"

    TEST_BATCH_END = "TEST_BATCH_END"
    TEST_END = "TEST_END"
    AVERAGE_BEST_MODELS_VALIDATION_START = "AVERAGE_BEST_MODELS_VALIDATION_START"
    AVERAGE_BEST_MODELS_VALIDATION_END = "AVERAGE_BEST_MODELS_VALIDATION_END"
    POST_TRAINING = "POST_TRAINING"
    """

    def __init__(self, phase: Union[Phase, str]):
        if isinstance(phase, str):
            phase = Phase.from_string(phase)
        elif not isinstance(phase, Phase):
            raise TypeError("phase must be a string or a Phase enum member")

        self.phase = phase

    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def __repr__(self) -> str:
        return self.__class__.__name__

    def on_training_start(self, context: PhaseContext) -> None:
        if self.phase == Phase.PRE_TRAINING:
            self(context)

    def on_train_loader_start(self, context: PhaseContext) -> None:
        if self.phase == Phase.TRAIN_EPOCH_START:
            self(context)

    def on_train_batch_loss_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.TRAIN_BATCH_END:
            self(context)

    def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.TRAIN_BATCH_STEP:
            self(context)

    def on_train_loader_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.TRAIN_EPOCH_END:
            self(context)

    def on_validation_batch_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.VALIDATION_BATCH_END:
            self(context)

    def on_validation_loader_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.VALIDATION_EPOCH_END:
            self(context)

    def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
        if self.phase == Phase.VALIDATION_END_BEST_EPOCH:
            self(context)

    def on_test_batch_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.TEST_BATCH_END:
            self(context)

    def on_test_loader_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.TEST_END:
            self(context)

    def on_average_best_models_validation_start(self, context: PhaseContext) -> None:
        if self.phase == Phase.AVERAGE_BEST_MODELS_VALIDATION_START:
            self(context)

    def on_average_best_models_validation_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.AVERAGE_BEST_MODELS_VALIDATION_END:
            self(context)

    def on_training_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.POST_TRAINING:
            self(context)
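
A sketch of a legacy-style callback (hypothetical example): the phase may be passed as a Phase member or as a string, which __init__ above converts via Phase.from_string, and the callback body goes in __call__.

from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext


class EpochEndPrinter(PhaseCallback):
    """Hypothetical example: a legacy-style callback fired at the end of every train epoch."""

    def __init__(self):
        super().__init__(phase="TRAIN_EPOCH_END")  # resolved to Phase.TRAIN_EPOCH_END by the base class

    def __call__(self, context: PhaseContext):
        print(f"Finished train epoch {context.epoch}")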

PhaseContext

Represents the input for phase callbacks, and is constantly updated after callback calls.

Source code in src/super_gradients/training/utils/callbacks/base_callbacks.py
class PhaseContext:
    """
    Represents the input for phase callbacks, and is constantly updated after callback calls.

    """

    def __init__(
        self,
        epoch: Optional[int] = None,
        batch_idx: Optional[int] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        metrics_dict=None,
        inputs: Optional[torch.Tensor] = None,
        preds: Optional[torch.Tensor] = None,
        target: Optional[torch.Tensor] = None,
        metrics_compute_fn: Optional[MetricCollection] = None,
        loss_avg_meter: Optional["AverageMeter"] = None,  # noqa: ignore
        loss_log_items: Optional[torch.Tensor] = None,
        criterion: Optional[_Loss] = None,
        device: Optional[str] = None,
        experiment_name: Optional[str] = None,
        ckpt_dir: Optional[str] = None,
        net: Optional["SgModule"] = None,  # noqa: ignore
        lr_warmup_epochs: Optional[int] = None,
        sg_logger: Optional["BaseSGLogger"] = None,  # noqa: ignore
        train_loader: Optional[DataLoader] = None,
        valid_loader: Optional[DataLoader] = None,
        test_loader: Optional[DataLoader] = None,
        training_params: Optional["TrainingParams"] = None,  # noqa: ignore
        ddp_silent_mode: Optional[bool] = None,
        checkpoint_params: Optional["HpmStruct"] = None,  # noqa: ignore
        architecture: Optional = None,
        arch_params: Optional["HpmStruct"] = None,  # noqa: ignore
        metric_to_watch: Optional[str] = None,
        valid_metrics: Optional[MetricCollection] = None,  # noqa: ignore
        ema_model: Optional["SgModule"] = None,  # noqa: ignore
        loss_logging_items_names: Optional[List[str]] = None,
        additional_batch_items: Optional[Any] = None,
    ):
        self.epoch = epoch
        self.batch_idx = batch_idx
        self.optimizer = optimizer
        self.inputs = inputs
        self.preds = preds
        self.target = target
        self.metrics_dict = metrics_dict
        self.metrics_compute_fn = metrics_compute_fn
        self.loss_avg_meter = loss_avg_meter
        self.loss_log_items = loss_log_items
        self.criterion = criterion
        self.device = device
        self.stop_training = False
        self.experiment_name = experiment_name
        self.ckpt_dir = ckpt_dir
        self.net = net
        self.lr_warmup_epochs = lr_warmup_epochs
        self.sg_logger = sg_logger
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader
        self.training_params = training_params
        self.ddp_silent_mode = ddp_silent_mode
        self.checkpoint_params = checkpoint_params
        self.architecture = architecture
        self.arch_params = arch_params
        self.metric_to_watch = metric_to_watch
        self.valid_metrics = valid_metrics
        self.ema_model = ema_model
        self.loss_logging_items_names = loss_logging_items_names
        self.additional_batch_items = additional_batch_items

    def update_context(self, **kwargs):
        for attr, attr_val in kwargs.items():
            setattr(self, attr, attr_val)
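
PhaseContext is essentially a mutable bag of training state: the Trainer creates it once and refreshes it via update_context before firing each event. A small standalone illustration (all values below are arbitrary examples):

from super_gradients.training.utils.callbacks.base_callbacks import PhaseContext

context = PhaseContext(epoch=0, batch_idx=0, device="cpu")
print(context.epoch, context.stop_training)  # -> 0 False (stop_training is initialized to False)

# Between events, the same object is updated in place:
context.update_context(epoch=1, batch_idx=42)
print(context.epoch, context.batch_idx)  # -> 1 42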

BinarySegmentationVisualizationCallback

Bases: PhaseCallback

A callback that adds a visualization of a batch of segmentation predictions to context.sg_logger

Parameters:

Name Type Description Default
phase Union[Phase, str]

When to trigger the callback.

required
freq int

Frequency (in epochs) to perform this callback.

required
batch_idx int

Batch index to perform visualization for.

0
last_img_idx_in_batch int

Last image index to add to log. (default=-1, will take entire batch).

-1
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
class BinarySegmentationVisualizationCallback(PhaseCallback):
    """
    A callback that adds a visualization of a batch of segmentation predictions to context.sg_logger

    :param phase:                   When to trigger the callback.
    :param freq:                    Frequency (in epochs) to perform this callback.
    :param batch_idx:               Batch index to perform visualization for.
    :param last_img_idx_in_batch:   Last image index to add to log. (default=-1, will take entire batch).
    """

    def __init__(self, phase: Union[Phase, str], freq: int, batch_idx: int = 0, last_img_idx_in_batch: int = -1):
        super(BinarySegmentationVisualizationCallback, self).__init__(phase)
        self.freq = freq
        self.batch_idx = batch_idx
        self.last_img_idx_in_batch = last_img_idx_in_batch

    def __call__(self, context: PhaseContext):
        if context.epoch % self.freq == 0 and context.batch_idx == self.batch_idx:
            if isinstance(context.preds, tuple):
                preds = context.preds[0].clone()
            else:
                preds = context.preds.clone()
            batch_imgs = BinarySegmentationVisualization.visualize_batch(context.inputs, preds, context.target, self.batch_idx)
            batch_imgs = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in batch_imgs]
            batch_imgs = np.stack(batch_imgs)
            tag = "batch_" + str(self.batch_idx) + "_images"
            context.sg_logger.add_images(tag=tag, images=batch_imgs[: self.last_img_idx_in_batch], global_step=context.epoch, data_format="NHWC")
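
A hedged usage sketch: the callback is constructed with a phase (given here as a string, although a Phase member works as well) and then registered with the trainer like any other phase callback. The argument values below are illustrative only.

from super_gradients.training.utils.callbacks.callbacks import BinarySegmentationVisualizationCallback

viz_callback = BinarySegmentationVisualizationCallback(
    phase="VALIDATION_BATCH_END",  # visualize predictions on a validation batch
    freq=1,                        # every epoch
    batch_idx=0,                   # the first batch of the loader
    last_img_idx_in_batch=4,       # log at most the first 4 images of that batch
)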

CosineLRScheduler

Bases: LRCallbackBase

Hard-coded step cosine annealing learning rate scheduling.

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_lr_scheduler(LRSchedulers.COSINE, deprecated_name="cosine")
class CosineLRScheduler(LRCallbackBase):
    """
    Hard-coded step cosine annealing learning rate scheduling.
    """

    def __init__(self, max_epochs, cosine_final_lr_ratio, **kwargs):
        super().__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
        self.max_epochs = max_epochs
        self.cosine_final_lr_ratio = cosine_final_lr_ratio

    def perform_scheduling(self, context):
        effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
        effective_max_epochs = self.max_epochs - self.training_params.lr_warmup_epochs - self.training_params.lr_cooldown_epochs
        current_iter = max(0, self.train_loader_len * effective_epoch + context.batch_idx - self.training_params.lr_warmup_steps)
        max_iter = self.train_loader_len * effective_max_epochs - self.training_params.lr_warmup_steps
        for group_name in self.lr.keys():
            self.lr[group_name] = float(self.compute_learning_rate(current_iter, max_iter, self.initial_lr[group_name], self.cosine_final_lr_ratio))

        self.update_lr(context.optimizer, context.epoch, context.batch_idx)

    def is_lr_scheduling_enabled(self, context):
        # Account of per-step warmup
        if self.training_params.lr_warmup_steps > 0:
            current_step = self.train_loader_len * context.epoch + context.batch_idx
            return current_step >= self.training_params.lr_warmup_steps

        post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
        return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs

    @classmethod
    def compute_learning_rate(cls, step: Union[float, np.ndarray], total_steps: float, initial_lr: float, final_lr_ratio: float):
        # the cosine starts from initial_lr and reaches initial_lr * cosine_final_lr_ratio in last epoch

        lr = 0.5 * initial_lr * (1.0 + np.cos(step / (total_steps + 1) * math.pi))
        return lr * (1 - final_lr_ratio) + (initial_lr * final_lr_ratio)
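
To see what this schedule produces, compute_learning_rate can be called directly as a classmethod: at step 0 it returns initial_lr, and at the final step it approaches initial_lr * cosine_final_lr_ratio (values below are arbitrary examples).

from super_gradients.training.utils.callbacks.callbacks import CosineLRScheduler

initial_lr, final_lr_ratio, total_steps = 0.1, 0.01, 1000
for step in (0, 250, 500, 750, 1000):
    lr = CosineLRScheduler.compute_learning_rate(step, total_steps, initial_lr, final_lr_ratio)
    print(f"step {step:4d}: lr = {float(lr):.6f}")
# step 0 yields exactly 0.1; step 1000 is close to 0.1 * 0.01 = 0.001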

DeciLabUploadCallback

Bases: PhaseCallback

Post-training callback for uploading and optimizing a model.

Parameters:

Name Type Description Default
model_meta_data

Model's meta-data object. Type: ModelMetadata

required
optimization_request_form

Optimization request form object. Type: OptimizationRequestForm

required
ckpt_name str

Checkpoint filename, inside the checkpoint directory.

'ckpt_best.pth'
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_callback(Callbacks.DECI_LAB_UPLOAD)
class DeciLabUploadCallback(PhaseCallback):
    """
    Post-training callback for uploading and optimizing a model.

    :param model_meta_data:             Model's meta-data object. Type: ModelMetadata
    :param optimization_request_form:   Optimization request form object. Type: OptimizationRequestForm
    :param ckpt_name:                   Checkpoint filename, inside the checkpoint directory.
    """

    def __init__(
        self,
        model_name: str,
        input_dimensions: Sequence[int],
        target_hardware_types: "Optional[List[str]]" = None,
        target_batch_size: "Optional[int]" = None,
        target_quantization_level: "Optional[str]" = None,
        ckpt_name: str = "ckpt_best.pth",
        **kwargs,
    ):
        super().__init__(phase=Phase.POST_TRAINING)
        self.input_dimensions = input_dimensions
        self.model_name = model_name
        self.target_hardware_types = target_hardware_types
        self.target_batch_size = target_batch_size
        self.target_quantization_level = target_quantization_level
        self.ckpt_name = ckpt_name
        self.platform_client = DeciClient()

    @staticmethod
    def log_optimization_failed():
        logger.info("We couldn't finish your model optimization. Visit https://console.deci.ai for details")

    def upload_model(self, model):
        """
        This function will upload the trained model to the Deci Lab

        :param model: The resulting model from the training process
        """
        self.platform_client.upload_model(
            model=model,
            name=self.model_name,
            input_dimensions=self.input_dimensions,
            target_hardware_types=self.target_hardware_types,
            target_batch_size=self.target_batch_size,
            target_quantization_level=self.target_quantization_level,
        )

    def get_optimization_status(self, optimized_model_name: str):
        """
        This function fetches the optimized version of the trained model and checks on its benchmark status.
        The status will be checked against the server every 30 seconds and the process will time out after 30 minutes
        or log the successful optimization - whichever happens first.

        :param optimized_model_name: Optimized model name

        :return: Whether or not the optimized model has been benchmarked
        """

        def handler(_signum, _frame):
            logger.error("Process timed out. Visit https://console.deci.ai for details")
            return False

        signal.signal(signal.SIGALRM, handler)
        signal.alarm(1800)

        finished = False
        while not finished:
            if self.platform_client.is_model_benchmarking(name=optimized_model_name):
                time.sleep(30)
            else:
                finished = True

        signal.alarm(0)
        return True

    def __call__(self, context: PhaseContext) -> None:
        """
        This function will attempt to upload the trained model and schedule an optimization for it.

        :param context: Training phase context
        """
        try:
            model = copy.deepcopy(unwrap_model(context.net))
            model_state_dict_path = os.path.join(context.ckpt_dir, self.ckpt_name)
            model_state_dict = torch.load(model_state_dict_path)["net"]
            model.load_state_dict(state_dict=model_state_dict)

            model = model.cpu()
            if hasattr(model, "prep_model_for_conversion"):
                model.prep_model_for_conversion(input_size=self.input_dimensions)

            self.upload_model(model=model)
            model_name = self.model_name
            logger.info(f"Successfully added {model_name} to the model repository")

            optimized_model_name = f"{model_name}_1_1"
            logger.info("We'll wait for the scheduled optimization to finish. Please don't close this window")
            success = self.get_optimization_status(optimized_model_name=optimized_model_name)
            if success:
                logger.info("Successfully finished your model optimization. Visit https://console.deci.ai for details")
            else:
                DeciLabUploadCallback.log_optimization_failed()
        except Exception as ex:
            DeciLabUploadCallback.log_optimization_failed()
            logger.error(ex)

__call__(context)

This function will attempt to upload the trained model and schedule an optimization for it.

Parameters:

Name Type Description Default
context PhaseContext

Training phase context

required
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
def __call__(self, context: PhaseContext) -> None:
    """
    This function will attempt to upload the trained model and schedule an optimization for it.

    :param context: Training phase context
    """
    try:
        model = copy.deepcopy(unwrap_model(context.net))
        model_state_dict_path = os.path.join(context.ckpt_dir, self.ckpt_name)
        model_state_dict = torch.load(model_state_dict_path)["net"]
        model.load_state_dict(state_dict=model_state_dict)

        model = model.cpu()
        if hasattr(model, "prep_model_for_conversion"):
            model.prep_model_for_conversion(input_size=self.input_dimensions)

        self.upload_model(model=model)
        model_name = self.model_name
        logger.info(f"Successfully added {model_name} to the model repository")

        optimized_model_name = f"{model_name}_1_1"
        logger.info("We'll wait for the scheduled optimization to finish. Please don't close this window")
        success = self.get_optimization_status(optimized_model_name=optimized_model_name)
        if success:
            logger.info("Successfully finished your model optimization. Visit https://console.deci.ai for details")
        else:
            DeciLabUploadCallback.log_optimization_failed()
    except Exception as ex:
        DeciLabUploadCallback.log_optimization_failed()
        logger.error(ex)

get_optimization_status(optimized_model_name)

This function fetches the optimized version of the trained model and checks on its benchmark status. The status is checked against the server every 30 seconds; the process times out after 30 minutes or logs the successful optimization, whichever happens first.

Parameters:

Name Type Description Default
optimized_model_name str

Optimized model name

required

Returns:

Type Description

Whether or not the optimized model has been benchmarked

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
def get_optimization_status(self, optimized_model_name: str):
    """
    This function fetches the optimized version of the trained model and checks on its benchmark status.
    The status will be checked against the server every 30 seconds and the process will time out after 30 minutes
    or log the successful optimization - whichever happens first.

    :param optimized_model_name: Optimized model name

    :return: Whether or not the optimized model has been benchmarked
    """

    def handler(_signum, _frame):
        logger.error("Process timed out. Visit https://console.deci.ai for details")
        return False

    signal.signal(signal.SIGALRM, handler)
    signal.alarm(1800)

    finished = False
    while not finished:
        if self.platform_client.is_model_benchmarking(name=optimized_model_name):
            time.sleep(30)
        else:
            finished = True

    signal.alarm(0)
    return True

upload_model(model)

This function will upload the trained model to the Deci Lab

Parameters:

Name Type Description Default
model

The resulting model from the training process

required
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
def upload_model(self, model):
    """
    This function will upload the trained model to the Deci Lab

    :param model: The resulting model from the training process
    """
    self.platform_client.upload_model(
        model=model,
        name=self.model_name,
        input_dimensions=self.input_dimensions,
        target_hardware_types=self.target_hardware_types,
        target_batch_size=self.target_batch_size,
        target_quantization_level=self.target_quantization_level,
    )

DetectionVisualizationCallback

Bases: PhaseCallback

A callback that adds a visualization of a batch of detection predictions to context.sg_logger

Parameters:

Name Type Description Default
phase Union[Phase, str]

When to trigger the callback.

required
freq int

Frequency (in epochs) to perform this callback.

required
batch_idx int

Batch index to perform visualization for.

0
classes list

Class list of the dataset.

required
last_img_idx_in_batch int

Last image index to add to log. (default=-1, will take entire batch).

-1
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_callback(Callbacks.DETECTION_VISUALIZATION_CALLBACK)
class DetectionVisualizationCallback(PhaseCallback):
    """
    A callback that adds a visualization of a batch of detection predictions to context.sg_logger

    :param phase:                   When to trigger the callback.
    :param freq:                    Frequency (in epochs) to perform this callback.
    :param batch_idx:               Batch index to perform visualization for.
    :param classes:                 Class list of the dataset.
    :param last_img_idx_in_batch:   Last image index to add to log. (default=-1, will take entire batch).
    """

    def __init__(
        self,
        phase: Union[Phase, str],
        freq: int,
        post_prediction_callback: DetectionPostPredictionCallback,
        classes: list,
        batch_idx: int = 0,
        last_img_idx_in_batch: int = -1,
    ):
        super(DetectionVisualizationCallback, self).__init__(phase)
        self.freq = freq
        self.post_prediction_callback = post_prediction_callback
        self.batch_idx = batch_idx
        self.classes = classes
        self.last_img_idx_in_batch = last_img_idx_in_batch

    def __call__(self, context: PhaseContext):
        if context.epoch % self.freq == 0 and context.batch_idx == self.batch_idx and not context.ddp_silent_mode:
            # SOME CALCULATIONS ARE IN-PLACE IN NMS, SO CLONE THE PREDICTIONS
            preds = (context.preds[0].clone(), None)
            preds = self.post_prediction_callback(preds)
            batch_imgs = DetectionVisualization.visualize_batch(context.inputs, preds, context.target, self.batch_idx, self.classes)
            batch_imgs = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in batch_imgs]
            batch_imgs = np.stack(batch_imgs)
            tag = "batch_" + str(self.batch_idx) + "_images"
            context.sg_logger.add_images(tag=tag, images=batch_imgs[: self.last_img_idx_in_batch], global_step=context.epoch, data_format="NHWC")
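
A hedged construction sketch: the callback needs a DetectionPostPredictionCallback (typically the model's NMS post-prediction callback) and the dataset class list; both arguments of the helper below are hypothetical placeholders.

from super_gradients.training.utils.callbacks.callbacks import DetectionVisualizationCallback


def make_detection_viz_callback(nms_callback, class_names):
    # nms_callback: any DetectionPostPredictionCallback instance (placeholder)
    # class_names:  the class list of your detection dataset (placeholder)
    return DetectionVisualizationCallback(
        phase="VALIDATION_BATCH_END",
        freq=5,  # visualize every 5th epoch
        post_prediction_callback=nms_callback,
        classes=class_names,
        batch_idx=0,
    )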

ExponentialLRScheduler

Bases: LRCallbackBase

Exponential decay learning rate scheduling. Decays the learning rate by lr_decay_factor every epoch.

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_lr_scheduler(LRSchedulers.EXP, deprecated_name="exp")
class ExponentialLRScheduler(LRCallbackBase):
    """
    Exponential decay learning rate scheduling. Decays the learning rate by `lr_decay_factor` every epoch.
    """

    def __init__(self, lr_decay_factor: float, **kwargs):
        super().__init__(phase=Phase.TRAIN_BATCH_STEP, **kwargs)
        self.lr_decay_factor = lr_decay_factor

    def perform_scheduling(self, context):
        effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
        current_iter = self.train_loader_len * effective_epoch + context.batch_idx
        for group_name in self.lr.keys():
            self.lr[group_name] = self.initial_lr[group_name] * self.lr_decay_factor ** (current_iter / self.train_loader_len)
        self.update_lr(context.optimizer, context.epoch, context.batch_idx)

    def is_lr_scheduling_enabled(self, context):
        post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
        return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs
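
The expression in perform_scheduling above is a per-iteration form of "multiply the learning rate by lr_decay_factor once per epoch". A standalone re-statement of that formula (not library code, example values only):

initial_lr = 0.1
lr_decay_factor = 0.9
train_loader_len = 500  # iterations per epoch (example value)


def exp_lr(current_iter: int) -> float:
    # Same expression as in ExponentialLRScheduler.perform_scheduling above
    return initial_lr * lr_decay_factor ** (current_iter / train_loader_len)


print(exp_lr(0))                     # 0.1 at the first iteration
print(exp_lr(train_loader_len))      # 0.1 * 0.9 after one full epoch
print(exp_lr(2 * train_loader_len))  # 0.1 * 0.9 ** 2 after two epochs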

ExtremeBatchCaseVisualizationCallback

Bases: Callback, ABC

ExtremeBatchCaseVisualizationCallback

A base class for visualizing worst/best validation batches in an epoch according to some metric or loss value, with Full DDP support.

Images are saved with training_hyperparams.sg_logger.

Parameters:

Name Type Description Default
metric Optional[Metric]

Metric, will be the metric which is monitored.

None
metric_component_name Optional[str]

In case metric returns multiple values (as Mapping), the value at metric.compute()[metric_component_name] will be the one monitored.

None
loss_to_monitor Optional[str]

str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...). Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be: if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple: <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>. If a single item is returned rather than a tuple: <LOSS_CLASS.__name__>. When there is no such attributes and criterion.forward(..) returns a tuple: <LOSS_CLASS.__name__>"/"Loss_"<IDX>.

None
max bool

bool, Whether to take the batch corresponding to the max value of the metric/loss or the minimum (default=False).

False
freq int

int, epoch frequency to perform all of the above (default=1). Inheritors should implement process_extreme_batch which returns an image, as np.ndarray (uint8) with shape BHWC.

1
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
class ExtremeBatchCaseVisualizationCallback(Callback, ABC):
    """
    ExtremeBatchCaseVisualizationCallback

    A base class for visualizing worst/best validation batches in an epoch
     according to some metric or loss value, with Full DDP support.

    Images are saved with training_hyperparams.sg_logger.

    :param metric: Metric, will be the metric which is monitored.

    :param metric_component_name: In case metric returns multiple values (as Mapping),
     the value at metric.compute()[metric_component_name] will be the one monitored.

    :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
     Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:

        if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.

        If a single item is returned rather than a tuple:
            <LOSS_CLASS.__name__>.

        When there is no such attributes and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"Loss_"<IDX>

    :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
     the minimum (default=False).

    :param freq: int, epoch frequency to perform all of the above (default=1).

     Inheritors should implement process_extreme_batch which returns an image, as np.ndarray (uint8) with shape BHWC.
    """

    @resolve_param("metric", MetricsFactory())
    def __init__(
        self,
        metric: Optional[Metric] = None,
        metric_component_name: Optional[str] = None,
        loss_to_monitor: Optional[str] = None,
        max: bool = False,
        freq: int = 1,
        enable_on_train_loader: bool = False,
        enable_on_valid_loader: bool = True,
        max_images: int = -1,
    ):
        """
        :param metric: Metric, will be the metric which is monitored.

        :param metric_component_name: In case metric returns multiple values (as Mapping),
         the value at metric.compute()[metric_component_name] will be the one monitored.

        :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
         Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:

        if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.

        If a single item is returned rather than a tuple:
            <LOSS_CLASS.__name__>.

        When there is no such attributes and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"Loss_"<IDX>

        :param max:                    bool, Whether to take the batch corresponding to the max value of the metric/loss or
        the minimum (default=False).

        :param freq:                   int, epoch frequency to perform all of the above (default=1).
        :param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
        :param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
        :param max_images:             Maximum images to save. If -1, save all images.
        """
        super(ExtremeBatchCaseVisualizationCallback, self).__init__()

        if (metric and loss_to_monitor) or (metric is None and loss_to_monitor is None):
            raise RuntimeError("Must pass exactly one of: loss, metric != None")

        self._set_tag_attr(loss_to_monitor, max, metric, metric_component_name)
        self.metric = metric
        if self.metric:
            self.metric = MetricCollection(self.metric)
            self.metric.to(device_config.device)

        self.metric_component_name = metric_component_name

        self.loss_to_monitor = loss_to_monitor
        self.max = max
        self.freq = freq

        self.extreme_score = None
        self.extreme_batch = None
        self.extreme_preds = None
        self.extreme_targets = None
        self.extreme_additional_batch_items = None

        self._first_call = True
        self._idx_loss_tuple = None

        self.enable_on_train_loader = enable_on_train_loader
        self.enable_on_valid_loader = enable_on_valid_loader
        self.max_images = max_images

    def _set_tag_attr(self, loss_to_monitor, max, metric, metric_component_name):
        if metric_component_name:
            monitored_val_name = metric_component_name
        elif metric:
            monitored_val_name = metric.__class__.__name__
        else:
            monitored_val_name = loss_to_monitor
        self._tag = f"max_{monitored_val_name}_batch" if max else f"min_{monitored_val_name}_batch"

    @abstractmethod
    def process_extreme_batch(self) -> np.ndarray:
        """
        This method is called right before adding the images to the SGLogger (inside the on_validation_loader_end call).
         It should process self.extreme_batch, self.extreme_preds and self.extreme_targets and output the images as np.ndarray.
         Output should be of shape N,H,W,3 and uint8.
        :return: images to save, np.ndarray
        """
        raise NotImplementedError

    def on_train_loader_start(self, context: PhaseContext) -> None:
        self._reset()

    def on_train_batch_end(self, context: PhaseContext) -> None:
        if self.enable_on_train_loader and (context.epoch + 1) % self.freq == 0:
            self._on_batch_end(context)

    def on_train_loader_end(self, context: PhaseContext) -> None:
        if self.enable_on_train_loader and (context.epoch + 1) % self.freq == 0:
            self._gather_extreme_batch_images_and_log(context, "train")
            self._reset()

    def on_validation_loader_start(self, context: PhaseContext) -> None:
        self._reset()

    def on_validation_batch_end(self, context: PhaseContext) -> None:
        if self.enable_on_valid_loader and (context.epoch + 1) % self.freq == 0:
            self._on_batch_end(context)

    def on_validation_loader_end(self, context: PhaseContext) -> None:
        if self.enable_on_valid_loader and (context.epoch + 1) % self.freq == 0:
            self._gather_extreme_batch_images_and_log(context, "valid")
            self._reset()

    def _gather_extreme_batch_images_and_log(self, context, loader_name: str):
        input_images_to_save = self.process_extreme_batch()
        images_to_save = maybe_all_gather_as_list(input_images_to_save)
        images_to_save: List[np.ndarray] = list(itertools.chain(*images_to_save))

        if not context.ddp_silent_mode:
            if self.max_images > 0:
                images_to_save = images_to_save[: self.max_images]

            # Before saving images to logger we need to pad them to the same size
            max_height = max([image.shape[0] for image in images_to_save])
            max_width = max([image.shape[1] for image in images_to_save])
            images_to_save = [
                cv2.copyMakeBorder(image, 0, max_height - image.shape[0], 0, max_width - image.shape[1], cv2.BORDER_CONSTANT, value=0)
                for image in images_to_save
            ]
            images_to_save = np.stack(images_to_save, axis=0)

            context.sg_logger.add_images(tag=f"{loader_name}/{self._tag}", images=images_to_save, global_step=context.epoch, data_format="NHWC")

    def _on_batch_end(self, context: PhaseContext) -> None:
        if self.metric is not None:
            self.metric.update(**context.__dict__)
            score = self.metric.compute()
            if self.metric_component_name is not None:
                if not isinstance(score, Mapping) or (isinstance(score, Mapping) and self.metric_component_name not in score.keys()):
                    raise RuntimeError(
                        f"metric_component_name: {self.metric_component_name} is not a component of the monitored metric: {self.metric.__class__.__name__}"
                    )
                score = score[self.metric_component_name]
            elif len(score) > 1:
                raise RuntimeError(f"returned multiple values from {self.metric} but no metric_component_name has been passed to __init__.")
            else:
                score = score.pop(list(score.keys())[0])
            self.metric.reset()

        else:
            # FOR LOSS VALUES, GET THE RIGHT COMPONENT, DERIVE IT ON THE FIRST PASS
            loss_tuple = context.loss_log_items
            if self._first_call:
                self._init_loss_attributes(context)
            score = loss_tuple[self._idx_loss_tuple].detach().cpu().item()

            # IN CONTRAST TO METRICS, LOSS VALUES NEED TO BE REDUCED IN DDP
            device = infer_model_device(context.net)
            score = torch.tensor(score, device=device)
            score = maybe_all_reduce_tensor_average(score)

        if self._is_more_extreme(score):
            self.extreme_score = tensor_container_to_device(score, device="cpu", detach=True, non_blocking=False)
            self.extreme_batch = tensor_container_to_device(context.inputs, device="cpu", detach=True, non_blocking=False)
            self.extreme_preds = tensor_container_to_device(context.preds, device="cpu", detach=True, non_blocking=False)
            self.extreme_targets = tensor_container_to_device(context.target, device="cpu", detach=True, non_blocking=False)
            self.extreme_additional_batch_items = tensor_container_to_device(context.additional_batch_items, device="cpu", detach=True, non_blocking=False)

    def _init_loss_attributes(self, context: PhaseContext):
        if self.loss_to_monitor not in context.loss_logging_items_names:
            raise ValueError(f"{self.loss_to_monitor} not a loss or loss component.")
        self._idx_loss_tuple = context.loss_logging_items_names.index(self.loss_to_monitor)
        self._first_call = False

    def _reset(self):
        self.extreme_score = None
        self.extreme_batch = None
        self.extreme_preds = None
        self.extreme_targets = None
        self.extreme_additional_batch_items = None
        if self.metric is not None:
            self.metric.reset()

    def _is_more_extreme(self, score: Union[float, torch.Tensor]) -> bool:
        """
        Checks whether computed score is the more extreme than the current extreme score.
        If the current score is None (first call), returns True.
        :param score: A newly computed score.
        :return:      True if score is more extreme than the current extreme score, False otherwise.
        """
        # A score can be NaN/Inf (a rare but possible event when training diverges).
        # In such a case both the < and > operators would return False according to IEEE 754.
        # As a consequence, self.extreme_inputs / self.extreme_outputs would not be updated,
        # which would crash when attempting to visualize the batch.
        if self.extreme_score is None:
            return True

        if self.max:
            return self.extreme_score < score
        else:
            return self.extreme_score > score

__init__(metric=None, metric_component_name=None, loss_to_monitor=None, max=False, freq=1, enable_on_train_loader=False, enable_on_valid_loader=True, max_images=-1)

Parameters:

Name Type Description Default
metric Optional[Metric]

Metric, will be the metric which is monitored.

None
metric_component_name Optional[str]

In case metric returns multiple values (as Mapping), the value at metric.compute()[metric_component_name] will be the one monitored.

None
loss_to_monitor Optional[str]

str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...). Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..) when watching the loss, and should be: if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple: <LOSS_CLASS.__name__>/<COMPONENT_NAME>. If a single item is returned rather than a tuple: <LOSS_CLASS.__name__>. When there is no such attribute and criterion.forward(..) returns a tuple: <LOSS_CLASS.__name__>/Loss_<IDX>.

None
max bool

bool, Whether to take the batch corresponding to the max value of the metric/loss or the minimum (default=False).

False
freq int

int, epoch frequency to perform all of the above (default=1).

1
enable_on_train_loader bool

Controls whether to enable this callback on the train loader. Default is False.

False
enable_on_valid_loader bool

Controls whether to enable this callback on the valid loader. Default is True.

True
max_images int

Maximum images to save. If -1, save all images.

-1
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@resolve_param("metric", MetricsFactory())
def __init__(
    self,
    metric: Optional[Metric] = None,
    metric_component_name: Optional[str] = None,
    loss_to_monitor: Optional[str] = None,
    max: bool = False,
    freq: int = 1,
    enable_on_train_loader: bool = False,
    enable_on_valid_loader: bool = True,
    max_images: int = -1,
):
    """
    :param metric: Metric, will be the metric which is monitored.

    :param metric_component_name: In case metric returns multiple values (as Mapping),
     the value at metric.compute()[metric_component_name] will be the one monitored.

    :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
     Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:

    if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
        <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.

    If a single item is returned rather then a tuple:
        <LOSS_CLASS.__name__>.

    When there is no such attributes and criterion.forward(..) returns a tuple:
        <LOSS_CLASS.__name__>"/"Loss_"<IDX>

    :param max:                    bool, Whether to take the batch corresponding to the max value of the metric/loss or
    the minimum (default=False).

    :param freq:                   int, epoch frequency to perform all of the above (default=1).
    :param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
    :param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
    :param max_images:             Maximum images to save. If -1, save all images.
    """
    super(ExtremeBatchCaseVisualizationCallback, self).__init__()

    if (metric and loss_to_monitor) or (metric is None and loss_to_monitor is None):
        raise RuntimeError("Must pass exactly one of: loss, metric != None")

    self._set_tag_attr(loss_to_monitor, max, metric, metric_component_name)
    self.metric = metric
    if self.metric:
        self.metric = MetricCollection(self.metric)
        self.metric.to(device_config.device)

    self.metric_component_name = metric_component_name

    self.loss_to_monitor = loss_to_monitor
    self.max = max
    self.freq = freq

    self.extreme_score = None
    self.extreme_batch = None
    self.extreme_preds = None
    self.extreme_targets = None
    self.extreme_additional_batch_items = None

    self._first_call = True
    self._idx_loss_tuple = None

    self.enable_on_train_loader = enable_on_train_loader
    self.enable_on_valid_loader = enable_on_valid_loader
    self.max_images = max_images
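
For reference, here is a hedged illustration of how the loss_to_monitor string is composed in each of the three cases above; the criterion class names below are only examples (MyCustomLoss is hypothetical and not part of the library):

# 1) criterion has `component_names` and forward(..) returns a tuple:
loss_to_monitor = "DiceCEEdgeLoss/aux_loss0"   # <LOSS_CLASS.__name__>/<COMPONENT_NAME>
# 2) forward(..) returns a single item:
loss_to_monitor = "CrossEntropyLoss"           # <LOSS_CLASS.__name__>
# 3) no `component_names` attribute and forward(..) returns a tuple:
loss_to_monitor = "MyCustomLoss/Loss_0"        # <LOSS_CLASS.__name__>/Loss_<IDX>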

process_extreme_batch() abstractmethod

This method is called right before adding the images to the SGLogger (inside the on_validation_loader_end call). It should process self.extreme_batch, self.extreme_preds and self.extreme_targets and output the images as an np.ndarray. Output should be of shape (N, H, W, 3) and dtype uint8.

Returns:

Type Description
np.ndarray

images to save, np.ndarray

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@abstractmethod
def process_extreme_batch(self) -> np.ndarray:
    """
    This method is called right before adding the images to the SGLogger (inside the on_validation_loader_end call).
    It should process self.extreme_batch, self.extreme_preds and self.extreme_targets and output the images as an np.ndarray.
    Output should be of shape (N, H, W, 3) and dtype uint8.
    :return: images to save, np.ndarray
    """
    raise NotImplementedError
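
As a minimal sketch (not part of the library), a subclass could satisfy this contract as follows, assuming the cached batch is a BCHW float tensor of images:

import numpy as np
import torch

class MyExtremeBatchVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
    def process_extreme_batch(self) -> np.ndarray:
        # Rescale the cached inputs to 0..255 and convert BCHW -> (N, H, W, 3) uint8.
        inputs = self.extreme_batch.clone()
        inputs -= inputs.min()
        inputs /= inputs.max() + 1e-8
        inputs *= 255
        images = inputs.to(torch.uint8).cpu().numpy().transpose(0, 2, 3, 1)
        return np.ascontiguousarray(images)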

ExtremeBatchDetectionVisualizationCallback

Bases: ExtremeBatchCaseVisualizationCallback

ExtremeBatchDetectionVisualizationCallback

Visualizes the worst/best batch in an epoch for object detection. For clarity, the batch is saved twice in the SG Logger, once with the model's predictions and once with ground truth targets.

Assumptions on bbox formats:

  • After applying post_prediction_callback on context.preds, the predictions are a list/Tensor s.t. predictions[i] is a tensor of shape nx6 - (x1, y1, x2, y2, confidence, class) where x and y are in pixel units.

  • context.targets is a tensor of shape (total_num_targets, 6), in LABEL_CXCYWH format: (index, label, cx, cy, w, h).

Example usage in Yaml config:

training_hyperparams:
  phase_callbacks:
    - ExtremeBatchDetectionVisualizationCallback:
        metric:
          DetectionMetrics_050:
            score_thres: 0.1
            top_k_predictions: 300
            num_cls: ${num_classes}
            normalize_targets: True
            post_prediction_callback:
              _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
              score_threshold: 0.01
              nms_top_k: 1000
              max_predictions: 300
              nms_threshold: 0.7
        metric_component_name: 'mAP@0.50'
        post_prediction_callback:
          _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
          score_threshold: 0.25
          nms_top_k: 1000
          max_predictions: 300
          nms_threshold: 0.7
        normalize_targets: True

Parameters:

Name Type Description Default
metric Optional[Metric]

Metric, will be the metric which is monitored.

None
metric_component_name Optional[str]

In case metric returns multiple values (as Mapping), the value at metric.compute()[metric_component_name] will be the one monitored.

None
loss_to_monitor Optional[str]

str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...). Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..) when watching the loss, and should be: if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple: <LOSS_CLASS.__name__>/<COMPONENT_NAME>. If a single item is returned rather than a tuple: <LOSS_CLASS.__name__>. When there is no such attribute and criterion.forward(..) returns a tuple: <LOSS_CLASS.__name__>/Loss_<IDX>.

None
max bool

bool, Whether to take the batch corresponding to the max value of the metric/loss or the minimum (default=False).

False
freq int

int, epoch frequency to perform all of the above (default=1).

1
classes Optional[List[str]]

List[str], a list of class names corresponding to the class indices for display. When None, will try to fetch this through a "classes" attribute of the validation dataset. If no such attribute exists, an error will be raised (default=None).

None
normalize_targets bool

bool, whether to scale the target bboxes. If the bboxes returned by the validation data loader are in pixel values range, this needs to be set to True (default=False)

False
enable_on_train_loader bool

Controls whether to enable this callback on the train loader. Default is False.

False
enable_on_valid_loader bool

Controls whether to enable this callback on the valid loader. Default is True.

True
max_images int

Maximum images to save. If -1, save all images.

-1
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_callback("ExtremeBatchDetectionVisualizationCallback")
class ExtremeBatchDetectionVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
    """
    ExtremeBatchDetectionVisualizationCallback

    Visualizes worst/best batch in an epoch for Object detection.
    For clarity, the batch is saved twice in the SG Logger, once with the model's predictions and once with
     ground truth targets.

     Assumptions on bbox formats:
     - After applying post_prediction_callback on context.preds, the predictions are a list/Tensor s.t:
        predictions[i] is a tensor of shape nx6 - (x1, y1, x2, y2, confidence, class) where x and y are in pixel units.

     - context.targets is a tensor of shape (total_num_targets, 6), in LABEL_CXCYWH format:  (index, label, cx, cy, w, h).



    Example usage in Yaml config:

        training_hyperparams:
          phase_callbacks:
            - ExtremeBatchDetectionVisualizationCallback:
                metric:
                  DetectionMetrics_050:
                    score_thres: 0.1
                    top_k_predictions: 300
                    num_cls: ${num_classes}
                    normalize_targets: True
                    post_prediction_callback:
                      _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
                      score_threshold: 0.01
                      nms_top_k: 1000
                      max_predictions: 300
                      nms_threshold: 0.7
                metric_component_name: 'mAP@0.50'
                post_prediction_callback:
                  _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
                  score_threshold: 0.25
                  nms_top_k: 1000
                  max_predictions: 300
                  nms_threshold: 0.7
                normalize_targets: True

    :param metric: Metric, will be the metric which is monitored.

    :param metric_component_name: In case metric returns multiple values (as Mapping),
     the value at metric.compute()[metric_component_name] will be the one monitored.

    :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
     Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:

        if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.

        If a single item is returned rather then a tuple:
            <LOSS_CLASS.__name__>.

        When there is no such attributes and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"Loss_"<IDX>

    :param max:                    bool, Whether to take the batch corresponding to the max value of the metric/loss or
    the minimum (default=False).

    :param freq:                   int, epoch frequency to perform all of the above (default=1).

    :param classes:                List[str], a list of class names corresponding to the class indices for display.
    When None, will try to fetch this through a "classes" attribute of the validation dataset. If no such attribute
    exists, an error will be raised (default=None).

    :param normalize_targets:      bool, whether to scale the target bboxes. If the bboxes returned by the validation data loader
     are in pixel values range, this needs to be set to True (default=False)

    :param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
    :param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
    :param max_images:             Maximum images to save. If -1, save all images.
    """

    def __init__(
        self,
        post_prediction_callback: DetectionPostPredictionCallback,
        metric: Optional[Metric] = None,
        metric_component_name: Optional[str] = None,
        loss_to_monitor: Optional[str] = None,
        max: bool = False,
        freq: int = 1,
        classes: Optional[List[str]] = None,
        normalize_targets: bool = False,
        enable_on_train_loader: bool = False,
        enable_on_valid_loader: bool = True,
        max_images: int = -1,
    ):
        super(ExtremeBatchDetectionVisualizationCallback, self).__init__(
            metric=metric,
            metric_component_name=metric_component_name,
            loss_to_monitor=loss_to_monitor,
            max=max,
            freq=freq,
            enable_on_valid_loader=enable_on_valid_loader,
            enable_on_train_loader=enable_on_train_loader,
            max_images=max_images,
        )
        self.post_prediction_callback = post_prediction_callback
        if classes is None:
            logger.info(
                "No classes have been passed to ExtremeBatchDetectionVisualizationCallback. "
                "Will try to fetch them through context.valid_loader.dataset classes attribute if it exists."
            )
        self.classes = list(classes) if classes is not None else None
        self.normalize_targets = normalize_targets

    @staticmethod
    def universal_undo_preprocessing_fn(inputs: torch.Tensor) -> np.ndarray:
        """
        A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg.
        This function scales the input tensor to the 0..255 range and casts it to uint8 dtype.

        :param inputs: Input 4D tensor of images in BCHW format with unknown normalization.
        :return:       Numpy 4D tensor of images in BHWC format, normalized to 0..255 range (uint8).
        """
        inputs -= inputs.min()
        inputs /= inputs.max() + 1e-8
        inputs *= 255
        inputs = inputs.to(torch.uint8)
        inputs = inputs.cpu().numpy()
        inputs = inputs[:, ::-1, :, :].transpose(0, 2, 3, 1)
        inputs = np.ascontiguousarray(inputs, dtype=np.uint8)
        return inputs

    def process_extreme_batch(self) -> np.ndarray:
        """
        Processes the extreme batch and returns a list of images for visualization.
        The default implementation stacks GT and prediction overlays horizontally.

        :return: np.ndarray A 4D tensor of BHWC shape with visualizations of the extreme batch.
        """
        inputs = self.extreme_batch
        preds = self.post_prediction_callback(self.extreme_preds, self.extreme_batch.device)
        targets = self.extreme_targets.clone()
        if self.normalize_targets:
            target_bboxes = targets[:, 2:]
            target_bboxes = cxcywh2xyxy(target_bboxes)
            _, _, height, width = inputs.shape
            target_bboxes[:, [0, 2]] /= width
            target_bboxes[:, [1, 3]] /= height
            target_bboxes = xyxy2cxcywh(target_bboxes)
            targets[:, 2:] = target_bboxes

        images_to_save_preds = DetectionVisualization.visualize_batch(
            inputs, preds, targets, "extreme_batch_preds", self.classes, gt_alpha=0.0, undo_preprocessing_func=self.universal_undo_preprocessing_fn
        )
        images_to_save_preds = np.stack(images_to_save_preds)

        images_to_save_gt = DetectionVisualization.visualize_batch(
            inputs, None, targets, "extreme_batch_gt", self.classes, gt_alpha=1.0, undo_preprocessing_func=self.universal_undo_preprocessing_fn
        )
        images_to_save_gt = np.stack(images_to_save_gt)

        # Stack the predictions and GT images together
        return np.concatenate([images_to_save_gt, images_to_save_preds], axis=2)

    def on_validation_loader_start(self, context: PhaseContext) -> None:
        if self.classes is None:
            if hasattr(context.valid_loader.dataset, "classes"):
                self.classes = context.valid_loader.dataset.classes
            else:
                raise RuntimeError("Couldn't fetch classes from valid_loader, please pass classes explicitly")
        super().on_validation_loader_start(context)

process_extreme_batch()

Processes the extreme batch and returns a list of images for visualization. The default implementation stacks GT and prediction overlays horizontally.

Returns:

Type Description
np.ndarray

np.ndarray A 4D tensor of BHWC shape with visualizations of the extreme batch.

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
def process_extreme_batch(self) -> np.ndarray:
    """
    Processes the extreme batch and returns a list of images for visualization.
    The default implementation stacks GT and prediction overlays horizontally.

    :return: np.ndarray A 4D tensor of BHWC shape with visualizations of the extreme batch.
    """
    inputs = self.extreme_batch
    preds = self.post_prediction_callback(self.extreme_preds, self.extreme_batch.device)
    targets = self.extreme_targets.clone()
    if self.normalize_targets:
        target_bboxes = targets[:, 2:]
        target_bboxes = cxcywh2xyxy(target_bboxes)
        _, _, height, width = inputs.shape
        target_bboxes[:, [0, 2]] /= width
        target_bboxes[:, [1, 3]] /= height
        target_bboxes = xyxy2cxcywh(target_bboxes)
        targets[:, 2:] = target_bboxes

    images_to_save_preds = DetectionVisualization.visualize_batch(
        inputs, preds, targets, "extreme_batch_preds", self.classes, gt_alpha=0.0, undo_preprocessing_func=self.universal_undo_preprocessing_fn
    )
    images_to_save_preds = np.stack(images_to_save_preds)

    images_to_save_gt = DetectionVisualization.visualize_batch(
        inputs, None, targets, "extreme_batch_gt", self.classes, gt_alpha=1.0, undo_preprocessing_func=self.universal_undo_preprocessing_fn
    )
    images_to_save_gt = np.stack(images_to_save_gt)

    # Stack the predictions and GT images together
    return np.concatenate([images_to_save_gt, images_to_save_preds], axis=2)

universal_undo_preprocessing_fn(inputs) staticmethod

A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg. This function scales the input tensor to the 0..255 range and casts it to uint8 dtype.

Parameters:

Name Type Description Default
inputs torch.Tensor

Input 4D tensor of images in BCHW format with unknown normalization.

required

Returns:

Type Description
np.ndarray

Numpy 4D tensor of images in BHWC format, normalized to 0..255 range (uint8).

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@staticmethod
def universal_undo_preprocessing_fn(inputs: torch.Tensor) -> np.ndarray:
    """
    A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg.
    This function scales the input tensor to the 0..255 range and casts it to uint8 dtype.

    :param inputs: Input 4D tensor of images in BCHW format with unknown normalization.
    :return:       Numpy 4D tensor of images in BHWC format, normalized to 0..255 range (uint8).
    """
    inputs -= inputs.min()
    inputs /= inputs.max() + 1e-8
    inputs *= 255
    inputs = inputs.to(torch.uint8)
    inputs = inputs.cpu().numpy()
    inputs = inputs[:, ::-1, :, :].transpose(0, 2, 3, 1)
    inputs = np.ascontiguousarray(inputs, dtype=np.uint8)
    return inputs
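
A quick, hedged sanity check of the expected input/output shapes (purely illustrative):

import torch

batch = torch.rand(2, 3, 64, 64)  # BCHW float images with arbitrary normalization
images = ExtremeBatchDetectionVisualizationCallback.universal_undo_preprocessing_fn(batch)
print(images.shape, images.dtype)  # (2, 64, 64, 3) uint8
# Note: the function also reverses the channel order (e.g. RGB -> BGR) and modifies `batch` in place.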

ExtremeBatchSegVisualizationCallback

Bases: ExtremeBatchCaseVisualizationCallback

ExtremeBatchSegVisualizationCallback

Visualizes worst/best batch in an epoch, for segmentation. Assumes context.preds in validation is a score tensor of shape BCHW, or a tuple whose first item is one.

True predictions will be marked with green, false ones with red.

Example usage in training_params definition:

training_hyperparams ={
  ...
  "phase_callbacks":
    [ExtremeBatchSegVisualizationCallback(
        metric=IoU(20, ignore_idx=19)
        max=False
        ignore_idx=19),
    ExtremeBatchSegVisualizationCallback(
        loss_to_monitor="CrossEntropyLoss"
        max=True
        ignore_idx=19)]
        ...}

Example usage in Yaml config:

training_hyperparams:
  phase_callbacks:
    - ExtremeBatchSegVisualizationCallback:
        loss_to_monitor: DiceCEEdgeLoss/aux_loss0
        ignore_idx: 19

Parameters:

Name Type Description Default
metric Optional[Metric]

Metric, will be the metric which is monitored.

None
metric_component_name Optional[str]

In case metric returns multiple values (as Mapping), the value at metric.compute()[metric_component_name] will be the one monitored.

None
loss_to_monitor Optional[str]

str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...). Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..) when watching the loss, and should be: if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple: <LOSS_CLASS.__name__>/<COMPONENT_NAME>. If a single item is returned rather than a tuple: <LOSS_CLASS.__name__>. When there is no such attribute and criterion.forward(..) returns a tuple: <LOSS_CLASS.__name__>/Loss_<IDX>.

None
max bool

bool, Whether to take the batch corresponding to the max value of the metric/loss or the minimum (default=False).

False
freq int

int, epoch frequency to perform all of the above (default=1).

1
enable_on_train_loader bool

Controls whether to enable this callback on the train loader. Default is False.

False
enable_on_valid_loader bool

Controls whether to enable this callback on the valid loader. Default is True.

True
max_images int

Maximum images to save. If -1, save all images.

-1
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_callback("ExtremeBatchSegVisualizationCallback")
class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
    """
    ExtremeBatchSegVisualizationCallback

    Visualizes worst/best batch in an epoch, for segmentation.
    Assumes context.preds in validation is a score tensor of shape BCHW, or a tuple whose first item is one.

    True predictions will be marked with green, false ones with red.

    Example usage in training_params definition:

        training_hyperparams ={
          ...
          "phase_callbacks":
            [ExtremeBatchSegVisualizationCallback(
                metric=IoU(20, ignore_idx=19)
                max=False
                ignore_idx=19),
            ExtremeBatchSegVisualizationCallback(
                loss_to_monitor="CrossEntropyLoss"
                max=True
                ignore_idx=19)]
                ...}

    Example usage in Yaml config:

        training_hyperparams:
          phase_callbacks:
            - ExtremeBatchSegVisualizationCallback:
                loss_to_monitor: DiceCEEdgeLoss/aux_loss0
                ignore_idx: 19

    :param metric: Metric, will be the metric which is monitored.

    :param metric_component_name: In case metric returns multiple values (as Mapping),
     the value at metric.compute()[metric_component_name] will be the one monitored.

    :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
     Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:

        if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.

        If a single item is returned rather then a tuple:
            <LOSS_CLASS.__name__>.

        When there is no such attributes and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"Loss_"<IDX>

    :param max:                    bool, Whether to take the batch corresponding to the max value of the metric/loss or
    the minimum (default=False).

    :param freq:                   int, epoch frequency to perform all of the above (default=1).

    :param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
    :param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
    :param max_images:             Maximum images to save. If -1, save all images.
    """

    def __init__(
        self,
        metric: Optional[Metric] = None,
        metric_component_name: Optional[str] = None,
        loss_to_monitor: Optional[str] = None,
        max: bool = False,
        freq: int = 1,
        ignore_idx: int = -1,
        enable_on_train_loader: bool = False,
        enable_on_valid_loader: bool = True,
        max_images: int = -1,
    ):
        super(ExtremeBatchSegVisualizationCallback, self).__init__(
            metric=metric,
            metric_component_name=metric_component_name,
            loss_to_monitor=loss_to_monitor,
            max=max,
            freq=freq,
            enable_on_valid_loader=enable_on_valid_loader,
            enable_on_train_loader=enable_on_train_loader,
            max_images=max_images,
        )
        self.ignore_idx = ignore_idx

    @torch.no_grad()
    def process_extreme_batch(self) -> np.ndarray:
        inputs = self.extreme_batch
        inputs -= inputs.min()
        inputs /= inputs.max()
        inputs *= 255
        inputs = inputs.to(torch.uint8)
        preds = self.extreme_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        preds = preds.argmax(1)
        p_mask = preds == self.extreme_targets
        n_mask = preds != self.extreme_targets
        p_mask[self.extreme_targets == self.ignore_idx] = False
        n_mask[self.extreme_targets == self.ignore_idx] = False
        overlay = torch.cat([p_mask.unsqueeze(1), n_mask.unsqueeze(1)], 1)
        colors = ["green", "red"]
        images_to_save = []
        for i in range(len(inputs)):
            image = draw_segmentation_masks(inputs[i].cpu(), overlay[i].cpu(), colors=colors, alpha=0.4).numpy()
            image = np.transpose(image, (1, 2, 0))
            images_to_save.append(image)
        images_to_save = np.stack(images_to_save)
        return images_to_save

FunctionLRScheduler

Bases: LRCallbackBase

Hard-coded learning rate scheduling driven by a user-defined lr scheduling function.

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_lr_scheduler(LRSchedulers.FUNCTION, deprecated_name="function")
class FunctionLRScheduler(LRCallbackBase):
    """
    Hard-coded learning rate scheduling driven by a user-defined lr scheduling function.
    """

    def __init__(self, max_epochs, lr_schedule_function, **kwargs):
        super().__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
        assert callable(lr_schedule_function), "self.lr_function must be callable"
        self.lr_schedule_function = lr_schedule_function
        self.max_epochs = max_epochs

    def is_lr_scheduling_enabled(self, context):
        post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
        return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs

    def perform_scheduling(self, context):
        effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
        effective_max_epochs = self.max_epochs - self.training_params.lr_warmup_epochs - self.training_params.lr_cooldown_epochs
        for group_name in self.lr.keys():
            self.lr[group_name] = self.lr_schedule_function(
                initial_lr=self.initial_lr[group_name],
                epoch=effective_epoch,
                iter=context.batch_idx,
                max_epoch=effective_max_epochs,
                iters_per_epoch=self.train_loader_len,
            )
        self.update_lr(context.optimizer, context.epoch, context.batch_idx)
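
A hedged sketch of a compatible lr_schedule_function: it must accept the keyword arguments used in perform_scheduling above; the linear-decay rule itself is illustrative, not part of the library:

def my_lr_schedule(initial_lr, epoch, iter, max_epoch, iters_per_epoch, **kwargs):
    # Linear decay from initial_lr towards 0 over the effective (post-warmup) schedule.
    progress = (epoch * iters_per_epoch + iter) / max(1, max_epoch * iters_per_epoch)
    return initial_lr * (1.0 - progress)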

IllegalLRSchedulerMetric

Bases: Exception

Exception raised for an illegal combination of training parameters.

Parameters:

Name Type Description Default
metric_name str

Name of the metric that is not supported.

required
metrics_dict dict

Dictionary of metrics that are supported.

required
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
class IllegalLRSchedulerMetric(Exception):
    """Exception raised illegal combination of training parameters.

    :param metric_name: Name of the metric that is not supported.
    :param metrics_dict: Dictionary of metrics that are supported.
    """

    def __init__(self, metric_name: str, metrics_dict: dict):
        self.message = "Illegal metric name: " + metric_name + ". Expected one of metics_dics keys: " + str(metrics_dict.keys())
        super().__init__(self.message)

LRCallbackBase

Bases: PhaseCallback

Base class for hard coded learning rate scheduling regimes, implemented as callbacks.

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_callback(Callbacks.LR_CALLBACK_BASE)
class LRCallbackBase(PhaseCallback):
    """
    Base class for hard coded learning rate scheduling regimes, implemented as callbacks.
    """

    def __init__(self, phase, initial_lr, update_param_groups, train_loader_len, net, training_params, **kwargs):
        super(LRCallbackBase, self).__init__(phase)
        if not isinstance(initial_lr, dict):
            initial_lr = {"default": float(initial_lr)}
        self.initial_lr = initial_lr
        self.lr = initial_lr.copy()
        self.update_param_groups = update_param_groups
        self.train_loader_len = train_loader_len
        self.net = net
        self.training_params = training_params

    def __call__(self, context: PhaseContext, **kwargs):
        if self.is_lr_scheduling_enabled(context):
            self.perform_scheduling(context)

    def is_lr_scheduling_enabled(self, context: PhaseContext):
        """
        Predicate that controls whether to perform lr scheduling based on values in context.

        :param context: PhaseContext: current phase's context.
        :return: bool, whether to apply lr scheduling or not.
        """
        raise NotImplementedError

    def perform_scheduling(self, context: PhaseContext):
        """
        Performs lr scheduling based on values in context.

        :param context: PhaseContext: current phase's context.
        """
        raise NotImplementedError

    def update_lr(self, optimizer, epoch, batch_idx=None):
        for param_group in optimizer.param_groups:
            param_group["lr"] = self.lr[param_group["name"]]

is_lr_scheduling_enabled(context)

Predicate that controls whether to perform lr scheduling based on values in context.

Parameters:

Name Type Description Default
context PhaseContext

PhaseContext: current phase's context.

required

Returns:

Type Description

bool, whether to apply lr scheduling or not.

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
def is_lr_scheduling_enabled(self, context: PhaseContext):
    """
    Predicate that controls whether to perform lr scheduling based on values in context.

    :param context: PhaseContext: current phase's context.
    :return: bool, whether to apply lr scheduling or not.
    """
    raise NotImplementedError

perform_scheduling(context)

Performs lr scheduling based on values in context.

Parameters:

Name Type Description Default
context PhaseContext

PhaseContext: current phase's context.

required
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
def perform_scheduling(self, context: PhaseContext):
    """
    Performs lr scheduling based on values in context.

    :param context: PhaseContext: current phase's context.
    """
    raise NotImplementedError
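
A hedged sketch of a custom scheduling regime built on this base class; the class name and the halve-every-ten-epochs rule are hypothetical, but the structure mirrors the schedulers shown on this page:

class HalveEveryTenEpochs(LRCallbackBase):
    def __init__(self, **kwargs):
        super().__init__(Phase.TRAIN_EPOCH_START, **kwargs)

    def is_lr_scheduling_enabled(self, context):
        # Only schedule once warmup has finished.
        return self.training_params.lr_warmup_epochs <= context.epoch

    def perform_scheduling(self, context):
        factor = 0.5 ** (context.epoch // 10)
        for group_name in self.lr.keys():
            self.lr[group_name] = self.initial_lr[group_name] * factor
        self.update_lr(context.optimizer, context.epoch, None)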

LRSchedulerCallback

Bases: PhaseCallback

Learning rate scheduler callback.

When passing __call__ a metrics_dict with a key=self.metric_name, the value of that metric will be monitored for ReduceLROnPlateau (i.e. step(metrics_dict[self.metric_name])).

Parameters:

Name Type Description Default
scheduler torch.optim.lr_scheduler._LRScheduler

Learning rate scheduler to be called step() with.

required
metric_name str

Metric name for ReduceLROnPlateau learning rate scheduler.

None
phase Union[Phase, str]

Phase of when to trigger it.

required
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_callback(Callbacks.LR_SCHEDULER)
class LRSchedulerCallback(PhaseCallback):
    """
    Learning rate scheduler callback.

    When passing __call__ a metrics_dict with a key=self.metric_name, the value of that metric will be monitored
         for ReduceLROnPlateau (i.e. step(metrics_dict[self.metric_name])).

    :param scheduler:       Learning rate scheduler to be called step() with.
    :param metric_name:     Metric name for ReduceLROnPlateau learning rate scheduler.
    :param phase:           Phase of when to trigger it.
    """

    def __init__(self, scheduler: torch.optim.lr_scheduler._LRScheduler, phase: Union[Phase, str], metric_name: str = None):
        super(LRSchedulerCallback, self).__init__(phase)
        self.scheduler = scheduler
        self.metric_name = metric_name

    def __call__(self, context: PhaseContext):
        if context.lr_warmup_epochs <= context.epoch:
            if self.metric_name and self.metric_name in context.metrics_dict.keys():
                self.scheduler.step(context.metrics_dict[self.metric_name])
            elif self.metric_name is None:
                self.scheduler.step()
            else:
                raise IllegalLRSchedulerMetric(self.metric_name, context.metrics_dict)

    def __repr__(self):
        return "LRSchedulerCallback: " + repr(self.scheduler)

LinearBatchLRWarmup

Bases: Callback

LR scheduling callback for linear step warmup on each batch step. LR climbs from warmup_initial_lr to initial_lr.

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_lr_warmup(LRWarmups.LINEAR_BATCH_STEP, deprecated_name="linear_batch_step")
class LinearBatchLRWarmup(Callback):
    """
    LR scheduling callback for linear step warmup on each batch step.
    LR climbs from warmup_initial_lr to initial_lr.
    """

    def __init__(
        self,
        warmup_initial_lr: float,
        initial_lr: float,
        train_loader_len: int,
        lr_warmup_steps: int,
        training_params,
        net,
        **kwargs,
    ):
        """

        :param warmup_initial_lr: Starting learning rate
        :param initial_lr: Target learning rate after warmup
        :param train_loader_len: Length of train data loader
        :param lr_warmup_steps: Optional. If passed, will use fixed number of warmup steps to warmup LR. Default is None.
        :param kwargs:
        """

        super().__init__()

        if lr_warmup_steps > train_loader_len:
            logger.warning(
                f"Number of warmup steps ({lr_warmup_steps}) is greater than number of steps in epoch ({train_loader_len}). "
                f"Warmup steps will be capped to number of steps in epoch to avoid interfering with any pre-epoch LR schedulers."
            )

        if isinstance(initial_lr, numbers.Number):
            initial_lr = {"default": initial_lr}
        self.initial_lr = initial_lr
        self.lr = initial_lr.copy()

        if isinstance(warmup_initial_lr, numbers.Number):
            warmup_initial_lr = {group_name: warmup_initial_lr for group_name in self.lr.keys()}
        elif isinstance(warmup_initial_lr, Mapping):
            warmup_initial_lr = warmup_initial_lr
        else:
            raise TypeError("Warmup initial lr expected to be of type float or Mapping.")

        lr_warmup_steps = min(lr_warmup_steps, train_loader_len)
        learning_rates = {
            group_name: np.linspace(start=warmup_initial_lr[group_name], stop=initial_lr[group_name], num=lr_warmup_steps, endpoint=True)
            for group_name in self.initial_lr.keys()
        }
        self.training_params = training_params
        self.net = net
        self.learning_rates = learning_rates
        self.train_loader_len = train_loader_len
        self.lr_warmup_steps = lr_warmup_steps

    def on_train_batch_start(self, context: PhaseContext) -> None:
        global_training_step = context.batch_idx + context.epoch * self.train_loader_len
        if global_training_step < self.lr_warmup_steps:
            for group_name in self.initial_lr.keys():
                self.lr[group_name] = float(self.learning_rates[group_name][global_training_step])
            self.update_lr(context.optimizer, context.epoch, context.batch_idx)

    def update_lr(self, optimizer, epoch, batch_idx=None):
        """
        Same as in LRCallbackBase
        :param optimizer:
        :param epoch:
        :param batch_idx:
        :return:
        """
        # UPDATE THE OPTIMIZERS PARAMETER
        for param_group in optimizer.param_groups:
            param_group["lr"] = self.lr[param_group["name"]]

__init__(warmup_initial_lr, initial_lr, train_loader_len, lr_warmup_steps, training_params, net, **kwargs)

Parameters:

Name Type Description Default
warmup_initial_lr float

Starting learning rate

required
initial_lr float

Target learning rate after warmup

required
train_loader_len int

Length of train data loader

required
lr_warmup_steps int

Optional. If passed, will use fixed number of warmup steps to warmup LR. Default is None.

required
kwargs {}
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
def __init__(
    self,
    warmup_initial_lr: float,
    initial_lr: float,
    train_loader_len: int,
    lr_warmup_steps: int,
    training_params,
    net,
    **kwargs,
):
    """

    :param warmup_initial_lr: Starting learning rate
    :param initial_lr: Target learning rate after warmup
    :param train_loader_len: Length of train data loader
    :param lr_warmup_steps: Optional. If passed, will use fixed number of warmup steps to warmup LR. Default is None.
    :param kwargs:
    """

    super().__init__()

    if lr_warmup_steps > train_loader_len:
        logger.warning(
            f"Number of warmup steps ({lr_warmup_steps}) is greater than number of steps in epoch ({train_loader_len}). "
            f"Warmup steps will be capped to number of steps in epoch to avoid interfering with any pre-epoch LR schedulers."
        )

    if isinstance(initial_lr, numbers.Number):
        initial_lr = {"default": initial_lr}
    self.initial_lr = initial_lr
    self.lr = initial_lr.copy()

    if isinstance(warmup_initial_lr, numbers.Number):
        warmup_initial_lr = {group_name: warmup_initial_lr for group_name in self.lr.keys()}
    elif isinstance(warmup_initial_lr, Mapping):
        warmup_initial_lr = warmup_initial_lr
    else:
        raise TypeError("Warmup initial lr expected to be of type float or Mapping.")

    lr_warmup_steps = min(lr_warmup_steps, train_loader_len)
    learning_rates = {
        group_name: np.linspace(start=warmup_initial_lr[group_name], stop=initial_lr[group_name], num=lr_warmup_steps, endpoint=True)
        for group_name in self.initial_lr.keys()
    }
    self.training_params = training_params
    self.net = net
    self.learning_rates = learning_rates
    self.train_loader_len = train_loader_len
    self.lr_warmup_steps = lr_warmup_steps

update_lr(optimizer, epoch, batch_idx=None)

Same as in LRCallbackBase

Parameters:

Name Type Description Default
optimizer required
epoch required
batch_idx None

Returns:

Type Description
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
def update_lr(self, optimizer, epoch, batch_idx=None):
    """
    Same as in LRCallbackBase
    :param optimizer:
    :param epoch:
    :param batch_idx:
    :return:
    """
    # UPDATE THE OPTIMIZERS PARAMETER
    for param_group in optimizer.param_groups:
        param_group["lr"] = self.lr[param_group["name"]]

LinearEpochLRWarmup

Bases: LRCallbackBase

LR scheduling callback for linear step warmup. This scheduler uses a whole epoch as a single step. LR climbs from warmup_initial_lr in even steps to initial_lr. When warmup_initial_lr is None, the LR climb starts from initial_lr/(1+warmup_epochs).

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_lr_warmup(LRWarmups.LINEAR_EPOCH_STEP, deprecated_name="linear_epoch_step")
class LinearEpochLRWarmup(LRCallbackBase):
    """
    LR scheduling callback for linear step warmup. This scheduler uses a whole epoch as a single step.
    LR climbs from warmup_initial_lr in even steps to initial_lr. When warmup_initial_lr is None, the LR climb starts from
     initial_lr/(1+warmup_epochs).

    """

    def __init__(self, **kwargs):
        super().__init__(Phase.TRAIN_EPOCH_START, **kwargs)
        warmup_initial_lr = {}
        if self.training_params.warmup_initial_lr is not None:
            if isinstance(self.training_params.warmup_initial_lr, float):
                for group_name in self.initial_lr.keys():
                    warmup_initial_lr[group_name] = self.training_params.warmup_initial_lr
            elif isinstance(self.training_params.warmup_initial_lr, Mapping):
                warmup_initial_lr = self.training_params.warmup_initial_lr
            else:
                raise TypeError("Warmup initial lr expected to be of type float or Mapping.")
        else:
            for group_name in self.initial_lr.keys():
                warmup_initial_lr[group_name] = self.initial_lr[group_name] / (self.training_params.lr_warmup_epochs + 1)
        self.warmup_initial_lr = warmup_initial_lr

        warmup_step_size = {}
        for group_name in self.initial_lr.keys():
            warmup_step_size[group_name] = (
                (self.initial_lr[group_name] - self.warmup_initial_lr[group_name]) / self.training_params.lr_warmup_epochs
                if self.training_params.lr_warmup_epochs > 0
                else 0
            )
        self.warmup_step_size = warmup_step_size

    def perform_scheduling(self, context):
        for group_name in self.initial_lr.keys():
            self.lr[group_name] = self.warmup_initial_lr[group_name] + context.epoch * self.warmup_step_size[group_name]
        self.update_lr(context.optimizer, context.epoch, None)

    def is_lr_scheduling_enabled(self, context):
        return self.training_params.lr_warmup_epochs > 0 and self.training_params.lr_warmup_epochs >= context.epoch
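
A worked example of the resulting per-epoch warmup values, using the default warmup_initial_lr rule and illustrative numbers:

initial_lr, lr_warmup_epochs = 0.1, 3
warmup_initial_lr = initial_lr / (lr_warmup_epochs + 1)       # 0.025
step = (initial_lr - warmup_initial_lr) / lr_warmup_epochs    # 0.025
lrs = [warmup_initial_lr + epoch * step for epoch in range(lr_warmup_epochs + 1)]
print(lrs)  # ≈ [0.025, 0.05, 0.075, 0.1]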

ModelConversionCheckCallback

Bases: PhaseCallback

Pre-training callback that verifies model conversion to onnx given specified conversion parameters.

The model is converted, then inference is applied with onnx runtime.

Use this callback with the same args as DeciPlatformCallback to prevent conversion failures at the end of training.

Parameters:

Name Type Description Default
model_name str

Model's name

required
input_dimensions Sequence[int]

Model's input dimensions

required
primary_batch_size int

Model's primary batch size

required
opset_version

(default=11)

required
do_constant_folding

(default=True)

required
dynamic_axes

(default={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})

required
input_names

(default=["input"])

required
output_names

(default=["output"])

required
rtol

(default=1e-03)

required
atol

(default=1e-05)

required
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_callback(Callbacks.MODEL_CONVERSION_CHECK)
class ModelConversionCheckCallback(PhaseCallback):
    """
    Pre-training callback that verifies model conversion to onnx given specified conversion parameters.

    The model is converted, then inference is applied with onnx runtime.

    Use this callback with the same args as DeciPlatformCallback to prevent conversion failures at the end of training.

    :param model_name:              Model's name
    :param input_dimensions:        Model's input dimensions
    :param primary_batch_size:      Model's primary batch size
    :param opset_version:           (default=11)
    :param do_constant_folding:     (default=True)
    :param dynamic_axes:            (default={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
    :param input_names:             (default=["input"])
    :param output_names:            (default=["output"])
    :param rtol:                    (default=1e-03)
    :param atol:                    (default=1e-05)
    """

    def __init__(self, model_name: str, input_dimensions: Sequence[int], primary_batch_size: int, **kwargs):
        super(ModelConversionCheckCallback, self).__init__(phase=Phase.PRE_TRAINING)
        self.model_name = model_name
        self.input_dimensions = input_dimensions
        self.primary_batch_size = primary_batch_size

        self.opset_version = kwargs.get("opset_version", 10)
        self.do_constant_folding = kwargs.get("do_constant_folding", None) if kwargs.get("do_constant_folding", None) else True
        self.input_names = kwargs.get("input_names") or ["input"]
        self.output_names = kwargs.get("output_names") or ["output"]
        self.dynamic_axes = kwargs.get("dynamic_axes") or {"input": {0: "batch_size"}, "output": {0: "batch_size"}}

        self.rtol = kwargs.get("rtol", 1e-03)
        self.atol = kwargs.get("atol", 1e-05)

    def __call__(self, context: PhaseContext):
        model = copy.deepcopy(unwrap_model(context.net))
        model = model.cpu()
        model.eval()  # Put model into eval mode

        if hasattr(model, "prep_model_for_conversion"):
            model.prep_model_for_conversion(input_size=self.input_dimensions)

        x = torch.randn(self.primary_batch_size, *self.input_dimensions, requires_grad=False)

        tmp_model_path = os.path.join(context.ckpt_dir, self.model_name + "_tmp.onnx")

        with torch.no_grad():
            torch_out = model(x)

        torch.onnx.export(
            model,  # Model being run
            x,  # Model input (or a tuple for multiple inputs)
            tmp_model_path,  # Where to save the model (can be a file or file-like object)
            export_params=True,  # Store the trained parameter weights inside the model file
            opset_version=self.opset_version,
            do_constant_folding=self.do_constant_folding,
            input_names=self.input_names,
            output_names=self.output_names,
            dynamic_axes=self.dynamic_axes,
        )

        onnx_model = onnx.load(tmp_model_path)
        onnx.checker.check_model(onnx_model)

        ort_session = onnxruntime.InferenceSession(tmp_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

        # compute ONNX Runtime output prediction
        ort_inputs = {ort_session.get_inputs()[0].name: x.cpu().numpy()}
        ort_outs = ort_session.run(None, ort_inputs)

        # TODO: Ideally we don't want to check this but have the certainty of just calling torch_out.cpu()
        if isinstance(torch_out, List) or isinstance(torch_out, tuple):
            torch_out = torch_out[0]
        # compare ONNX Runtime and PyTorch results
        np.testing.assert_allclose(torch_out.cpu().numpy(), ort_outs[0], rtol=self.rtol, atol=self.atol)

        os.remove(tmp_model_path)

        logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")

PhaseContextTestCallback

Bases: PhaseCallback

A callback that saves the phase context for testing.

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
class PhaseContextTestCallback(PhaseCallback):
    """
    A callback that saves the phase context for testing.
    """

    def __init__(self, phase: Union[Phase, str]):
        super(PhaseContextTestCallback, self).__init__(phase)
        self.context = None

    def __call__(self, context: PhaseContext):
        self.context = context

PolyLRScheduler

Bases: LRCallbackBase

Hard-coded polynomial decay learning rate scheduling, applied on each batch step.

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_lr_scheduler(LRSchedulers.POLY, deprecated_name="poly")
class PolyLRScheduler(LRCallbackBase):
    """
    Hard-coded polynomial decay learning rate scheduling, applied on each batch step.
    """

    def __init__(self, max_epochs, **kwargs):
        super().__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
        self.max_epochs = max_epochs

    def perform_scheduling(self, context):
        effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
        effective_max_epochs = self.max_epochs - self.training_params.lr_warmup_epochs - self.training_params.lr_cooldown_epochs
        current_iter = (self.train_loader_len * effective_epoch + context.batch_idx) / self.training_params.batch_accumulate
        max_iter = self.train_loader_len * effective_max_epochs / self.training_params.batch_accumulate
        for group_name in self.lr.keys():
            self.lr[group_name] = self.initial_lr[group_name] * pow((1.0 - (current_iter / max_iter)), 0.9)
        self.update_lr(context.optimizer, context.epoch, context.batch_idx)

    def is_lr_scheduling_enabled(self, context):
        post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
        return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs
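
In isolation, the decay applied above looks like this (illustrative numbers):

initial_lr = 0.01
for current_iter, max_iter in [(0, 1000), (500, 1000), (900, 1000)]:
    lr = initial_lr * (1.0 - current_iter / max_iter) ** 0.9
    print(round(lr, 5))  # 0.01, then ≈0.00536, then ≈0.00126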

RoboflowResultCallback

Bases: Callback

Append the training results to a csv file. Be aware that this does not fully overwrite the existing file, just appends.

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_callback(Callbacks.ROBOFLOW_RESULT_CALLBACK)
class RoboflowResultCallback(Callback):
    """Append the training results to a csv file. Be aware that this does not fully overwrite the existing file, just appends."""

    def __init__(self, dataset_name: str, output_path: Optional[str] = None):
        """
        :param dataset_name:    Name of the dataset that was used to train the model.
        :param output_path:     Full path to the output csv file. By default, save at 'checkpoint_dir/results.csv'
        """
        self.dataset_name = dataset_name
        self.output_path = output_path or os.path.join(get_project_checkpoints_dir_path(), "results.csv")

        if self.output_path is None:
            raise ValueError("Output path must be specified")

        super(RoboflowResultCallback, self).__init__()

    @multi_process_safe
    def on_training_end(self, context: PhaseContext):
        with open(self.output_path, mode="a", newline="") as csv_file:
            writer = csv.writer(csv_file)

            mAP = context.metrics_dict["mAP@0.50:0.95"].item()
            writer.writerow([self.dataset_name, mAP])

__init__(dataset_name, output_path=None)

Parameters:

Name Type Description Default
dataset_name str

Name of the dataset that was used to train the model.

required
output_path Optional[str]

Full path to the output csv file. By default, save at 'checkpoint_dir/results.csv'

None
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
def __init__(self, dataset_name: str, output_path: Optional[str] = None):
    """
    :param dataset_name:    Name of the dataset that was used to train the model.
    :param output_path:     Full path to the output csv file. By default, save at 'checkpoint_dir/results.csv'
    """
    self.dataset_name = dataset_name
    self.output_path = output_path or os.path.join(get_project_checkpoints_dir_path(), "results.csv")

    if self.output_path is None:
        raise ValueError("Output path must be specified")

    super(RoboflowResultCallback, self).__init__()
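
As a usage illustration (a hedged sketch, not taken from the source): the callback can be appended to the phase callbacks passed through the training hyper-parameters, assuming a detection recipe whose metrics dictionary contains 'mAP@0.50:0.95'. The dataset name and output path below are placeholder values.

# Hedged sketch: registering RoboflowResultCallback through the training hyper-parameters.
from super_gradients.training.utils.callbacks.callbacks import RoboflowResultCallback

training_hyperparams = {
    # ... the rest of the training hyper-parameters ...
    "phase_callbacks": [
        RoboflowResultCallback(dataset_name="my-roboflow-dataset", output_path="/tmp/results.csv"),
    ],
}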

SlidingWindowValidationCallback

Bases: Callback

Performs single-scale sliding-window inference at the last epoch, on the validation set and on the averaged model.

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_callback(Callbacks.SLIDING_WINDOW_VALIDATION)
class SlidingWindowValidationCallback(Callback):
    """
    Performing single-scale sliding window during inference at the last epoch on the validation set and on the average model.
    """

    def __init__(self, transforms_for_sliding_window) -> None:
        self.transforms_for_sliding_window = transforms_for_sliding_window
        self.valid_loader_transforms = []
        self.test_loader_transforms = []

    def on_validation_loader_start(self, context: PhaseContext) -> None:
        if context.training_params.max_epochs - 1 == context.epoch:
            unwrap_model(context.net).enable_sliding_window_validation()
            self.valid_loader_transforms = context.valid_loader.dataset.transforms.transforms
            context.valid_loader.dataset.transforms.transforms = self.transforms_for_sliding_window
            iter(context.valid_loader)

    def on_validation_loader_end(self, context: PhaseContext) -> None:
        if context.training_params.max_epochs - 1 == context.epoch:
            unwrap_model(context.net).disable_sliding_window_validation()

    def on_average_best_models_validation_start(self, context: PhaseContext) -> None:
        if context.training_params.max_epochs - 1 == context.epoch and context.training_params.average_best_models:
            unwrap_model(context.net).enable_sliding_window_validation()
            context.valid_loader.dataset.transforms.transforms = self.transforms_for_sliding_window
            iter(context.valid_loader)

    def on_average_best_models_validation_end(self, context: PhaseContext) -> None:
        if context.training_params.max_epochs == context.epoch and context.training_params.average_best_models:
            unwrap_model(context.net).disable_sliding_window_validation()
            context.valid_loader.dataset.transforms.transforms = self.valid_loader_transforms
            iter(context.valid_loader)

    def on_test_loader_start(self, context: PhaseContext) -> None:
        unwrap_model(context.net).enable_sliding_window_validation()
        self.test_loader_transforms = context.test_loader.dataset.transforms.transforms
        context.test_loader.dataset.transforms.transforms = self.transforms_for_sliding_window
        iter(context.test_loader)

    def on_test_loader_end(self, context: PhaseContext) -> None:
        unwrap_model(context.net).disable_sliding_window_validation()
        context.test_loader.dataset.transforms.transforms = self.test_loader_transforms
        iter(context.test_loader)

StepLRScheduler

Bases: LRCallbackBase

Hard coded step learning rate scheduling (i.e. the learning rate is decayed at specific milestone epochs).

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_lr_scheduler(LRSchedulers.STEP, deprecated_name="step")
class StepLRScheduler(LRCallbackBase):
    """
    Hard coded step learning rate scheduling (i.e at specific milestones).
    """

    def __init__(self, lr_updates, lr_decay_factor, step_lr_update_freq=None, **kwargs):
        super().__init__(Phase.TRAIN_EPOCH_END, **kwargs)
        if step_lr_update_freq and len(lr_updates):
            raise ValueError(
                "Parameters lr_updates and step_lr_update_freq are mutually exclusive"
                f" and cannot be passed to {StepLRScheduler.__name__} constructor simultaneously"
            )

        if step_lr_update_freq is None and len(lr_updates) == 0:
            raise ValueError(f"At least one of [lr_updates, step_lr_update_freq] parameters should be passed to {StepLRScheduler.__name__} constructor")

        if step_lr_update_freq:
            max_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
            warmup_epochs = self.training_params.lr_warmup_epochs
            lr_updates = [
                int(np.ceil(step_lr_update_freq * x)) for x in range(1, max_epochs) if warmup_epochs <= int(np.ceil(step_lr_update_freq * x)) < max_epochs
            ]
        elif self.training_params.lr_cooldown_epochs > 0:
            logger.warning("Specific lr_updates were passed along with cooldown_epochs > 0," " cooldown will have no effect.")
        self.lr_updates = lr_updates
        self.lr_decay_factor = lr_decay_factor

    def perform_scheduling(self, context):
        num_updates_passed = [x for x in self.lr_updates if x <= context.epoch]
        for group_name in self.lr.keys():
            self.lr[group_name] = self.initial_lr[group_name] * self.lr_decay_factor ** len(num_updates_passed)
        self.update_lr(context.optimizer, context.epoch, None)

    def is_lr_scheduling_enabled(self, context):
        return self.training_params.lr_warmup_epochs <= context.epoch
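
The resulting learning rate is simply the initial LR multiplied by lr_decay_factor once for every milestone that has already passed. A minimal sketch with illustrative values:

# Hedged sketch: the step decay applied by StepLRScheduler, with assumed milestones.
initial_lr = 0.1
lr_decay_factor = 0.1
lr_updates = [30, 60, 90]      # milestone epochs (assumed)

for epoch in (0, 30, 60, 95):
    num_updates_passed = len([x for x in lr_updates if x <= epoch])
    lr = initial_lr * lr_decay_factor ** num_updates_passed
    print(f"epoch {epoch:3d}: lr = {lr:g}")   # 0.1, 0.01, 0.001, 0.0001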

TestLRCallback

Bases: PhaseCallback

Phase callback that collects the learning rates in lr_placeholder at the end of each epoch (used for testing). In the case of multiple parameter groups (i.e. multiple learning rates) the learning rate is collected from the first one. The phase is VALIDATION_EPOCH_END to ensure all lr updates have been performed before calling this callback.

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
class TestLRCallback(PhaseCallback):
    """
    Phase callback that collects the learning rates in lr_placeholder at the end of each epoch (used for testing). In
     the case of multiple parameter groups (i.e multiple learning rates) the learning rate is collected from the first
     one. The phase is VALIDATION_EPOCH_END to ensure all lr updates have been performed before calling this callback.
    """

    def __init__(self, lr_placeholder):
        super(TestLRCallback, self).__init__(Phase.VALIDATION_EPOCH_END)
        self.lr_placeholder = lr_placeholder

    def __call__(self, context: PhaseContext):
        self.lr_placeholder.append(context.optimizer.param_groups[0]["lr"])

TrainingStageSwitchCallbackBase

Bases: PhaseCallback

TrainingStageSwitchCallback

A phase callback that is called at a specific epoch (epoch start) to support multi-stage training. It does so by manipulating the objects inside the context.

Parameters:

Name Type Description Default
next_stage_start_epoch int

Epoch idx to apply the stage change.

required
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
class TrainingStageSwitchCallbackBase(PhaseCallback):
    """
    TrainingStageSwitchCallback

    A phase callback that is called at a specific epoch (epoch start) to support multi-stage training.
    It does so by manipulating the objects inside the context.

    :param next_stage_start_epoch: Epoch idx to apply the stage change.
    """

    def __init__(self, next_stage_start_epoch: int):
        super(TrainingStageSwitchCallbackBase, self).__init__(phase=Phase.TRAIN_EPOCH_START)
        self.next_stage_start_epoch = next_stage_start_epoch

    def __call__(self, context: PhaseContext):
        if context.epoch == self.next_stage_start_epoch:
            self.apply_stage_change(context)

    def apply_stage_change(self, context: PhaseContext):
        """
        This method is called when the callback is fired on the next_stage_start_epoch,
         and holds the stage change logic that should be applied to the context's objects.

        :param context: PhaseContext, context of current phase
        """
        raise NotImplementedError

apply_stage_change(context)

This method is called when the callback is fired on the next_stage_start_epoch, and holds the stage change logic that should be applied to the context's objects.

Parameters:

Name Type Description Default
context PhaseContext

PhaseContext, context of current phase

required
Source code in src/super_gradients/training/utils/callbacks/callbacks.py
def apply_stage_change(self, context: PhaseContext):
    """
    This method is called when the callback is fired on the next_stage_start_epoch,
     and holds the stage change logic that should be applied to the context's objects.

    :param context: PhaseContext, context of current phase
    """
    raise NotImplementedError
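
A minimal sketch of a custom subclass (the frozen-backbone logic is purely illustrative and assumes the model exposes a `backbone` attribute):

# Hedged sketch: a custom stage switch that freezes the backbone at a chosen epoch.
from super_gradients.training.utils.callbacks.callbacks import PhaseContext, TrainingStageSwitchCallbackBase


class FreezeBackboneStageSwitchCallback(TrainingStageSwitchCallbackBase):
    def __init__(self, next_stage_start_epoch: int = 10):
        super().__init__(next_stage_start_epoch=next_stage_start_epoch)

    def apply_stage_change(self, context: PhaseContext):
        # Called exactly once, at the start of epoch `next_stage_start_epoch`.
        # context.net may be a wrapped (e.g. DDP) model; unwrap it first if needed.
        for param in context.net.backbone.parameters():
            param.requires_grad = False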

YoloXTrainingStageSwitchCallback

Bases: TrainingStageSwitchCallbackBase

YoloXTrainingStageSwitchCallback

Training stage switch for YoloX training. Disables mosaic augmentation and switches the YoloX loss to use L1.

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
@register_callback(Callbacks.YOLOX_TRAINING_STAGE_SWITCH)
class YoloXTrainingStageSwitchCallback(TrainingStageSwitchCallbackBase):
    """
    YoloXTrainingStageSwitchCallback

    Training stage switch for YoloX training.
    Disables mosaic, and manipulates YoloX loss to use L1.

    """

    def __init__(self, next_stage_start_epoch: int = 285):
        super(YoloXTrainingStageSwitchCallback, self).__init__(next_stage_start_epoch=next_stage_start_epoch)

    def apply_stage_change(self, context: PhaseContext):
        for transform in context.train_loader.dataset.transforms:
            if hasattr(transform, "close"):
                transform.close()
        iter(context.train_loader)
        context.criterion.use_l1 = True

create_lr_scheduler_callback(lr_mode, train_loader, net, training_params, update_param_groups, optimizer)

Creates the phase callback in charge of LR scheduling, to be used by Trainer.

Parameters:

Name Type Description Default
lr_mode Union[str, Mapping]

Union[str, Mapping]. When str: learning rate scheduling policy, one of ['StepLRScheduler','PolyLRScheduler','CosineLRScheduler','FunctionLRScheduler']. 'StepLRScheduler' refers to constant updates at epoch numbers passed through lr_updates; each update decays the learning rate by lr_decay_factor. 'CosineLRScheduler' refers to the Cosine Annealing policy as described in https://arxiv.org/abs/1608.03983; the final learning rate ratio is controlled by the cosine_final_lr_ratio training parameter. 'PolyLRScheduler' refers to the polynomial decrease: in each epoch iteration self.lr = self.initial_lr * pow((1.0 - (current_iter / max_iter)), 0.9). 'FunctionLRScheduler' refers to a user-defined learning rate scheduling function, passed through lr_schedule_function. When Mapping, refers to a torch.optim.lr_scheduler._LRScheduler, following the API: lr_mode = {LR_SCHEDULER_CLASS_NAME: {**LR_SCHEDULER_KWARGS, "phase": XXX, "metric_name": XXX}}, where "phase" (of Phase type) controls when to call torch.optim.lr_scheduler._LRScheduler.step(). For instance, to update the LR on each batch use phase: Phase.TRAIN_BATCH_END; to update the LR after each epoch use phase: Phase.TRAIN_EPOCH_END. The "metric_name" refers to the metric to watch (see the docs for "metric_to_watch" in train(...) at https://docs.deci.ai/super-gradients/docstring/training/sg_trainer.html) when using ReduceLROnPlateau; in any other case this kwarg is ignored. The remaining **LR_SCHEDULER_KWARGS are simply passed to the torch scheduler's __init__.

required
train_loader DataLoader

DataLoader, the Trainer.train_loader used for training.

required
net torch.nn.Module

torch.nn.Module, the Trainer.net used for training.

required
training_params Mapping

Mapping, Trainer.training_params.

required
update_param_groups bool

bool, whether the Trainer.net has a specific way of updating its parameter groups.

required
optimizer torch.optim.Optimizer

The optimizer used for training. Will be passed to the LR callback's __init__ (or the torch scheduler's __init__, depending on the lr_mode value as described above).

required

Returns:

Type Description
PhaseCallback

a PhaseCallback instance to be used by Trainer for LR scheduling.

Source code in src/super_gradients/training/utils/callbacks/callbacks.py
def create_lr_scheduler_callback(
    lr_mode: Union[str, Mapping],
    train_loader: DataLoader,
    net: torch.nn.Module,
    training_params: Mapping,
    update_param_groups: bool,
    optimizer: torch.optim.Optimizer,
) -> PhaseCallback:
    """
    Creates the phase callback in charge of LR scheduling, to be used by Trainer.

    :param lr_mode: Union[str, Mapping],

                    When str:

                    Learning rate scheduling policy, one of ['StepLRScheduler','PolyLRScheduler','CosineLRScheduler','FunctionLRScheduler'].

                    'StepLRScheduler' refers to constant updates at epoch numbers passed through `lr_updates`.
                        Each update decays the learning rate by `lr_decay_factor`.

                    'CosineLRScheduler' refers to the Cosine Anealing policy as mentioned in https://arxiv.org/abs/1608.03983.
                      The final learning rate ratio is controlled by `cosine_final_lr_ratio` training parameter.

                    'PolyLRScheduler' refers to the polynomial decrease:
                        in each epoch iteration `self.lr = self.initial_lr * pow((1.0 - (current_iter / max_iter)), 0.9)`

                    'FunctionLRScheduler' refers to a user-defined learning rate scheduling function, that is passed through `lr_schedule_function`.



                    When Mapping, refers to a torch.optim.lr_scheduler._LRScheduler, following the below API:

                        lr_mode = {LR_SCHEDULER_CLASS_NAME: {**LR_SCHEDULER_KWARGS, "phase": XXX, "metric_name": XXX)

                        Where "phase" (of Phase type) controls when to call torch.optim.lr_scheduler._LRScheduler.step().

                        For instance, in order to:
                        - Update LR on each batch: Use phase: Phase.TRAIN_BATCH_END
                        - Update LR after each epoch: Use phase: Phase.TRAIN_EPOCH_END

                        The "metric_name" refers to the metric to watch (See docs for "metric_to_watch" in train(...)
                         https://docs.deci.ai/super-gradients/docstring/training/sg_trainer.html) when using
                          ReduceLROnPlateau. In any other case this kwarg is ignored.

                        **LR_SCHEDULER_KWARGS are simply passed to the torch scheduler's __init__.




    :param train_loader: DataLoader, the Trainer.train_loader used for training.

    :param net: torch.nn.Module, the Trainer.net used for training.

    :param training_params: Mapping, Trainer.training_params.

    :param update_param_groups:bool,  Whether the Trainer.net has a specific way of updaitng its parameter group.

    :param optimizer: The optimizer used for training. Will be passed to the LR callback's __init__
     (or the torch scheduler's init, depending on the lr_mode value as described above).

    :return: a PhaseCallback instance to be used by Trainer for LR scheduling.
    """

    if isinstance(lr_mode, str) and lr_mode in LR_SCHEDULERS_CLS_DICT:
        sg_lr_callback_cls = LR_SCHEDULERS_CLS_DICT[lr_mode]
        sg_lr_callback = sg_lr_callback_cls(
            train_loader_len=len(train_loader),
            net=net,
            training_params=training_params,
            update_param_groups=update_param_groups,
            **training_params.to_dict(),
        )
    elif isinstance(lr_mode, Mapping) and list(lr_mode.keys())[0] in TORCH_LR_SCHEDULERS:
        if update_param_groups:
            logger.warning(
                "The network's way of updataing (i.e update_param_groups) is not supported with native " "torch lr schedulers and will have no effect."
            )
        lr_scheduler_name = list(lr_mode.keys())[0]
        torch_scheduler_params = {k: v for k, v in lr_mode[lr_scheduler_name].items() if k != "phase" and k != "metric_name"}
        torch_scheduler_params["optimizer"] = optimizer
        torch_scheduler = TORCH_LR_SCHEDULERS[lr_scheduler_name](**torch_scheduler_params)
        if get_param(lr_mode[lr_scheduler_name], "phase") is None:
            raise ValueError("Phase is required argument when working with torch schedulers.")

        if lr_scheduler_name == "ReduceLROnPlateau" and get_param(lr_mode[lr_scheduler_name], "metric_name") is None:
            raise ValueError("metric_name is required argument when working with ReduceLROnPlateau schedulers.")

        sg_lr_callback = LRSchedulerCallback(
            scheduler=torch_scheduler, phase=lr_mode[lr_scheduler_name]["phase"], metric_name=get_param(lr_mode[lr_scheduler_name], "metric_name")
        )
    else:
        raise ValueError(f"Unknown lr_mode: {lr_mode}")

    return sg_lr_callback
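
As an illustration of the Mapping form (a hedged sketch: "StepLR" is assumed here to be a key of TORCH_LR_SCHEDULERS, and gamma/step_size are standard torch.optim.lr_scheduler.StepLR arguments):

# Hedged sketch: selecting a native torch scheduler through the Mapping form of lr_mode.
from super_gradients.training.utils.callbacks.callbacks import Phase

lr_mode = {
    "StepLR": {
        "gamma": 0.1,                    # passed straight to the torch scheduler's __init__
        "step_size": 30,
        "phase": Phase.TRAIN_EPOCH_END,  # when scheduler.step() is called
    }
}
# e.g. training_params["lr_mode"] = lr_mode, or pass lr_mode directly to create_lr_scheduler_callback.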

ExtremeBatchPoseEstimationVisualizationCallback

Bases: ExtremeBatchCaseVisualizationCallback

ExtremeBatchPoseEstimationVisualizationCallback

Visualizes the worst/best batch in an epoch for the pose estimation task. This class visualizes horizontally-stacked GT and predicted poses. It requires a key 'gt_samples' (List[PoseEstimationSample]) to be present in the additional_batch_items dictionary.

Supported models: YoloNASPose
Supported datasets: COCOPoseEstimationDataset

Example usage in Yaml config:

training_hyperparams:
  phase_callbacks:
      - ExtremeBatchPoseEstimationVisualizationCallback:
          keypoint_colors: ${dataset_params.keypoint_colors}
          edge_colors: ${dataset_params.edge_colors}
          edge_links: ${dataset_params.edge_links}
          loss_to_monitor: YoloNASPoseLoss/loss
          max: True
          freq: 1
          max_images: 16
          enable_on_train_loader: True
          enable_on_valid_loader: True
          post_prediction_callback:
            _target_: super_gradients.training.models.pose_estimation_models.yolo_nas_pose.YoloNASPosePostPredictionCallback
            pose_confidence_threshold: 0.01
            nms_iou_threshold: 0.7
            pre_nms_max_predictions: 300
            post_nms_max_predictions: 30

Parameters:

Name Type Description Default
metric Optional[Metric]

Metric, will be the metric which is monitored.

None
metric_component_name Optional[str]

In case metric returns multiple values (as Mapping), the value at metric.compute()[metric_component_name] will be the one monitored.

None
loss_to_monitor Optional[str]

str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...). Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..) when watching the loss, and should be: if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple: <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>. If a single item is returned rather than a tuple: <LOSS_CLASS.__name__>. When there are no such attributes and criterion.forward(..) returns a tuple: <LOSS_CLASS.__name__>"/"Loss_"<IDX>

None
max bool

bool, Whether to take the batch corresponding to the max value of the metric/loss or the minimum (default=False).

False
freq int

int, epoch frequency to perform all of the above (default=1).

1
classes

List[str], a list of class names corresponding to the class indices for display. When None, will try to fetch this through a "classes" attribute of the validation dataset. If such an attribute does not exist, an error will be raised (default=None).

required
normalize_targets

bool, whether to scale the target bboxes. If the bboxes returned by the validation data loader are in pixel values range, this needs to be set to True (default=False)

required
Source code in src/super_gradients/training/utils/callbacks/extreme_batch_pose_visualization_callback.py
@register_callback("ExtremeBatchPoseEstimationVisualizationCallback")
class ExtremeBatchPoseEstimationVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
    """
    ExtremeBatchPoseEstimationVisualizationCallback

    Visualizes worst/best batch in an epoch for pose estimation task.
    This class visualize horizontally-stacked GT and predicted poses.
    It requires a key 'gt_samples' (List[PoseEstimationSample]) to be present in additional_batch_items dictionary.

    Supported models: YoloNASPose
    Supported datasets: COCOPoseEstimationDataset

    Example usage in Yaml config:

        training_hyperparams:
          phase_callbacks:
              - ExtremeBatchPoseEstimationVisualizationCallback:
                  keypoint_colors: ${dataset_params.keypoint_colors}
                  edge_colors: ${dataset_params.edge_colors}
                  edge_links: ${dataset_params.edge_links}
                  loss_to_monitor: YoloNASPoseLoss/loss
                  max: True
                  freq: 1
                  max_images: 16
                  enable_on_train_loader: True
                  enable_on_valid_loader: True
                  post_prediction_callback:
                    _target_: super_gradients.training.models.pose_estimation_models.yolo_nas_pose.YoloNASPosePostPredictionCallback
                    pose_confidence_threshold: 0.01
                    nms_iou_threshold: 0.7
                    pre_nms_max_predictions: 300
                    post_nms_max_predictions: 30

    :param metric: Metric, will be the metric which is monitored.

    :param metric_component_name: In case metric returns multiple values (as Mapping),
     the value at metric.compute()[metric_component_name] will be the one monitored.

    :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
     Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:

        if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.

        If a single item is returned rather then a tuple:
            <LOSS_CLASS.__name__>.

        When there is no such attributes and criterion.forward(..) returns a tuple:
            <LOSS_CLASS.__name__>"/"Loss_"<IDX>

    :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
     the minimum (default=False).

    :param freq: int, epoch frequency to perform all of the above (default=1).

    :param classes: List[str], a list of class names corresponding to the class indices for display.
     When None, will try to fetch this through a "classes" attribute of the valdiation dataset. If such attribute does
      not exist an error will be raised (default=None).

    :param normalize_targets: bool, whether to scale the target bboxes. If the bboxes returned by the validation data loader
     are in pixel values range, this needs to be set to True (default=False)

    """

    def __init__(
        self,
        post_prediction_callback: Callable,
        keypoint_colors: Union[ListConfig, List[Tuple[int, int, int]]],
        edge_colors: Union[ListConfig, List[Tuple[int, int, int]]],
        edge_links: Union[ListConfig, List[Tuple[int, int]]],
        metric: Optional[Metric] = None,
        metric_component_name: Optional[str] = None,
        loss_to_monitor: Optional[str] = None,
        max: bool = False,
        freq: int = 1,
        max_images: Optional[int] = None,
        enable_on_train_loader: bool = False,
        enable_on_valid_loader: bool = True,
    ):
        super().__init__(
            metric=metric,
            metric_component_name=metric_component_name,
            loss_to_monitor=loss_to_monitor,
            max=max,
            freq=freq,
            enable_on_train_loader=enable_on_train_loader,
            enable_on_valid_loader=enable_on_valid_loader,
        )
        self.post_prediction_callback = post_prediction_callback
        self.keypoint_colors = list(tuple(map(int, color)) for color in keypoint_colors)
        self.edge_colors = list(tuple(map(int, color)) for color in edge_colors)
        self.edge_links = list(tuple(map(int, link)) for link in edge_links)
        self.max_images = max_images

    @classmethod
    def universal_undo_preprocessing_fn(cls, inputs: torch.Tensor) -> np.ndarray:
        """
        A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg.
        :param inputs:
        :return:
        """
        inputs = inputs - inputs.min()
        inputs /= inputs.max()
        inputs *= 255
        inputs = inputs.to(torch.uint8)
        inputs = inputs.cpu().numpy()
        inputs = inputs[:, ::-1, :, :].transpose(0, 2, 3, 1)
        inputs = np.ascontiguousarray(inputs, dtype=np.uint8)
        return inputs

    @classmethod
    def _visualize_batch(
        cls,
        image_tensor: np.ndarray,
        keypoints: List[Union[np.ndarray, Tensor]],
        bboxes: List[Union[None, np.ndarray, Tensor]],
        scores: Optional[List[Union[None, np.ndarray, Tensor]]],
        is_crowd: Optional[List[Union[None, np.ndarray, Tensor]]],
        keypoint_colors: List[Tuple[int, int, int]],
        edge_colors: List[Tuple[int, int, int]],
        edge_links: List[Tuple[int, int]],
        show_keypoint_confidence: bool,
    ) -> List[np.ndarray]:
        """
        Generate list of samples visualization of a batch of images with keypoints and bounding boxes.

        :param image_tensor:             Images batch of [Batch Size, 3, H, W] shape with values in [0, 255] range.
                                         The images should be scaled to [0, 255] range and converted to uint8 type beforehead.
        :param keypoints:                Keypoints in XY format. Shape [Num Instances, Num Joints, 2]. Can be None.
        :param bboxes:                   Bounding boxes in XYXY format. Shape [Num Instances, 4]. Can be None.
        :param scores:                   Keypoint scores. Shape [Num Instances, Num Joints]. Can be None.
        :param is_crowd:                 Whether each sample is crowd or not. Shape [Num Instances]. Can be None.
        :param keypoint_colors:          Keypoint colors. Shape [Num Joints, 3]
        :param edge_colors:              Edge colors between joints. Shape [Num Links, 3]
        :param edge_links:               Edge links between joints. Shape [Num Links, 2]
        :param show_keypoint_confidence: Whether to show confidence for each keypoint. Requires `scores` to be not None.
        :return:                         List of visualization images.
        """

        out_images = []
        for i in range(image_tensor.shape[0]):
            keypoints_i = keypoints[i]
            bboxes_i = bboxes[i]
            scores_i = scores[i] if scores is not None else None
            is_crowd_i = is_crowd[i] if is_crowd is not None else None

            if torch.is_tensor(keypoints_i):
                keypoints_i = keypoints_i.detach().cpu().numpy()
            if torch.is_tensor(bboxes_i):
                bboxes_i = bboxes_i.detach().cpu().numpy()
            if torch.is_tensor(scores_i):
                scores_i = scores_i.detach().cpu().numpy()
            if torch.is_tensor(is_crowd_i):
                is_crowd_i = is_crowd_i.detach().cpu().numpy()

            res_image = image_tensor[i]
            res_image = PoseVisualization.draw_poses(
                image=res_image,
                poses=keypoints_i,
                boxes=bboxes_i,
                scores=scores_i,
                is_crowd=is_crowd_i,
                show_keypoint_confidence=show_keypoint_confidence,
                edge_links=edge_links,
                edge_colors=edge_colors,
                keypoint_colors=keypoint_colors,
                keypoint_confidence_threshold=0.01,
            )

            out_images.append(res_image)

        return out_images

    @torch.no_grad()
    def process_extreme_batch(self) -> np.ndarray:
        """
        Processes the extreme batch, and returns batche of images for visualization - predictions and GT poses stacked horizontally.

        :return: np.ndarray - the visualization of predictions and GT
        """
        if "gt_samples" not in self.extreme_additional_batch_items:
            raise RuntimeError(
                "ExtremeBatchPoseEstimationVisualizationCallback requires 'gt_samples' to be present in additional_batch_items."
                "Currently only YoloNASPose model is supported. Old DEKR recipe is not supported at the moment."
            )

        inputs = self.universal_undo_preprocessing_fn(self.extreme_batch)
        gt_samples: List[PoseEstimationSample] = self.extreme_additional_batch_items["gt_samples"]
        predictions: List[PoseEstimationPredictions] = self.post_prediction_callback(self.extreme_preds)

        images_to_save_preds = self._visualize_batch(
            image_tensor=inputs,
            keypoints=[p.poses for p in predictions],
            bboxes=[p.bboxes_xyxy for p in predictions],
            scores=[p.scores for p in predictions],
            is_crowd=None,
            edge_links=self.edge_links,
            edge_colors=self.edge_colors,
            keypoint_colors=self.keypoint_colors,
            show_keypoint_confidence=True,
        )
        images_to_save_preds = np.stack(images_to_save_preds)

        images_to_save_gt = self._visualize_batch(
            image_tensor=inputs,
            keypoints=[gt.joints for gt in gt_samples],
            bboxes=[xywh_to_xyxy(gt.bboxes_xywh, image_shape=None) if gt.bboxes_xywh is not None else None for gt in gt_samples],
            scores=None,
            is_crowd=[gt.is_crowd for gt in gt_samples],
            edge_links=self.edge_links,
            edge_colors=self.edge_colors,
            keypoint_colors=self.keypoint_colors,
            show_keypoint_confidence=False,
        )
        images_to_save_gt = np.stack(images_to_save_gt)

        # Stack the predictions and GT images together
        return np.concatenate([images_to_save_gt, images_to_save_preds], axis=2)

process_extreme_batch()

Processes the extreme batch and returns a batch of images for visualization - predictions and GT poses stacked horizontally.

Returns:

Type Description
np.ndarray

np.ndarray - the visualization of predictions and GT

Source code in src/super_gradients/training/utils/callbacks/extreme_batch_pose_visualization_callback.py
@torch.no_grad()
def process_extreme_batch(self) -> np.ndarray:
    """
    Processes the extreme batch, and returns batche of images for visualization - predictions and GT poses stacked horizontally.

    :return: np.ndarray - the visualization of predictions and GT
    """
    if "gt_samples" not in self.extreme_additional_batch_items:
        raise RuntimeError(
            "ExtremeBatchPoseEstimationVisualizationCallback requires 'gt_samples' to be present in additional_batch_items."
            "Currently only YoloNASPose model is supported. Old DEKR recipe is not supported at the moment."
        )

    inputs = self.universal_undo_preprocessing_fn(self.extreme_batch)
    gt_samples: List[PoseEstimationSample] = self.extreme_additional_batch_items["gt_samples"]
    predictions: List[PoseEstimationPredictions] = self.post_prediction_callback(self.extreme_preds)

    images_to_save_preds = self._visualize_batch(
        image_tensor=inputs,
        keypoints=[p.poses for p in predictions],
        bboxes=[p.bboxes_xyxy for p in predictions],
        scores=[p.scores for p in predictions],
        is_crowd=None,
        edge_links=self.edge_links,
        edge_colors=self.edge_colors,
        keypoint_colors=self.keypoint_colors,
        show_keypoint_confidence=True,
    )
    images_to_save_preds = np.stack(images_to_save_preds)

    images_to_save_gt = self._visualize_batch(
        image_tensor=inputs,
        keypoints=[gt.joints for gt in gt_samples],
        bboxes=[xywh_to_xyxy(gt.bboxes_xywh, image_shape=None) if gt.bboxes_xywh is not None else None for gt in gt_samples],
        scores=None,
        is_crowd=[gt.is_crowd for gt in gt_samples],
        edge_links=self.edge_links,
        edge_colors=self.edge_colors,
        keypoint_colors=self.keypoint_colors,
        show_keypoint_confidence=False,
    )
    images_to_save_gt = np.stack(images_to_save_gt)

    # Stack the predictions and GT images together
    return np.concatenate([images_to_save_gt, images_to_save_preds], axis=2)

universal_undo_preprocessing_fn(inputs) classmethod

A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg.

Parameters:

Name Type Description Default
inputs torch.Tensor required

Returns:

Type Description
np.ndarray
Source code in src/super_gradients/training/utils/callbacks/extreme_batch_pose_visualization_callback.py
@classmethod
def universal_undo_preprocessing_fn(cls, inputs: torch.Tensor) -> np.ndarray:
    """
    A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg.
    :param inputs:
    :return:
    """
    inputs = inputs - inputs.min()
    inputs /= inputs.max()
    inputs *= 255
    inputs = inputs.to(torch.uint8)
    inputs = inputs.cpu().numpy()
    inputs = inputs[:, ::-1, :, :].transpose(0, 2, 3, 1)
    inputs = np.ascontiguousarray(inputs, dtype=np.uint8)
    return inputs
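
A small self-contained check of what the reversal does numerically (the batch values are random and purely illustrative):

# Hedged sketch: the same min-max rescaling applied to a random "preprocessed" batch.
import numpy as np
import torch

batch = torch.randn(2, 3, 4, 4)                            # [B, C, H, W], arbitrary float range
restored = batch - batch.min()
restored = restored / restored.max() * 255                 # now in [0, 255]
restored = restored.to(torch.uint8).cpu().numpy()
restored = restored[:, ::-1, :, :].transpose(0, 2, 3, 1)   # reverse channel order, NCHW -> NHWC
restored = np.ascontiguousarray(restored, dtype=np.uint8)
print(restored.shape)                                      # (2, 4, 4, 3)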

PPYoloETrainingStageSwitchCallback

Bases: TrainingStageSwitchCallbackBase

PPYoloETrainingStageSwitchCallback

Training stage switch for PPYolo training. It changes the static bbox assigner to a task-aligned assigner after a certain number of epochs have passed.

Source code in src/super_gradients/training/utils/callbacks/ppyoloe_switch_callback.py
@register_callback(Callbacks.PPYOLOE_TRAINING_STAGE_SWITCH)
class PPYoloETrainingStageSwitchCallback(TrainingStageSwitchCallbackBase):
    """
    PPYoloETrainingStageSwitchCallback

    Training stage switch for PPYolo training.
    It changes static bbox assigner to a task aligned assigned after certain number of epochs passed

    """

    def __init__(
        self,
        static_assigner_end_epoch: int = 30,
    ):
        super().__init__(next_stage_start_epoch=static_assigner_end_epoch)

    def apply_stage_change(self, context: PhaseContext):
        from super_gradients.training.losses import PPYoloELoss

        if not isinstance(context.criterion, PPYoloELoss):
            raise RuntimeError(
                f"A criterion must be an instance of PPYoloELoss when using PPYoloETrainingStageSwitchCallback. " f"Got criterion {repr(context.criterion)}"
            )
        context.criterion.use_static_assigner = False

DefaultCheckpointSolver

Implements the default behavior from adaptive_load_state_dict. If the model state dict and checkpoint state dict have no 1:1 matching by name, the default solver uses simple ordered matching: it assumes that the order of layers in the checkpoint is the same as in the model and iterates over them simultaneously. If the shapes of the source and recipient tensors differ, the solver raises an error.

Source code in src/super_gradients/training/utils/checkpoint_utils.py
class DefaultCheckpointSolver:
    """
    Implements the default behavior from adaptive_load_state_dict.
    If the model state dict and checkpoint state dict has no 1:1 matching by name,
    then default solver uses simple ordered matching.
    It assumes that order of layers in the checkpoint is the same as in the model and
    iterates over them simultaneously.
    If shape of the source and recipient tensors are different, solver raises an error.
    """

    def __call__(self, model_state_dict: Mapping[str, Tensor], checkpoint_state_dict: Mapping[str, Tensor]) -> Mapping[str, Tensor]:
        """
        Map checkpoint state_dict to model state_dict.

        :param model_state_dict: (Mapping[str, Tensor]) A checkpoint state dict
        :param checkpoint_state_dict: (Mapping[str, Tensor]) A model state dict
        :return: (Mapping[str, Tensor]) New checkpoint state dict with keys/values converted to match model state_dict
        """
        new_ckpt_dict = {}
        for (ckpt_key, ckpt_val), (model_key, model_val) in zip(checkpoint_state_dict.items(), model_state_dict.items()):
            if ckpt_val.shape != model_val.shape:
                raise ValueError(f"ckpt layer {ckpt_key} with shape {ckpt_val.shape} does not match {model_key}" f" with shape {model_val.shape} in the model")
            new_ckpt_dict[model_key] = ckpt_val
        return new_ckpt_dict

__call__(model_state_dict, checkpoint_state_dict)

Map checkpoint state_dict to model state_dict.

Parameters:

Name Type Description Default
model_state_dict Mapping[str, Tensor]

(Mapping[str, Tensor]) A model state dict

required
checkpoint_state_dict Mapping[str, Tensor]

(Mapping[str, Tensor]) A checkpoint state dict

required

Returns:

Type Description
Mapping[str, Tensor]

(Mapping[str, Tensor]) New checkpoint state dict with keys/values converted to match model state_dict

Source code in src/super_gradients/training/utils/checkpoint_utils.py
def __call__(self, model_state_dict: Mapping[str, Tensor], checkpoint_state_dict: Mapping[str, Tensor]) -> Mapping[str, Tensor]:
    """
    Map checkpoint state_dict to model state_dict.

    :param model_state_dict: (Mapping[str, Tensor]) A checkpoint state dict
    :param checkpoint_state_dict: (Mapping[str, Tensor]) A model state dict
    :return: (Mapping[str, Tensor]) New checkpoint state dict with keys/values converted to match model state_dict
    """
    new_ckpt_dict = {}
    for (ckpt_key, ckpt_val), (model_key, model_val) in zip(checkpoint_state_dict.items(), model_state_dict.items()):
        if ckpt_val.shape != model_val.shape:
            raise ValueError(f"ckpt layer {ckpt_key} with shape {ckpt_val.shape} does not match {model_key}" f" with shape {model_val.shape} in the model")
        new_ckpt_dict[model_key] = ckpt_val
    return new_ckpt_dict
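
A minimal demonstration of the ordered matching with toy state dicts (names and shapes are illustrative, not real model weights):

# Hedged sketch: ordered matching between differently named but shape-compatible state dicts.
from collections import OrderedDict

import torch

from super_gradients.training.utils.checkpoint_utils import DefaultCheckpointSolver

model_sd = OrderedDict([("backbone.conv.weight", torch.zeros(8, 3, 3, 3)), ("head.fc.weight", torch.zeros(10, 8))])
ckpt_sd = OrderedDict([("features.0.weight", torch.ones(8, 3, 3, 3)), ("classifier.weight", torch.ones(10, 8))])

remapped = DefaultCheckpointSolver()(model_sd, ckpt_sd)
print(list(remapped.keys()))  # ['backbone.conv.weight', 'head.fc.weight'], values taken from ckpt_sd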

MissingPretrainedWeightsException

Bases: Exception

Exception raised for an unsupported pretrained model.

Parameters:

Name Type Description Default
desc

explanation of the error

required
Source code in src/super_gradients/training/utils/checkpoint_utils.py
class MissingPretrainedWeightsException(Exception):
    """Exception raised by unsupported pretrianed model.

    :param desc: explanation of the error
    """

    def __init__(self, desc):
        self.message = "Missing pretrained wights: " + desc
        super().__init__(self.message)

YoloXCheckpointSolver

Implementation of checkpoint solver for old YoloX model checkpoints.

Source code in src/super_gradients/training/utils/checkpoint_utils.py
class YoloXCheckpointSolver:
    """
    Implementation of checkpoint solver for old YoloX model checkpoints.
    """

    @classmethod
    def generate_mapping_table(cls) -> Mapping[str, str]:
        """
        Helper method to generate mapping table between olx YoloX checkpoints and the current YoloX layer names.
        :return: A mapping dictionary {checkpoint_key: model_key}
        """
        from super_gradients.common.object_names import Models
        from super_gradients.training import models

        all_mapping_keys = {}
        model_names = [Models.YOLOX_N, Models.YOLOX_T, Models.YOLOX_S, Models.YOLOX_M, Models.YOLOX_L]

        for model_name in model_names:
            model_url = MODEL_URLS[model_name + "_coco"]
            state_dict = load_state_dict_from_url(model_url, progress=True, map_location="cpu")

            model = models.get(model_name, num_classes=80)
            model_state_dict = model.state_dict()
            checkpoint_state_dict = maybe_remove_module_prefix(state_dict["net"])
            new_sd = {
                k: v
                for k, v in checkpoint_state_dict.items()
                if k not in {"stride", "_head.anchors._anchors", "_head.anchors._anchor_grid", "_head.anchors._stride", "_head._modules_list.14.stride"}
            }

            for (model_key, model_value), (checkpoint_key, checkpoint_value) in zip(model_state_dict.items(), new_sd.items()):
                if model_value.size() == checkpoint_value.size() and model_key.split(".")[-1] == checkpoint_key.split(".")[-1]:
                    if checkpoint_key in all_mapping_keys:
                        assert all_mapping_keys[checkpoint_key] == model_key
                    all_mapping_keys[checkpoint_key] = model_key
                else:
                    raise RuntimeError(
                        "Detected mismatch between model and checkpoint state dict keys."
                        f"Model key {model_key} of shape {model_value.size()} does not "
                        f"match checkpoint key {checkpoint_key} of shape {checkpoint_value.size()}"
                    )

        return all_mapping_keys

    def __init__(self):
        # The layers_rename_table below is a result of a manual mapping between the checkpoint keys and the model keys.
        # It was code-generated using YoloXCheckpointSolver.generate_mapping_table() method and tested for
        # correctness with:
        # tests.unit_tests.yolox_unit_test.TestYOLOX.test_yolo_x_checkpoint_solver.
        # tests.unit_tests.test_predict.TestModelPredict.test_detection_models

        self.layers_rename_table = {
            "_backbone._modules_list.0.conv.bn.bias": "_backbone._modules_list.0.bn.bias",
            "_backbone._modules_list.0.conv.bn.num_batches_tracked": "_backbone._modules_list.0.bn.num_batches_tracked",
            "_backbone._modules_list.0.conv.bn.running_mean": "_backbone._modules_list.0.bn.running_mean",
            "_backbone._modules_list.0.conv.bn.running_var": "_backbone._modules_list.0.bn.running_var",
            "_backbone._modules_list.0.conv.bn.weight": "_backbone._modules_list.0.bn.weight",
            "_backbone._modules_list.1.bn.bias": "_backbone._modules_list.1.bn.bias",
            "_backbone._modules_list.1.bn.num_batches_tracked": "_backbone._modules_list.1.bn.num_batches_tracked",
            "_backbone._modules_list.1.bn.running_mean": "_backbone._modules_list.1.bn.running_mean",
            "_backbone._modules_list.1.bn.running_var": "_backbone._modules_list.1.bn.running_var",
            "_backbone._modules_list.1.bn.weight": "_backbone._modules_list.1.bn.weight",
            "_backbone._modules_list.1.conv.bn.bias": "_backbone._modules_list.1.conv.bn.bias",
            "_backbone._modules_list.1.conv.bn.num_batches_tracked": "_backbone._modules_list.1.conv.bn.num_batches_tracked",
            "_backbone._modules_list.1.conv.bn.running_mean": "_backbone._modules_list.1.conv.bn.running_mean",
            "_backbone._modules_list.1.conv.bn.running_var": "_backbone._modules_list.1.conv.bn.running_var",
            "_backbone._modules_list.1.conv.bn.weight": "_backbone._modules_list.1.conv.bn.weight",
            "_backbone._modules_list.1.conv.conv.weight": "_backbone._modules_list.1.conv.conv.weight",
            "_backbone._modules_list.1.conv.weight": "_backbone._modules_list.1.conv.weight",
            "_backbone._modules_list.1.dconv.bn.bias": "_backbone._modules_list.1.dconv.bn.bias",
            "_backbone._modules_list.1.dconv.bn.num_batches_tracked": "_backbone._modules_list.1.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.1.dconv.bn.running_mean": "_backbone._modules_list.1.dconv.bn.running_mean",
            "_backbone._modules_list.1.dconv.bn.running_var": "_backbone._modules_list.1.dconv.bn.running_var",
            "_backbone._modules_list.1.dconv.bn.weight": "_backbone._modules_list.1.dconv.bn.weight",
            "_backbone._modules_list.1.dconv.conv.weight": "_backbone._modules_list.1.dconv.conv.weight",
            "_backbone._modules_list.2.cv1.bn.bias": "_backbone._modules_list.2.conv1.bn.bias",
            "_backbone._modules_list.2.cv1.bn.num_batches_tracked": "_backbone._modules_list.2.conv1.bn.num_batches_tracked",
            "_backbone._modules_list.2.cv1.bn.running_mean": "_backbone._modules_list.2.conv1.bn.running_mean",
            "_backbone._modules_list.2.cv1.bn.running_var": "_backbone._modules_list.2.conv1.bn.running_var",
            "_backbone._modules_list.2.cv1.bn.weight": "_backbone._modules_list.2.conv1.bn.weight",
            "_backbone._modules_list.2.cv1.conv.weight": "_backbone._modules_list.2.conv1.conv.weight",
            "_backbone._modules_list.2.cv2.bn.bias": "_backbone._modules_list.2.conv2.bn.bias",
            "_backbone._modules_list.2.cv2.bn.num_batches_tracked": "_backbone._modules_list.2.conv2.bn.num_batches_tracked",
            "_backbone._modules_list.2.cv2.bn.running_mean": "_backbone._modules_list.2.conv2.bn.running_mean",
            "_backbone._modules_list.2.cv2.bn.running_var": "_backbone._modules_list.2.conv2.bn.running_var",
            "_backbone._modules_list.2.cv2.bn.weight": "_backbone._modules_list.2.conv2.bn.weight",
            "_backbone._modules_list.2.cv2.conv.weight": "_backbone._modules_list.2.conv2.conv.weight",
            "_backbone._modules_list.2.cv3.bn.bias": "_backbone._modules_list.2.conv3.bn.bias",
            "_backbone._modules_list.2.cv3.bn.num_batches_tracked": "_backbone._modules_list.2.conv3.bn.num_batches_tracked",
            "_backbone._modules_list.2.cv3.bn.running_mean": "_backbone._modules_list.2.conv3.bn.running_mean",
            "_backbone._modules_list.2.cv3.bn.running_var": "_backbone._modules_list.2.conv3.bn.running_var",
            "_backbone._modules_list.2.cv3.bn.weight": "_backbone._modules_list.2.conv3.bn.weight",
            "_backbone._modules_list.2.cv3.conv.weight": "_backbone._modules_list.2.conv3.conv.weight",
            "_backbone._modules_list.2.m.0.cv1.bn.bias": "_backbone._modules_list.2.bottlenecks.0.cv1.bn.bias",
            "_backbone._modules_list.2.m.0.cv1.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.0.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.0.cv1.bn.running_mean": "_backbone._modules_list.2.bottlenecks.0.cv1.bn.running_mean",
            "_backbone._modules_list.2.m.0.cv1.bn.running_var": "_backbone._modules_list.2.bottlenecks.0.cv1.bn.running_var",
            "_backbone._modules_list.2.m.0.cv1.bn.weight": "_backbone._modules_list.2.bottlenecks.0.cv1.bn.weight",
            "_backbone._modules_list.2.m.0.cv1.conv.weight": "_backbone._modules_list.2.bottlenecks.0.cv1.conv.weight",
            "_backbone._modules_list.2.m.0.cv2.bn.bias": "_backbone._modules_list.2.bottlenecks.0.cv2.bn.bias",
            "_backbone._modules_list.2.m.0.cv2.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.0.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.0.cv2.bn.running_mean": "_backbone._modules_list.2.bottlenecks.0.cv2.bn.running_mean",
            "_backbone._modules_list.2.m.0.cv2.bn.running_var": "_backbone._modules_list.2.bottlenecks.0.cv2.bn.running_var",
            "_backbone._modules_list.2.m.0.cv2.bn.weight": "_backbone._modules_list.2.bottlenecks.0.cv2.bn.weight",
            "_backbone._modules_list.2.m.0.cv2.conv.bn.bias": "_backbone._modules_list.2.bottlenecks.0.cv2.conv.bn.bias",
            "_backbone._modules_list.2.m.0.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.0.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.0.cv2.conv.bn.running_mean": "_backbone._modules_list.2.bottlenecks.0.cv2.conv.bn.running_mean",
            "_backbone._modules_list.2.m.0.cv2.conv.bn.running_var": "_backbone._modules_list.2.bottlenecks.0.cv2.conv.bn.running_var",
            "_backbone._modules_list.2.m.0.cv2.conv.bn.weight": "_backbone._modules_list.2.bottlenecks.0.cv2.conv.bn.weight",
            "_backbone._modules_list.2.m.0.cv2.conv.conv.weight": "_backbone._modules_list.2.bottlenecks.0.cv2.conv.conv.weight",
            "_backbone._modules_list.2.m.0.cv2.conv.weight": "_backbone._modules_list.2.bottlenecks.0.cv2.conv.weight",
            "_backbone._modules_list.2.m.0.cv2.dconv.bn.bias": "_backbone._modules_list.2.bottlenecks.0.cv2.dconv.bn.bias",
            "_backbone._modules_list.2.m.0.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.0.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.0.cv2.dconv.bn.running_mean": "_backbone._modules_list.2.bottlenecks.0.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.2.m.0.cv2.dconv.bn.running_var": "_backbone._modules_list.2.bottlenecks.0.cv2.dconv.bn.running_var",
            "_backbone._modules_list.2.m.0.cv2.dconv.bn.weight": "_backbone._modules_list.2.bottlenecks.0.cv2.dconv.bn.weight",
            "_backbone._modules_list.2.m.0.cv2.dconv.conv.weight": "_backbone._modules_list.2.bottlenecks.0.cv2.dconv.conv.weight",
            "_backbone._modules_list.2.m.1.cv1.bn.bias": "_backbone._modules_list.2.bottlenecks.1.cv1.bn.bias",
            "_backbone._modules_list.2.m.1.cv1.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.1.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.1.cv1.bn.running_mean": "_backbone._modules_list.2.bottlenecks.1.cv1.bn.running_mean",
            "_backbone._modules_list.2.m.1.cv1.bn.running_var": "_backbone._modules_list.2.bottlenecks.1.cv1.bn.running_var",
            "_backbone._modules_list.2.m.1.cv1.bn.weight": "_backbone._modules_list.2.bottlenecks.1.cv1.bn.weight",
            "_backbone._modules_list.2.m.1.cv1.conv.weight": "_backbone._modules_list.2.bottlenecks.1.cv1.conv.weight",
            "_backbone._modules_list.2.m.1.cv2.bn.bias": "_backbone._modules_list.2.bottlenecks.1.cv2.bn.bias",
            "_backbone._modules_list.2.m.1.cv2.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.1.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.1.cv2.bn.running_mean": "_backbone._modules_list.2.bottlenecks.1.cv2.bn.running_mean",
            "_backbone._modules_list.2.m.1.cv2.bn.running_var": "_backbone._modules_list.2.bottlenecks.1.cv2.bn.running_var",
            "_backbone._modules_list.2.m.1.cv2.bn.weight": "_backbone._modules_list.2.bottlenecks.1.cv2.bn.weight",
            "_backbone._modules_list.2.m.1.cv2.conv.weight": "_backbone._modules_list.2.bottlenecks.1.cv2.conv.weight",
            "_backbone._modules_list.2.m.2.cv1.bn.bias": "_backbone._modules_list.2.bottlenecks.2.cv1.bn.bias",
            "_backbone._modules_list.2.m.2.cv1.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.2.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.2.cv1.bn.running_mean": "_backbone._modules_list.2.bottlenecks.2.cv1.bn.running_mean",
            "_backbone._modules_list.2.m.2.cv1.bn.running_var": "_backbone._modules_list.2.bottlenecks.2.cv1.bn.running_var",
            "_backbone._modules_list.2.m.2.cv1.bn.weight": "_backbone._modules_list.2.bottlenecks.2.cv1.bn.weight",
            "_backbone._modules_list.2.m.2.cv1.conv.weight": "_backbone._modules_list.2.bottlenecks.2.cv1.conv.weight",
            "_backbone._modules_list.2.m.2.cv2.bn.bias": "_backbone._modules_list.2.bottlenecks.2.cv2.bn.bias",
            "_backbone._modules_list.2.m.2.cv2.bn.num_batches_tracked": "_backbone._modules_list.2.bottlenecks.2.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.2.m.2.cv2.bn.running_mean": "_backbone._modules_list.2.bottlenecks.2.cv2.bn.running_mean",
            "_backbone._modules_list.2.m.2.cv2.bn.running_var": "_backbone._modules_list.2.bottlenecks.2.cv2.bn.running_var",
            "_backbone._modules_list.2.m.2.cv2.bn.weight": "_backbone._modules_list.2.bottlenecks.2.cv2.bn.weight",
            "_backbone._modules_list.2.m.2.cv2.conv.weight": "_backbone._modules_list.2.bottlenecks.2.cv2.conv.weight",
            "_backbone._modules_list.3.bn.bias": "_backbone._modules_list.3.bn.bias",
            "_backbone._modules_list.3.bn.num_batches_tracked": "_backbone._modules_list.3.bn.num_batches_tracked",
            "_backbone._modules_list.3.bn.running_mean": "_backbone._modules_list.3.bn.running_mean",
            "_backbone._modules_list.3.bn.running_var": "_backbone._modules_list.3.bn.running_var",
            "_backbone._modules_list.3.bn.weight": "_backbone._modules_list.3.bn.weight",
            "_backbone._modules_list.3.conv.bn.bias": "_backbone._modules_list.3.conv.bn.bias",
            "_backbone._modules_list.3.conv.bn.num_batches_tracked": "_backbone._modules_list.3.conv.bn.num_batches_tracked",
            "_backbone._modules_list.3.conv.bn.running_mean": "_backbone._modules_list.3.conv.bn.running_mean",
            "_backbone._modules_list.3.conv.bn.running_var": "_backbone._modules_list.3.conv.bn.running_var",
            "_backbone._modules_list.3.conv.bn.weight": "_backbone._modules_list.3.conv.bn.weight",
            "_backbone._modules_list.3.conv.conv.weight": "_backbone._modules_list.3.conv.conv.weight",
            "_backbone._modules_list.3.conv.weight": "_backbone._modules_list.3.conv.weight",
            "_backbone._modules_list.3.dconv.bn.bias": "_backbone._modules_list.3.dconv.bn.bias",
            "_backbone._modules_list.3.dconv.bn.num_batches_tracked": "_backbone._modules_list.3.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.3.dconv.bn.running_mean": "_backbone._modules_list.3.dconv.bn.running_mean",
            "_backbone._modules_list.3.dconv.bn.running_var": "_backbone._modules_list.3.dconv.bn.running_var",
            "_backbone._modules_list.3.dconv.bn.weight": "_backbone._modules_list.3.dconv.bn.weight",
            "_backbone._modules_list.3.dconv.conv.weight": "_backbone._modules_list.3.dconv.conv.weight",
            "_backbone._modules_list.4.cv1.bn.bias": "_backbone._modules_list.4.conv1.bn.bias",
            "_backbone._modules_list.4.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.conv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.cv1.bn.running_mean": "_backbone._modules_list.4.conv1.bn.running_mean",
            "_backbone._modules_list.4.cv1.bn.running_var": "_backbone._modules_list.4.conv1.bn.running_var",
            "_backbone._modules_list.4.cv1.bn.weight": "_backbone._modules_list.4.conv1.bn.weight",
            "_backbone._modules_list.4.cv1.conv.weight": "_backbone._modules_list.4.conv1.conv.weight",
            "_backbone._modules_list.4.cv2.bn.bias": "_backbone._modules_list.4.conv2.bn.bias",
            "_backbone._modules_list.4.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.conv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.cv2.bn.running_mean": "_backbone._modules_list.4.conv2.bn.running_mean",
            "_backbone._modules_list.4.cv2.bn.running_var": "_backbone._modules_list.4.conv2.bn.running_var",
            "_backbone._modules_list.4.cv2.bn.weight": "_backbone._modules_list.4.conv2.bn.weight",
            "_backbone._modules_list.4.cv2.conv.weight": "_backbone._modules_list.4.conv2.conv.weight",
            "_backbone._modules_list.4.cv3.bn.bias": "_backbone._modules_list.4.conv3.bn.bias",
            "_backbone._modules_list.4.cv3.bn.num_batches_tracked": "_backbone._modules_list.4.conv3.bn.num_batches_tracked",
            "_backbone._modules_list.4.cv3.bn.running_mean": "_backbone._modules_list.4.conv3.bn.running_mean",
            "_backbone._modules_list.4.cv3.bn.running_var": "_backbone._modules_list.4.conv3.bn.running_var",
            "_backbone._modules_list.4.cv3.bn.weight": "_backbone._modules_list.4.conv3.bn.weight",
            "_backbone._modules_list.4.cv3.conv.weight": "_backbone._modules_list.4.conv3.conv.weight",
            "_backbone._modules_list.4.m.0.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.0.cv1.bn.bias",
            "_backbone._modules_list.4.m.0.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.0.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.0.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.0.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.0.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.0.cv1.bn.running_var",
            "_backbone._modules_list.4.m.0.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.0.cv1.bn.weight",
            "_backbone._modules_list.4.m.0.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.0.cv1.conv.weight",
            "_backbone._modules_list.4.m.0.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.0.cv2.bn.bias",
            "_backbone._modules_list.4.m.0.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.0.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.0.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.0.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.0.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.0.cv2.bn.running_var",
            "_backbone._modules_list.4.m.0.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.0.cv2.bn.weight",
            "_backbone._modules_list.4.m.0.cv2.conv.bn.bias": "_backbone._modules_list.4.bottlenecks.0.cv2.conv.bn.bias",
            "_backbone._modules_list.4.m.0.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.0.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.0.cv2.conv.bn.running_mean": "_backbone._modules_list.4.bottlenecks.0.cv2.conv.bn.running_mean",
            "_backbone._modules_list.4.m.0.cv2.conv.bn.running_var": "_backbone._modules_list.4.bottlenecks.0.cv2.conv.bn.running_var",
            "_backbone._modules_list.4.m.0.cv2.conv.bn.weight": "_backbone._modules_list.4.bottlenecks.0.cv2.conv.bn.weight",
            "_backbone._modules_list.4.m.0.cv2.conv.conv.weight": "_backbone._modules_list.4.bottlenecks.0.cv2.conv.conv.weight",
            "_backbone._modules_list.4.m.0.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.0.cv2.conv.weight",
            "_backbone._modules_list.4.m.0.cv2.dconv.bn.bias": "_backbone._modules_list.4.bottlenecks.0.cv2.dconv.bn.bias",
            "_backbone._modules_list.4.m.0.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.0.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.0.cv2.dconv.bn.running_mean": "_backbone._modules_list.4.bottlenecks.0.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.4.m.0.cv2.dconv.bn.running_var": "_backbone._modules_list.4.bottlenecks.0.cv2.dconv.bn.running_var",
            "_backbone._modules_list.4.m.0.cv2.dconv.bn.weight": "_backbone._modules_list.4.bottlenecks.0.cv2.dconv.bn.weight",
            "_backbone._modules_list.4.m.0.cv2.dconv.conv.weight": "_backbone._modules_list.4.bottlenecks.0.cv2.dconv.conv.weight",
            "_backbone._modules_list.4.m.1.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.1.cv1.bn.bias",
            "_backbone._modules_list.4.m.1.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.1.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.1.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.1.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.1.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.1.cv1.bn.running_var",
            "_backbone._modules_list.4.m.1.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.1.cv1.bn.weight",
            "_backbone._modules_list.4.m.1.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.1.cv1.conv.weight",
            "_backbone._modules_list.4.m.1.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.1.cv2.bn.bias",
            "_backbone._modules_list.4.m.1.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.1.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.1.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.1.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.1.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.1.cv2.bn.running_var",
            "_backbone._modules_list.4.m.1.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.1.cv2.bn.weight",
            "_backbone._modules_list.4.m.1.cv2.conv.bn.bias": "_backbone._modules_list.4.bottlenecks.1.cv2.conv.bn.bias",
            "_backbone._modules_list.4.m.1.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.1.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.1.cv2.conv.bn.running_mean": "_backbone._modules_list.4.bottlenecks.1.cv2.conv.bn.running_mean",
            "_backbone._modules_list.4.m.1.cv2.conv.bn.running_var": "_backbone._modules_list.4.bottlenecks.1.cv2.conv.bn.running_var",
            "_backbone._modules_list.4.m.1.cv2.conv.bn.weight": "_backbone._modules_list.4.bottlenecks.1.cv2.conv.bn.weight",
            "_backbone._modules_list.4.m.1.cv2.conv.conv.weight": "_backbone._modules_list.4.bottlenecks.1.cv2.conv.conv.weight",
            "_backbone._modules_list.4.m.1.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.1.cv2.conv.weight",
            "_backbone._modules_list.4.m.1.cv2.dconv.bn.bias": "_backbone._modules_list.4.bottlenecks.1.cv2.dconv.bn.bias",
            "_backbone._modules_list.4.m.1.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.1.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.1.cv2.dconv.bn.running_mean": "_backbone._modules_list.4.bottlenecks.1.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.4.m.1.cv2.dconv.bn.running_var": "_backbone._modules_list.4.bottlenecks.1.cv2.dconv.bn.running_var",
            "_backbone._modules_list.4.m.1.cv2.dconv.bn.weight": "_backbone._modules_list.4.bottlenecks.1.cv2.dconv.bn.weight",
            "_backbone._modules_list.4.m.1.cv2.dconv.conv.weight": "_backbone._modules_list.4.bottlenecks.1.cv2.dconv.conv.weight",
            "_backbone._modules_list.4.m.2.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.2.cv1.bn.bias",
            "_backbone._modules_list.4.m.2.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.2.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.2.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.2.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.2.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.2.cv1.bn.running_var",
            "_backbone._modules_list.4.m.2.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.2.cv1.bn.weight",
            "_backbone._modules_list.4.m.2.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.2.cv1.conv.weight",
            "_backbone._modules_list.4.m.2.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.2.cv2.bn.bias",
            "_backbone._modules_list.4.m.2.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.2.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.2.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.2.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.2.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.2.cv2.bn.running_var",
            "_backbone._modules_list.4.m.2.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.2.cv2.bn.weight",
            "_backbone._modules_list.4.m.2.cv2.conv.bn.bias": "_backbone._modules_list.4.bottlenecks.2.cv2.conv.bn.bias",
            "_backbone._modules_list.4.m.2.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.2.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.2.cv2.conv.bn.running_mean": "_backbone._modules_list.4.bottlenecks.2.cv2.conv.bn.running_mean",
            "_backbone._modules_list.4.m.2.cv2.conv.bn.running_var": "_backbone._modules_list.4.bottlenecks.2.cv2.conv.bn.running_var",
            "_backbone._modules_list.4.m.2.cv2.conv.bn.weight": "_backbone._modules_list.4.bottlenecks.2.cv2.conv.bn.weight",
            "_backbone._modules_list.4.m.2.cv2.conv.conv.weight": "_backbone._modules_list.4.bottlenecks.2.cv2.conv.conv.weight",
            "_backbone._modules_list.4.m.2.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.2.cv2.conv.weight",
            "_backbone._modules_list.4.m.2.cv2.dconv.bn.bias": "_backbone._modules_list.4.bottlenecks.2.cv2.dconv.bn.bias",
            "_backbone._modules_list.4.m.2.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.2.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.2.cv2.dconv.bn.running_mean": "_backbone._modules_list.4.bottlenecks.2.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.4.m.2.cv2.dconv.bn.running_var": "_backbone._modules_list.4.bottlenecks.2.cv2.dconv.bn.running_var",
            "_backbone._modules_list.4.m.2.cv2.dconv.bn.weight": "_backbone._modules_list.4.bottlenecks.2.cv2.dconv.bn.weight",
            "_backbone._modules_list.4.m.2.cv2.dconv.conv.weight": "_backbone._modules_list.4.bottlenecks.2.cv2.dconv.conv.weight",
            "_backbone._modules_list.4.m.3.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.3.cv1.bn.bias",
            "_backbone._modules_list.4.m.3.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.3.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.3.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.3.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.3.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.3.cv1.bn.running_var",
            "_backbone._modules_list.4.m.3.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.3.cv1.bn.weight",
            "_backbone._modules_list.4.m.3.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.3.cv1.conv.weight",
            "_backbone._modules_list.4.m.3.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.3.cv2.bn.bias",
            "_backbone._modules_list.4.m.3.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.3.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.3.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.3.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.3.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.3.cv2.bn.running_var",
            "_backbone._modules_list.4.m.3.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.3.cv2.bn.weight",
            "_backbone._modules_list.4.m.3.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.3.cv2.conv.weight",
            "_backbone._modules_list.4.m.4.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.4.cv1.bn.bias",
            "_backbone._modules_list.4.m.4.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.4.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.4.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.4.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.4.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.4.cv1.bn.running_var",
            "_backbone._modules_list.4.m.4.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.4.cv1.bn.weight",
            "_backbone._modules_list.4.m.4.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.4.cv1.conv.weight",
            "_backbone._modules_list.4.m.4.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.4.cv2.bn.bias",
            "_backbone._modules_list.4.m.4.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.4.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.4.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.4.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.4.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.4.cv2.bn.running_var",
            "_backbone._modules_list.4.m.4.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.4.cv2.bn.weight",
            "_backbone._modules_list.4.m.4.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.4.cv2.conv.weight",
            "_backbone._modules_list.4.m.5.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.5.cv1.bn.bias",
            "_backbone._modules_list.4.m.5.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.5.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.5.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.5.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.5.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.5.cv1.bn.running_var",
            "_backbone._modules_list.4.m.5.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.5.cv1.bn.weight",
            "_backbone._modules_list.4.m.5.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.5.cv1.conv.weight",
            "_backbone._modules_list.4.m.5.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.5.cv2.bn.bias",
            "_backbone._modules_list.4.m.5.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.5.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.5.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.5.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.5.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.5.cv2.bn.running_var",
            "_backbone._modules_list.4.m.5.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.5.cv2.bn.weight",
            "_backbone._modules_list.4.m.5.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.5.cv2.conv.weight",
            "_backbone._modules_list.4.m.6.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.6.cv1.bn.bias",
            "_backbone._modules_list.4.m.6.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.6.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.6.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.6.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.6.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.6.cv1.bn.running_var",
            "_backbone._modules_list.4.m.6.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.6.cv1.bn.weight",
            "_backbone._modules_list.4.m.6.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.6.cv1.conv.weight",
            "_backbone._modules_list.4.m.6.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.6.cv2.bn.bias",
            "_backbone._modules_list.4.m.6.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.6.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.6.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.6.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.6.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.6.cv2.bn.running_var",
            "_backbone._modules_list.4.m.6.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.6.cv2.bn.weight",
            "_backbone._modules_list.4.m.6.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.6.cv2.conv.weight",
            "_backbone._modules_list.4.m.7.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.7.cv1.bn.bias",
            "_backbone._modules_list.4.m.7.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.7.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.7.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.7.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.7.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.7.cv1.bn.running_var",
            "_backbone._modules_list.4.m.7.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.7.cv1.bn.weight",
            "_backbone._modules_list.4.m.7.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.7.cv1.conv.weight",
            "_backbone._modules_list.4.m.7.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.7.cv2.bn.bias",
            "_backbone._modules_list.4.m.7.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.7.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.7.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.7.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.7.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.7.cv2.bn.running_var",
            "_backbone._modules_list.4.m.7.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.7.cv2.bn.weight",
            "_backbone._modules_list.4.m.7.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.7.cv2.conv.weight",
            "_backbone._modules_list.4.m.8.cv1.bn.bias": "_backbone._modules_list.4.bottlenecks.8.cv1.bn.bias",
            "_backbone._modules_list.4.m.8.cv1.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.8.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.8.cv1.bn.running_mean": "_backbone._modules_list.4.bottlenecks.8.cv1.bn.running_mean",
            "_backbone._modules_list.4.m.8.cv1.bn.running_var": "_backbone._modules_list.4.bottlenecks.8.cv1.bn.running_var",
            "_backbone._modules_list.4.m.8.cv1.bn.weight": "_backbone._modules_list.4.bottlenecks.8.cv1.bn.weight",
            "_backbone._modules_list.4.m.8.cv1.conv.weight": "_backbone._modules_list.4.bottlenecks.8.cv1.conv.weight",
            "_backbone._modules_list.4.m.8.cv2.bn.bias": "_backbone._modules_list.4.bottlenecks.8.cv2.bn.bias",
            "_backbone._modules_list.4.m.8.cv2.bn.num_batches_tracked": "_backbone._modules_list.4.bottlenecks.8.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.4.m.8.cv2.bn.running_mean": "_backbone._modules_list.4.bottlenecks.8.cv2.bn.running_mean",
            "_backbone._modules_list.4.m.8.cv2.bn.running_var": "_backbone._modules_list.4.bottlenecks.8.cv2.bn.running_var",
            "_backbone._modules_list.4.m.8.cv2.bn.weight": "_backbone._modules_list.4.bottlenecks.8.cv2.bn.weight",
            "_backbone._modules_list.4.m.8.cv2.conv.weight": "_backbone._modules_list.4.bottlenecks.8.cv2.conv.weight",
            "_backbone._modules_list.5.bn.bias": "_backbone._modules_list.5.bn.bias",
            "_backbone._modules_list.5.bn.num_batches_tracked": "_backbone._modules_list.5.bn.num_batches_tracked",
            "_backbone._modules_list.5.bn.running_mean": "_backbone._modules_list.5.bn.running_mean",
            "_backbone._modules_list.5.bn.running_var": "_backbone._modules_list.5.bn.running_var",
            "_backbone._modules_list.5.bn.weight": "_backbone._modules_list.5.bn.weight",
            "_backbone._modules_list.5.conv.bn.bias": "_backbone._modules_list.5.conv.bn.bias",
            "_backbone._modules_list.5.conv.bn.num_batches_tracked": "_backbone._modules_list.5.conv.bn.num_batches_tracked",
            "_backbone._modules_list.5.conv.bn.running_mean": "_backbone._modules_list.5.conv.bn.running_mean",
            "_backbone._modules_list.5.conv.bn.running_var": "_backbone._modules_list.5.conv.bn.running_var",
            "_backbone._modules_list.5.conv.bn.weight": "_backbone._modules_list.5.conv.bn.weight",
            "_backbone._modules_list.5.conv.conv.weight": "_backbone._modules_list.5.conv.conv.weight",
            "_backbone._modules_list.5.conv.weight": "_backbone._modules_list.5.conv.weight",
            "_backbone._modules_list.5.dconv.bn.bias": "_backbone._modules_list.5.dconv.bn.bias",
            "_backbone._modules_list.5.dconv.bn.num_batches_tracked": "_backbone._modules_list.5.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.5.dconv.bn.running_mean": "_backbone._modules_list.5.dconv.bn.running_mean",
            "_backbone._modules_list.5.dconv.bn.running_var": "_backbone._modules_list.5.dconv.bn.running_var",
            "_backbone._modules_list.5.dconv.bn.weight": "_backbone._modules_list.5.dconv.bn.weight",
            "_backbone._modules_list.5.dconv.conv.weight": "_backbone._modules_list.5.dconv.conv.weight",
            "_backbone._modules_list.6.cv1.bn.bias": "_backbone._modules_list.6.conv1.bn.bias",
            "_backbone._modules_list.6.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.conv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.cv1.bn.running_mean": "_backbone._modules_list.6.conv1.bn.running_mean",
            "_backbone._modules_list.6.cv1.bn.running_var": "_backbone._modules_list.6.conv1.bn.running_var",
            "_backbone._modules_list.6.cv1.bn.weight": "_backbone._modules_list.6.conv1.bn.weight",
            "_backbone._modules_list.6.cv1.conv.weight": "_backbone._modules_list.6.conv1.conv.weight",
            "_backbone._modules_list.6.cv2.bn.bias": "_backbone._modules_list.6.conv2.bn.bias",
            "_backbone._modules_list.6.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.conv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.cv2.bn.running_mean": "_backbone._modules_list.6.conv2.bn.running_mean",
            "_backbone._modules_list.6.cv2.bn.running_var": "_backbone._modules_list.6.conv2.bn.running_var",
            "_backbone._modules_list.6.cv2.bn.weight": "_backbone._modules_list.6.conv2.bn.weight",
            "_backbone._modules_list.6.cv2.conv.weight": "_backbone._modules_list.6.conv2.conv.weight",
            "_backbone._modules_list.6.cv3.bn.bias": "_backbone._modules_list.6.conv3.bn.bias",
            "_backbone._modules_list.6.cv3.bn.num_batches_tracked": "_backbone._modules_list.6.conv3.bn.num_batches_tracked",
            "_backbone._modules_list.6.cv3.bn.running_mean": "_backbone._modules_list.6.conv3.bn.running_mean",
            "_backbone._modules_list.6.cv3.bn.running_var": "_backbone._modules_list.6.conv3.bn.running_var",
            "_backbone._modules_list.6.cv3.bn.weight": "_backbone._modules_list.6.conv3.bn.weight",
            "_backbone._modules_list.6.cv3.conv.weight": "_backbone._modules_list.6.conv3.conv.weight",
            "_backbone._modules_list.6.m.0.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.0.cv1.bn.bias",
            "_backbone._modules_list.6.m.0.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.0.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.0.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.0.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.0.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.0.cv1.bn.running_var",
            "_backbone._modules_list.6.m.0.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.0.cv1.bn.weight",
            "_backbone._modules_list.6.m.0.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.0.cv1.conv.weight",
            "_backbone._modules_list.6.m.0.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.0.cv2.bn.bias",
            "_backbone._modules_list.6.m.0.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.0.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.0.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.0.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.0.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.0.cv2.bn.running_var",
            "_backbone._modules_list.6.m.0.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.0.cv2.bn.weight",
            "_backbone._modules_list.6.m.0.cv2.conv.bn.bias": "_backbone._modules_list.6.bottlenecks.0.cv2.conv.bn.bias",
            "_backbone._modules_list.6.m.0.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.0.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.0.cv2.conv.bn.running_mean": "_backbone._modules_list.6.bottlenecks.0.cv2.conv.bn.running_mean",
            "_backbone._modules_list.6.m.0.cv2.conv.bn.running_var": "_backbone._modules_list.6.bottlenecks.0.cv2.conv.bn.running_var",
            "_backbone._modules_list.6.m.0.cv2.conv.bn.weight": "_backbone._modules_list.6.bottlenecks.0.cv2.conv.bn.weight",
            "_backbone._modules_list.6.m.0.cv2.conv.conv.weight": "_backbone._modules_list.6.bottlenecks.0.cv2.conv.conv.weight",
            "_backbone._modules_list.6.m.0.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.0.cv2.conv.weight",
            "_backbone._modules_list.6.m.0.cv2.dconv.bn.bias": "_backbone._modules_list.6.bottlenecks.0.cv2.dconv.bn.bias",
            "_backbone._modules_list.6.m.0.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.0.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.0.cv2.dconv.bn.running_mean": "_backbone._modules_list.6.bottlenecks.0.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.6.m.0.cv2.dconv.bn.running_var": "_backbone._modules_list.6.bottlenecks.0.cv2.dconv.bn.running_var",
            "_backbone._modules_list.6.m.0.cv2.dconv.bn.weight": "_backbone._modules_list.6.bottlenecks.0.cv2.dconv.bn.weight",
            "_backbone._modules_list.6.m.0.cv2.dconv.conv.weight": "_backbone._modules_list.6.bottlenecks.0.cv2.dconv.conv.weight",
            "_backbone._modules_list.6.m.1.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.1.cv1.bn.bias",
            "_backbone._modules_list.6.m.1.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.1.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.1.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.1.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.1.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.1.cv1.bn.running_var",
            "_backbone._modules_list.6.m.1.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.1.cv1.bn.weight",
            "_backbone._modules_list.6.m.1.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.1.cv1.conv.weight",
            "_backbone._modules_list.6.m.1.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.1.cv2.bn.bias",
            "_backbone._modules_list.6.m.1.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.1.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.1.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.1.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.1.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.1.cv2.bn.running_var",
            "_backbone._modules_list.6.m.1.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.1.cv2.bn.weight",
            "_backbone._modules_list.6.m.1.cv2.conv.bn.bias": "_backbone._modules_list.6.bottlenecks.1.cv2.conv.bn.bias",
            "_backbone._modules_list.6.m.1.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.1.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.1.cv2.conv.bn.running_mean": "_backbone._modules_list.6.bottlenecks.1.cv2.conv.bn.running_mean",
            "_backbone._modules_list.6.m.1.cv2.conv.bn.running_var": "_backbone._modules_list.6.bottlenecks.1.cv2.conv.bn.running_var",
            "_backbone._modules_list.6.m.1.cv2.conv.bn.weight": "_backbone._modules_list.6.bottlenecks.1.cv2.conv.bn.weight",
            "_backbone._modules_list.6.m.1.cv2.conv.conv.weight": "_backbone._modules_list.6.bottlenecks.1.cv2.conv.conv.weight",
            "_backbone._modules_list.6.m.1.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.1.cv2.conv.weight",
            "_backbone._modules_list.6.m.1.cv2.dconv.bn.bias": "_backbone._modules_list.6.bottlenecks.1.cv2.dconv.bn.bias",
            "_backbone._modules_list.6.m.1.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.1.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.1.cv2.dconv.bn.running_mean": "_backbone._modules_list.6.bottlenecks.1.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.6.m.1.cv2.dconv.bn.running_var": "_backbone._modules_list.6.bottlenecks.1.cv2.dconv.bn.running_var",
            "_backbone._modules_list.6.m.1.cv2.dconv.bn.weight": "_backbone._modules_list.6.bottlenecks.1.cv2.dconv.bn.weight",
            "_backbone._modules_list.6.m.1.cv2.dconv.conv.weight": "_backbone._modules_list.6.bottlenecks.1.cv2.dconv.conv.weight",
            "_backbone._modules_list.6.m.2.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.2.cv1.bn.bias",
            "_backbone._modules_list.6.m.2.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.2.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.2.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.2.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.2.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.2.cv1.bn.running_var",
            "_backbone._modules_list.6.m.2.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.2.cv1.bn.weight",
            "_backbone._modules_list.6.m.2.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.2.cv1.conv.weight",
            "_backbone._modules_list.6.m.2.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.2.cv2.bn.bias",
            "_backbone._modules_list.6.m.2.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.2.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.2.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.2.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.2.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.2.cv2.bn.running_var",
            "_backbone._modules_list.6.m.2.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.2.cv2.bn.weight",
            "_backbone._modules_list.6.m.2.cv2.conv.bn.bias": "_backbone._modules_list.6.bottlenecks.2.cv2.conv.bn.bias",
            "_backbone._modules_list.6.m.2.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.2.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.2.cv2.conv.bn.running_mean": "_backbone._modules_list.6.bottlenecks.2.cv2.conv.bn.running_mean",
            "_backbone._modules_list.6.m.2.cv2.conv.bn.running_var": "_backbone._modules_list.6.bottlenecks.2.cv2.conv.bn.running_var",
            "_backbone._modules_list.6.m.2.cv2.conv.bn.weight": "_backbone._modules_list.6.bottlenecks.2.cv2.conv.bn.weight",
            "_backbone._modules_list.6.m.2.cv2.conv.conv.weight": "_backbone._modules_list.6.bottlenecks.2.cv2.conv.conv.weight",
            "_backbone._modules_list.6.m.2.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.2.cv2.conv.weight",
            "_backbone._modules_list.6.m.2.cv2.dconv.bn.bias": "_backbone._modules_list.6.bottlenecks.2.cv2.dconv.bn.bias",
            "_backbone._modules_list.6.m.2.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.2.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.2.cv2.dconv.bn.running_mean": "_backbone._modules_list.6.bottlenecks.2.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.6.m.2.cv2.dconv.bn.running_var": "_backbone._modules_list.6.bottlenecks.2.cv2.dconv.bn.running_var",
            "_backbone._modules_list.6.m.2.cv2.dconv.bn.weight": "_backbone._modules_list.6.bottlenecks.2.cv2.dconv.bn.weight",
            "_backbone._modules_list.6.m.2.cv2.dconv.conv.weight": "_backbone._modules_list.6.bottlenecks.2.cv2.dconv.conv.weight",
            "_backbone._modules_list.6.m.3.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.3.cv1.bn.bias",
            "_backbone._modules_list.6.m.3.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.3.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.3.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.3.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.3.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.3.cv1.bn.running_var",
            "_backbone._modules_list.6.m.3.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.3.cv1.bn.weight",
            "_backbone._modules_list.6.m.3.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.3.cv1.conv.weight",
            "_backbone._modules_list.6.m.3.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.3.cv2.bn.bias",
            "_backbone._modules_list.6.m.3.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.3.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.3.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.3.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.3.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.3.cv2.bn.running_var",
            "_backbone._modules_list.6.m.3.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.3.cv2.bn.weight",
            "_backbone._modules_list.6.m.3.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.3.cv2.conv.weight",
            "_backbone._modules_list.6.m.4.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.4.cv1.bn.bias",
            "_backbone._modules_list.6.m.4.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.4.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.4.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.4.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.4.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.4.cv1.bn.running_var",
            "_backbone._modules_list.6.m.4.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.4.cv1.bn.weight",
            "_backbone._modules_list.6.m.4.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.4.cv1.conv.weight",
            "_backbone._modules_list.6.m.4.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.4.cv2.bn.bias",
            "_backbone._modules_list.6.m.4.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.4.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.4.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.4.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.4.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.4.cv2.bn.running_var",
            "_backbone._modules_list.6.m.4.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.4.cv2.bn.weight",
            "_backbone._modules_list.6.m.4.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.4.cv2.conv.weight",
            "_backbone._modules_list.6.m.5.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.5.cv1.bn.bias",
            "_backbone._modules_list.6.m.5.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.5.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.5.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.5.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.5.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.5.cv1.bn.running_var",
            "_backbone._modules_list.6.m.5.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.5.cv1.bn.weight",
            "_backbone._modules_list.6.m.5.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.5.cv1.conv.weight",
            "_backbone._modules_list.6.m.5.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.5.cv2.bn.bias",
            "_backbone._modules_list.6.m.5.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.5.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.5.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.5.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.5.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.5.cv2.bn.running_var",
            "_backbone._modules_list.6.m.5.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.5.cv2.bn.weight",
            "_backbone._modules_list.6.m.5.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.5.cv2.conv.weight",
            "_backbone._modules_list.6.m.6.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.6.cv1.bn.bias",
            "_backbone._modules_list.6.m.6.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.6.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.6.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.6.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.6.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.6.cv1.bn.running_var",
            "_backbone._modules_list.6.m.6.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.6.cv1.bn.weight",
            "_backbone._modules_list.6.m.6.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.6.cv1.conv.weight",
            "_backbone._modules_list.6.m.6.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.6.cv2.bn.bias",
            "_backbone._modules_list.6.m.6.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.6.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.6.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.6.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.6.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.6.cv2.bn.running_var",
            "_backbone._modules_list.6.m.6.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.6.cv2.bn.weight",
            "_backbone._modules_list.6.m.6.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.6.cv2.conv.weight",
            "_backbone._modules_list.6.m.7.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.7.cv1.bn.bias",
            "_backbone._modules_list.6.m.7.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.7.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.7.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.7.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.7.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.7.cv1.bn.running_var",
            "_backbone._modules_list.6.m.7.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.7.cv1.bn.weight",
            "_backbone._modules_list.6.m.7.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.7.cv1.conv.weight",
            "_backbone._modules_list.6.m.7.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.7.cv2.bn.bias",
            "_backbone._modules_list.6.m.7.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.7.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.7.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.7.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.7.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.7.cv2.bn.running_var",
            "_backbone._modules_list.6.m.7.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.7.cv2.bn.weight",
            "_backbone._modules_list.6.m.7.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.7.cv2.conv.weight",
            "_backbone._modules_list.6.m.8.cv1.bn.bias": "_backbone._modules_list.6.bottlenecks.8.cv1.bn.bias",
            "_backbone._modules_list.6.m.8.cv1.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.8.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.8.cv1.bn.running_mean": "_backbone._modules_list.6.bottlenecks.8.cv1.bn.running_mean",
            "_backbone._modules_list.6.m.8.cv1.bn.running_var": "_backbone._modules_list.6.bottlenecks.8.cv1.bn.running_var",
            "_backbone._modules_list.6.m.8.cv1.bn.weight": "_backbone._modules_list.6.bottlenecks.8.cv1.bn.weight",
            "_backbone._modules_list.6.m.8.cv1.conv.weight": "_backbone._modules_list.6.bottlenecks.8.cv1.conv.weight",
            "_backbone._modules_list.6.m.8.cv2.bn.bias": "_backbone._modules_list.6.bottlenecks.8.cv2.bn.bias",
            "_backbone._modules_list.6.m.8.cv2.bn.num_batches_tracked": "_backbone._modules_list.6.bottlenecks.8.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.6.m.8.cv2.bn.running_mean": "_backbone._modules_list.6.bottlenecks.8.cv2.bn.running_mean",
            "_backbone._modules_list.6.m.8.cv2.bn.running_var": "_backbone._modules_list.6.bottlenecks.8.cv2.bn.running_var",
            "_backbone._modules_list.6.m.8.cv2.bn.weight": "_backbone._modules_list.6.bottlenecks.8.cv2.bn.weight",
            "_backbone._modules_list.6.m.8.cv2.conv.weight": "_backbone._modules_list.6.bottlenecks.8.cv2.conv.weight",
            "_backbone._modules_list.7.bn.bias": "_backbone._modules_list.7.bn.bias",
            "_backbone._modules_list.7.bn.num_batches_tracked": "_backbone._modules_list.7.bn.num_batches_tracked",
            "_backbone._modules_list.7.bn.running_mean": "_backbone._modules_list.7.bn.running_mean",
            "_backbone._modules_list.7.bn.running_var": "_backbone._modules_list.7.bn.running_var",
            "_backbone._modules_list.7.bn.weight": "_backbone._modules_list.7.bn.weight",
            "_backbone._modules_list.7.conv.bn.bias": "_backbone._modules_list.7.conv.bn.bias",
            "_backbone._modules_list.7.conv.bn.num_batches_tracked": "_backbone._modules_list.7.conv.bn.num_batches_tracked",
            "_backbone._modules_list.7.conv.bn.running_mean": "_backbone._modules_list.7.conv.bn.running_mean",
            "_backbone._modules_list.7.conv.bn.running_var": "_backbone._modules_list.7.conv.bn.running_var",
            "_backbone._modules_list.7.conv.bn.weight": "_backbone._modules_list.7.conv.bn.weight",
            "_backbone._modules_list.7.conv.conv.weight": "_backbone._modules_list.7.conv.conv.weight",
            "_backbone._modules_list.7.conv.weight": "_backbone._modules_list.7.conv.weight",
            "_backbone._modules_list.7.dconv.bn.bias": "_backbone._modules_list.7.dconv.bn.bias",
            "_backbone._modules_list.7.dconv.bn.num_batches_tracked": "_backbone._modules_list.7.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.7.dconv.bn.running_mean": "_backbone._modules_list.7.dconv.bn.running_mean",
            "_backbone._modules_list.7.dconv.bn.running_var": "_backbone._modules_list.7.dconv.bn.running_var",
            "_backbone._modules_list.7.dconv.bn.weight": "_backbone._modules_list.7.dconv.bn.weight",
            "_backbone._modules_list.7.dconv.conv.weight": "_backbone._modules_list.7.dconv.conv.weight",
            "_backbone._modules_list.8.cv1.bn.bias": "_backbone._modules_list.8.cv1.bn.bias",
            "_backbone._modules_list.8.cv1.bn.num_batches_tracked": "_backbone._modules_list.8.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.8.cv1.bn.running_mean": "_backbone._modules_list.8.cv1.bn.running_mean",
            "_backbone._modules_list.8.cv1.bn.running_var": "_backbone._modules_list.8.cv1.bn.running_var",
            "_backbone._modules_list.8.cv1.bn.weight": "_backbone._modules_list.8.cv1.bn.weight",
            "_backbone._modules_list.8.cv1.conv.weight": "_backbone._modules_list.8.cv1.conv.weight",
            "_backbone._modules_list.8.cv2.bn.bias": "_backbone._modules_list.8.cv2.bn.bias",
            "_backbone._modules_list.8.cv2.bn.num_batches_tracked": "_backbone._modules_list.8.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.8.cv2.bn.running_mean": "_backbone._modules_list.8.cv2.bn.running_mean",
            "_backbone._modules_list.8.cv2.bn.running_var": "_backbone._modules_list.8.cv2.bn.running_var",
            "_backbone._modules_list.8.cv2.bn.weight": "_backbone._modules_list.8.cv2.bn.weight",
            "_backbone._modules_list.8.cv2.conv.weight": "_backbone._modules_list.8.cv2.conv.weight",
            "_backbone._modules_list.9.cv1.bn.bias": "_backbone._modules_list.9.conv1.bn.bias",
            "_backbone._modules_list.9.cv1.bn.num_batches_tracked": "_backbone._modules_list.9.conv1.bn.num_batches_tracked",
            "_backbone._modules_list.9.cv1.bn.running_mean": "_backbone._modules_list.9.conv1.bn.running_mean",
            "_backbone._modules_list.9.cv1.bn.running_var": "_backbone._modules_list.9.conv1.bn.running_var",
            "_backbone._modules_list.9.cv1.bn.weight": "_backbone._modules_list.9.conv1.bn.weight",
            "_backbone._modules_list.9.cv1.conv.weight": "_backbone._modules_list.9.conv1.conv.weight",
            "_backbone._modules_list.9.cv2.bn.bias": "_backbone._modules_list.9.conv2.bn.bias",
            "_backbone._modules_list.9.cv2.bn.num_batches_tracked": "_backbone._modules_list.9.conv2.bn.num_batches_tracked",
            "_backbone._modules_list.9.cv2.bn.running_mean": "_backbone._modules_list.9.conv2.bn.running_mean",
            "_backbone._modules_list.9.cv2.bn.running_var": "_backbone._modules_list.9.conv2.bn.running_var",
            "_backbone._modules_list.9.cv2.bn.weight": "_backbone._modules_list.9.conv2.bn.weight",
            "_backbone._modules_list.9.cv2.conv.weight": "_backbone._modules_list.9.conv2.conv.weight",
            "_backbone._modules_list.9.cv3.bn.bias": "_backbone._modules_list.9.conv3.bn.bias",
            "_backbone._modules_list.9.cv3.bn.num_batches_tracked": "_backbone._modules_list.9.conv3.bn.num_batches_tracked",
            "_backbone._modules_list.9.cv3.bn.running_mean": "_backbone._modules_list.9.conv3.bn.running_mean",
            "_backbone._modules_list.9.cv3.bn.running_var": "_backbone._modules_list.9.conv3.bn.running_var",
            "_backbone._modules_list.9.cv3.bn.weight": "_backbone._modules_list.9.conv3.bn.weight",
            "_backbone._modules_list.9.cv3.conv.weight": "_backbone._modules_list.9.conv3.conv.weight",
            "_backbone._modules_list.9.m.0.cv1.bn.bias": "_backbone._modules_list.9.bottlenecks.0.cv1.bn.bias",
            "_backbone._modules_list.9.m.0.cv1.bn.num_batches_tracked": "_backbone._modules_list.9.bottlenecks.0.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.9.m.0.cv1.bn.running_mean": "_backbone._modules_list.9.bottlenecks.0.cv1.bn.running_mean",
            "_backbone._modules_list.9.m.0.cv1.bn.running_var": "_backbone._modules_list.9.bottlenecks.0.cv1.bn.running_var",
            "_backbone._modules_list.9.m.0.cv1.bn.weight": "_backbone._modules_list.9.bottlenecks.0.cv1.bn.weight",
            "_backbone._modules_list.9.m.0.cv1.conv.weight": "_backbone._modules_list.9.bottlenecks.0.cv1.conv.weight",
            "_backbone._modules_list.9.m.0.cv2.bn.bias": "_backbone._modules_list.9.bottlenecks.0.cv2.bn.bias",
            "_backbone._modules_list.9.m.0.cv2.bn.num_batches_tracked": "_backbone._modules_list.9.bottlenecks.0.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.9.m.0.cv2.bn.running_mean": "_backbone._modules_list.9.bottlenecks.0.cv2.bn.running_mean",
            "_backbone._modules_list.9.m.0.cv2.bn.running_var": "_backbone._modules_list.9.bottlenecks.0.cv2.bn.running_var",
            "_backbone._modules_list.9.m.0.cv2.bn.weight": "_backbone._modules_list.9.bottlenecks.0.cv2.bn.weight",
            "_backbone._modules_list.9.m.0.cv2.conv.bn.bias": "_backbone._modules_list.9.bottlenecks.0.cv2.conv.bn.bias",
            "_backbone._modules_list.9.m.0.cv2.conv.bn.num_batches_tracked": "_backbone._modules_list.9.bottlenecks.0.cv2.conv.bn.num_batches_tracked",
            "_backbone._modules_list.9.m.0.cv2.conv.bn.running_mean": "_backbone._modules_list.9.bottlenecks.0.cv2.conv.bn.running_mean",
            "_backbone._modules_list.9.m.0.cv2.conv.bn.running_var": "_backbone._modules_list.9.bottlenecks.0.cv2.conv.bn.running_var",
            "_backbone._modules_list.9.m.0.cv2.conv.bn.weight": "_backbone._modules_list.9.bottlenecks.0.cv2.conv.bn.weight",
            "_backbone._modules_list.9.m.0.cv2.conv.conv.weight": "_backbone._modules_list.9.bottlenecks.0.cv2.conv.conv.weight",
            "_backbone._modules_list.9.m.0.cv2.conv.weight": "_backbone._modules_list.9.bottlenecks.0.cv2.conv.weight",
            "_backbone._modules_list.9.m.0.cv2.dconv.bn.bias": "_backbone._modules_list.9.bottlenecks.0.cv2.dconv.bn.bias",
            "_backbone._modules_list.9.m.0.cv2.dconv.bn.num_batches_tracked": "_backbone._modules_list.9.bottlenecks.0.cv2.dconv.bn.num_batches_tracked",
            "_backbone._modules_list.9.m.0.cv2.dconv.bn.running_mean": "_backbone._modules_list.9.bottlenecks.0.cv2.dconv.bn.running_mean",
            "_backbone._modules_list.9.m.0.cv2.dconv.bn.running_var": "_backbone._modules_list.9.bottlenecks.0.cv2.dconv.bn.running_var",
            "_backbone._modules_list.9.m.0.cv2.dconv.bn.weight": "_backbone._modules_list.9.bottlenecks.0.cv2.dconv.bn.weight",
            "_backbone._modules_list.9.m.0.cv2.dconv.conv.weight": "_backbone._modules_list.9.bottlenecks.0.cv2.dconv.conv.weight",
            "_backbone._modules_list.9.m.1.cv1.bn.bias": "_backbone._modules_list.9.bottlenecks.1.cv1.bn.bias",
            "_backbone._modules_list.9.m.1.cv1.bn.num_batches_tracked": "_backbone._modules_list.9.bottlenecks.1.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.9.m.1.cv1.bn.running_mean": "_backbone._modules_list.9.bottlenecks.1.cv1.bn.running_mean",
            "_backbone._modules_list.9.m.1.cv1.bn.running_var": "_backbone._modules_list.9.bottlenecks.1.cv1.bn.running_var",
            "_backbone._modules_list.9.m.1.cv1.bn.weight": "_backbone._modules_list.9.bottlenecks.1.cv1.bn.weight",
            "_backbone._modules_list.9.m.1.cv1.conv.weight": "_backbone._modules_list.9.bottlenecks.1.cv1.conv.weight",
            "_backbone._modules_list.9.m.1.cv2.bn.bias": "_backbone._modules_list.9.bottlenecks.1.cv2.bn.bias",
            "_backbone._modules_list.9.m.1.cv2.bn.num_batches_tracked": "_backbone._modules_list.9.bottlenecks.1.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.9.m.1.cv2.bn.running_mean": "_backbone._modules_list.9.bottlenecks.1.cv2.bn.running_mean",
            "_backbone._modules_list.9.m.1.cv2.bn.running_var": "_backbone._modules_list.9.bottlenecks.1.cv2.bn.running_var",
            "_backbone._modules_list.9.m.1.cv2.bn.weight": "_backbone._modules_list.9.bottlenecks.1.cv2.bn.weight",
            "_backbone._modules_list.9.m.1.cv2.conv.weight": "_backbone._modules_list.9.bottlenecks.1.cv2.conv.weight",
            "_backbone._modules_list.9.m.2.cv1.bn.bias": "_backbone._modules_list.9.bottlenecks.2.cv1.bn.bias",
            "_backbone._modules_list.9.m.2.cv1.bn.num_batches_tracked": "_backbone._modules_list.9.bottlenecks.2.cv1.bn.num_batches_tracked",
            "_backbone._modules_list.9.m.2.cv1.bn.running_mean": "_backbone._modules_list.9.bottlenecks.2.cv1.bn.running_mean",
            "_backbone._modules_list.9.m.2.cv1.bn.running_var": "_backbone._modules_list.9.bottlenecks.2.cv1.bn.running_var",
            "_backbone._modules_list.9.m.2.cv1.bn.weight": "_backbone._modules_list.9.bottlenecks.2.cv1.bn.weight",
            "_backbone._modules_list.9.m.2.cv1.conv.weight": "_backbone._modules_list.9.bottlenecks.2.cv1.conv.weight",
            "_backbone._modules_list.9.m.2.cv2.bn.bias": "_backbone._modules_list.9.bottlenecks.2.cv2.bn.bias",
            "_backbone._modules_list.9.m.2.cv2.bn.num_batches_tracked": "_backbone._modules_list.9.bottlenecks.2.cv2.bn.num_batches_tracked",
            "_backbone._modules_list.9.m.2.cv2.bn.running_mean": "_backbone._modules_list.9.bottlenecks.2.cv2.bn.running_mean",
            "_backbone._modules_list.9.m.2.cv2.bn.running_var": "_backbone._modules_list.9.bottlenecks.2.cv2.bn.running_var",
            "_backbone._modules_list.9.m.2.cv2.bn.weight": "_backbone._modules_list.9.bottlenecks.2.cv2.bn.weight",
            "_backbone._modules_list.9.m.2.cv2.conv.weight": "_backbone._modules_list.9.bottlenecks.2.cv2.conv.weight",
            "_head._modules_list.0.bn.bias": "_head._modules_list.0.bn.bias",
            "_head._modules_list.0.bn.num_batches_tracked": "_head._modules_list.0.bn.num_batches_tracked",
            "_head._modules_list.0.bn.running_mean": "_head._modules_list.0.bn.running_mean",
            "_head._modules_list.0.bn.running_var": "_head._modules_list.0.bn.running_var",
            "_head._modules_list.0.bn.weight": "_head._modules_list.0.bn.weight",
            "_head._modules_list.0.conv.weight": "_head._modules_list.0.conv.weight",
            "_head._modules_list.10.cv1.bn.bias": "_head._modules_list.10.conv1.bn.bias",
            "_head._modules_list.10.cv1.bn.num_batches_tracked": "_head._modules_list.10.conv1.bn.num_batches_tracked",
            "_head._modules_list.10.cv1.bn.running_mean": "_head._modules_list.10.conv1.bn.running_mean",
            "_head._modules_list.10.cv1.bn.running_var": "_head._modules_list.10.conv1.bn.running_var",
            "_head._modules_list.10.cv1.bn.weight": "_head._modules_list.10.conv1.bn.weight",
            "_head._modules_list.10.cv1.conv.weight": "_head._modules_list.10.conv1.conv.weight",
            "_head._modules_list.10.cv2.bn.bias": "_head._modules_list.10.conv2.bn.bias",
            "_head._modules_list.10.cv2.bn.num_batches_tracked": "_head._modules_list.10.conv2.bn.num_batches_tracked",
            "_head._modules_list.10.cv2.bn.running_mean": "_head._modules_list.10.conv2.bn.running_mean",
            "_head._modules_list.10.cv2.bn.running_var": "_head._modules_list.10.conv2.bn.running_var",
            "_head._modules_list.10.cv2.bn.weight": "_head._modules_list.10.conv2.bn.weight",
            "_head._modules_list.10.cv2.conv.weight": "_head._modules_list.10.conv2.conv.weight",
            "_head._modules_list.10.cv3.bn.bias": "_head._modules_list.10.conv3.bn.bias",
            "_head._modules_list.10.cv3.bn.num_batches_tracked": "_head._modules_list.10.conv3.bn.num_batches_tracked",
            "_head._modules_list.10.cv3.bn.running_mean": "_head._modules_list.10.conv3.bn.running_mean",
            "_head._modules_list.10.cv3.bn.running_var": "_head._modules_list.10.conv3.bn.running_var",
            "_head._modules_list.10.cv3.bn.weight": "_head._modules_list.10.conv3.bn.weight",
            "_head._modules_list.10.cv3.conv.weight": "_head._modules_list.10.conv3.conv.weight",
            "_head._modules_list.10.m.0.cv1.bn.bias": "_head._modules_list.10.bottlenecks.0.cv1.bn.bias",
            "_head._modules_list.10.m.0.cv1.bn.num_batches_tracked": "_head._modules_list.10.bottlenecks.0.cv1.bn.num_batches_tracked",
            "_head._modules_list.10.m.0.cv1.bn.running_mean": "_head._modules_list.10.bottlenecks.0.cv1.bn.running_mean",
            "_head._modules_list.10.m.0.cv1.bn.running_var": "_head._modules_list.10.bottlenecks.0.cv1.bn.running_var",
            "_head._modules_list.10.m.0.cv1.bn.weight": "_head._modules_list.10.bottlenecks.0.cv1.bn.weight",
            "_head._modules_list.10.m.0.cv1.conv.weight": "_head._modules_list.10.bottlenecks.0.cv1.conv.weight",
            "_head._modules_list.10.m.0.cv2.bn.bias": "_head._modules_list.10.bottlenecks.0.cv2.bn.bias",
            "_head._modules_list.10.m.0.cv2.bn.num_batches_tracked": "_head._modules_list.10.bottlenecks.0.cv2.bn.num_batches_tracked",
            "_head._modules_list.10.m.0.cv2.bn.running_mean": "_head._modules_list.10.bottlenecks.0.cv2.bn.running_mean",
            "_head._modules_list.10.m.0.cv2.bn.running_var": "_head._modules_list.10.bottlenecks.0.cv2.bn.running_var",
            "_head._modules_list.10.m.0.cv2.bn.weight": "_head._modules_list.10.bottlenecks.0.cv2.bn.weight",
            "_head._modules_list.10.m.0.cv2.conv.bn.bias": "_head._modules_list.10.bottlenecks.0.cv2.conv.bn.bias",
            "_head._modules_list.10.m.0.cv2.conv.bn.num_batches_tracked": "_head._modules_list.10.bottlenecks.0.cv2.conv.bn.num_batches_tracked",
            "_head._modules_list.10.m.0.cv2.conv.bn.running_mean": "_head._modules_list.10.bottlenecks.0.cv2.conv.bn.running_mean",
            "_head._modules_list.10.m.0.cv2.conv.bn.running_var": "_head._modules_list.10.bottlenecks.0.cv2.conv.bn.running_var",
            "_head._modules_list.10.m.0.cv2.conv.bn.weight": "_head._modules_list.10.bottlenecks.0.cv2.conv.bn.weight",
            "_head._modules_list.10.m.0.cv2.conv.conv.weight": "_head._modules_list.10.bottlenecks.0.cv2.conv.conv.weight",
            "_head._modules_list.10.m.0.cv2.conv.weight": "_head._modules_list.10.bottlenecks.0.cv2.conv.weight",
            "_head._modules_list.10.m.0.cv2.dconv.bn.bias": "_head._modules_list.10.bottlenecks.0.cv2.dconv.bn.bias",
            "_head._modules_list.10.m.0.cv2.dconv.bn.num_batches_tracked": "_head._modules_list.10.bottlenecks.0.cv2.dconv.bn.num_batches_tracked",
            "_head._modules_list.10.m.0.cv2.dconv.bn.running_mean": "_head._modules_list.10.bottlenecks.0.cv2.dconv.bn.running_mean",
            "_head._modules_list.10.m.0.cv2.dconv.bn.running_var": "_head._modules_list.10.bottlenecks.0.cv2.dconv.bn.running_var",
            "_head._modules_list.10.m.0.cv2.dconv.bn.weight": "_head._modules_list.10.bottlenecks.0.cv2.dconv.bn.weight",
            "_head._modules_list.10.m.0.cv2.dconv.conv.weight": "_head._modules_list.10.bottlenecks.0.cv2.dconv.conv.weight",
            "_head._modules_list.10.m.1.cv1.bn.bias": "_head._modules_list.10.bottlenecks.1.cv1.bn.bias",
            "_head._modules_list.10.m.1.cv1.bn.num_batches_tracked": "_head._modules_list.10.bottlenecks.1.cv1.bn.num_batches_tracked",
            "_head._modules_list.10.m.1.cv1.bn.running_mean": "_head._modules_list.10.bottlenecks.1.cv1.bn.running_mean",
            "_head._modules_list.10.m.1.cv1.bn.running_var": "_head._modules_list.10.bottlenecks.1.cv1.bn.running_var",
            "_head._modules_list.10.m.1.cv1.bn.weight": "_head._modules_list.10.bottlenecks.1.cv1.bn.weight",
            "_head._modules_list.10.m.1.cv1.conv.weight": "_head._modules_list.10.bottlenecks.1.cv1.conv.weight",
            "_head._modules_list.10.m.1.cv2.bn.bias": "_head._modules_list.10.bottlenecks.1.cv2.bn.bias",
            "_head._modules_list.10.m.1.cv2.bn.num_batches_tracked": "_head._modules_list.10.bottlenecks.1.cv2.bn.num_batches_tracked",
            "_head._modules_list.10.m.1.cv2.bn.running_mean": "_head._modules_list.10.bottlenecks.1.cv2.bn.running_mean",
            "_head._modules_list.10.m.1.cv2.bn.running_var": "_head._modules_list.10.bottlenecks.1.cv2.bn.running_var",
            "_head._modules_list.10.m.1.cv2.bn.weight": "_head._modules_list.10.bottlenecks.1.cv2.bn.weight",
            "_head._modules_list.10.m.1.cv2.conv.weight": "_head._modules_list.10.bottlenecks.1.cv2.conv.weight",
            "_head._modules_list.10.m.2.cv1.bn.bias": "_head._modules_list.10.bottlenecks.2.cv1.bn.bias",
            "_head._modules_list.10.m.2.cv1.bn.num_batches_tracked": "_head._modules_list.10.bottlenecks.2.cv1.bn.num_batches_tracked",
            "_head._modules_list.10.m.2.cv1.bn.running_mean": "_head._modules_list.10.bottlenecks.2.cv1.bn.running_mean",
            "_head._modules_list.10.m.2.cv1.bn.running_var": "_head._modules_list.10.bottlenecks.2.cv1.bn.running_var",
            "_head._modules_list.10.m.2.cv1.bn.weight": "_head._modules_list.10.bottlenecks.2.cv1.bn.weight",
            "_head._modules_list.10.m.2.cv1.conv.weight": "_head._modules_list.10.bottlenecks.2.cv1.conv.weight",
            "_head._modules_list.10.m.2.cv2.bn.bias": "_head._modules_list.10.bottlenecks.2.cv2.bn.bias",
            "_head._modules_list.10.m.2.cv2.bn.num_batches_tracked": "_head._modules_list.10.bottlenecks.2.cv2.bn.num_batches_tracked",
            "_head._modules_list.10.m.2.cv2.bn.running_mean": "_head._modules_list.10.bottlenecks.2.cv2.bn.running_mean",
            "_head._modules_list.10.m.2.cv2.bn.running_var": "_head._modules_list.10.bottlenecks.2.cv2.bn.running_var",
            "_head._modules_list.10.m.2.cv2.bn.weight": "_head._modules_list.10.bottlenecks.2.cv2.bn.weight",
            "_head._modules_list.10.m.2.cv2.conv.weight": "_head._modules_list.10.bottlenecks.2.cv2.conv.weight",
            "_head._modules_list.11.bn.bias": "_head._modules_list.11.bn.bias",
            "_head._modules_list.11.bn.num_batches_tracked": "_head._modules_list.11.bn.num_batches_tracked",
            "_head._modules_list.11.bn.running_mean": "_head._modules_list.11.bn.running_mean",
            "_head._modules_list.11.bn.running_var": "_head._modules_list.11.bn.running_var",
            "_head._modules_list.11.bn.weight": "_head._modules_list.11.bn.weight",
            "_head._modules_list.11.conv.bn.bias": "_head._modules_list.11.conv.bn.bias",
            "_head._modules_list.11.conv.bn.num_batches_tracked": "_head._modules_list.11.conv.bn.num_batches_tracked",
            "_head._modules_list.11.conv.bn.running_mean": "_head._modules_list.11.conv.bn.running_mean",
            "_head._modules_list.11.conv.bn.running_var": "_head._modules_list.11.conv.bn.running_var",
            "_head._modules_list.11.conv.bn.weight": "_head._modules_list.11.conv.bn.weight",
            "_head._modules_list.11.conv.conv.weight": "_head._modules_list.11.conv.conv.weight",
            "_head._modules_list.11.conv.weight": "_head._modules_list.11.conv.weight",
            "_head._modules_list.11.dconv.bn.bias": "_head._modules_list.11.dconv.bn.bias",
            "_head._modules_list.11.dconv.bn.num_batches_tracked": "_head._modules_list.11.dconv.bn.num_batches_tracked",
            "_head._modules_list.11.dconv.bn.running_mean": "_head._modules_list.11.dconv.bn.running_mean",
            "_head._modules_list.11.dconv.bn.running_var": "_head._modules_list.11.dconv.bn.running_var",
            "_head._modules_list.11.dconv.bn.weight": "_head._modules_list.11.dconv.bn.weight",
            "_head._modules_list.11.dconv.conv.weight": "_head._modules_list.11.dconv.conv.weight",
            "_head._modules_list.13.cv1.bn.bias": "_head._modules_list.13.conv1.bn.bias",
            "_head._modules_list.13.cv1.bn.num_batches_tracked": "_head._modules_list.13.conv1.bn.num_batches_tracked",
            "_head._modules_list.13.cv1.bn.running_mean": "_head._modules_list.13.conv1.bn.running_mean",
            "_head._modules_list.13.cv1.bn.running_var": "_head._modules_list.13.conv1.bn.running_var",
            "_head._modules_list.13.cv1.bn.weight": "_head._modules_list.13.conv1.bn.weight",
            "_head._modules_list.13.cv1.conv.weight": "_head._modules_list.13.conv1.conv.weight",
            "_head._modules_list.13.cv2.bn.bias": "_head._modules_list.13.conv2.bn.bias",
            "_head._modules_list.13.cv2.bn.num_batches_tracked": "_head._modules_list.13.conv2.bn.num_batches_tracked",
            "_head._modules_list.13.cv2.bn.running_mean": "_head._modules_list.13.conv2.bn.running_mean",
            "_head._modules_list.13.cv2.bn.running_var": "_head._modules_list.13.conv2.bn.running_var",
            "_head._modules_list.13.cv2.bn.weight": "_head._modules_list.13.conv2.bn.weight",
            "_head._modules_list.13.cv2.conv.weight": "_head._modules_list.13.conv2.conv.weight",
            "_head._modules_list.13.cv3.bn.bias": "_head._modules_list.13.conv3.bn.bias",
            "_head._modules_list.13.cv3.bn.num_batches_tracked": "_head._modules_list.13.conv3.bn.num_batches_tracked",
            "_head._modules_list.13.cv3.bn.running_mean": "_head._modules_list.13.conv3.bn.running_mean",
            "_head._modules_list.13.cv3.bn.running_var": "_head._modules_list.13.conv3.bn.running_var",
            "_head._modules_list.13.cv3.bn.weight": "_head._modules_list.13.conv3.bn.weight",
            "_head._modules_list.13.cv3.conv.weight": "_head._modules_list.13.conv3.conv.weight",
            "_head._modules_list.13.m.0.cv1.bn.bias": "_head._modules_list.13.bottlenecks.0.cv1.bn.bias",
            "_head._modules_list.13.m.0.cv1.bn.num_batches_tracked": "_head._modules_list.13.bottlenecks.0.cv1.bn.num_batches_tracked",
            "_head._modules_list.13.m.0.cv1.bn.running_mean": "_head._modules_list.13.bottlenecks.0.cv1.bn.running_mean",
            "_head._modules_list.13.m.0.cv1.bn.running_var": "_head._modules_list.13.bottlenecks.0.cv1.bn.running_var",
            "_head._modules_list.13.m.0.cv1.bn.weight": "_head._modules_list.13.bottlenecks.0.cv1.bn.weight",
            "_head._modules_list.13.m.0.cv1.conv.weight": "_head._modules_list.13.bottlenecks.0.cv1.conv.weight",
            "_head._modules_list.13.m.0.cv2.bn.bias": "_head._modules_list.13.bottlenecks.0.cv2.bn.bias",
            "_head._modules_list.13.m.0.cv2.bn.num_batches_tracked": "_head._modules_list.13.bottlenecks.0.cv2.bn.num_batches_tracked",
            "_head._modules_list.13.m.0.cv2.bn.running_mean": "_head._modules_list.13.bottlenecks.0.cv2.bn.running_mean",
            "_head._modules_list.13.m.0.cv2.bn.running_var": "_head._modules_list.13.bottlenecks.0.cv2.bn.running_var",
            "_head._modules_list.13.m.0.cv2.bn.weight": "_head._modules_list.13.bottlenecks.0.cv2.bn.weight",
            "_head._modules_list.13.m.0.cv2.conv.bn.bias": "_head._modules_list.13.bottlenecks.0.cv2.conv.bn.bias",
            "_head._modules_list.13.m.0.cv2.conv.bn.num_batches_tracked": "_head._modules_list.13.bottlenecks.0.cv2.conv.bn.num_batches_tracked",
            "_head._modules_list.13.m.0.cv2.conv.bn.running_mean": "_head._modules_list.13.bottlenecks.0.cv2.conv.bn.running_mean",
            "_head._modules_list.13.m.0.cv2.conv.bn.running_var": "_head._modules_list.13.bottlenecks.0.cv2.conv.bn.running_var",
            "_head._modules_list.13.m.0.cv2.conv.bn.weight": "_head._modules_list.13.bottlenecks.0.cv2.conv.bn.weight",
            "_head._modules_list.13.m.0.cv2.conv.conv.weight": "_head._modules_list.13.bottlenecks.0.cv2.conv.conv.weight",
            "_head._modules_list.13.m.0.cv2.conv.weight": "_head._modules_list.13.bottlenecks.0.cv2.conv.weight",
            "_head._modules_list.13.m.0.cv2.dconv.bn.bias": "_head._modules_list.13.bottlenecks.0.cv2.dconv.bn.bias",
            "_head._modules_list.13.m.0.cv2.dconv.bn.num_batches_tracked": "_head._modules_list.13.bottlenecks.0.cv2.dconv.bn.num_batches_tracked",
            "_head._modules_list.13.m.0.cv2.dconv.bn.running_mean": "_head._modules_list.13.bottlenecks.0.cv2.dconv.bn.running_mean",
            "_head._modules_list.13.m.0.cv2.dconv.bn.running_var": "_head._modules_list.13.bottlenecks.0.cv2.dconv.bn.running_var",
            "_head._modules_list.13.m.0.cv2.dconv.bn.weight": "_head._modules_list.13.bottlenecks.0.cv2.dconv.bn.weight",
            "_head._modules_list.13.m.0.cv2.dconv.conv.weight": "_head._modules_list.13.bottlenecks.0.cv2.dconv.conv.weight",
            "_head._modules_list.13.m.1.cv1.bn.bias": "_head._modules_list.13.bottlenecks.1.cv1.bn.bias",
            "_head._modules_list.13.m.1.cv1.bn.num_batches_tracked": "_head._modules_list.13.bottlenecks.1.cv1.bn.num_batches_tracked",
            "_head._modules_list.13.m.1.cv1.bn.running_mean": "_head._modules_list.13.bottlenecks.1.cv1.bn.running_mean",
            "_head._modules_list.13.m.1.cv1.bn.running_var": "_head._modules_list.13.bottlenecks.1.cv1.bn.running_var",
            "_head._modules_list.13.m.1.cv1.bn.weight": "_head._modules_list.13.bottlenecks.1.cv1.bn.weight",
            "_head._modules_list.13.m.1.cv1.conv.weight": "_head._modules_list.13.bottlenecks.1.cv1.conv.weight",
            "_head._modules_list.13.m.1.cv2.bn.bias": "_head._modules_list.13.bottlenecks.1.cv2.bn.bias",
            "_head._modules_list.13.m.1.cv2.bn.num_batches_tracked": "_head._modules_list.13.bottlenecks.1.cv2.bn.num_batches_tracked",
            "_head._modules_list.13.m.1.cv2.bn.running_mean": "_head._modules_list.13.bottlenecks.1.cv2.bn.running_mean",
            "_head._modules_list.13.m.1.cv2.bn.running_var": "_head._modules_list.13.bottlenecks.1.cv2.bn.running_var",
            "_head._modules_list.13.m.1.cv2.bn.weight": "_head._modules_list.13.bottlenecks.1.cv2.bn.weight",
            "_head._modules_list.13.m.1.cv2.conv.weight": "_head._modules_list.13.bottlenecks.1.cv2.conv.weight",
            "_head._modules_list.13.m.2.cv1.bn.bias": "_head._modules_list.13.bottlenecks.2.cv1.bn.bias",
            "_head._modules_list.13.m.2.cv1.bn.num_batches_tracked": "_head._modules_list.13.bottlenecks.2.cv1.bn.num_batches_tracked",
            "_head._modules_list.13.m.2.cv1.bn.running_mean": "_head._modules_list.13.bottlenecks.2.cv1.bn.running_mean",
            "_head._modules_list.13.m.2.cv1.bn.running_var": "_head._modules_list.13.bottlenecks.2.cv1.bn.running_var",
            "_head._modules_list.13.m.2.cv1.bn.weight": "_head._modules_list.13.bottlenecks.2.cv1.bn.weight",
            "_head._modules_list.13.m.2.cv1.conv.weight": "_head._modules_list.13.bottlenecks.2.cv1.conv.weight",
            "_head._modules_list.13.m.2.cv2.bn.bias": "_head._modules_list.13.bottlenecks.2.cv2.bn.bias",
            "_head._modules_list.13.m.2.cv2.bn.num_batches_tracked": "_head._modules_list.13.bottlenecks.2.cv2.bn.num_batches_tracked",
            "_head._modules_list.13.m.2.cv2.bn.running_mean": "_head._modules_list.13.bottlenecks.2.cv2.bn.running_mean",
            "_head._modules_list.13.m.2.cv2.bn.running_var": "_head._modules_list.13.bottlenecks.2.cv2.bn.running_var",
            "_head._modules_list.13.m.2.cv2.bn.weight": "_head._modules_list.13.bottlenecks.2.cv2.bn.weight",
            "_head._modules_list.13.m.2.cv2.conv.weight": "_head._modules_list.13.bottlenecks.2.cv2.conv.weight",
            "_head._modules_list.14.cls_convs.0.0.bn.bias": "_head._modules_list.14.cls_convs.0.0.bn.bias",
            "_head._modules_list.14.cls_convs.0.0.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.0.0.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.0.0.bn.running_mean": "_head._modules_list.14.cls_convs.0.0.bn.running_mean",
            "_head._modules_list.14.cls_convs.0.0.bn.running_var": "_head._modules_list.14.cls_convs.0.0.bn.running_var",
            "_head._modules_list.14.cls_convs.0.0.bn.weight": "_head._modules_list.14.cls_convs.0.0.bn.weight",
            "_head._modules_list.14.cls_convs.0.0.conv.bn.bias": "_head._modules_list.14.cls_convs.0.0.conv.bn.bias",
            "_head._modules_list.14.cls_convs.0.0.conv.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.0.0.conv.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.0.0.conv.bn.running_mean": "_head._modules_list.14.cls_convs.0.0.conv.bn.running_mean",
            "_head._modules_list.14.cls_convs.0.0.conv.bn.running_var": "_head._modules_list.14.cls_convs.0.0.conv.bn.running_var",
            "_head._modules_list.14.cls_convs.0.0.conv.bn.weight": "_head._modules_list.14.cls_convs.0.0.conv.bn.weight",
            "_head._modules_list.14.cls_convs.0.0.conv.conv.weight": "_head._modules_list.14.cls_convs.0.0.conv.conv.weight",
            "_head._modules_list.14.cls_convs.0.0.conv.weight": "_head._modules_list.14.cls_convs.0.0.conv.weight",
            "_head._modules_list.14.cls_convs.0.0.dconv.bn.bias": "_head._modules_list.14.cls_convs.0.0.dconv.bn.bias",
            "_head._modules_list.14.cls_convs.0.0.dconv.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.0.0.dconv.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.0.0.dconv.bn.running_mean": "_head._modules_list.14.cls_convs.0.0.dconv.bn.running_mean",
            "_head._modules_list.14.cls_convs.0.0.dconv.bn.running_var": "_head._modules_list.14.cls_convs.0.0.dconv.bn.running_var",
            "_head._modules_list.14.cls_convs.0.0.dconv.bn.weight": "_head._modules_list.14.cls_convs.0.0.dconv.bn.weight",
            "_head._modules_list.14.cls_convs.0.0.dconv.conv.weight": "_head._modules_list.14.cls_convs.0.0.dconv.conv.weight",
            "_head._modules_list.14.cls_convs.0.1.bn.bias": "_head._modules_list.14.cls_convs.0.1.bn.bias",
            "_head._modules_list.14.cls_convs.0.1.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.0.1.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.0.1.bn.running_mean": "_head._modules_list.14.cls_convs.0.1.bn.running_mean",
            "_head._modules_list.14.cls_convs.0.1.bn.running_var": "_head._modules_list.14.cls_convs.0.1.bn.running_var",
            "_head._modules_list.14.cls_convs.0.1.bn.weight": "_head._modules_list.14.cls_convs.0.1.bn.weight",
            "_head._modules_list.14.cls_convs.0.1.conv.bn.bias": "_head._modules_list.14.cls_convs.0.1.conv.bn.bias",
            "_head._modules_list.14.cls_convs.0.1.conv.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.0.1.conv.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.0.1.conv.bn.running_mean": "_head._modules_list.14.cls_convs.0.1.conv.bn.running_mean",
            "_head._modules_list.14.cls_convs.0.1.conv.bn.running_var": "_head._modules_list.14.cls_convs.0.1.conv.bn.running_var",
            "_head._modules_list.14.cls_convs.0.1.conv.bn.weight": "_head._modules_list.14.cls_convs.0.1.conv.bn.weight",
            "_head._modules_list.14.cls_convs.0.1.conv.conv.weight": "_head._modules_list.14.cls_convs.0.1.conv.conv.weight",
            "_head._modules_list.14.cls_convs.0.1.conv.weight": "_head._modules_list.14.cls_convs.0.1.conv.weight",
            "_head._modules_list.14.cls_convs.0.1.dconv.bn.bias": "_head._modules_list.14.cls_convs.0.1.dconv.bn.bias",
            "_head._modules_list.14.cls_convs.0.1.dconv.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.0.1.dconv.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.0.1.dconv.bn.running_mean": "_head._modules_list.14.cls_convs.0.1.dconv.bn.running_mean",
            "_head._modules_list.14.cls_convs.0.1.dconv.bn.running_var": "_head._modules_list.14.cls_convs.0.1.dconv.bn.running_var",
            "_head._modules_list.14.cls_convs.0.1.dconv.bn.weight": "_head._modules_list.14.cls_convs.0.1.dconv.bn.weight",
            "_head._modules_list.14.cls_convs.0.1.dconv.conv.weight": "_head._modules_list.14.cls_convs.0.1.dconv.conv.weight",
            "_head._modules_list.14.cls_convs.1.0.bn.bias": "_head._modules_list.14.cls_convs.1.0.bn.bias",
            "_head._modules_list.14.cls_convs.1.0.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.1.0.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.1.0.bn.running_mean": "_head._modules_list.14.cls_convs.1.0.bn.running_mean",
            "_head._modules_list.14.cls_convs.1.0.bn.running_var": "_head._modules_list.14.cls_convs.1.0.bn.running_var",
            "_head._modules_list.14.cls_convs.1.0.bn.weight": "_head._modules_list.14.cls_convs.1.0.bn.weight",
            "_head._modules_list.14.cls_convs.1.0.conv.bn.bias": "_head._modules_list.14.cls_convs.1.0.conv.bn.bias",
            "_head._modules_list.14.cls_convs.1.0.conv.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.1.0.conv.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.1.0.conv.bn.running_mean": "_head._modules_list.14.cls_convs.1.0.conv.bn.running_mean",
            "_head._modules_list.14.cls_convs.1.0.conv.bn.running_var": "_head._modules_list.14.cls_convs.1.0.conv.bn.running_var",
            "_head._modules_list.14.cls_convs.1.0.conv.bn.weight": "_head._modules_list.14.cls_convs.1.0.conv.bn.weight",
            "_head._modules_list.14.cls_convs.1.0.conv.conv.weight": "_head._modules_list.14.cls_convs.1.0.conv.conv.weight",
            "_head._modules_list.14.cls_convs.1.0.conv.weight": "_head._modules_list.14.cls_convs.1.0.conv.weight",
            "_head._modules_list.14.cls_convs.1.0.dconv.bn.bias": "_head._modules_list.14.cls_convs.1.0.dconv.bn.bias",
            "_head._modules_list.14.cls_convs.1.0.dconv.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.1.0.dconv.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.1.0.dconv.bn.running_mean": "_head._modules_list.14.cls_convs.1.0.dconv.bn.running_mean",
            "_head._modules_list.14.cls_convs.1.0.dconv.bn.running_var": "_head._modules_list.14.cls_convs.1.0.dconv.bn.running_var",
            "_head._modules_list.14.cls_convs.1.0.dconv.bn.weight": "_head._modules_list.14.cls_convs.1.0.dconv.bn.weight",
            "_head._modules_list.14.cls_convs.1.0.dconv.conv.weight": "_head._modules_list.14.cls_convs.1.0.dconv.conv.weight",
            "_head._modules_list.14.cls_convs.1.1.bn.bias": "_head._modules_list.14.cls_convs.1.1.bn.bias",
            "_head._modules_list.14.cls_convs.1.1.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.1.1.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.1.1.bn.running_mean": "_head._modules_list.14.cls_convs.1.1.bn.running_mean",
            "_head._modules_list.14.cls_convs.1.1.bn.running_var": "_head._modules_list.14.cls_convs.1.1.bn.running_var",
            "_head._modules_list.14.cls_convs.1.1.bn.weight": "_head._modules_list.14.cls_convs.1.1.bn.weight",
            "_head._modules_list.14.cls_convs.1.1.conv.bn.bias": "_head._modules_list.14.cls_convs.1.1.conv.bn.bias",
            "_head._modules_list.14.cls_convs.1.1.conv.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.1.1.conv.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.1.1.conv.bn.running_mean": "_head._modules_list.14.cls_convs.1.1.conv.bn.running_mean",
            "_head._modules_list.14.cls_convs.1.1.conv.bn.running_var": "_head._modules_list.14.cls_convs.1.1.conv.bn.running_var",
            "_head._modules_list.14.cls_convs.1.1.conv.bn.weight": "_head._modules_list.14.cls_convs.1.1.conv.bn.weight",
            "_head._modules_list.14.cls_convs.1.1.conv.conv.weight": "_head._modules_list.14.cls_convs.1.1.conv.conv.weight",
            "_head._modules_list.14.cls_convs.1.1.conv.weight": "_head._modules_list.14.cls_convs.1.1.conv.weight",
            "_head._modules_list.14.cls_convs.1.1.dconv.bn.bias": "_head._modules_list.14.cls_convs.1.1.dconv.bn.bias",
            "_head._modules_list.14.cls_convs.1.1.dconv.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.1.1.dconv.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.1.1.dconv.bn.running_mean": "_head._modules_list.14.cls_convs.1.1.dconv.bn.running_mean",
            "_head._modules_list.14.cls_convs.1.1.dconv.bn.running_var": "_head._modules_list.14.cls_convs.1.1.dconv.bn.running_var",
            "_head._modules_list.14.cls_convs.1.1.dconv.bn.weight": "_head._modules_list.14.cls_convs.1.1.dconv.bn.weight",
            "_head._modules_list.14.cls_convs.1.1.dconv.conv.weight": "_head._modules_list.14.cls_convs.1.1.dconv.conv.weight",
            "_head._modules_list.14.cls_convs.2.0.bn.bias": "_head._modules_list.14.cls_convs.2.0.bn.bias",
            "_head._modules_list.14.cls_convs.2.0.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.2.0.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.2.0.bn.running_mean": "_head._modules_list.14.cls_convs.2.0.bn.running_mean",
            "_head._modules_list.14.cls_convs.2.0.bn.running_var": "_head._modules_list.14.cls_convs.2.0.bn.running_var",
            "_head._modules_list.14.cls_convs.2.0.bn.weight": "_head._modules_list.14.cls_convs.2.0.bn.weight",
            "_head._modules_list.14.cls_convs.2.0.conv.bn.bias": "_head._modules_list.14.cls_convs.2.0.conv.bn.bias",
            "_head._modules_list.14.cls_convs.2.0.conv.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.2.0.conv.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.2.0.conv.bn.running_mean": "_head._modules_list.14.cls_convs.2.0.conv.bn.running_mean",
            "_head._modules_list.14.cls_convs.2.0.conv.bn.running_var": "_head._modules_list.14.cls_convs.2.0.conv.bn.running_var",
            "_head._modules_list.14.cls_convs.2.0.conv.bn.weight": "_head._modules_list.14.cls_convs.2.0.conv.bn.weight",
            "_head._modules_list.14.cls_convs.2.0.conv.conv.weight": "_head._modules_list.14.cls_convs.2.0.conv.conv.weight",
            "_head._modules_list.14.cls_convs.2.0.conv.weight": "_head._modules_list.14.cls_convs.2.0.conv.weight",
            "_head._modules_list.14.cls_convs.2.0.dconv.bn.bias": "_head._modules_list.14.cls_convs.2.0.dconv.bn.bias",
            "_head._modules_list.14.cls_convs.2.0.dconv.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.2.0.dconv.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.2.0.dconv.bn.running_mean": "_head._modules_list.14.cls_convs.2.0.dconv.bn.running_mean",
            "_head._modules_list.14.cls_convs.2.0.dconv.bn.running_var": "_head._modules_list.14.cls_convs.2.0.dconv.bn.running_var",
            "_head._modules_list.14.cls_convs.2.0.dconv.bn.weight": "_head._modules_list.14.cls_convs.2.0.dconv.bn.weight",
            "_head._modules_list.14.cls_convs.2.0.dconv.conv.weight": "_head._modules_list.14.cls_convs.2.0.dconv.conv.weight",
            "_head._modules_list.14.cls_convs.2.1.bn.bias": "_head._modules_list.14.cls_convs.2.1.bn.bias",
            "_head._modules_list.14.cls_convs.2.1.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.2.1.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.2.1.bn.running_mean": "_head._modules_list.14.cls_convs.2.1.bn.running_mean",
            "_head._modules_list.14.cls_convs.2.1.bn.running_var": "_head._modules_list.14.cls_convs.2.1.bn.running_var",
            "_head._modules_list.14.cls_convs.2.1.bn.weight": "_head._modules_list.14.cls_convs.2.1.bn.weight",
            "_head._modules_list.14.cls_convs.2.1.conv.bn.bias": "_head._modules_list.14.cls_convs.2.1.conv.bn.bias",
            "_head._modules_list.14.cls_convs.2.1.conv.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.2.1.conv.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.2.1.conv.bn.running_mean": "_head._modules_list.14.cls_convs.2.1.conv.bn.running_mean",
            "_head._modules_list.14.cls_convs.2.1.conv.bn.running_var": "_head._modules_list.14.cls_convs.2.1.conv.bn.running_var",
            "_head._modules_list.14.cls_convs.2.1.conv.bn.weight": "_head._modules_list.14.cls_convs.2.1.conv.bn.weight",
            "_head._modules_list.14.cls_convs.2.1.conv.conv.weight": "_head._modules_list.14.cls_convs.2.1.conv.conv.weight",
            "_head._modules_list.14.cls_convs.2.1.conv.weight": "_head._modules_list.14.cls_convs.2.1.conv.weight",
            "_head._modules_list.14.cls_convs.2.1.dconv.bn.bias": "_head._modules_list.14.cls_convs.2.1.dconv.bn.bias",
            "_head._modules_list.14.cls_convs.2.1.dconv.bn.num_batches_tracked": "_head._modules_list.14.cls_convs.2.1.dconv.bn.num_batches_tracked",
            "_head._modules_list.14.cls_convs.2.1.dconv.bn.running_mean": "_head._modules_list.14.cls_convs.2.1.dconv.bn.running_mean",
            "_head._modules_list.14.cls_convs.2.1.dconv.bn.running_var": "_head._modules_list.14.cls_convs.2.1.dconv.bn.running_var",
            "_head._modules_list.14.cls_convs.2.1.dconv.bn.weight": "_head._modules_list.14.cls_convs.2.1.dconv.bn.weight",
            "_head._modules_list.14.cls_convs.2.1.dconv.conv.weight": "_head._modules_list.14.cls_convs.2.1.dconv.conv.weight",
            "_head._modules_list.14.cls_preds.0.bias": "_head._modules_list.14.cls_preds.0.bias",
            "_head._modules_list.14.cls_preds.0.weight": "_head._modules_list.14.cls_preds.0.weight",
            "_head._modules_list.14.cls_preds.1.bias": "_head._modules_list.14.cls_preds.1.bias",
            "_head._modules_list.14.cls_preds.1.weight": "_head._modules_list.14.cls_preds.1.weight",
            "_head._modules_list.14.cls_preds.2.bias": "_head._modules_list.14.cls_preds.2.bias",
            "_head._modules_list.14.cls_preds.2.weight": "_head._modules_list.14.cls_preds.2.weight",
            "_head._modules_list.14.obj_preds.0.bias": "_head._modules_list.14.obj_preds.0.bias",
            "_head._modules_list.14.obj_preds.0.weight": "_head._modules_list.14.obj_preds.0.weight",
            "_head._modules_list.14.obj_preds.1.bias": "_head._modules_list.14.obj_preds.1.bias",
            "_head._modules_list.14.obj_preds.1.weight": "_head._modules_list.14.obj_preds.1.weight",
            "_head._modules_list.14.obj_preds.2.bias": "_head._modules_list.14.obj_preds.2.bias",
            "_head._modules_list.14.obj_preds.2.weight": "_head._modules_list.14.obj_preds.2.weight",
            "_head._modules_list.14.reg_convs.0.0.bn.bias": "_head._modules_list.14.reg_convs.0.0.bn.bias",
            "_head._modules_list.14.reg_convs.0.0.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.0.0.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.0.0.bn.running_mean": "_head._modules_list.14.reg_convs.0.0.bn.running_mean",
            "_head._modules_list.14.reg_convs.0.0.bn.running_var": "_head._modules_list.14.reg_convs.0.0.bn.running_var",
            "_head._modules_list.14.reg_convs.0.0.bn.weight": "_head._modules_list.14.reg_convs.0.0.bn.weight",
            "_head._modules_list.14.reg_convs.0.0.conv.bn.bias": "_head._modules_list.14.reg_convs.0.0.conv.bn.bias",
            "_head._modules_list.14.reg_convs.0.0.conv.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.0.0.conv.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.0.0.conv.bn.running_mean": "_head._modules_list.14.reg_convs.0.0.conv.bn.running_mean",
            "_head._modules_list.14.reg_convs.0.0.conv.bn.running_var": "_head._modules_list.14.reg_convs.0.0.conv.bn.running_var",
            "_head._modules_list.14.reg_convs.0.0.conv.bn.weight": "_head._modules_list.14.reg_convs.0.0.conv.bn.weight",
            "_head._modules_list.14.reg_convs.0.0.conv.conv.weight": "_head._modules_list.14.reg_convs.0.0.conv.conv.weight",
            "_head._modules_list.14.reg_convs.0.0.conv.weight": "_head._modules_list.14.reg_convs.0.0.conv.weight",
            "_head._modules_list.14.reg_convs.0.0.dconv.bn.bias": "_head._modules_list.14.reg_convs.0.0.dconv.bn.bias",
            "_head._modules_list.14.reg_convs.0.0.dconv.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.0.0.dconv.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.0.0.dconv.bn.running_mean": "_head._modules_list.14.reg_convs.0.0.dconv.bn.running_mean",
            "_head._modules_list.14.reg_convs.0.0.dconv.bn.running_var": "_head._modules_list.14.reg_convs.0.0.dconv.bn.running_var",
            "_head._modules_list.14.reg_convs.0.0.dconv.bn.weight": "_head._modules_list.14.reg_convs.0.0.dconv.bn.weight",
            "_head._modules_list.14.reg_convs.0.0.dconv.conv.weight": "_head._modules_list.14.reg_convs.0.0.dconv.conv.weight",
            "_head._modules_list.14.reg_convs.0.1.bn.bias": "_head._modules_list.14.reg_convs.0.1.bn.bias",
            "_head._modules_list.14.reg_convs.0.1.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.0.1.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.0.1.bn.running_mean": "_head._modules_list.14.reg_convs.0.1.bn.running_mean",
            "_head._modules_list.14.reg_convs.0.1.bn.running_var": "_head._modules_list.14.reg_convs.0.1.bn.running_var",
            "_head._modules_list.14.reg_convs.0.1.bn.weight": "_head._modules_list.14.reg_convs.0.1.bn.weight",
            "_head._modules_list.14.reg_convs.0.1.conv.bn.bias": "_head._modules_list.14.reg_convs.0.1.conv.bn.bias",
            "_head._modules_list.14.reg_convs.0.1.conv.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.0.1.conv.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.0.1.conv.bn.running_mean": "_head._modules_list.14.reg_convs.0.1.conv.bn.running_mean",
            "_head._modules_list.14.reg_convs.0.1.conv.bn.running_var": "_head._modules_list.14.reg_convs.0.1.conv.bn.running_var",
            "_head._modules_list.14.reg_convs.0.1.conv.bn.weight": "_head._modules_list.14.reg_convs.0.1.conv.bn.weight",
            "_head._modules_list.14.reg_convs.0.1.conv.conv.weight": "_head._modules_list.14.reg_convs.0.1.conv.conv.weight",
            "_head._modules_list.14.reg_convs.0.1.conv.weight": "_head._modules_list.14.reg_convs.0.1.conv.weight",
            "_head._modules_list.14.reg_convs.0.1.dconv.bn.bias": "_head._modules_list.14.reg_convs.0.1.dconv.bn.bias",
            "_head._modules_list.14.reg_convs.0.1.dconv.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.0.1.dconv.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.0.1.dconv.bn.running_mean": "_head._modules_list.14.reg_convs.0.1.dconv.bn.running_mean",
            "_head._modules_list.14.reg_convs.0.1.dconv.bn.running_var": "_head._modules_list.14.reg_convs.0.1.dconv.bn.running_var",
            "_head._modules_list.14.reg_convs.0.1.dconv.bn.weight": "_head._modules_list.14.reg_convs.0.1.dconv.bn.weight",
            "_head._modules_list.14.reg_convs.0.1.dconv.conv.weight": "_head._modules_list.14.reg_convs.0.1.dconv.conv.weight",
            "_head._modules_list.14.reg_convs.1.0.bn.bias": "_head._modules_list.14.reg_convs.1.0.bn.bias",
            "_head._modules_list.14.reg_convs.1.0.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.1.0.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.1.0.bn.running_mean": "_head._modules_list.14.reg_convs.1.0.bn.running_mean",
            "_head._modules_list.14.reg_convs.1.0.bn.running_var": "_head._modules_list.14.reg_convs.1.0.bn.running_var",
            "_head._modules_list.14.reg_convs.1.0.bn.weight": "_head._modules_list.14.reg_convs.1.0.bn.weight",
            "_head._modules_list.14.reg_convs.1.0.conv.bn.bias": "_head._modules_list.14.reg_convs.1.0.conv.bn.bias",
            "_head._modules_list.14.reg_convs.1.0.conv.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.1.0.conv.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.1.0.conv.bn.running_mean": "_head._modules_list.14.reg_convs.1.0.conv.bn.running_mean",
            "_head._modules_list.14.reg_convs.1.0.conv.bn.running_var": "_head._modules_list.14.reg_convs.1.0.conv.bn.running_var",
            "_head._modules_list.14.reg_convs.1.0.conv.bn.weight": "_head._modules_list.14.reg_convs.1.0.conv.bn.weight",
            "_head._modules_list.14.reg_convs.1.0.conv.conv.weight": "_head._modules_list.14.reg_convs.1.0.conv.conv.weight",
            "_head._modules_list.14.reg_convs.1.0.conv.weight": "_head._modules_list.14.reg_convs.1.0.conv.weight",
            "_head._modules_list.14.reg_convs.1.0.dconv.bn.bias": "_head._modules_list.14.reg_convs.1.0.dconv.bn.bias",
            "_head._modules_list.14.reg_convs.1.0.dconv.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.1.0.dconv.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.1.0.dconv.bn.running_mean": "_head._modules_list.14.reg_convs.1.0.dconv.bn.running_mean",
            "_head._modules_list.14.reg_convs.1.0.dconv.bn.running_var": "_head._modules_list.14.reg_convs.1.0.dconv.bn.running_var",
            "_head._modules_list.14.reg_convs.1.0.dconv.bn.weight": "_head._modules_list.14.reg_convs.1.0.dconv.bn.weight",
            "_head._modules_list.14.reg_convs.1.0.dconv.conv.weight": "_head._modules_list.14.reg_convs.1.0.dconv.conv.weight",
            "_head._modules_list.14.reg_convs.1.1.bn.bias": "_head._modules_list.14.reg_convs.1.1.bn.bias",
            "_head._modules_list.14.reg_convs.1.1.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.1.1.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.1.1.bn.running_mean": "_head._modules_list.14.reg_convs.1.1.bn.running_mean",
            "_head._modules_list.14.reg_convs.1.1.bn.running_var": "_head._modules_list.14.reg_convs.1.1.bn.running_var",
            "_head._modules_list.14.reg_convs.1.1.bn.weight": "_head._modules_list.14.reg_convs.1.1.bn.weight",
            "_head._modules_list.14.reg_convs.1.1.conv.bn.bias": "_head._modules_list.14.reg_convs.1.1.conv.bn.bias",
            "_head._modules_list.14.reg_convs.1.1.conv.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.1.1.conv.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.1.1.conv.bn.running_mean": "_head._modules_list.14.reg_convs.1.1.conv.bn.running_mean",
            "_head._modules_list.14.reg_convs.1.1.conv.bn.running_var": "_head._modules_list.14.reg_convs.1.1.conv.bn.running_var",
            "_head._modules_list.14.reg_convs.1.1.conv.bn.weight": "_head._modules_list.14.reg_convs.1.1.conv.bn.weight",
            "_head._modules_list.14.reg_convs.1.1.conv.conv.weight": "_head._modules_list.14.reg_convs.1.1.conv.conv.weight",
            "_head._modules_list.14.reg_convs.1.1.conv.weight": "_head._modules_list.14.reg_convs.1.1.conv.weight",
            "_head._modules_list.14.reg_convs.1.1.dconv.bn.bias": "_head._modules_list.14.reg_convs.1.1.dconv.bn.bias",
            "_head._modules_list.14.reg_convs.1.1.dconv.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.1.1.dconv.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.1.1.dconv.bn.running_mean": "_head._modules_list.14.reg_convs.1.1.dconv.bn.running_mean",
            "_head._modules_list.14.reg_convs.1.1.dconv.bn.running_var": "_head._modules_list.14.reg_convs.1.1.dconv.bn.running_var",
            "_head._modules_list.14.reg_convs.1.1.dconv.bn.weight": "_head._modules_list.14.reg_convs.1.1.dconv.bn.weight",
            "_head._modules_list.14.reg_convs.1.1.dconv.conv.weight": "_head._modules_list.14.reg_convs.1.1.dconv.conv.weight",
            "_head._modules_list.14.reg_convs.2.0.bn.bias": "_head._modules_list.14.reg_convs.2.0.bn.bias",
            "_head._modules_list.14.reg_convs.2.0.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.2.0.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.2.0.bn.running_mean": "_head._modules_list.14.reg_convs.2.0.bn.running_mean",
            "_head._modules_list.14.reg_convs.2.0.bn.running_var": "_head._modules_list.14.reg_convs.2.0.bn.running_var",
            "_head._modules_list.14.reg_convs.2.0.bn.weight": "_head._modules_list.14.reg_convs.2.0.bn.weight",
            "_head._modules_list.14.reg_convs.2.0.conv.bn.bias": "_head._modules_list.14.reg_convs.2.0.conv.bn.bias",
            "_head._modules_list.14.reg_convs.2.0.conv.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.2.0.conv.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.2.0.conv.bn.running_mean": "_head._modules_list.14.reg_convs.2.0.conv.bn.running_mean",
            "_head._modules_list.14.reg_convs.2.0.conv.bn.running_var": "_head._modules_list.14.reg_convs.2.0.conv.bn.running_var",
            "_head._modules_list.14.reg_convs.2.0.conv.bn.weight": "_head._modules_list.14.reg_convs.2.0.conv.bn.weight",
            "_head._modules_list.14.reg_convs.2.0.conv.conv.weight": "_head._modules_list.14.reg_convs.2.0.conv.conv.weight",
            "_head._modules_list.14.reg_convs.2.0.conv.weight": "_head._modules_list.14.reg_convs.2.0.conv.weight",
            "_head._modules_list.14.reg_convs.2.0.dconv.bn.bias": "_head._modules_list.14.reg_convs.2.0.dconv.bn.bias",
            "_head._modules_list.14.reg_convs.2.0.dconv.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.2.0.dconv.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.2.0.dconv.bn.running_mean": "_head._modules_list.14.reg_convs.2.0.dconv.bn.running_mean",
            "_head._modules_list.14.reg_convs.2.0.dconv.bn.running_var": "_head._modules_list.14.reg_convs.2.0.dconv.bn.running_var",
            "_head._modules_list.14.reg_convs.2.0.dconv.bn.weight": "_head._modules_list.14.reg_convs.2.0.dconv.bn.weight",
            "_head._modules_list.14.reg_convs.2.0.dconv.conv.weight": "_head._modules_list.14.reg_convs.2.0.dconv.conv.weight",
            "_head._modules_list.14.reg_convs.2.1.bn.bias": "_head._modules_list.14.reg_convs.2.1.bn.bias",
            "_head._modules_list.14.reg_convs.2.1.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.2.1.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.2.1.bn.running_mean": "_head._modules_list.14.reg_convs.2.1.bn.running_mean",
            "_head._modules_list.14.reg_convs.2.1.bn.running_var": "_head._modules_list.14.reg_convs.2.1.bn.running_var",
            "_head._modules_list.14.reg_convs.2.1.bn.weight": "_head._modules_list.14.reg_convs.2.1.bn.weight",
            "_head._modules_list.14.reg_convs.2.1.conv.bn.bias": "_head._modules_list.14.reg_convs.2.1.conv.bn.bias",
            "_head._modules_list.14.reg_convs.2.1.conv.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.2.1.conv.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.2.1.conv.bn.running_mean": "_head._modules_list.14.reg_convs.2.1.conv.bn.running_mean",
            "_head._modules_list.14.reg_convs.2.1.conv.bn.running_var": "_head._modules_list.14.reg_convs.2.1.conv.bn.running_var",
            "_head._modules_list.14.reg_convs.2.1.conv.bn.weight": "_head._modules_list.14.reg_convs.2.1.conv.bn.weight",
            "_head._modules_list.14.reg_convs.2.1.conv.conv.weight": "_head._modules_list.14.reg_convs.2.1.conv.conv.weight",
            "_head._modules_list.14.reg_convs.2.1.conv.weight": "_head._modules_list.14.reg_convs.2.1.conv.weight",
            "_head._modules_list.14.reg_convs.2.1.dconv.bn.bias": "_head._modules_list.14.reg_convs.2.1.dconv.bn.bias",
            "_head._modules_list.14.reg_convs.2.1.dconv.bn.num_batches_tracked": "_head._modules_list.14.reg_convs.2.1.dconv.bn.num_batches_tracked",
            "_head._modules_list.14.reg_convs.2.1.dconv.bn.running_mean": "_head._modules_list.14.reg_convs.2.1.dconv.bn.running_mean",
            "_head._modules_list.14.reg_convs.2.1.dconv.bn.running_var": "_head._modules_list.14.reg_convs.2.1.dconv.bn.running_var",
            "_head._modules_list.14.reg_convs.2.1.dconv.bn.weight": "_head._modules_list.14.reg_convs.2.1.dconv.bn.weight",
            "_head._modules_list.14.reg_convs.2.1.dconv.conv.weight": "_head._modules_list.14.reg_convs.2.1.dconv.conv.weight",
            "_head._modules_list.14.reg_preds.0.bias": "_head._modules_list.14.reg_preds.0.bias",
            "_head._modules_list.14.reg_preds.0.weight": "_head._modules_list.14.reg_preds.0.weight",
            "_head._modules_list.14.reg_preds.1.bias": "_head._modules_list.14.reg_preds.1.bias",
            "_head._modules_list.14.reg_preds.1.weight": "_head._modules_list.14.reg_preds.1.weight",
            "_head._modules_list.14.reg_preds.2.bias": "_head._modules_list.14.reg_preds.2.bias",
            "_head._modules_list.14.reg_preds.2.weight": "_head._modules_list.14.reg_preds.2.weight",
            "_head._modules_list.14.stems.0.bn.bias": "_head._modules_list.14.stems.0.bn.bias",
            "_head._modules_list.14.stems.0.bn.num_batches_tracked": "_head._modules_list.14.stems.0.bn.num_batches_tracked",
            "_head._modules_list.14.stems.0.bn.running_mean": "_head._modules_list.14.stems.0.bn.running_mean",
            "_head._modules_list.14.stems.0.bn.running_var": "_head._modules_list.14.stems.0.bn.running_var",
            "_head._modules_list.14.stems.0.bn.weight": "_head._modules_list.14.stems.0.bn.weight",
            "_head._modules_list.14.stems.0.conv.weight": "_head._modules_list.14.stems.0.conv.weight",
            "_head._modules_list.14.stems.1.bn.bias": "_head._modules_list.14.stems.1.bn.bias",
            "_head._modules_list.14.stems.1.bn.num_batches_tracked": "_head._modules_list.14.stems.1.bn.num_batches_tracked",
            "_head._modules_list.14.stems.1.bn.running_mean": "_head._modules_list.14.stems.1.bn.running_mean",
            "_head._modules_list.14.stems.1.bn.running_var": "_head._modules_list.14.stems.1.bn.running_var",
            "_head._modules_list.14.stems.1.bn.weight": "_head._modules_list.14.stems.1.bn.weight",
            "_head._modules_list.14.stems.1.conv.weight": "_head._modules_list.14.stems.1.conv.weight",
            "_head._modules_list.14.stems.2.bn.bias": "_head._modules_list.14.stems.2.bn.bias",
            "_head._modules_list.14.stems.2.bn.num_batches_tracked": "_head._modules_list.14.stems.2.bn.num_batches_tracked",
            "_head._modules_list.14.stems.2.bn.running_mean": "_head._modules_list.14.stems.2.bn.running_mean",
            "_head._modules_list.14.stems.2.bn.running_var": "_head._modules_list.14.stems.2.bn.running_var",
            "_head._modules_list.14.stems.2.bn.weight": "_head._modules_list.14.stems.2.bn.weight",
            "_head._modules_list.14.stems.2.conv.weight": "_head._modules_list.14.stems.2.conv.weight",
            "_head._modules_list.3.cv1.bn.bias": "_head._modules_list.3.conv1.bn.bias",
            "_head._modules_list.3.cv1.bn.num_batches_tracked": "_head._modules_list.3.conv1.bn.num_batches_tracked",
            "_head._modules_list.3.cv1.bn.running_mean": "_head._modules_list.3.conv1.bn.running_mean",
            "_head._modules_list.3.cv1.bn.running_var": "_head._modules_list.3.conv1.bn.running_var",
            "_head._modules_list.3.cv1.bn.weight": "_head._modules_list.3.conv1.bn.weight",
            "_head._modules_list.3.cv1.conv.weight": "_head._modules_list.3.conv1.conv.weight",
            "_head._modules_list.3.cv2.bn.bias": "_head._modules_list.3.conv2.bn.bias",
            "_head._modules_list.3.cv2.bn.num_batches_tracked": "_head._modules_list.3.conv2.bn.num_batches_tracked",
            "_head._modules_list.3.cv2.bn.running_mean": "_head._modules_list.3.conv2.bn.running_mean",
            "_head._modules_list.3.cv2.bn.running_var": "_head._modules_list.3.conv2.bn.running_var",
            "_head._modules_list.3.cv2.bn.weight": "_head._modules_list.3.conv2.bn.weight",
            "_head._modules_list.3.cv2.conv.weight": "_head._modules_list.3.conv2.conv.weight",
            "_head._modules_list.3.cv3.bn.bias": "_head._modules_list.3.conv3.bn.bias",
            "_head._modules_list.3.cv3.bn.num_batches_tracked": "_head._modules_list.3.conv3.bn.num_batches_tracked",
            "_head._modules_list.3.cv3.bn.running_mean": "_head._modules_list.3.conv3.bn.running_mean",
            "_head._modules_list.3.cv3.bn.running_var": "_head._modules_list.3.conv3.bn.running_var",
            "_head._modules_list.3.cv3.bn.weight": "_head._modules_list.3.conv3.bn.weight",
            "_head._modules_list.3.cv3.conv.weight": "_head._modules_list.3.conv3.conv.weight",
            "_head._modules_list.3.m.0.cv1.bn.bias": "_head._modules_list.3.bottlenecks.0.cv1.bn.bias",
            "_head._modules_list.3.m.0.cv1.bn.num_batches_tracked": "_head._modules_list.3.bottlenecks.0.cv1.bn.num_batches_tracked",
            "_head._modules_list.3.m.0.cv1.bn.running_mean": "_head._modules_list.3.bottlenecks.0.cv1.bn.running_mean",
            "_head._modules_list.3.m.0.cv1.bn.running_var": "_head._modules_list.3.bottlenecks.0.cv1.bn.running_var",
            "_head._modules_list.3.m.0.cv1.bn.weight": "_head._modules_list.3.bottlenecks.0.cv1.bn.weight",
            "_head._modules_list.3.m.0.cv1.conv.weight": "_head._modules_list.3.bottlenecks.0.cv1.conv.weight",
            "_head._modules_list.3.m.0.cv2.bn.bias": "_head._modules_list.3.bottlenecks.0.cv2.bn.bias",
            "_head._modules_list.3.m.0.cv2.bn.num_batches_tracked": "_head._modules_list.3.bottlenecks.0.cv2.bn.num_batches_tracked",
            "_head._modules_list.3.m.0.cv2.bn.running_mean": "_head._modules_list.3.bottlenecks.0.cv2.bn.running_mean",
            "_head._modules_list.3.m.0.cv2.bn.running_var": "_head._modules_list.3.bottlenecks.0.cv2.bn.running_var",
            "_head._modules_list.3.m.0.cv2.bn.weight": "_head._modules_list.3.bottlenecks.0.cv2.bn.weight",
            "_head._modules_list.3.m.0.cv2.conv.bn.bias": "_head._modules_list.3.bottlenecks.0.cv2.conv.bn.bias",
            "_head._modules_list.3.m.0.cv2.conv.bn.num_batches_tracked": "_head._modules_list.3.bottlenecks.0.cv2.conv.bn.num_batches_tracked",
            "_head._modules_list.3.m.0.cv2.conv.bn.running_mean": "_head._modules_list.3.bottlenecks.0.cv2.conv.bn.running_mean",
            "_head._modules_list.3.m.0.cv2.conv.bn.running_var": "_head._modules_list.3.bottlenecks.0.cv2.conv.bn.running_var",
            "_head._modules_list.3.m.0.cv2.conv.bn.weight": "_head._modules_list.3.bottlenecks.0.cv2.conv.bn.weight",
            "_head._modules_list.3.m.0.cv2.conv.conv.weight": "_head._modules_list.3.bottlenecks.0.cv2.conv.conv.weight",
            "_head._modules_list.3.m.0.cv2.conv.weight": "_head._modules_list.3.bottlenecks.0.cv2.conv.weight",
            "_head._modules_list.3.m.0.cv2.dconv.bn.bias": "_head._modules_list.3.bottlenecks.0.cv2.dconv.bn.bias",
            "_head._modules_list.3.m.0.cv2.dconv.bn.num_batches_tracked": "_head._modules_list.3.bottlenecks.0.cv2.dconv.bn.num_batches_tracked",
            "_head._modules_list.3.m.0.cv2.dconv.bn.running_mean": "_head._modules_list.3.bottlenecks.0.cv2.dconv.bn.running_mean",
            "_head._modules_list.3.m.0.cv2.dconv.bn.running_var": "_head._modules_list.3.bottlenecks.0.cv2.dconv.bn.running_var",
            "_head._modules_list.3.m.0.cv2.dconv.bn.weight": "_head._modules_list.3.bottlenecks.0.cv2.dconv.bn.weight",
            "_head._modules_list.3.m.0.cv2.dconv.conv.weight": "_head._modules_list.3.bottlenecks.0.cv2.dconv.conv.weight",
            "_head._modules_list.3.m.1.cv1.bn.bias": "_head._modules_list.3.bottlenecks.1.cv1.bn.bias",
            "_head._modules_list.3.m.1.cv1.bn.num_batches_tracked": "_head._modules_list.3.bottlenecks.1.cv1.bn.num_batches_tracked",
            "_head._modules_list.3.m.1.cv1.bn.running_mean": "_head._modules_list.3.bottlenecks.1.cv1.bn.running_mean",
            "_head._modules_list.3.m.1.cv1.bn.running_var": "_head._modules_list.3.bottlenecks.1.cv1.bn.running_var",
            "_head._modules_list.3.m.1.cv1.bn.weight": "_head._modules_list.3.bottlenecks.1.cv1.bn.weight",
            "_head._modules_list.3.m.1.cv1.conv.weight": "_head._modules_list.3.bottlenecks.1.cv1.conv.weight",
            "_head._modules_list.3.m.1.cv2.bn.bias": "_head._modules_list.3.bottlenecks.1.cv2.bn.bias",
            "_head._modules_list.3.m.1.cv2.bn.num_batches_tracked": "_head._modules_list.3.bottlenecks.1.cv2.bn.num_batches_tracked",
            "_head._modules_list.3.m.1.cv2.bn.running_mean": "_head._modules_list.3.bottlenecks.1.cv2.bn.running_mean",
            "_head._modules_list.3.m.1.cv2.bn.running_var": "_head._modules_list.3.bottlenecks.1.cv2.bn.running_var",
            "_head._modules_list.3.m.1.cv2.bn.weight": "_head._modules_list.3.bottlenecks.1.cv2.bn.weight",
            "_head._modules_list.3.m.1.cv2.conv.weight": "_head._modules_list.3.bottlenecks.1.cv2.conv.weight",
            "_head._modules_list.3.m.2.cv1.bn.bias": "_head._modules_list.3.bottlenecks.2.cv1.bn.bias",
            "_head._modules_list.3.m.2.cv1.bn.num_batches_tracked": "_head._modules_list.3.bottlenecks.2.cv1.bn.num_batches_tracked",
            "_head._modules_list.3.m.2.cv1.bn.running_mean": "_head._modules_list.3.bottlenecks.2.cv1.bn.running_mean",
            "_head._modules_list.3.m.2.cv1.bn.running_var": "_head._modules_list.3.bottlenecks.2.cv1.bn.running_var",
            "_head._modules_list.3.m.2.cv1.bn.weight": "_head._modules_list.3.bottlenecks.2.cv1.bn.weight",
            "_head._modules_list.3.m.2.cv1.conv.weight": "_head._modules_list.3.bottlenecks.2.cv1.conv.weight",
            "_head._modules_list.3.m.2.cv2.bn.bias": "_head._modules_list.3.bottlenecks.2.cv2.bn.bias",
            "_head._modules_list.3.m.2.cv2.bn.num_batches_tracked": "_head._modules_list.3.bottlenecks.2.cv2.bn.num_batches_tracked",
            "_head._modules_list.3.m.2.cv2.bn.running_mean": "_head._modules_list.3.bottlenecks.2.cv2.bn.running_mean",
            "_head._modules_list.3.m.2.cv2.bn.running_var": "_head._modules_list.3.bottlenecks.2.cv2.bn.running_var",
            "_head._modules_list.3.m.2.cv2.bn.weight": "_head._modules_list.3.bottlenecks.2.cv2.bn.weight",
            "_head._modules_list.3.m.2.cv2.conv.weight": "_head._modules_list.3.bottlenecks.2.cv2.conv.weight",
            "_head._modules_list.4.bn.bias": "_head._modules_list.4.bn.bias",
            "_head._modules_list.4.bn.num_batches_tracked": "_head._modules_list.4.bn.num_batches_tracked",
            "_head._modules_list.4.bn.running_mean": "_head._modules_list.4.bn.running_mean",
            "_head._modules_list.4.bn.running_var": "_head._modules_list.4.bn.running_var",
            "_head._modules_list.4.bn.weight": "_head._modules_list.4.bn.weight",
            "_head._modules_list.4.conv.weight": "_head._modules_list.4.conv.weight",
            "_head._modules_list.7.cv1.bn.bias": "_head._modules_list.7.conv1.bn.bias",
            "_head._modules_list.7.cv1.bn.num_batches_tracked": "_head._modules_list.7.conv1.bn.num_batches_tracked",
            "_head._modules_list.7.cv1.bn.running_mean": "_head._modules_list.7.conv1.bn.running_mean",
            "_head._modules_list.7.cv1.bn.running_var": "_head._modules_list.7.conv1.bn.running_var",
            "_head._modules_list.7.cv1.bn.weight": "_head._modules_list.7.conv1.bn.weight",
            "_head._modules_list.7.cv1.conv.weight": "_head._modules_list.7.conv1.conv.weight",
            "_head._modules_list.7.cv2.bn.bias": "_head._modules_list.7.conv2.bn.bias",
            "_head._modules_list.7.cv2.bn.num_batches_tracked": "_head._modules_list.7.conv2.bn.num_batches_tracked",
            "_head._modules_list.7.cv2.bn.running_mean": "_head._modules_list.7.conv2.bn.running_mean",
            "_head._modules_list.7.cv2.bn.running_var": "_head._modules_list.7.conv2.bn.running_var",
            "_head._modules_list.7.cv2.bn.weight": "_head._modules_list.7.conv2.bn.weight",
            "_head._modules_list.7.cv2.conv.weight": "_head._modules_list.7.conv2.conv.weight",
            "_head._modules_list.7.cv3.bn.bias": "_head._modules_list.7.conv3.bn.bias",
            "_head._modules_list.7.cv3.bn.num_batches_tracked": "_head._modules_list.7.conv3.bn.num_batches_tracked",
            "_head._modules_list.7.cv3.bn.running_mean": "_head._modules_list.7.conv3.bn.running_mean",
            "_head._modules_list.7.cv3.bn.running_var": "_head._modules_list.7.conv3.bn.running_var",
            "_head._modules_list.7.cv3.bn.weight": "_head._modules_list.7.conv3.bn.weight",
            "_head._modules_list.7.cv3.conv.weight": "_head._modules_list.7.conv3.conv.weight",
            "_head._modules_list.7.m.0.cv1.bn.bias": "_head._modules_list.7.bottlenecks.0.cv1.bn.bias",
            "_head._modules_list.7.m.0.cv1.bn.num_batches_tracked": "_head._modules_list.7.bottlenecks.0.cv1.bn.num_batches_tracked",
            "_head._modules_list.7.m.0.cv1.bn.running_mean": "_head._modules_list.7.bottlenecks.0.cv1.bn.running_mean",
            "_head._modules_list.7.m.0.cv1.bn.running_var": "_head._modules_list.7.bottlenecks.0.cv1.bn.running_var",
            "_head._modules_list.7.m.0.cv1.bn.weight": "_head._modules_list.7.bottlenecks.0.cv1.bn.weight",
            "_head._modules_list.7.m.0.cv1.conv.weight": "_head._modules_list.7.bottlenecks.0.cv1.conv.weight",
            "_head._modules_list.7.m.0.cv2.bn.bias": "_head._modules_list.7.bottlenecks.0.cv2.bn.bias",
            "_head._modules_list.7.m.0.cv2.bn.num_batches_tracked": "_head._modules_list.7.bottlenecks.0.cv2.bn.num_batches_tracked",
            "_head._modules_list.7.m.0.cv2.bn.running_mean": "_head._modules_list.7.bottlenecks.0.cv2.bn.running_mean",
            "_head._modules_list.7.m.0.cv2.bn.running_var": "_head._modules_list.7.bottlenecks.0.cv2.bn.running_var",
            "_head._modules_list.7.m.0.cv2.bn.weight": "_head._modules_list.7.bottlenecks.0.cv2.bn.weight",
            "_head._modules_list.7.m.0.cv2.conv.bn.bias": "_head._modules_list.7.bottlenecks.0.cv2.conv.bn.bias",
            "_head._modules_list.7.m.0.cv2.conv.bn.num_batches_tracked": "_head._modules_list.7.bottlenecks.0.cv2.conv.bn.num_batches_tracked",
            "_head._modules_list.7.m.0.cv2.conv.bn.running_mean": "_head._modules_list.7.bottlenecks.0.cv2.conv.bn.running_mean",
            "_head._modules_list.7.m.0.cv2.conv.bn.running_var": "_head._modules_list.7.bottlenecks.0.cv2.conv.bn.running_var",
            "_head._modules_list.7.m.0.cv2.conv.bn.weight": "_head._modules_list.7.bottlenecks.0.cv2.conv.bn.weight",
            "_head._modules_list.7.m.0.cv2.conv.conv.weight": "_head._modules_list.7.bottlenecks.0.cv2.conv.conv.weight",
            "_head._modules_list.7.m.0.cv2.conv.weight": "_head._modules_list.7.bottlenecks.0.cv2.conv.weight",
            "_head._modules_list.7.m.0.cv2.dconv.bn.bias": "_head._modules_list.7.bottlenecks.0.cv2.dconv.bn.bias",
            "_head._modules_list.7.m.0.cv2.dconv.bn.num_batches_tracked": "_head._modules_list.7.bottlenecks.0.cv2.dconv.bn.num_batches_tracked",
            "_head._modules_list.7.m.0.cv2.dconv.bn.running_mean": "_head._modules_list.7.bottlenecks.0.cv2.dconv.bn.running_mean",
            "_head._modules_list.7.m.0.cv2.dconv.bn.running_var": "_head._modules_list.7.bottlenecks.0.cv2.dconv.bn.running_var",
            "_head._modules_list.7.m.0.cv2.dconv.bn.weight": "_head._modules_list.7.bottlenecks.0.cv2.dconv.bn.weight",
            "_head._modules_list.7.m.0.cv2.dconv.conv.weight": "_head._modules_list.7.bottlenecks.0.cv2.dconv.conv.weight",
            "_head._modules_list.7.m.1.cv1.bn.bias": "_head._modules_list.7.bottlenecks.1.cv1.bn.bias",
            "_head._modules_list.7.m.1.cv1.bn.num_batches_tracked": "_head._modules_list.7.bottlenecks.1.cv1.bn.num_batches_tracked",
            "_head._modules_list.7.m.1.cv1.bn.running_mean": "_head._modules_list.7.bottlenecks.1.cv1.bn.running_mean",
            "_head._modules_list.7.m.1.cv1.bn.running_var": "_head._modules_list.7.bottlenecks.1.cv1.bn.running_var",
            "_head._modules_list.7.m.1.cv1.bn.weight": "_head._modules_list.7.bottlenecks.1.cv1.bn.weight",
            "_head._modules_list.7.m.1.cv1.conv.weight": "_head._modules_list.7.bottlenecks.1.cv1.conv.weight",
            "_head._modules_list.7.m.1.cv2.bn.bias": "_head._modules_list.7.bottlenecks.1.cv2.bn.bias",
            "_head._modules_list.7.m.1.cv2.bn.num_batches_tracked": "_head._modules_list.7.bottlenecks.1.cv2.bn.num_batches_tracked",
            "_head._modules_list.7.m.1.cv2.bn.running_mean": "_head._modules_list.7.bottlenecks.1.cv2.bn.running_mean",
            "_head._modules_list.7.m.1.cv2.bn.running_var": "_head._modules_list.7.bottlenecks.1.cv2.bn.running_var",
            "_head._modules_list.7.m.1.cv2.bn.weight": "_head._modules_list.7.bottlenecks.1.cv2.bn.weight",
            "_head._modules_list.7.m.1.cv2.conv.weight": "_head._modules_list.7.bottlenecks.1.cv2.conv.weight",
            "_head._modules_list.7.m.2.cv1.bn.bias": "_head._modules_list.7.bottlenecks.2.cv1.bn.bias",
            "_head._modules_list.7.m.2.cv1.bn.num_batches_tracked": "_head._modules_list.7.bottlenecks.2.cv1.bn.num_batches_tracked",
            "_head._modules_list.7.m.2.cv1.bn.running_mean": "_head._modules_list.7.bottlenecks.2.cv1.bn.running_mean",
            "_head._modules_list.7.m.2.cv1.bn.running_var": "_head._modules_list.7.bottlenecks.2.cv1.bn.running_var",
            "_head._modules_list.7.m.2.cv1.bn.weight": "_head._modules_list.7.bottlenecks.2.cv1.bn.weight",
            "_head._modules_list.7.m.2.cv1.conv.weight": "_head._modules_list.7.bottlenecks.2.cv1.conv.weight",
            "_head._modules_list.7.m.2.cv2.bn.bias": "_head._modules_list.7.bottlenecks.2.cv2.bn.bias",
            "_head._modules_list.7.m.2.cv2.bn.num_batches_tracked": "_head._modules_list.7.bottlenecks.2.cv2.bn.num_batches_tracked",
            "_head._modules_list.7.m.2.cv2.bn.running_mean": "_head._modules_list.7.bottlenecks.2.cv2.bn.running_mean",
            "_head._modules_list.7.m.2.cv2.bn.running_var": "_head._modules_list.7.bottlenecks.2.cv2.bn.running_var",
            "_head._modules_list.7.m.2.cv2.bn.weight": "_head._modules_list.7.bottlenecks.2.cv2.bn.weight",
            "_head._modules_list.7.m.2.cv2.conv.weight": "_head._modules_list.7.bottlenecks.2.cv2.conv.weight",
            "_head._modules_list.8.bn.bias": "_head._modules_list.8.bn.bias",
            "_head._modules_list.8.bn.num_batches_tracked": "_head._modules_list.8.bn.num_batches_tracked",
            "_head._modules_list.8.bn.running_mean": "_head._modules_list.8.bn.running_mean",
            "_head._modules_list.8.bn.running_var": "_head._modules_list.8.bn.running_var",
            "_head._modules_list.8.bn.weight": "_head._modules_list.8.bn.weight",
            "_head._modules_list.8.conv.bn.bias": "_head._modules_list.8.conv.bn.bias",
            "_head._modules_list.8.conv.bn.num_batches_tracked": "_head._modules_list.8.conv.bn.num_batches_tracked",
            "_head._modules_list.8.conv.bn.running_mean": "_head._modules_list.8.conv.bn.running_mean",
            "_head._modules_list.8.conv.bn.running_var": "_head._modules_list.8.conv.bn.running_var",
            "_head._modules_list.8.conv.bn.weight": "_head._modules_list.8.conv.bn.weight",
            "_head._modules_list.8.conv.conv.weight": "_head._modules_list.8.conv.conv.weight",
            "_head._modules_list.8.conv.weight": "_head._modules_list.8.conv.weight",
            "_head._modules_list.8.dconv.bn.bias": "_head._modules_list.8.dconv.bn.bias",
            "_head._modules_list.8.dconv.bn.num_batches_tracked": "_head._modules_list.8.dconv.bn.num_batches_tracked",
            "_head._modules_list.8.dconv.bn.running_mean": "_head._modules_list.8.dconv.bn.running_mean",
            "_head._modules_list.8.dconv.bn.running_var": "_head._modules_list.8.dconv.bn.running_var",
            "_head._modules_list.8.dconv.bn.weight": "_head._modules_list.8.dconv.bn.weight",
            "_head._modules_list.8.dconv.conv.weight": "_head._modules_list.8.dconv.conv.weight",
        }

    def __call__(self, model_state_dict: Mapping[str, Tensor], checkpoint_state_dict: Mapping[str, Tensor]) -> Mapping[str, Tensor]:
        checkpoint_state_dict = self._remove_saved_stride_tensors(checkpoint_state_dict)
        checkpoint_state_dict = self._reshape_old_focus_weights(checkpoint_state_dict)
        checkpoint_state_dict = self._rename_layers(checkpoint_state_dict)
        return checkpoint_state_dict

    def _remove_saved_stride_tensors(self, state_dict):
        exclude_stride_keys = {
            "stride",
            "_head.anchors._anchors",
            "_head.anchors._anchor_grid",
            "_head.anchors._stride",
            "_head._modules_list.14.stride",
        }
        return collections.OrderedDict([(k, v) for k, v in state_dict.items() if k not in exclude_stride_keys])

    def _rename_layers(self, state_dict):
        new_state_dict = collections.OrderedDict()
        for k, v in state_dict.items():
            k = self.layers_rename_table.get(k, k)
            new_state_dict[k] = v
        return new_state_dict

    def _reshape_old_focus_weights(self, state_dict):
        if "_backbone._modules_list.0.conv.conv.weight" in state_dict:
            layer = state_dict["_backbone._modules_list.0.conv.conv.weight"]
            del state_dict["_backbone._modules_list.0.conv.conv.weight"]

            data = torch.zeros((layer.size(0), 3, 6, 6))
            data[:, :, ::2, ::2] = layer.data[:, :3]
            data[:, :, 1::2, ::2] = layer.data[:, 3:6]
            data[:, :, ::2, 1::2] = layer.data[:, 6:9]
            data[:, :, 1::2, 1::2] = layer.data[:, 9:12]
            state_dict["_backbone._modules_list.0.conv.weight"] = data

        return state_dict

    def _yolox_ckpt_solver(self, ckpt_key, ckpt_val, model_key, model_val):
        """
        Helper method for reshaping old pretrained checkpoint's focus weights to 6x6 conv weights.
        """

        if (
            ckpt_val.shape != model_val.shape
            and (ckpt_key == "module._backbone._modules_list.0.conv.conv.weight" or ckpt_key == "_backbone._modules_list.0.conv.conv.weight")
            and model_key == "_backbone._modules_list.0.conv.weight"
        ):
            model_val.data[:, :, ::2, ::2] = ckpt_val.data[:, :3]
            model_val.data[:, :, 1::2, ::2] = ckpt_val.data[:, 3:6]
            model_val.data[:, :, ::2, 1::2] = ckpt_val.data[:, 6:9]
            model_val.data[:, :, 1::2, 1::2] = ckpt_val.data[:, 9:12]
            replacement = model_val
        else:
            replacement = ckpt_val

        return replacement

generate_mapping_table() classmethod

Helper method to generate a mapping table between old YoloX checkpoints and the current YoloX layer names.

Returns:

Type Description
Mapping[str, str]

A mapping dictionary {checkpoint_key: model_key}

Source code in src/super_gradients/training/utils/checkpoint_utils.py
@classmethod
def generate_mapping_table(cls) -> Mapping[str, str]:
    """
    Helper method to generate a mapping table between old YoloX checkpoints and the current YoloX layer names.
    :return: A mapping dictionary {checkpoint_key: model_key}
    """
    from super_gradients.common.object_names import Models
    from super_gradients.training import models

    all_mapping_keys = {}
    model_names = [Models.YOLOX_N, Models.YOLOX_T, Models.YOLOX_S, Models.YOLOX_M, Models.YOLOX_L]

    for model_name in model_names:
        model_url = MODEL_URLS[model_name + "_coco"]
        state_dict = load_state_dict_from_url(model_url, progress=True, map_location="cpu")

        model = models.get(model_name, num_classes=80)
        model_state_dict = model.state_dict()
        checkpoint_state_dict = maybe_remove_module_prefix(state_dict["net"])
        new_sd = {
            k: v
            for k, v in checkpoint_state_dict.items()
            if k not in {"stride", "_head.anchors._anchors", "_head.anchors._anchor_grid", "_head.anchors._stride", "_head._modules_list.14.stride"}
        }

        for (model_key, model_value), (checkpoint_key, checkpoint_value) in zip(model_state_dict.items(), new_sd.items()):
            if model_value.size() == checkpoint_value.size() and model_key.split(".")[-1] == checkpoint_key.split(".")[-1]:
                if checkpoint_key in all_mapping_keys:
                    assert all_mapping_keys[checkpoint_key] == model_key
                all_mapping_keys[checkpoint_key] = model_key
            else:
                raise RuntimeError(
                    "Detected mismatch between model and checkpoint state dict keys."
                    f"Model key {model_key} of shape {model_value.size()} does not "
                    f"match checkpoint key {checkpoint_key} of shape {checkpoint_value.size()}"
                )

    return all_mapping_keys

adapt_state_dict_to_fit_model_layer_names(model_state_dict, source_ckpt, exclude=[], solver=None)

Given a model state dict and a source checkpoint, the method tries to rename the checkpoint keys to fit the model_state_dict in order to properly load the weights into the model. If unsuccessful - returns None.

:param model_state_dict: the model state_dict
:param source_ckpt: checkpoint dict
:param exclude: optional list of excluded layers
:param solver: callable with signature (ckpt_key, ckpt_val, model_key, model_val) that returns a desired weight for ckpt_val.
:return: renamed checkpoint dict (if possible)

Source code in src/super_gradients/training/utils/checkpoint_utils.py
def adapt_state_dict_to_fit_model_layer_names(model_state_dict: dict, source_ckpt: dict, exclude: list = [], solver: callable = None):
    """
    Given a model state dict and a source checkpoint, the method tries to rename the checkpoint keys to fit
    the model_state_dict in order to properly load the weights into the model. If unsuccessful - returns None
        :param model_state_dict:       the model state_dict
        :param source_ckpt:             checkpoint dict
        :param exclude:                 optional list of excluded layers
        :param solver:                  callable with signature (ckpt_key, ckpt_val, model_key, model_val)
                                        that returns a desired weight for ckpt_val.
        :return: renamed checkpoint dict (if possible)
    """
    if solver is None:
        solver = DefaultCheckpointSolver()

    if "net" in source_ckpt.keys():
        source_ckpt = source_ckpt["net"]

    if len(exclude):
        model_state_dict = {k: v for k, v in model_state_dict.items() if not any(x in k for x in exclude)}

    new_ckpt_dict = solver(model_state_dict, source_ckpt)
    return {"net": new_ckpt_dict}
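
A minimal usage sketch (the checkpoint path below is hypothetical): rename the keys of an externally trained YoloX checkpoint so its weights can be loaded into a freshly built SuperGradients model.

from super_gradients.common.object_names import Models
from super_gradients.training import models
import torch

model = models.get(Models.YOLOX_S, num_classes=80)
source_ckpt = torch.load("/path/to/old_yolox_s.pth", map_location="cpu")  # hypothetical checkpoint path

adapted = adapt_state_dict_to_fit_model_layer_names(model.state_dict(), source_ckpt)
model.load_state_dict(adapted["net"])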

adaptive_load_state_dict(net, state_dict, strict, solver=None)

Adaptively loads state_dict to net, by adapting the state_dict to net's layer names first.

Parameters:

Name Type Description Default
net torch.nn.Module

(nn.Module) to load state_dict to

required
state_dict dict

(dict) Checkpoint state_dict

required
strict Union[bool, StrictLoad]

(StrictLoad) key matching strictness

required
solver

callable with signature (ckpt_key, ckpt_val, model_key, model_val) that returns a desired weight for ckpt_val.

None

Returns:

Type Description
Source code in src/super_gradients/training/utils/checkpoint_utils.py
def adaptive_load_state_dict(net: torch.nn.Module, state_dict: dict, strict: Union[bool, StrictLoad], solver=None):
    """
    Adaptively loads state_dict to net, by adapting the state_dict to net's layer names first.
    :param net: (nn.Module) to load state_dict to
    :param state_dict: (dict) Checkpoint state_dict
    :param strict: (StrictLoad) key matching strictness
    :param solver: callable with signature (ckpt_key, ckpt_val, model_key, model_val)
                     that returns a desired weight for ckpt_val.
    :return:
    """
    state_dict = state_dict["net"] if "net" in state_dict else state_dict

    # This is a backward compatibility fix for checkpoints that were saved with DataParallel/DistributedDataParallel wrapper
    # and contains "module." prefix in all keys
    # If all keys start with "module.", then we remove it.
    state_dict = maybe_remove_module_prefix(state_dict)

    try:
        strict_bool = strict if isinstance(strict, bool) else strict != StrictLoad.OFF
        net.load_state_dict(state_dict, strict=strict_bool)
    except (RuntimeError, ValueError, KeyError) as ex:
        if strict == StrictLoad.NO_KEY_MATCHING:
            adapted_state_dict = adapt_state_dict_to_fit_model_layer_names(net.state_dict(), state_dict, solver=solver)
            net.load_state_dict(adapted_state_dict["net"], strict=True)
        elif strict == StrictLoad.KEY_MATCHING:
            transfer_weights(net, state_dict)
        else:
            raise_informative_runtime_error(net.state_dict(), state_dict, ex)
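
A minimal sketch, assuming a checkpoint saved at a hypothetical local path: with StrictLoad.NO_KEY_MATCHING, mismatching layer names are adapted automatically before loading.

import torch
from super_gradients.common.data_types.enum.strict_load import StrictLoad
from super_gradients.common.object_names import Models
from super_gradients.training import models

model = models.get(Models.YOLOX_S, num_classes=80)
state_dict = torch.load("/path/to/ckpt_best.pth", map_location="cpu")  # hypothetical checkpoint path
adaptive_load_state_dict(net=model, state_dict=state_dict, strict=StrictLoad.NO_KEY_MATCHING)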

copy_ckpt_to_local_folder(local_ckpt_destination_dir, ckpt_filename, remote_ckpt_source_dir=None, path_src='local', overwrite_local_ckpt=False, load_weights_only=False)

Copy the checkpoint from any supported source to a local destination path.

:param local_ckpt_destination_dir: destination where the checkpoint will be saved to
:param ckpt_filename: ckpt_best.pth or ckpt_latest.pth
:param remote_ckpt_source_dir: Name of the source checkpoint to be loaded (S3 model / full URL)
:param path_src: S3 / url
:param overwrite_local_ckpt: determines if checkpoint will be saved in destination dir or in a temp folder
:return: Path to checkpoint
Source code in src/super_gradients/training/utils/checkpoint_utils.py
@explicit_params_validation(validation_type="None")
def copy_ckpt_to_local_folder(
    local_ckpt_destination_dir: str,
    ckpt_filename: str,
    remote_ckpt_source_dir: str = None,
    path_src: str = "local",
    overwrite_local_ckpt: bool = False,
    load_weights_only: bool = False,
):
    """
    Copy the checkpoint from any supported source to a local destination path
        :param local_ckpt_destination_dir:  destination where the checkpoint will be saved to
        :param ckpt_filename:         ckpt_best.pth or ckpt_latest.pth
        :param remote_ckpt_source_dir:       Name of the source checkpoint to be loaded (S3 model / full URL)
        :param path_src:              S3 / url
        :param overwrite_local_ckpt:  determines if checkpoint will be saved in destination dir or in a temp folder

        :return: Path to checkpoint
    """
    ckpt_file_full_local_path = None

    # IF NOT DEFINED - IT IS SET TO THE TARGET's FOLDER NAME
    remote_ckpt_source_dir = local_ckpt_destination_dir if remote_ckpt_source_dir is None else remote_ckpt_source_dir

    if not overwrite_local_ckpt:
        # CREATE A TEMP FOLDER TO SAVE THE CHECKPOINT TO
        download_ckpt_destination_dir = tempfile.gettempdir()
        print(
            "PLEASE NOTICE - YOU ARE IMPORTING A REMOTE CHECKPOINT WITH overwrite_local_checkpoint = False "
            "-> IT WILL BE REDIRECTED TO A TEMP FOLDER AND DELETED ON MACHINE RESTART"
        )
    else:
        # SAVE THE CHECKPOINT TO MODEL's FOLDER
        download_ckpt_destination_dir = pkg_resources.resource_filename("checkpoints", local_ckpt_destination_dir)

    if path_src.startswith("s3"):
        model_checkpoints_data_interface = ADNNModelRepositoryDataInterfaces(data_connection_location=path_src)
        # DOWNLOAD THE FILE FROM S3 TO THE DESTINATION FOLDER
        ckpt_file_full_local_path = model_checkpoints_data_interface.load_remote_checkpoints_file(
            ckpt_source_remote_dir=remote_ckpt_source_dir,
            ckpt_destination_local_dir=download_ckpt_destination_dir,
            ckpt_file_name=ckpt_filename,
            overwrite_local_checkpoints_file=overwrite_local_ckpt,
        )

        if not load_weights_only:
            # COPY LOG FILES FROM THE REMOTE DIRECTORY TO THE LOCAL ONE ONLY IF LOADING THE CURRENT MODELs CKPT
            model_checkpoints_data_interface.load_all_remote_log_files(
                model_name=remote_ckpt_source_dir, model_checkpoint_local_dir=download_ckpt_destination_dir
            )

    if path_src == "url":
        ckpt_file_full_local_path = download_ckpt_destination_dir + os.path.sep + ckpt_filename
        # DOWNLOAD THE FILE FROM URL TO THE DESTINATION FOLDER
        with wait_for_the_master(get_local_rank()):
            download_url_to_file(remote_ckpt_source_dir, ckpt_file_full_local_path, progress=True)

    return ckpt_file_full_local_path
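
A short sketch using a hypothetical URL: with path_src="url" and the default overwrite_local_ckpt=False, the file is downloaded into a temporary folder and its local path is returned.

local_path = copy_ckpt_to_local_folder(
    local_ckpt_destination_dir="my_experiment",                               # hypothetical experiment folder name
    ckpt_filename="ckpt_best.pth",
    remote_ckpt_source_dir="https://example.com/checkpoints/ckpt_best.pth",   # hypothetical URL
    path_src="url",
)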

get_scheduler_state(scheduler)

Wrapper for getting a torch lr scheduler state dict, resolving some issues with CyclicLR (see https://github.com/pytorch/pytorch/pull/91400)

Parameters:

Name Type Description Default
scheduler

torch.optim.lr_scheduler._LRScheduler, the scheduler

required

Returns:

Type Description
Dict[str, Tensor]

the scheduler's state_dict

Source code in src/super_gradients/training/utils/checkpoint_utils.py
def get_scheduler_state(scheduler) -> Dict[str, Tensor]:
    """
    Wrapper for getting a torch lr scheduler state dict, resolving some issues with CyclicLR
    (see https://github.com/pytorch/pytorch/pull/91400)
    :param scheduler: torch.optim.lr_scheduler._LRScheduler, the scheduler
    :return:          the scheduler's state_dict
    """
    from super_gradients.training.utils import torch_version_is_greater_or_equal
    from torch.optim.lr_scheduler import CyclicLR

    state = scheduler.state_dict()
    if isinstance(scheduler, CyclicLR) and not torch_version_is_greater_or_equal(2, 0):
        # A check is needed since torch 1.12 does not have the _scale_fn_ref attribute, while other versions do
        if "_scale_fn_ref" in state:
            del state["_scale_fn_ref"]
    return state
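
A small sketch (the module below is just a placeholder): capture a CyclicLR state that is safe to save into a checkpoint regardless of the torch version.

import torch
from torch.optim.lr_scheduler import CyclicLR

model = torch.nn.Linear(10, 2)  # placeholder module
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scheduler = CyclicLR(optimizer, base_lr=0.001, max_lr=0.01)

scheduler_state = get_scheduler_state(scheduler)  # "_scale_fn_ref" is dropped on torch < 2.0
torch.save({"net": model.state_dict(), "scheduler_state": scheduler_state}, "ckpt_latest.pth")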

load_checkpoint_to_model(net, ckpt_local_path, load_backbone=False, strict=StrictLoad.NO_KEY_MATCHING, load_weights_only=False, load_ema_as_net=False, load_processing_params=False)

Loads the state dict in ckpt_local_path to net and returns the checkpoint's state dict.

Parameters:

Name Type Description Default
net torch.nn.Module

Network to load the checkpoint to

required
ckpt_local_path str

Local path to the checkpoint file

required
load_ema_as_net bool

Will load the EMA inside the checkpoint file to the network when set

False
load_backbone bool

Whether to load the checkpoint as a backbone

False
strict Union[str, StrictLoad]

See super_gradients.common.data_types.enum.strict_load.StrictLoad class documentation for details (default=NO_KEY_MATCHING to support SG-trained checkpoints)

StrictLoad.NO_KEY_MATCHING
load_weights_only bool

Whether to ignore all entries other than "net".

False
load_processing_params bool

Whether to call set_dataset_processing_params on "processing_params" entry inside the checkpoint file (default=False).

False

Returns:

Type Description
Source code in src/super_gradients/training/utils/checkpoint_utils.py
@resolve_param("strict", TypeFactory.from_enum_cls(StrictLoad))
def load_checkpoint_to_model(
    net: torch.nn.Module,
    ckpt_local_path: str,
    load_backbone: bool = False,
    strict: Union[str, StrictLoad] = StrictLoad.NO_KEY_MATCHING,
    load_weights_only: bool = False,
    load_ema_as_net: bool = False,
    load_processing_params: bool = False,
):
    """
    Loads the state dict in ckpt_local_path to net and returns the checkpoint's state dict.

    :param net:                    Network to load the checkpoint to
    :param ckpt_local_path:        Local path to the checkpoint file
    :param load_ema_as_net:        Will load the EMA inside the checkpoint file to the network when set
    :param load_backbone:          Whether to load the checkpoint as a backbone
    :param strict:                 See super_gradients.common.data_types.enum.strict_load.StrictLoad class documentation for details
                                   (default=NO_KEY_MATCHING to support SG-trained checkpoints)
    :param load_weights_only:      Whether to ignore all entries other than "net".
    :param load_processing_params: Whether to call set_dataset_processing_params on "processing_params" entry inside the
                                   checkpoint file (default=False).
    :return:
    """
    net = unwrap_model(net)

    if load_backbone and not hasattr(net, "backbone"):
        raise ValueError("No backbone attribute in net - Can't load backbone weights")

    # LOAD THE LOCAL CHECKPOINT PATH INTO A state_dict OBJECT
    checkpoint = read_ckpt_state_dict(ckpt_path=ckpt_local_path)

    if load_ema_as_net:
        if "ema_net" not in checkpoint.keys():
            raise ValueError("Can't load ema network- no EMA network stored in checkpoint file")
        else:
            checkpoint["net"] = checkpoint["ema_net"]

    # LOAD THE CHECKPOINTS WEIGHTS TO THE MODEL
    if load_backbone:
        adaptive_load_state_dict(net.backbone, checkpoint, strict)
    else:
        adaptive_load_state_dict(net, checkpoint, strict)

    message_suffix = " checkpoint." if not load_ema_as_net else " EMA checkpoint."
    message_model = "model" if not load_backbone else "model's backbone"
    logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)

    _maybe_load_preprocessing_params(net, checkpoint)

    if load_weights_only or load_backbone:
        # DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
        [checkpoint.pop(key) for key in list(checkpoint.keys()) if key != "net"]

    return checkpoint
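
A minimal sketch with a hypothetical checkpoint path: load the EMA weights stored inside a checkpoint into a model (load_ema_as_net=True requires an "ema_net" entry in the checkpoint file).

from super_gradients.common.object_names import Models
from super_gradients.training import models

model = models.get(Models.YOLOX_S, num_classes=80)
checkpoint = load_checkpoint_to_model(
    net=model,
    ckpt_local_path="/path/to/ckpt_best.pth",  # hypothetical checkpoint path
    load_ema_as_net=True,
    load_weights_only=True,
)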

load_pretrained_weights(model, architecture, pretrained_weights)

Loads pretrained weights from the MODEL_URLS dictionary to model

Parameters:

Name Type Description Default
architecture str

name of the model's architecture

required
model torch.nn.Module

model to load pretrained weights for

required
pretrained_weights str

name of the pretrained weights (e.g. imagenet)

required

Returns:

Type Description

None

Source code in src/super_gradients/training/utils/checkpoint_utils.py
def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
    """
    Loads pretrained weights from the MODEL_URLS dictionary to model

    :param architecture:        name of the model's architecture
    :param model:               model to load pretrained weights for
    :param pretrained_weights:  name of the pretrained weights (e.g. imagenet)

    :return:                    None
    """
    from super_gradients.common.object_names import Models

    model_url_key = architecture + "_" + str(pretrained_weights)
    if model_url_key not in MODEL_URLS.keys():
        raise MissingPretrainedWeightsException(model_url_key)

    if pretrained_weights in DATASET_LICENSES:
        logger.warning(
            f":warning: The pre-trained models provided by SuperGradients may have their own licenses or terms and "
            "conditions derived from the dataset used for pre-training.\n It is your responsibility to determine whether you "
            "have permission to use the models for your use case.\n The model you have requested was pre-trained on the "
            f"{pretrained_weights} dataset, published under the following terms: {DATASET_LICENSES[pretrained_weights]}"
        )
    url = MODEL_URLS[model_url_key]

    if architecture in {Models.YOLO_NAS_S, Models.YOLO_NAS_M, Models.YOLO_NAS_L}:
        logger.info(
            "License Notification: YOLO-NAS pre-trained weights are subjected to the specific license terms and conditions detailed in \n"
            "https://github.com/Deci-AI/super-gradients/blob/master/LICENSE.YOLONAS.md\n"
            "By downloading the pre-trained weight files you agree to comply with these terms."
        )
    elif architecture in {Models.YOLO_NAS_POSE_N, Models.YOLO_NAS_POSE_S, Models.YOLO_NAS_POSE_M, Models.YOLO_NAS_POSE_L}:
        logger.info(
            "License Notification: YOLO-NAS-POSE pre-trained weights are subjected to the specific license terms and conditions detailed in \n"
            "https://github.com/Deci-AI/super-gradients/blob/master/LICENSE.YOLONAS-POSE.md\n"
            "By downloading the pre-trained weight files you agree to comply with these terms."
        )

    # Basically this check allows setting pretrained weights from a local path using the file:///path/to/weights scheme
    # which is a valid URI scheme for local files
    # Supporting local files and file URIs allows us to modify the pretrained weights dicts in unit tests
    if url.startswith("file://") or os.path.exists(url):
        pretrained_state_dict = torch.load(url.replace("file://", ""), map_location="cpu")
    else:
        unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
        map_location = torch.device("cpu")
        with wait_for_the_master(get_local_rank()):
            pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)

    _load_weights(architecture, model, pretrained_state_dict)
    _maybe_load_preprocessing_params(model, pretrained_state_dict)
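
A minimal sketch: attach COCO pre-trained weights to a freshly built YOLOX-S (in typical usage, models.get(..., pretrained_weights="coco") would perform this step for you).

from super_gradients.common.object_names import Models
from super_gradients.training import models

model = models.get(Models.YOLOX_S, num_classes=80)
load_pretrained_weights(model, architecture=Models.YOLOX_S, pretrained_weights="coco")  # resolves MODEL_URLS["yolox_s_coco"]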

load_pretrained_weights_local(model, architecture, pretrained_weights)

Loads pretrained weights from a local path to model

Parameters:

Name Type Description Default
architecture str

name of the model's architecture

required
model torch.nn.Module

model to load pretrained weights for

required
pretrained_weights str

path to pretrained weights

required

Returns:

Type Description

None

Source code in src/super_gradients/training/utils/checkpoint_utils.py
def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
    """
    Loads pretrained weights from a local path to model
    :param architecture: name of the model's architecture
    :param model: model to load pretrained weights for
    :param pretrained_weights: path to pretrained weights
    :return: None
    """

    map_location = torch.device("cpu")

    pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
    _load_weights(architecture, model, pretrained_state_dict)
    _maybe_load_preprocessing_params(model, pretrained_state_dict)

maybe_remove_module_prefix(state_dict, prefix='module.')

Checks if all the keys in state_dict start with prefix and, if so, removes this prefix. This function is intended to drop a "module." prefix from all keys in a checkpoint that was saved with a DataParallel/DistributedDataParallel wrapper.

Since SG 3.1 we changed this behavior and always unwrap the model before saving the state_dict. However, to keep the compatibility with older checkpoints, we must do the 'cleanup' before loading the state_dict.

Returns:

Type Description
Mapping[str, Tensor]

state_dict: The model state_dict after removing the prefix

Source code in src/super_gradients/training/utils/checkpoint_utils.py
def maybe_remove_module_prefix(state_dict: Mapping[str, Tensor], prefix: str = "module.") -> Mapping[str, Tensor]:
    """
    Checks if all the keys in `state_dict` start with `prefix` and, if so, removes this prefix.
    This function is intended to drop a "module." prefix from all keys in checkpoint that was saved
    with DataParallel/DistributedDataParallel wrapper.

    Since SG 3.1 we changed this behavior and always unwrap the model before saving the state_dict.
    However, to keep the compatibility with older checkpoints, we must do the 'cleanup' before loading the state_dict.

    :params: state_dict: The model state_dict
    :params: prefix: (str) prefix to remove. Default is "module."
    :return: state_dict: The model state_dict after removing the prefix

    """
    offset = len(prefix)
    if all([key.startswith(prefix) for key in state_dict.keys()]):
        state_dict = collections.OrderedDict([(key[offset:], value) for key, value in state_dict.items()])
    return state_dict
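
A tiny illustration of the prefix cleanup (the key name and tensor shape are arbitrary):

import collections
import torch

wrapped = collections.OrderedDict({"module.backbone.conv.weight": torch.zeros(8, 3, 3, 3)})
clean = maybe_remove_module_prefix(wrapped)
print(list(clean.keys()))  # ['backbone.conv.weight']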

raise_informative_runtime_error(state_dict, checkpoint, exception_msg)

Given a model state dict and source checkpoints, the method calls "adapt_state_dict_to_fit_model_layer_names" and enhances the exception_msg if loading the checkpoint_dict via the conversion method is possible

Source code in src/super_gradients/training/utils/checkpoint_utils.py
def raise_informative_runtime_error(state_dict, checkpoint, exception_msg):
    """
    Given a model state dict and source checkpoints, the method calls "adapt_state_dict_to_fit_model_layer_names"
    and enhances the exception_msg if loading the checkpoint_dict via the conversion method is possible
    """
    try:
        new_ckpt_dict = adapt_state_dict_to_fit_model_layer_names(state_dict, checkpoint)
        temp_file = tempfile.NamedTemporaryFile().name + ".pt"
        torch.save(new_ckpt_dict, temp_file)
        exception_msg = (
            f"\n{'=' * 200}\n{str(exception_msg)} \nconvert ckpt via the utils.adapt_state_dict_to_fit_"
            f"model_layer_names method\na converted checkpoint file was saved in the path {temp_file}\n{'=' * 200}"
        )
    except ValueError as ex:  # IN CASE adapt_state_dict_to_fit_model_layer_names WAS UNSUCCESSFUL
        exception_msg = f"\n{'=' * 200} \nThe checkpoint and model shapes do not fit, e.g.: {ex}\n{'=' * 200}"
    finally:
        raise RuntimeError(exception_msg)

read_ckpt_state_dict(ckpt_path, device='cpu')

Reads a checkpoint state dict from a given path or url

Parameters:

Name Type Description Default
ckpt_path str

Checkpoint path or url

required
device

Target device where tensors should be loaded

'cpu'

Returns:

Type Description
Mapping[str, torch.Tensor]

Checkpoint state dict object

Source code in src/super_gradients/training/utils/checkpoint_utils.py
def read_ckpt_state_dict(ckpt_path: str, device="cpu") -> Mapping[str, torch.Tensor]:
    """
    Reads a checkpoint state dict from a given path or url

    :param ckpt_path: Checkpoint path or url
    :param device: Target device where tensors should be loaded
    :return: Checkpoint state dict object
    """

    if ckpt_path.startswith("https://"):
        with wait_for_the_master(get_local_rank()):
            state_dict = load_state_dict_from_url(ckpt_path, progress=False, map_location=device)
        return state_dict
    else:
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Incorrect Checkpoint path: {ckpt_path} (This should be an absolute path)")

        state_dict = torch.load(ckpt_path, map_location=device)
        return state_dict
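
A short sketch with a hypothetical path; checkpoints produced by SuperGradients typically hold the weights under the "net" key (and possibly "ema_net" and "processing_params").

checkpoint = read_ckpt_state_dict("/path/to/ckpt_best.pth")  # hypothetical local path; URLs also work
weights = checkpoint["net"] if "net" in checkpoint else checkpoint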

transfer_weights(model, model_state_dict)

Copy weights from model_state_dict to model, skipping layers that are incompatible (having a different shape). This method is helpful if you are doing some model surgery and want to load part of the model weights into a different model. This function will go over all the layers in model_state_dict, try to find a matching layer in model, and copy the weights into it. If the shapes do not match, the layer will be skipped.

Parameters:

Name Type Description Default
model nn.Module

Model to load weights into

required
model_state_dict Mapping[str, Tensor]

Model state dict to load weights from

required

Returns:

Type Description
None

None

Source code in src/super_gradients/training/utils/checkpoint_utils.py
def transfer_weights(model: nn.Module, model_state_dict: Mapping[str, Tensor]) -> None:
    """
    Copy weights from `model_state_dict` to `model`, skipping layers that are incompatible (having a different shape).
    This method is helpful if you are doing some model surgery and want to load
    part of the model weights into a different model.
    This function will go over all the layers in `model_state_dict`, try to find a matching layer in `model`, and
    copy the weights into it. If the shapes do not match, the layer will be skipped.

    :param model: Model to load weights into
    :param model_state_dict: Model state dict to load weights from
    :return: None
    """

    transfered_weights = 0
    for name, value in model_state_dict.items():
        try:
            model.load_state_dict(collections.OrderedDict([(name, value)]), strict=False)
            transfered_weights += 1
        except RuntimeError:
            pass

    percentage_of_checkpoint = transfered_weights / len(model_state_dict)
    percentage_of_model = transfered_weights / len(model.state_dict())
    logger.debug(
        f"Transfered {transfered_weights} ({(100 * percentage_of_checkpoint):.2f}%) weights from the checkpoint. "
        f"{(100 * percentage_of_model):.2f}% of the model layers were initialized using checkpoint."
    )
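
A minimal sketch of model surgery: every layer whose name and shape match is copied, while the mismatching classification head is silently skipped.

import torch.nn as nn

src = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128), nn.Linear(128, 1000))
dst = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128), nn.Linear(128, 10))

transfer_weights(dst, src.state_dict())  # copies the shared layers, skips the 1000-class head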

BaseDatasetAdapterCollateFN

Bases: ABC

Base Collate function that adapts an input data to SuperGradients format

This is done by applying the adapter logic either before or after the original collate function, depending on whether the adapter was set up on a batch or a sample.

Note that the original collate function (if any) will still be used, but will be wrapped into this class.

Source code in src/super_gradients/training/utils/collate_fn/adapters/base_adapter_collate_fn.py
class BaseDatasetAdapterCollateFN(ABC):
    """Base Collate function that adapts an input data to SuperGradients format

    This is done by applying the adapter logic either before or after the original collate function,
    depending on whether the adapter was set up on a batch or a sample.

    Note that the original collate function (if any) will still be used, but will be wrapped into this class.
    """

    @resolve_param("base_collate_fn", CollateFunctionsFactory())
    def __init__(self, adapter: BaseDatasetAdapter, base_collate_fn: Callable):
        """
        :param adapter:             Dataset adapter to use
        :param base_collate_fn:     Collate function to wrap. If None, the default collate function will be used.
        """
        self._adapt_on_batch = adapter.data_config.is_batch

        self.adapter = adapter
        self._base_collate_fn = base_collate_fn or default_collate

        if isinstance(self._base_collate_fn, type(self)):
            raise RuntimeError(f"You just tried to instantiate {self.__class__.__name__} with a `base_collate_fn` of the same type, which is not supported.")

    def __call__(self, samples: Iterable[SupportedDataType]) -> Tuple[torch.Tensor, torch.Tensor]:

        if self._require_setup:
            # This is required because python `input` is not compatible with multiprocessing (e.g. `num_workers > 0`, or `DDP`)
            # If the adapter still requires setup, it would need to ask at least one question using the python `input`
            raise RuntimeError(
                f"Trying to collate using `{self.__class__.__name__}`, but it was not fully set up yet. Please do one of the following\n"
                f"   - Call `{self.__class__.__name__}(...).setup_adapter(dataloader)` before iterating over the dataloader.\n"
                f"   - or Instantiate `{self.__class__.__name__}(config_path=...)` with `config_path` mapping to the cache file of "
                f"an adapter that was already set up on this data.\n"
            )

        if not self._adapt_on_batch:
            samples = self._adapt_samples(samples=samples)

        batch = self._base_collate_fn(samples)

        if self._adapt_on_batch:
            batch = self._adapt_batch(batch=batch)

        images, targets = batch  # At this point we know it is (images, targets) because the adapter was used - either on samples or batch
        return images, targets

    @property
    def _require_setup(self) -> bool:
        return not self.adapter.data_config.is_completely_initialized

    def _adapt_samples(self, samples: Iterable[SupportedDataType]) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Apply the adapter logic to a list of samples. This should be called only if the adapter was NOT setup on a batch.
        :param samples: List of samples to adapt
        :return:        List of (Image, Targets)
        """
        adapted_samples = []
        for sample in samples:
            images, targets = self._adapt(data=sample)  # Will construct batch of 1
            images, targets = images.squeeze(0), targets.squeeze(0)  # Extract the sample
            adapted_samples.append((images, targets))
        return adapted_samples

    def _adapt_batch(self, batch: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the adapter logic to a batch. This should be called only if the adapter was setup on a batch.
        :param batch: Batch of samples to adapt
        :return:      Adapted batch (Images, Targets)
        """
        return self._adapt(data=batch)

    def _adapt(self, data: Iterable[SupportedDataType]) -> Tuple[torch.Tensor, torch.Tensor]:
        images, targets = self.adapter.adapt(data)
        images = images.float()  # SG takes float as input
        return images, targets

__init__(adapter, base_collate_fn)

Parameters:

Name Type Description Default
adapter BaseDatasetAdapter

Dataset adapter to use

required
base_collate_fn Callable

Collate function to wrap. If None, the default collate function will be used.

required
Source code in src/super_gradients/training/utils/collate_fn/adapters/base_adapter_collate_fn.py
@resolve_param("base_collate_fn", CollateFunctionsFactory())
def __init__(self, adapter: BaseDatasetAdapter, base_collate_fn: Callable):
    """
    :param adapter:             Dataset adapter to use
    :param base_collate_fn:     Collate function to wrap. If None, the default collate function will be used.
    """
    self._adapt_on_batch = adapter.data_config.is_batch

    self.adapter = adapter
    self._base_collate_fn = base_collate_fn or default_collate

    if isinstance(self._base_collate_fn, type(self)):
        raise RuntimeError(f"You just tried to instantiate {self.__class__.__name__} with a `base_collate_fn` of the same type, which is not supported.")

ClassificationDatasetAdapterCollateFN

Bases: BaseDatasetAdapterCollateFN

Classification Collate function that adapts an input data to SuperGradients format

This is done by applying the adapter logic either before or after the original collate function, depending on whether the adapter was set up on a batch or a sample.

Note that the original collate function (if any) will still be used, but will be wrapped into this class.

Source code in src/super_gradients/training/utils/collate_fn/adapters/classification_adapter_collate_fn.py
@register_collate_function()
class ClassificationDatasetAdapterCollateFN(BaseDatasetAdapterCollateFN):
    """Classification Collate function that adapts an input data to SuperGradients format

    This is done by applying the adapter logic either before or after the original collate function,
    depending on whether the adapter was set up on a batch or a sample.

    Note that the original collate function (if any) will still be used, but will be wrapped into this class.
    """

    @resolve_param("base_collate_fn", CollateFunctionsFactory())
    def __init__(self, config: Optional[ClassificationDataConfig] = None, config_path: Optional[str] = None, base_collate_fn: Optional[Callable] = None):
        """
        :param config:          Adapter configuration. Use this if you want to hard code some specificities about your dataset.
                                Mutually exclusive with `config_path`.
        :param config_path:     Adapter cache path. Use this if you want to load and/or save the adapter config from a local path.
                                Mutually exclusive with `config`.
        :param base_collate_fn: Collate function to use. If None, the PyTorch default collate function will be used.
        """
        if config and config_path:
            raise ValueError("`config` and `config_path` cannot be set at the same time.")
        elif config is None and config_path:
            adapter = ClassificationDatasetAdapter.from_cache(cache_path=config_path)
        elif config is not None and config_path is None:
            adapter = ClassificationDatasetAdapter(data_config=config)
        else:
            raise ValueError("Please either set `config` or `config_path`.")

        super().__init__(adapter=adapter, base_collate_fn=base_collate_fn or base_collate_fn)

    def __call__(self, samples: Iterable[SupportedDataType]) -> Tuple[torch.Tensor, torch.Tensor]:
        images, targets = super().__call__(samples=samples)  # This already returns a batch of (images, targets)
        images = images / 255
        return images, targets

__init__(config=None, config_path=None, base_collate_fn=None)

Parameters:

Name Type Description Default
config Optional[ClassificationDataConfig]

Adapter configuration. Use this if you want to hard code some specificities about your dataset. Mutually exclusive with config_path.

None
config_path Optional[str]

Adapter cache path. Use this if you want to load and/or save the adapter config from a local path. Mutually exclusive with config.

None
base_collate_fn Optional[Callable]

Collate function to use. If None, the PyTorch default collate function will be used.

None
Source code in src/super_gradients/training/utils/collate_fn/adapters/classification_adapter_collate_fn.py
@resolve_param("base_collate_fn", CollateFunctionsFactory())
def __init__(self, config: Optional[ClassificationDataConfig] = None, config_path: Optional[str] = None, base_collate_fn: Optional[Callable] = None):
    """
    :param config:          Adapter configuration. Use this if you want to hard code some specificities about your dataset.
                            Mutually exclusive with `config_path`.
    :param config_path:     Adapter cache path. Use this if you want to load and/or save the adapter config from a local path.
                            Mutually exclusive with `config`.
    :param base_collate_fn: Collate function to use. If None, the PyTorch default collate function will be used.
    """
    if config and config_path:
        raise ValueError("`config` and `config_path` cannot be set at the same time.")
    elif config is None and config_path:
        adapter = ClassificationDatasetAdapter.from_cache(cache_path=config_path)
    elif config is not None and config_path is None:
        adapter = ClassificationDatasetAdapter(data_config=config)
    else:
        raise ValueError("Please either set `config` or `config_path`.")

    super().__init__(adapter=adapter, base_collate_fn=base_collate_fn or base_collate_fn)
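
A usage sketch under stated assumptions: my_dataset stands for any classification dataset, and the cache path (hypothetical) points to an adapter that was already set up on this data; otherwise, set the adapter up once via setup_adapter, as the error message above suggests.

from torch.utils.data import DataLoader

collate_fn = ClassificationDatasetAdapterCollateFN(config_path="/path/to/adapter_cache")  # hypothetical cache path
loader = DataLoader(my_dataset, batch_size=32, collate_fn=collate_fn)  # my_dataset is a placeholder
# each batch is (images, targets), with images returned as float tensors divided by 255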

DetectionDatasetAdapterCollateFN

Bases: BaseDatasetAdapterCollateFN

Detection Collate function that adapts an input data to SuperGradients format for YOLOX, YOLONAS and PPYOLOE.

This is done by applying the adapter logic either before or after the original collate function, depending on whether the adapter was set up on a batch or a sample.

Note that the original collate function (if any) will still be used, but will be wrapped into this class.

Source code in src/super_gradients/training/utils/collate_fn/adapters/detection_adapter_collate_fn.py
@register_collate_function()
class DetectionDatasetAdapterCollateFN(BaseDatasetAdapterCollateFN):
    """Detection Collate function that adapts an input data to SuperGradients format for YOLOX, YOLONAS and PPYOLOE.

    This is done by applying the adapter logic either before or after the original collate function,
    depending on whether the adapter was set up on a batch or a sample.

    Note that the original collate function (if any) will still be used, but will be wrapped into this class.
    """

    @resolve_param("base_collate_fn", CollateFunctionsFactory())
    def __init__(self, config: Optional[DetectionDataConfig] = None, config_path: Optional[str] = None, base_collate_fn: Optional[Callable] = None):
        """
        :param config:          Adapter configuration. Use this if you want to hard code some specificities about your dataset.
                                Mutually exclusive with `config_path`.
        :param config_path:     Adapter cache path. Use this if you want to load and/or save the adapter config from a local path.
                                Mutually exclusive with `config`.
        :param base_collate_fn: Collate function to use. If None, the PyTorch default collate function will be used.
        """
        if config and config_path:
            raise ValueError("`config` and `config_path` cannot be set at the same time.")
        elif config is None and config_path:
            adapter = DetectionDatasetAdapter.from_cache(cache_path=config_path)
        elif config is not None and config_path is None:
            adapter = DetectionDatasetAdapter(data_config=config)
        else:
            raise ValueError("Please either set `config` or `config_path`.")

        logger.info("You are using Detection Adapter. Please note that it was designed specifically for YOLONAS, YOLOX and PPYOLOE.")

        # `DetectionCollateFN()` is the default collate_fn for detection.
        # But if the adapter was used on already collated batches, we don't want to force it.
        base_collate_fn = base_collate_fn or (default_collate if adapter.data_config.is_batch else DetectionCollateFN())
        super().__init__(adapter=adapter, base_collate_fn=base_collate_fn)

    def _adapt_samples(self, samples: Iterable[SupportedDataType]) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Apply the adapter logic to a list of samples. This should be called only if the adapter was NOT setup on a batch.
        :param samples: List of samples to adapt
        :return:        List of (Image, Targets)
        """
        from super_gradients.training.utils.detection_utils import xyxy2cxcywh

        adapted_samples = []
        for sample in samples:
            images, targets = self._adapt(sample)  # Will construct batch of 1
            images, targets = images[0], targets[0]  # Extract the sample
            targets[:, 1:] = xyxy2cxcywh(targets[:, 1:])  # Adapter is designed to work on label_cxcywh format (YOLOX, PPYOLOE, YOLONAS)
            adapted_samples.append((images, targets))
        return adapted_samples

    def _adapt_batch(self, batch: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        from super_gradients.training.utils.detection_utils import xyxy2cxcywh

        images, targets = super()._adapt_batch(batch)
        targets = DetectionCollateFN._format_targets(targets)
        targets[:, 2:] = xyxy2cxcywh(targets[:, 2:])  # Adapter is designed to work on label_cxcywh format (YOLOX, PPYOLOE, YOLONAS)
        return images, targets

__init__(config=None, config_path=None, base_collate_fn=None)

Parameters:

Name Type Description Default
config Optional[DetectionDataConfig]

Adapter configuration. Use this if you want to hard code some specificities about your dataset. Mutually exclusive with config_path.

None
config_path Optional[str]

Adapter cache path. Use this if you want to load and/or save the adapter config from a local path. Mutually exclusive with config.

None
base_collate_fn Optional[Callable]

Collate function to use. If None, PyTorch's default_collate will be used when the adapter works on batches, and DetectionCollateFN otherwise.

None
Source code in src/super_gradients/training/utils/collate_fn/adapters/detection_adapter_collate_fn.py
@resolve_param("base_collate_fn", CollateFunctionsFactory())
def __init__(self, config: Optional[DetectionDataConfig] = None, config_path: Optional[str] = None, base_collate_fn: Optional[Callable] = None):
    """
    :param config:          Adapter configuration. Use this if you want to hard code some specificities about your dataset.
                            Mutually exclusive with `config_path`.
    :param config_path:     Adapter cache path. Use this if you want to load and/or save the adapter config from a local path.
                            Mutually exclusive with `config`.
    :param base_collate_fn: Collate function to use. If None, PyTorch's default_collate will be used when the adapter
                            works on batches, and DetectionCollateFN otherwise.
    """
    if config and config_path:
        raise ValueError("`config` and `config_path` cannot be set at the same time.")
    elif config is None and config_path:
        adapter = DetectionDatasetAdapter.from_cache(cache_path=config_path)
    elif config is not None and config_path is None:
        adapter = DetectionDatasetAdapter(data_config=config)
    else:
        raise ValueError("Please either set `config` or `config_path`.")

    logger.info("You are using Detection Adapter. Please note that it was designed specifically for YOLONAS, YOLOX and PPYOLOE.")

    # `DetectionCollateFN()` is the default collate_fn for detection.
    # But if the adapter was used on already collated batches, we don't want to force it.
    base_collate_fn = base_collate_fn or (default_collate if adapter.data_config.is_batch else DetectionCollateFN())
    super().__init__(adapter=adapter, base_collate_fn=base_collate_fn)
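
The snippet below is a minimal usage sketch, not taken from the library documentation: it builds the adapter collate function from a cached adapter config and plugs it into a standard PyTorch DataLoader. The dataset variable and the cache path are placeholders, and the import path is assumed from the source location shown above.

from torch.utils.data import DataLoader
from super_gradients.training.utils.collate_fn.adapters.detection_adapter_collate_fn import DetectionDatasetAdapterCollateFN

# `my_detection_dataset` and the cache path are placeholders for your own dataset and adapter config.
collate_fn = DetectionDatasetAdapterCollateFN(config_path="cache/adapter_config.json")
loader = DataLoader(my_detection_dataset, batch_size=16, collate_fn=collate_fn)
images, targets = next(iter(loader))  # images: [B, C, H, W], targets: [N, 6] in (batch_idx, label, cx, cy, w, h)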

ensure_flat_bbox_batch(bbox_batch)

Flatten a batched bounding box tensor and prepend the batch ID to each bounding box. Excludes padding boxes.

Parameters:

Name Type Description Default
bbox_batch torch.Tensor

Bounding box tensor of shape (BS, PaddingSize, 5).

required

Returns:

Type Description
torch.Tensor

Flattened tensor of shape (N, 6), where N <= BS * PaddingSize.

Source code in src/super_gradients/training/utils/collate_fn/adapters/detection_adapter_collate_fn.py
def ensure_flat_bbox_batch(bbox_batch: torch.Tensor) -> torch.Tensor:
    """
    Flatten a batched bounding box tensor and prepend the batch ID to each bounding box. Excludes padding boxes.

    :param bbox_batch: Bounding box tensor of shape (BS, PaddingSize, 5).
    :return: Flattened tensor of shape (N, 6), where N <= BS * PaddingSize.
    """

    # Create a tensor of batch IDs
    batch_ids = torch.arange(bbox_batch.size(0), device=bbox_batch.device).unsqueeze(-1)
    batch_ids = batch_ids.repeat(1, bbox_batch.size(1)).reshape(-1, 1)  # Shape: (BS*PaddingSize, 1)

    # Reshape bounding box tensor
    bbox_reshaped = bbox_batch.reshape(-1, 5)  # Shape: (BS*PaddingSize, 5)

    # Concatenate batch IDs and reshaped bounding boxes
    flat_bbox = torch.cat((batch_ids, bbox_reshaped), dim=1)  # Shape: (BS*PaddingSize, 6)

    # Filter out padding boxes (assuming padding boxes have all values zero)
    non_padding_mask = torch.any(flat_bbox[:, 1:] != 0, dim=1)
    flat_bbox = flat_bbox[non_padding_mask]

    return flat_bbox
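
A toy example of the flattening, assuming the function is imported from the module shown above. Each row holds 5 box values; all-zero rows are treated as padding and dropped:

import torch
from super_gradients.training.utils.collate_fn.adapters.detection_adapter_collate_fn import ensure_flat_bbox_batch

bbox_batch = torch.tensor(
    [
        [[0.0, 10, 10, 50, 50], [1, 20, 20, 40, 60], [0, 0, 0, 0, 0]],  # image 0: 2 real boxes + 1 padding row
        [[2.0, 5, 5, 15, 15], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],        # image 1: 1 real box + 2 padding rows
    ]
)
flat = ensure_flat_bbox_batch(bbox_batch)  # shape (3, 6): column 0 is the batch index, padding rows removed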

SegmentationDatasetAdapterCollateFN

Bases: BaseDatasetAdapterCollateFN

Segmentation Collate function that adapts input data to the SuperGradients format

This is done by applying the adapter logic either before or after the original collate function, depending on whether the adapter was set up on a batch or a sample.

Note that the original collate function (if any) will still be used, but will be wrapped into this class.

Source code in src/super_gradients/training/utils/collate_fn/adapters/segmentation_adapter_collate_fn.py
@register_collate_function()
class SegmentationDatasetAdapterCollateFN(BaseDatasetAdapterCollateFN):
    """Segmentation Collate function that adapts an input data to SuperGradients format

    This is done by applying the adapter logic either before or after the original collate function,
    depending on whether the adapter was set up on a batch or a sample.

    Note that the original collate function (if any) will still be used, but will be wrapped into this class.
    """

    @resolve_param("base_collate_fn", CollateFunctionsFactory())
    def __init__(self, config: Optional[SegmentationDataConfig] = None, config_path: Optional[str] = None, base_collate_fn: Optional[Callable] = None):
        """
        :param config:          Adapter configuration. Use this if you want to hard code some specificities about your dataset.
                                Mutually exclusive with `config_path`.
        :param config_path:     Adapter cache path. Use this if you want to load and/or save the adapter config from a local path.
                                Mutually exclusive with `config`.
        :param base_collate_fn: Collate function to use. If None, the PyTorch default collate function will be used.
        """
        if config and config_path:
            raise ValueError("`config` and `config_path` cannot be set at the same time.")
        elif config is None and config_path:
            adapter = SegmentationDatasetAdapter.from_cache(cache_path=config_path)
        elif config is not None and config_path is None:
            adapter = SegmentationDatasetAdapter(data_config=config)
        else:
            raise ValueError("Please either set `config` or `config_path`.")

        super().__init__(adapter=adapter, base_collate_fn=base_collate_fn)

    def __call__(self, samples: Iterable[SupportedDataType]) -> Tuple[torch.Tensor, torch.Tensor]:
        images, targets = super().__call__(samples=samples)
        return images, targets

__init__(config=None, config_path=None, base_collate_fn=None)

Parameters:

Name Type Description Default
config Optional[SegmentationDataConfig]

Adapter configuration. Use this if you want to hard code some specificities about your dataset. Mutually exclusive with config_path.

None
config_path Optional[str]

Adapter cache path. Use this if you want to load and/or save the adapter config from a local path. Mutually exclusive with config.

None
base_collate_fn Optional[Callable]

Collate function to use. If None, the PyTorch default collate function will be used.

None
Source code in src/super_gradients/training/utils/collate_fn/adapters/segmentation_adapter_collate_fn.py
@resolve_param("base_collate_fn", CollateFunctionsFactory())
def __init__(self, config: Optional[SegmentationDataConfig] = None, config_path: Optional[str] = None, base_collate_fn: Optional[Callable] = None):
    """
    :param config:          Adapter configuration. Use this if you want to hard code some specificities about your dataset.
                            Mutually exclusive with `config_path`.
    :param config_path:     Adapter cache path. Use this if you want to load and/or save the adapter config from a local path.
                            Mutually exclusive with `config`.
    :param base_collate_fn: Collate function to use. If None, the PyTorch default collate function will be used.
    """
    if config and config_path:
        raise ValueError("`config` and `config_path` cannot be set at the same time.")
    elif config is None and config_path:
        adapter = SegmentationDatasetAdapter.from_cache(cache_path=config_path)
    elif config is not None and config_path is None:
        adapter = SegmentationDatasetAdapter(data_config=config)
    else:
        raise ValueError("Please either set `config` or `config_path`.")

    super().__init__(adapter=adapter, base_collate_fn=base_collate_fn)

CrowdDetectionCollateFN

Bases: DetectionCollateFN

Collate function for Yolox training with additional_batch_items that includes crowd targets

Source code in src/super_gradients/training/utils/collate_fn/crowd_detection_collate_fn.py
@register_collate_function()
class CrowdDetectionCollateFN(DetectionCollateFN):
    """
    Collate function for Yolox training with additional_batch_items that includes crowd targets
    """

    def __init__(self):
        super().__init__()
        self.expected_item_names = ("image", "targets", "crowd_targets")

    def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        try:
            images_batch, labels_batch, crowd_labels_batch = list(zip(*data))
        except (ValueError, TypeError):
            raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names)

        return self._format_images(images_batch), self._format_targets(labels_batch), {"crowd_targets": self._format_targets(crowd_labels_batch)}
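
A quick sketch of the expected item structure and the collated output, assuming the import path from the source location above. Each dataset item is a triplet (image, targets, crowd_targets):

import numpy as np
from super_gradients.training.utils.collate_fn.crowd_detection_collate_fn import CrowdDetectionCollateFN

collate_fn = CrowdDetectionCollateFN()
sample = (
    np.zeros((64, 64, 3), dtype=np.uint8),    # image (H, W, C)
    np.zeros((2, 5), dtype=np.float32),       # 2 regular targets
    np.zeros((1, 5), dtype=np.float32),       # 1 crowd target
)
images, targets, extras = collate_fn([sample, sample])
# images: [2, 3, 64, 64], targets: [4, 6], extras["crowd_targets"]: [2, 6]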

CrowdDetectionPPYoloECollateFN

Bases: PPYoloECollateFN

Collate function for PPYoloE training with additional_batch_items that includes crowd targets

Source code in src/super_gradients/training/utils/collate_fn/crowd_detection_ppyoloe_collate_fn.py
@register_collate_function()
class CrowdDetectionPPYoloECollateFN(PPYoloECollateFN):
    """
    Collate function for PPYoloE training with additional_batch_items that includes crowd targets
    """

    def __init__(
        self, random_resize_sizes: Union[List[int], None] = None, random_resize_modes: Union[List[int], None] = None, random_aspect_ratio: bool = False
    ):
        """
        :param random_resize_sizes: List of sizes to randomly resize the image to. If None, will not resize.
        :param random_resize_modes: List of interpolation modes to randomly resize the image to. If None, will not resize.
        :param random_aspect_ratio: If True, will randomly choose both width and height from random_resize_sizes.
                                    If False, will randomly choose only one value, which will be used as both the width and height of the images.
        """
        super().__init__(random_resize_sizes, random_resize_modes, random_aspect_ratio)
        self.expected_item_names = ("image", "targets", "crowd_targets")

    def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        if self.random_resize_sizes is not None:
            data = self.random_resize(data)

        try:
            images_batch, labels_batch, crowd_labels_batch = list(zip(*data))
        except (ValueError, TypeError):
            raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names)

        return self._format_images(images_batch), self._format_targets(labels_batch), {"crowd_targets": self._format_targets(crowd_labels_batch)}

__init__(random_resize_sizes=None, random_resize_modes=None, random_aspect_ratio=False)

Parameters:

Name Type Description Default
random_resize_sizes Union[List[int], None]

List of sizes to randomly resize the image to. If None, will not resize.

None
random_resize_modes Union[List[int], None]

List of interpolation modes to randomly resize the image to. If None, will not resize.

None
random_aspect_ratio bool

If True, will randomly choose both width and height from random_resize_sizes. If False, will randomly choose only one value, which will be used as both the width and height of the images.

False
Source code in src/super_gradients/training/utils/collate_fn/crowd_detection_ppyoloe_collate_fn.py
def __init__(
    self, random_resize_sizes: Union[List[int], None] = None, random_resize_modes: Union[List[int], None] = None, random_aspect_ratio: bool = False
):
    """
    :param random_resize_sizes: List of sizes to randomly resize the image to. If None, will not resize.
    :param random_resize_modes: List of interpolation modes to randomly resize the image to. If None, will not resize.
    :param random_aspect_ratio: If True, will randomly choose both width and height from random_resize_sizes.
                                If False, will randomly choose only one value, which will be used as both the width and height of the images.
    """
    super().__init__(random_resize_sizes, random_resize_modes, random_aspect_ratio)
    self.expected_item_names = ("image", "targets", "crowd_targets")

DetectionCollateFN

Collate function for Yolox training

Source code in src/super_gradients/training/utils/collate_fn/detection_collate_fn.py
@register_collate_function()
class DetectionCollateFN:
    """
    Collate function for Yolox training
    """

    def __init__(self):
        self.expected_item_names = ("image", "targets")

    def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
        try:
            images_batch, labels_batch = list(zip(*data))
        except (ValueError, TypeError):
            raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names)

        return self._format_images(images_batch), self._format_targets(labels_batch)

    @staticmethod
    def _format_images(images_batch: List[Union[torch.Tensor, np.array]]) -> torch.Tensor:
        images_batch = [torch.tensor(img) for img in images_batch]
        images_batch_stack = torch.stack(images_batch, 0)
        if images_batch_stack.shape[3] == 3:
            images_batch_stack = torch.moveaxis(images_batch_stack, -1, 1).float()
        return images_batch_stack

    @staticmethod
    def _format_targets(labels_batch: List[Union[torch.Tensor, np.array]]) -> torch.Tensor:
        """
        Stack a batch id column to targets and concatenate
        :param labels_batch: a list of targets per image (each of arbitrary length)
        :return: one tensor of targets of all images of shape [N, 6], where N is the total number of targets in a batch
                 and the 1st column is batch item index
        """
        labels_batch = [torch.tensor(labels) for labels in labels_batch]
        labels_batch_indexed = []
        for i, labels in enumerate(labels_batch):
            batch_column = labels.new_ones((labels.shape[0], 1)) * i
            labels = torch.cat((batch_column, labels), dim=-1)
            labels_batch_indexed.append(labels)
        return torch.cat(labels_batch_indexed, 0)
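
The sketch below illustrates the collation, assuming the import path from the source location above: two samples with a different number of boxes are merged into a single [N, 6] target tensor whose first column is the batch item index.

import numpy as np
from super_gradients.training.utils.collate_fn.detection_collate_fn import DetectionCollateFN

collate_fn = DetectionCollateFN()
img = np.zeros((64, 64, 3), dtype=np.uint8)
labels_a = np.array([[0, 10, 10, 20, 20], [1, 30, 30, 10, 10]], dtype=np.float32)  # 2 boxes
labels_b = np.array([[2, 5, 5, 8, 8]], dtype=np.float32)                           # 1 box
images, targets = collate_fn([(img, labels_a), (img, labels_b)])
# images: [2, 3, 64, 64]; targets: [3, 6] with rows [0, ...], [0, ...], [1, ...]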

PPYoloECollateFN

Bases: DetectionCollateFN

Collate function for PPYoloE training

Source code in src/super_gradients/training/utils/collate_fn/ppyoloe_collate_fn.py
@register_collate_function()
class PPYoloECollateFN(DetectionCollateFN):
    """
    Collate function for PPYoloE training
    """

    def __init__(
        self,
        random_resize_sizes: Union[List[int], None] = None,
        random_resize_modes: Union[List[int], None] = None,
        random_aspect_ratio: Union[bool, Tuple[float, float]] = False,
    ):
        """
        :param random_resize_sizes: List of single image size dimensions to use for sampling the output image size.
                                    If None, random resizing will not be applied.
                                    If not None, will randomly sample output shape for entire batch:
                                    [B, C, random.choice(random_resize_sizes), random.choice(random_resize_sizes)]
                                    The values in random_resize_sizes should be compatible with the model.
                                    Example: If the model requires input size to be divisible by 32 then all values in `random_resize_sizes`
                                    should be divisible by 32.

        :param random_resize_modes: List of interpolation modes to sample from when resizing the image.
                                    Interpolation modes correspond to OpenCV interpolation modes:
                                    0 - INTER_NEAREST
                                    1 - INTER_LINEAR
                                    2 - INTER_CUBIC
                                    3 - INTER_AREA
                                    4 - INTER_LANCZOS4
                                    If None, defaults to linear interpolation.

        :param random_aspect_ratio: If True, will randomly choose both width and height from random_resize_sizes.
                                    If False, will randomly choose only one value, which will be used as both the width and height of the images.
                                    If tuple (min_aspect_ratio, max_aspect_ratio), will guarantee that sampled width and height
                                    satisfy required aspect ratio range.
        """
        super().__init__()
        if random_resize_sizes is not None:
            # All possible combinations
            random_resize_sizes = np.array(list(itertools.product(random_resize_sizes, random_resize_sizes)))  # [N, 2]
            if random_aspect_ratio is False:
                # Leave only square sizes
                random_resize_sizes = random_resize_sizes[random_resize_sizes[:, 0] == random_resize_sizes[:, 1]]
            elif random_aspect_ratio is True:
                # No action needed here
                pass
            elif isinstance(random_aspect_ratio, typing.Iterable):
                min_aspect_ratio, max_aspect_ratio = random_aspect_ratio
                if min_aspect_ratio > max_aspect_ratio:
                    raise ValueError(f"min_aspect_ratio: {min_aspect_ratio} must be smaller than max_aspect_ratio: {max_aspect_ratio}")

                # Leave only size combinations with aspect ratio in the given range
                aspect_ratios = random_resize_sizes[:, 0] / random_resize_sizes[:, 1]
                random_resize_sizes = random_resize_sizes[(aspect_ratios >= min_aspect_ratio) & (aspect_ratios <= max_aspect_ratio)]

                if len(random_resize_sizes) == 0:
                    raise ValueError(
                        f"Given random_aspect_ratio value: {random_aspect_ratio} leaves no valid size combinations. Please adjust random_aspect_ratio range."
                    )
            else:
                raise ValueError(f"Unsupported random_aspect_ratio value: {random_aspect_ratio}")
        self.random_resize_sizes = random_resize_sizes
        self.random_resize_modes = list(random_resize_modes) if random_resize_modes is not None else [1]  # Default to linear interpolation

    def __repr__(self):
        return f"PPYoloECollateFN(random_resize_sizes={self.random_resize_sizes}, random_resize_modes={self.random_resize_modes})"

    def __str__(self):
        return self.__repr__()

    def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.random_resize_sizes is not None:
            data = self.random_resize(data)
        return super().__call__(data)

    def random_resize(self, batch):
        target_size = random.choice(self.random_resize_sizes)
        interpolation = random.choice(self.random_resize_modes)
        batch = [self.random_resize_sample(sample, target_size, interpolation) for sample in batch]
        return batch

    def random_resize_sample(self, sample, target_size: Tuple[int, int], interpolation: int):
        if len(sample) == 2:
            image, targets = sample  # TARGETS ARE IN LABEL_CXCYWH
            with_crowd = False
        elif len(sample) == 3:
            image, targets, crowd_targets = sample
            with_crowd = True
        else:
            raise DatasetItemsException(data_sample=sample, collate_type=type(self), expected_item_names=self.expected_item_names)

        target_width, target_height = target_size
        dsize = int(target_width), int(target_height)
        scale_factors = target_height / image.shape[0], target_width / image.shape[1]

        image = cv2.resize(
            image,
            dsize=dsize,
            interpolation=interpolation,
        )

        sy, sx = scale_factors
        targets[:, 1:5] *= np.array([[sx, sy, sx, sy]], dtype=targets.dtype)
        if with_crowd:
            crowd_targets[:, 1:5] *= np.array([[sx, sy, sx, sy]], dtype=targets.dtype)
            return image, targets, crowd_targets

        return image, targets

__init__(random_resize_sizes=None, random_resize_modes=None, random_aspect_ratio=False)

Parameters:

Name Type Description Default
random_resize_sizes Union[List[int], None]

List of single image size dimensions to use for sampling the output image size. If None, random resizing will not be applied. If not None, will randomly sample output shape for entire batch: [B, C, random.choice(random_resize_sizes), random.choice(random_resize_sizes)] The values in random_resize_sizes should be compatible with the model. Example: If the model requires input size to be divisible by 32 then all values in random_resize_sizes should be divisible by 32.

None
random_resize_modes Union[List[int], None]

List of interpolation modes to sample from when resizing the image. Interpolation modes correspond to OpenCV interpolation modes: 0 - INTER_NEAREST, 1 - INTER_LINEAR, 2 - INTER_CUBIC, 3 - INTER_AREA, 4 - INTER_LANCZOS4. If None, defaults to linear interpolation.

None
random_aspect_ratio Union[bool, Tuple[float, float]]

If True, will randomly choose both width and height from random_resize_sizes. If False, will randomly choose only one value, which will be used as both the width and height of the images. If tuple (min_aspect_ratio, max_aspect_ratio), will guarantee that sampled width and height satisfy the required aspect ratio range.

False
Source code in src/super_gradients/training/utils/collate_fn/ppyoloe_collate_fn.py
def __init__(
    self,
    random_resize_sizes: Union[List[int], None] = None,
    random_resize_modes: Union[List[int], None] = None,
    random_aspect_ratio: Union[bool, Tuple[float, float]] = False,
):
    """
    :param random_resize_sizes: List of single image size dimensions to use for sampling the output image size.
                                If None, random resizing will not be applied.
                                If not None, will randomly sample output shape for entire batch:
                                [B, C, random.choice(random_resize_sizes), random.choice(random_resize_sizes)]
                                The values in random_resize_sizes should be compatible with the model.
                                Example: If the model requires input size to be divisible by 32 then all values in `random_resize_sizes`
                                should be divisible by 32.

    :param random_resize_modes: List of interpolation modes to sample from when resizing the image.
                                Interpolation modes correspond to OpenCV interpolation modes:
                                0 - INTER_NEAREST
                                1 - INTER_LINEAR
                                2 - INTER_CUBIC
                                3 - INTER_AREA
                                4 - INTER_LANCZOS4
                                If None, defaults to linear interpolation.

    :param random_aspect_ratio: If True, will randomly choose both width and height from random_resize_sizes.
                                If False, will randomly choose only one value, which will be used as both the width and height of the images.
                                If tuple (min_aspect_ratio, max_aspect_ratio), will guarantee that sampled width and height
                                satisfy required aspect ratio range.
    """
    super().__init__()
    if random_resize_sizes is not None:
        # All possible combinations
        random_resize_sizes = np.array(list(itertools.product(random_resize_sizes, random_resize_sizes)))  # [N, 2]
        if random_aspect_ratio is False:
            # Leave only square sizes
            random_resize_sizes = random_resize_sizes[random_resize_sizes[:, 0] == random_resize_sizes[:, 1]]
        elif random_aspect_ratio is True:
            # No action needed here
            pass
        elif isinstance(random_aspect_ratio, typing.Iterable):
            min_aspect_ratio, max_aspect_ratio = random_aspect_ratio
            if min_aspect_ratio > max_aspect_ratio:
                raise ValueError(f"min_aspect_ratio: {min_aspect_ratio} must be smaller than max_aspect_ratio: {max_aspect_ratio}")

            # Leave only size combinations with aspect ratio in the given range
            aspect_ratios = random_resize_sizes[:, 0] / random_resize_sizes[:, 1]
            random_resize_sizes = random_resize_sizes[(aspect_ratios >= min_aspect_ratio) & (aspect_ratios <= max_aspect_ratio)]

            if len(random_resize_sizes) == 0:
                raise ValueError(
                    f"Given random_aspect_ratio value: {random_aspect_ratio} leaves no valid size combinations. Please adjust random_aspect_ratio range."
                )
        else:
            raise ValueError(f"Unsupported random_aspect_ratio value: {random_aspect_ratio}")
    self.random_resize_sizes = random_resize_sizes
    self.random_resize_modes = list(random_resize_modes) if random_resize_modes is not None else [1]  # Default to linear interpolation
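
Two construction sketches, based only on the parameter semantics documented above (the size lists are example values):

from super_gradients.training.utils.collate_fn.ppyoloe_collate_fn import PPYoloECollateFN

# Square multi-scale training: every batch is resized to one randomly chosen size from the list.
collate_fn = PPYoloECollateFN(random_resize_sizes=[480, 512, 544, 576, 608, 640])

# Rectangular multi-scale training: width and height are sampled independently, restricted to
# aspect ratios between 0.5 and 2.0, using nearest (0) or linear (1) OpenCV interpolation.
collate_fn = PPYoloECollateFN(
    random_resize_sizes=[480, 512, 544, 576, 608, 640],
    random_resize_modes=[0, 1],
    random_aspect_ratio=(0.5, 2.0),
)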

AccessCounterMixin

Implements an access counting mechanism for configuration settings (dicts/lists). It is achieved by wrapping the underlying config and overriding the getitem and getattr methods to catch read operations and increment the access counter for each property.

Source code in src/super_gradients/training/utils/config_utils.py
class AccessCounterMixin:
    """
    Implements an access counting mechanism for configuration settings (dicts/lists).
    It is achieved by wrapping the underlying config and overriding the __getitem__ and __getattr__ methods to catch read
    operations and increment the access counter for each property.
    """

    _access_counter: Mapping[str, int]
    _prefix: str  # Prefix string

    def maybe_wrap_as_counter(self, value, key, count_usage: bool = True):
        """
        Return an attribute value optionally wrapped as access counter adapter to trace read counts.

        :param value: Attribute value
        :param key: Attribute name
        :param count_usage: Whether to increment the usage count for the given attribute. Default is True.

        :return: wrapped value
        """
        key_with_prefix = self._prefix + str(key)
        if count_usage:
            self._access_counter[key_with_prefix] += 1
        if isinstance(value, Mapping):
            return AccessCounterDict(value, access_counter=self._access_counter, prefix=key_with_prefix + ".")
        if isinstance(value, Iterable) and not isinstance(value, str):
            return AccessCounterList(value, access_counter=self._access_counter, prefix=key_with_prefix + ".")
        return value

    @property
    def access_counter(self):
        return self._access_counter

    @abc.abstractmethod
    def get_all_params(self) -> Set[str]:
        raise NotImplementedError()

    def get_used_params(self) -> Set[str]:
        used_params = {k for (k, v) in self._access_counter.items() if v > 0}
        return used_params

    def get_unused_params(self) -> Set[str]:
        unused_params = self.get_all_params() - self.get_used_params()
        return unused_params

    def __copy__(self):
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, deepcopy(v, memo))
        return result
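
A brief sketch of the counting behaviour. It assumes that AccessCounterDict (which uses this mixin and is constructed from a plain dict inside raise_if_unused_params below) implements get_all_params and is importable from the module shown above:

from super_gradients.training.utils.config_utils import AccessCounterDict

config = AccessCounterDict({"lr": 0.1, "optimizer": "SGD", "weight_decay": 1e-4})
_ = config["lr"]
_ = config["optimizer"]
print(config.get_used_params())    # expected: {'lr', 'optimizer'}
print(config.get_unused_params())  # expected: {'weight_decay'}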

maybe_wrap_as_counter(value, key, count_usage=True)

Return an attribute value optionally wrapped as access counter adapter to trace read counts.

Parameters:

Name Type Description Default
value

Attribute value

required
key

Attribute name

required
count_usage bool

Whether to increment the usage count for the given attribute. Default is True.

True

Returns:

Type Description

wrapped value

Source code in src/super_gradients/training/utils/config_utils.py
def maybe_wrap_as_counter(self, value, key, count_usage: bool = True):
    """
    Return an attribute value optionally wrapped as access counter adapter to trace read counts.

    :param value: Attribute value
    :param key: Attribute name
    :param count_usage: Whether to increment the usage count for the given attribute. Default is True.

    :return: wrapped value
    """
    key_with_prefix = self._prefix + str(key)
    if count_usage:
        self._access_counter[key_with_prefix] += 1
    if isinstance(value, Mapping):
        return AccessCounterDict(value, access_counter=self._access_counter, prefix=key_with_prefix + ".")
    if isinstance(value, Iterable) and not isinstance(value, str):
        return AccessCounterList(value, access_counter=self._access_counter, prefix=key_with_prefix + ".")
    return value

raise_if_unused_params(config)

A helper function to check whether all configuration parameters were used in a given block of code. The motivation for this check is to ensure there are no typos or outdated configuration parameters. If at least one of the config parameters was not used, this function will raise an UnusedConfigParamException exception. Example usage:

from super_gradients.training.utils import raise_if_unused_params

with raise_if_unused_params(some_config) as some_config: do_something_with_config(some_config)

Parameters:

Name Type Description Default
config Union[HpmStruct, DictConfig, ListConfig, Mapping, list, tuple]

A config to check

required

Returns:

Type Description
ConfigInspector

An instance of ConfigInspector

Source code in src/super_gradients/training/utils/config_utils.py
def raise_if_unused_params(config: Union[HpmStruct, DictConfig, ListConfig, Mapping, list, tuple]) -> ConfigInspector:
    """
    A helper function to check whether all configuration parameters were used in a given block of code. The motivation for
    this check is to ensure there are no typos or outdated configuration parameters.
    If at least one of the config parameters was not used, this function will raise an UnusedConfigParamException exception.
    Example usage:

    >>> from super_gradients.training.utils import raise_if_unused_params
    >>>
    >>> with raise_if_unused_params(some_config) as some_config:
    >>>    do_something_with_config(some_config)
    >>>

    :param config: A config to check
    :return: An instance of ConfigInspector
    """
    if isinstance(config, HpmStruct):
        wrapper_cls = AccessCounterHpmStruct
    elif isinstance(config, (Mapping, DictConfig)):
        wrapper_cls = AccessCounterDict
    elif isinstance(config, (list, tuple, ListConfig)):
        wrapper_cls = AccessCounterList
    else:
        raise RuntimeError(f"Unsupported type. Root configuration object must be a mapping or list. Got type {type(config)}")

    return ConfigInspector(wrapper_cls(config), unused_params_action="raise")
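
For instance, a typo in a config key is caught because the misspelled parameter is never read inside the block. This is a hypothetical illustration and assumes the check fires when the with-block exits:

from super_gradients.training.utils import raise_if_unused_params

config = {"lr": 0.1, "momentumm": 0.9}  # note the typo in "momentum"
with raise_if_unused_params(config) as config:
    lr = config["lr"]
# Expected to raise UnusedConfigParamException, since "momentumm" was never accessed.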

warn_if_unused_params(config)

A helper function to check whether all configuration parameters were used in a given block of code. The motivation for this check is to ensure there are no typos or outdated configuration parameters. If at least one of the config parameters was not used, this function will emit a warning. Example usage:

from super_gradients.training.utils import warn_if_unused_params

with warn_if_unused_params(some_config) as some_config: do_something_with_config(some_config)

Parameters:

Name Type Description Default
config

A config to check

required

Returns:

Type Description

An instance of ConfigInspector

Source code in src/super_gradients/training/utils/config_utils.py
def warn_if_unused_params(config):
    """
    A helper function to check whether all configuration parameters were used in a given block of code. The motivation for
    this check is to ensure there are no typos or outdated configuration parameters.
    If at least one of the config parameters was not used, this function will emit a warning.
    Example usage:

    >>> from super_gradients.training.utils import warn_if_unused_params
    >>>
    >>> with warn_if_unused_params(some_config) as some_config:
    >>>    do_something_with_config(some_config)
    >>>

    :param config: A config to check
    :return: An instance of ConfigInspector
    """
    if isinstance(config, HpmStruct):
        wrapper_cls = AccessCounterHpmStruct
    elif isinstance(config, (Mapping, DictConfig)):
        wrapper_cls = AccessCounterDict
    elif isinstance(config, (list, tuple, ListConfig)):
        wrapper_cls = AccessCounterList
    else:
        raise RuntimeError("Unsupported type. Root configuration object must be a mapping or list.")

    return ConfigInspector(wrapper_cls(config), unused_params_action="warn")

wrap_with_warning(cls, message)

Emits a warning when target class of function is called.

from super_gradients.training.utils.deprecated_utils import wrap_with_warning from super_gradients.training.utils.callbacks import LinearEpochLRWarmup, LinearBatchLRWarmup

LR_WARMUP_CLS_DICT = { "linear": wrap_with_warning( LinearEpochLRWarmup, message=f"Parameter linear has been made deprecated and will be removed in the next SG release. Please use linear_epoch instead", ), 'linear_epoch`': LinearEpochLRWarmup, }

Parameters:

Name Type Description Default
cls Callable

A class or function to wrap

required
message str

A message to emit when this class is called

required

Returns:

Type Description
Any

A factory method that returns wrapped class

Source code in src/super_gradients/training/utils/deprecated_utils.py
def wrap_with_warning(cls: Callable, message: str) -> Any:
    """
    Emits a warning when target class of function is called.

    >>> from super_gradients.training.utils.deprecated_utils import wrap_with_warning
    >>> from super_gradients.training.utils.callbacks import LinearEpochLRWarmup, LinearBatchLRWarmup
    >>>
    >>> LR_WARMUP_CLS_DICT = {
    >>>     "linear": wrap_with_warning(
    >>>         LinearEpochLRWarmup,
    >>>         message=f"Parameter `linear` has been made deprecated and will be removed in the next SG release. Please use `linear_epoch` instead",
    >>>     ),
    >>>     'linear_epoch`': LinearEpochLRWarmup,
    >>> }

    :param cls: A class or function to wrap
    :param message: A message to emit when this class is called
    :return: A factory method that returns wrapped class
    """

    def _inner_fn(*args, **kwargs):
        logger.warning(message)
        return cls(*args, **kwargs)

    return _inner_fn

Anchors

A wrapper function to hold the anchors used by detection models such as Yolo

Source code in src/super_gradients/training/utils/detection_utils.py
class Anchors:
    """
    A wrapper function to hold the anchors used by detection models such as Yolo
    """

    def __init__(self, anchors_list: List[List], strides: List[int]):
        """
        :param anchors_list: of the shape [[w1,h1,w2,h2,w3,h3], [w4,h4,w5,h5,w6,h6], ...] where each sublist holds
            the width and height of the anchors of a specific detection layer.
            I.e. for a model with 3 detection layers, each containing 5 anchors, the format will be 3 sublists of 10 numbers each.
            The width and height are in pixels (not relative to the image size).
        :param strides: a list containing the stride of the layers from which the detection heads are fed.
            I.e. if the first detection head is connected to the backbone after the input dimensions were reduced by 8, the first number will be 8.
        """
        super().__init__()

        self.__anchors_list = anchors_list
        self.__strides = tuple(strides)

        self._check_all_lists(anchors_list)
        self._check_all_len_equal_and_even(anchors_list)

        self._stride = np.array(strides, dtype=np.float32)
        anchors = np.array(anchors_list, dtype=np.float32).reshape((len(anchors_list), -1, 2))
        self._anchors = anchors / self._stride.reshape((-1, 1, 1))
        self._anchor_grid = anchors.copy().reshape(len(anchors_list), 1, -1, 1, 1, 2)

    @staticmethod
    def _check_all_lists(anchors: list) -> bool:
        for a in anchors:
            if not isinstance(a, (list, ListConfig)):
                raise RuntimeError("All objects of anchors_list must be lists")

    @staticmethod
    def _check_all_len_equal_and_even(anchors: list) -> bool:
        len_of_first = len(anchors[0])
        for a in anchors:
            if len(a) % 2 == 1 or len(a) != len_of_first:
                raise RuntimeError("All objects of anchors_list must be of the same even length")

    @property
    def stride(self) -> np.ndarray:
        return self._stride

    @property
    def anchors(self) -> np.ndarray:
        return self._anchors

    @property
    def anchor_grid(self) -> np.ndarray:
        return self._anchor_grid

    @property
    def detection_layers_num(self) -> int:
        return self._anchors.shape[0]

    @property
    def num_anchors(self) -> int:
        return self._anchors.shape[1]

    def __repr__(self):
        return f"anchors_list: {self.__anchors_list} strides: {self.__strides}"

__init__(anchors_list, strides)

Parameters:

Name Type Description Default
anchors_list List[List]

of the shape [[w1,h1,w2,h2,w3,h3], [w4,h4,w5,h5,w6,h6], ...] where each sublist holds the width and height of the anchors of a specific detection layer. I.e. for a model with 3 detection layers, each containing 5 anchors, the format will be 3 sublists of 10 numbers each. The width and height are in pixels (not relative to the image size).

required
strides List[int]

a list containing the stride of the layers from which the detection heads are fed. I.e. if the first detection head is connected to the backbone after the input dimensions were reduced by 8, the first number will be 8.

required
Source code in src/super_gradients/training/utils/detection_utils.py
def __init__(self, anchors_list: List[List], strides: List[int]):
    """
    :param anchors_list: of the shape [[w1,h1,w2,h2,w3,h3], [w4,h4,w5,h5,w6,h6], ...] where each sublist holds
        the width and height of the anchors of a specific detection layer.
        I.e. for a model with 3 detection layers, each containing 5 anchors, the format will be 3 sublists of 10 numbers each.
        The width and height are in pixels (not relative to the image size).
    :param strides: a list containing the stride of the layers from which the detection heads are fed.
        I.e. if the first detection head is connected to the backbone after the input dimensions were reduced by 8, the first number will be 8.
    """
    super().__init__()

    self.__anchors_list = anchors_list
    self.__strides = tuple(strides)

    self._check_all_lists(anchors_list)
    self._check_all_len_equal_and_even(anchors_list)

    self._stride = np.array(strides, dtype=np.float32)
    anchors = np.array(anchors_list, dtype=np.float32).reshape((len(anchors_list), -1, 2))
    self._anchors = anchors / self._stride.reshape((-1, 1, 1))
    self._anchor_grid = anchors.copy().reshape(len(anchors_list), 1, -1, 1, 1, 2)
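
As a concrete sketch, using the well-known COCO YOLO anchors purely as example values, a model with 3 detection layers and 3 anchors per layer would be described as:

from super_gradients.training.utils.detection_utils import Anchors

anchors = Anchors(
    anchors_list=[[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]],
    strides=[8, 16, 32],
)
print(anchors.detection_layers_num)  # 3
print(anchors.num_anchors)           # 3
print(anchors.stride)                # [ 8. 16. 32.]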

DetectionMatching

Bases: ABC

DetectionMatching is an abstract base class that defines the interface for matching detections in object detection models. It includes methods for computing targets for both regular and crowd scenarios, as well as getting thresholds for matching.

Source code in src/super_gradients/training/utils/detection_utils.py
class DetectionMatching(ABC):
    """
    DetectionMatching is an abstract base class that defines the interface for matching detections
    in object detection models. It includes methods for computing targets for both regular and crowd
    scenarios, as well as getting thresholds for matching.
    """

    @abstractmethod
    def get_thresholds(self) -> torch.Tensor:
        """
        Abstract method to get the thresholds used for detection matching.

        :return: (torch.Tensor) The thresholds used in the matching process.
        """
        pass

    @abstractmethod
    def compute_targets(
        self,
        preds_box_xyxy: torch.Tensor,
        preds_cls: torch.Tensor,
        targets_box_xyxy: torch.Tensor,
        targets_cls: torch.Tensor,
        preds_matched: torch.Tensor,
        targets_matched: torch.Tensor,
        preds_idx_to_use: torch.Tensor,
    ) -> torch.Tensor:
        """
        Abstract method to compute targets for regular scenarios.

        :param preds_box_xyxy: (torch.Tensor) Predicted bounding boxes in XYXY format.
        :param preds_cls: (torch.Tensor) Predicted classes.
        :param targets_box_xyxy: (torch.Tensor) Target bounding boxes in XYXY format.
        :param targets_cls: (torch.Tensor) Target classes.
        :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched.
        :param targets_matched: (torch.Tensor) Tensor indicating which targets are matched.
        :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use.
        :return: (torch.Tensor) Computed targets.
        """
        pass

    @abstractmethod
    def compute_crowd_targets(
        self,
        preds_box_xyxy: torch.Tensor,
        preds_cls: torch.Tensor,
        crowd_targets_cls: torch.Tensor,
        crowd_target_box_xyxy: torch.Tensor,
        preds_matched: torch.Tensor,
        preds_to_ignore: torch.Tensor,
        preds_idx_to_use: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Abstract method to compute targets for crowd scenarios.

        :param preds_box_xyxy: (torch.Tensor) Predicted bounding boxes in XYXY format.
        :param preds_cls: (torch.Tensor) Predicted classes.
        :param crowd_targets_cls: (torch.Tensor) Crowd target classes.
        :param crowd_target_box_xyxy: (torch.Tensor) Crowd target bounding boxes in XYXY format.
        :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched.
        :param preds_to_ignore: (torch.Tensor) Tensor indicating which predictions to ignore.
        :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use.
        :return: (Tuple[torch.Tensor, torch.Tensor]) Computed targets for crowd scenarios.
        """
        pass

compute_crowd_targets(preds_box_xyxy, preds_cls, crowd_targets_cls, crowd_target_box_xyxy, preds_matched, preds_to_ignore, preds_idx_to_use) abstractmethod

Abstract method to compute targets for crowd scenarios.

Parameters:

Name Type Description Default
preds_box_xyxy torch.Tensor

(torch.Tensor) Predicted bounding boxes in XYXY format.

required
preds_cls torch.Tensor

(torch.Tensor) Predicted classes.

required
crowd_targets_cls torch.Tensor

(torch.Tensor) Crowd target classes.

required
crowd_target_box_xyxy torch.Tensor

(torch.Tensor) Crowd target bounding boxes in XYXY format.

required
preds_matched torch.Tensor

(torch.Tensor) Tensor indicating which predictions are matched.

required
preds_to_ignore torch.Tensor

(torch.Tensor) Tensor indicating which predictions to ignore.

required
preds_idx_to_use torch.Tensor

(torch.Tensor) Indices of predictions to use.

required

Returns:

Type Description
Tuple[torch.Tensor, torch.Tensor]

(Tuple[torch.Tensor, torch.Tensor]) Computed targets for crowd scenarios.

Source code in src/super_gradients/training/utils/detection_utils.py
@abstractmethod
def compute_crowd_targets(
    self,
    preds_box_xyxy: torch.Tensor,
    preds_cls: torch.Tensor,
    crowd_targets_cls: torch.Tensor,
    crowd_target_box_xyxy: torch.Tensor,
    preds_matched: torch.Tensor,
    preds_to_ignore: torch.Tensor,
    preds_idx_to_use: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Abstract method to compute targets for crowd scenarios.

    :param preds_box_xyxy: (torch.Tensor) Predicted bounding boxes in XYXY format.
    :param preds_cls: (torch.Tensor) Predicted classes.
    :param crowd_targets_cls: (torch.Tensor) Crowd target classes.
    :param crowd_target_box_xyxy: (torch.Tensor) Crowd target bounding boxes in XYXY format.
    :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched.
    :param preds_to_ignore: (torch.Tensor) Tensor indicating which predictions to ignore.
    :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use.
    :return: (Tuple[torch.Tensor, torch.Tensor]) Computed targets for crowd scenarios.
    """
    pass

compute_targets(preds_box_xyxy, preds_cls, targets_box_xyxy, targets_cls, preds_matched, targets_matched, preds_idx_to_use) abstractmethod

Abstract method to compute targets for regular scenarios.

Parameters:

Name Type Description Default
preds_box_xyxy torch.Tensor

(torch.Tensor) Predicted bounding boxes in XYXY format.

required
preds_cls torch.Tensor

(torch.Tensor) Predicted classes.

required
targets_box_xyxy torch.Tensor

(torch.Tensor) Target bounding boxes in XYXY format.

required
targets_cls torch.Tensor

(torch.Tensor) Target classes.

required
preds_matched torch.Tensor

(torch.Tensor) Tensor indicating which predictions are matched.

required
targets_matched torch.Tensor

(torch.Tensor) Tensor indicating which targets are matched.

required
preds_idx_to_use torch.Tensor

(torch.Tensor) Indices of predictions to use.

required

Returns:

Type Description
torch.Tensor

(torch.Tensor) Computed targets.

Source code in src/super_gradients/training/utils/detection_utils.py
@abstractmethod
def compute_targets(
    self,
    preds_box_xyxy: torch.Tensor,
    preds_cls: torch.Tensor,
    targets_box_xyxy: torch.Tensor,
    targets_cls: torch.Tensor,
    preds_matched: torch.Tensor,
    targets_matched: torch.Tensor,
    preds_idx_to_use: torch.Tensor,
) -> torch.Tensor:
    """
    Abstract method to compute targets for regular scenarios.

    :param preds_box_xyxy: (torch.Tensor) Predicted bounding boxes in XYXY format.
    :param preds_cls: (torch.Tensor) Predicted classes.
    :param targets_box_xyxy: (torch.Tensor) Target bounding boxes in XYXY format.
    :param targets_cls: (torch.Tensor) Target classes.
    :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched.
    :param targets_matched: (torch.Tensor) Tensor indicating which targets are matched.
    :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use.
    :return: (torch.Tensor) Computed targets.
    """
    pass

get_thresholds() abstractmethod

Abstract method to get the thresholds used for detection matching.

Returns:

Type Description
torch.Tensor

(torch.Tensor) The thresholds used in the matching process.

Source code in src/super_gradients/training/utils/detection_utils.py
@abstractmethod
def get_thresholds(self) -> torch.Tensor:
    """
    Abstract method to get the thresholds used for detection matching.

    :return: (torch.Tensor) The thresholds used in the matching process.
    """
    pass

DetectionPostPredictionCallback

Bases: ABC, nn.Module

Source code in src/super_gradients/training/utils/detection_utils.py
class DetectionPostPredictionCallback(ABC, nn.Module):
    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(self, x, device: str = None):
        """

        :param x:       the output of your model
        :param device:  (Deprecated) Not used anymore, exists only for the sake of keeping the same interface as in the parent class.
                        Will be removed in SG 3.7.0.
                        A device parameter in case we want to move tensors to a specific device.
        :return:        a list of length batch_size, where each item in the list is a tensor of detections
                        with shape: nx6 (x1, y1, x2, y2, confidence, class) where x and y are in range [0,1]
        """
        raise NotImplementedError

forward(x, device=None) abstractmethod

Parameters:

Name Type Description Default
x

the output of your model

required
device str

(Deprecated) Not used anymore, exists only for the sake of keeping the same interface as in the parent class. Will be removed in SG 3.7.0. A device parameter in case we want to move tensors to a specific device.

None

Returns:

Type Description

a list of length batch_size, where each item in the list is a tensor of detections with shape: nx6 (x1, y1, x2, y2, confidence, class) where x and y are in range [0,1]

Source code in src/super_gradients/training/utils/detection_utils.py
@abstractmethod
def forward(self, x, device: str = None):
    """

    :param x:       the output of your model
    :param device:  (Deprecated) Not used anymore, exists only for the sake of keeping the same interface as in the parent class.
                    Will be removed in SG 3.7.0.
                    A device parameter in case we want to move tensors to a specific device.
    :return:        a list of length batch_size, where each item in the list is a tensor of detections
                    with shape: nx6 (x1, y1, x2, y2, confidence, class) where x and y are in range [0,1]
    """
    raise NotImplementedError
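
A minimal sketch of implementing this interface. The callback below is hypothetical (it is not one of the library's built-in post-prediction callbacks) and simply filters raw predictions of shape [B, N, 6] by confidence:

import torch
from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback

class ConfidenceFilterCallback(DetectionPostPredictionCallback):
    def __init__(self, conf_threshold: float = 0.25):
        super().__init__()
        self.conf_threshold = conf_threshold

    def forward(self, x: torch.Tensor, device: str = None):
        # x is assumed to be [B, N, 6] with columns (x1, y1, x2, y2, confidence, class).
        return [image_preds[image_preds[:, 4] >= self.conf_threshold] for image_preds in x]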

DetectionTargetsFormat

Bases: Enum

Enum class for the different detection output formats

When NORMALIZED is not specified, the type refers to unnormalized image coordinates (of the bboxes).

For example: LABEL_NORMALIZED_XYXY means [class_idx,x1,y1,x2,y2]

Source code in src/super_gradients/training/utils/detection_utils.py
class DetectionTargetsFormat(Enum):
    """
    Enum class for the different detection output formats

    When NORMALIZED is not specified, the type refers to unnormalized image coordinates (of the bboxes).

    For example:
    LABEL_NORMALIZED_XYXY means [class_idx,x1,y1,x2,y2]
    """

    LABEL_XYXY = "LABEL_XYXY"
    XYXY_LABEL = "XYXY_LABEL"
    LABEL_NORMALIZED_XYXY = "LABEL_NORMALIZED_XYXY"
    NORMALIZED_XYXY_LABEL = "NORMALIZED_XYXY_LABEL"
    LABEL_CXCYWH = "LABEL_CXCYWH"
    CXCYWH_LABEL = "CXCYWH_LABEL"
    LABEL_NORMALIZED_CXCYWH = "LABEL_NORMALIZED_CXCYWH"
    NORMALIZED_CXCYWH_LABEL = "NORMALIZED_CXCYWH_LABEL"

DetectionVisualization

Source code in src/super_gradients/training/utils/detection_utils.py
class DetectionVisualization:
    @staticmethod
    def _generate_color_mapping(num_classes: int) -> List[Tuple[int]]:
        """
        Generate a unique BGR color for each class
        """

        return generate_color_mapping(num_classes=num_classes)

    @staticmethod
    def draw_box_title(
        color_mapping: List[Tuple[int]],
        class_names: List[str],
        box_thickness: Optional[int],
        image_np: np.ndarray,
        x1: int,
        y1: int,
        x2: int,
        y2: int,
        class_id: int,
        pred_conf: float = None,
        bbox_prefix: str = "",
    ):
        """
        Draw a rectangle with class name, confidence on the image
        :param color_mapping: A list of N RGB colors for each class
        :param class_names: A list of N class names
        :param box_thickness: Thickness of the bounding box (in pixels)
        :param image_np: Image in RGB format (H, W, C) where to draw the bounding box
        :param x1: X coordinate of the top left corner of the bounding box
        :param y1: Y coordinate of the top left corner of the bounding box
        :param x2: X coordinate of the bottom right corner of the bounding box
        :param y2: Y coordinate of the bottom right corner of the bounding box
        :param class_id: A corresponding class id
        :param pred_conf: Class confidence score (optional)
        :param bbox_prefix: Prefix to add to the title of the bounding boxes
        """
        color = color_mapping[class_id]
        class_name = class_names[class_id]

        title = class_name
        if bbox_prefix:
            title = f"{bbox_prefix} {class_name}"
        if pred_conf is not None:
            title = f"{title} {str(round(pred_conf, 2))}"

        image_np = draw_bbox(image=image_np, title=title, x1=x1, y1=y1, x2=x2, y2=y2, box_thickness=box_thickness, color=color)
        return image_np

    @staticmethod
    def _visualize_image(
        image_np: np.ndarray,
        pred_boxes: np.ndarray,
        target_boxes: np.ndarray,
        class_names: List[str],
        box_thickness: Optional[int],
        gt_alpha: float,
        image_scale: float,
        checkpoint_dir: str,
        image_name: str,
    ):
        return DetectionVisualization.visualize_image(
            image_np=image_np,
            pred_boxes=pred_boxes,
            target_boxes=target_boxes,
            class_names=class_names,
            box_thickness=box_thickness,
            gt_alpha=gt_alpha,
            image_scale=image_scale,
            checkpoint_dir=checkpoint_dir,
            image_name=image_name,
        )

    @staticmethod
    def visualize_image(
        image_np: np.ndarray,
        class_names: List[str],
        target_boxes: Optional[np.ndarray] = None,
        pred_boxes: Optional[np.ndarray] = None,
        box_thickness: Optional[int] = 2,
        gt_alpha: float = 0.6,
        image_scale: float = 1.0,
        checkpoint_dir: Optional[str] = None,
        image_name: Optional[str] = None,
    ):
        image_np = cv2.resize(image_np, (0, 0), fx=image_scale, fy=image_scale, interpolation=cv2.INTER_NEAREST)
        color_mapping = DetectionVisualization._generate_color_mapping(len(class_names))

        if pred_boxes is not None:
            # Draw predictions
            pred_boxes[:, :4] *= image_scale
            for xyxy_score_label in pred_boxes:
                image_np = DetectionVisualization.draw_box_title(
                    color_mapping=color_mapping,
                    class_names=class_names,
                    box_thickness=box_thickness,
                    image_np=image_np,
                    x1=int(xyxy_score_label[0]),
                    y1=int(xyxy_score_label[1]),
                    x2=int(xyxy_score_label[2]),
                    y2=int(xyxy_score_label[3]),
                    class_id=int(xyxy_score_label[5]),
                    pred_conf=float(xyxy_score_label[4]),
                    bbox_prefix="[Pred]" if target_boxes is not None else "",  # If we have TARGETS, we want to add a prefix to distinguish.
                )

        if target_boxes is not None:
            # If gt_alpha is set, we will show it as a transparent overlay.
            if gt_alpha is not None:
                # Transparent overlay of ground truth boxes
                image_with_targets = np.zeros_like(image_np, np.uint8)
            else:
                image_with_targets = image_np

            for label_xyxy in target_boxes:
                image_with_targets = DetectionVisualization.draw_box_title(
                    color_mapping=color_mapping,
                    class_names=class_names,
                    box_thickness=box_thickness,
                    image_np=image_with_targets,
                    x1=int(label_xyxy[1]),
                    y1=int(label_xyxy[2]),
                    x2=int(label_xyxy[3]),
                    y2=int(label_xyxy[4]),
                    class_id=int(label_xyxy[0]),
                    bbox_prefix="[GT]" if pred_boxes is not None else "",  # If we have PREDICTIONS, we want to add a prefix to distinguish.
                )

            if gt_alpha is not None:
                # Transparent overlay of ground truth boxes
                mask = image_with_targets.astype(bool)
                image_np[mask] = cv2.addWeighted(image_np, 1 - gt_alpha, image_with_targets, gt_alpha, 0)[mask]
            else:
                image_np = image_with_targets

        if checkpoint_dir is None:
            return image_np
        else:
            pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
            cv2.imwrite(os.path.join(checkpoint_dir, str(image_name) + ".jpg"), image_np)

    @staticmethod
    def _scaled_ccwh_to_xyxy(target_boxes: np.ndarray, h: int, w: int, image_scale: float) -> np.ndarray:
        """
        Modifies target_boxes inplace
        :param target_boxes:    (c1, c2, w, h) boxes in [0, 1] range
        :param h:               image height
        :param w:               image width
        :param image_scale:     desired scale for the boxes w.r.t. w and h
        :return:                targets in (x1, y1, x2, y2) format
                                in range [0, w * self.image_scale] [0, h * self.image_scale]
        """
        # unscale
        target_boxes[:, 2:] *= np.array([[w, h, w, h]])

        # x1 = c1 - w // 2; y1 = c2 - h // 2
        target_boxes[:, 2] -= target_boxes[:, 4] // 2
        target_boxes[:, 3] -= target_boxes[:, 5] // 2
        # x2 = w + x1; y2 = h + y1
        target_boxes[:, 4] += target_boxes[:, 2]
        target_boxes[:, 5] += target_boxes[:, 3]

        target_boxes[:, 2:] *= image_scale
        target_boxes = target_boxes.astype(int)
        return target_boxes

    @staticmethod
    def visualize_batch(
        image_tensor: torch.Tensor,
        pred_boxes: List[torch.Tensor],
        target_boxes: torch.Tensor,
        batch_name: Union[int, str],
        class_names: List[str],
        checkpoint_dir: str = None,
        undo_preprocessing_func: Callable[[torch.Tensor], np.ndarray] = undo_image_preprocessing,
        box_thickness: Optional[int] = None,
        image_scale: float = 1.0,
        gt_alpha: float = 0.4,
    ):
        """
        A helper function to visualize detections predicted by a network:
        saves images into a given path with a name that is {batch_name}_{image_idx_in_the_batch}.jpg, one batch per call.
        Colors are generated on the fly: uniformly sampled from color wheel to support all given classes.

        Adjustable:
            * Ground truth box transparency;
            * Box width;
            * Image size (larger or smaller than what's provided)

        :param image_tensor:            rgb images, (B, H, W, 3)
        :param pred_boxes:              boxes after NMS for each image in a batch, each (Num_boxes, 6),
                                        values on dim 1 are: x1, y1, x2, y2, confidence, class
        :param target_boxes:            (Num_targets, 6), values on dim 1 are: image id in a batch, class, cx cy w h
                                        (coordinates scaled to [0, 1])
        :param batch_name:              id of the current batch to use for image naming

        :param class_names:             names of all classes, each on its own index
        :param checkpoint_dir:          a path where images with boxes will be saved. if None, the result images will
                                        be returned as a list of numpy image arrays

        :param undo_preprocessing_func: a function to convert preprocessed images tensor into a batch of cv2-like images
        :param box_thickness:           box line thickness in px
        :param image_scale:             scale of an image w.r.t. given image size,
                                        e.g. incoming images are (320x320), use scale = 2. to preview in (640x640)
        :param gt_alpha:                a value in [0., 1.] transparency on ground truth boxes,
                                        0 for invisible, 1 for fully opaque
        """
        image_np = undo_preprocessing_func(image_tensor.detach())
        targets = DetectionVisualization._scaled_ccwh_to_xyxy(target_boxes.detach().cpu().numpy().copy(), *image_np.shape[1:3], image_scale)
        if pred_boxes is None:
            pred_boxes = [None for _ in range(image_np.shape[0])]

        out_images = []
        for i in range(image_np.shape[0]):
            preds = pred_boxes[i].detach().cpu().numpy() if pred_boxes[i] is not None else np.empty((0, 6))
            targets_cur = targets[targets[:, 0] == i]

            image_name = "_".join([str(batch_name), str(i)])
            res_image = DetectionVisualization._visualize_image(
                image_np=image_np[i],
                pred_boxes=preds,
                target_boxes=targets_cur,
                class_names=class_names,
                box_thickness=box_thickness,
                gt_alpha=gt_alpha,
                image_scale=image_scale,
                checkpoint_dir=checkpoint_dir,
                image_name=image_name,
            )
            if res_image is not None:
                out_images.append(res_image)

        return out_images
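
A minimal usage sketch for visualize_image (the blank image, class name, and box coordinates below are dummy values made up for illustration):

import numpy as np
from super_gradients.training.utils.detection_utils import DetectionVisualization

image = np.zeros((640, 640, 3), dtype=np.uint8)                # blank RGB canvas
preds = np.array([[100.0, 120.0, 300.0, 340.0, 0.92, 0.0]])    # x1, y1, x2, y2, confidence, class
targets = np.array([[0.0, 110.0, 130.0, 290.0, 330.0]])        # class, x1, y1, x2, y2 (pixel coordinates)

annotated = DetectionVisualization.visualize_image(
    image_np=image,
    class_names=["person"],
    pred_boxes=preds,
    target_boxes=targets,
    gt_alpha=0.5,
)
# checkpoint_dir is None, so the annotated numpy image is returned instead of being written to disk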

draw_box_title(color_mapping, class_names, box_thickness, image_np, x1, y1, x2, y2, class_id, pred_conf=None, bbox_prefix='') staticmethod

Draw a rectangle with class name, confidence on the image

Parameters:

Name Type Description Default
color_mapping List[Tuple[int]]

A list of N RGB colors for each class

required
class_names List[str]

A list of N class names

required
box_thickness Optional[int]

Thickness of the bounding box (in pixels)

required
image_np np.ndarray

Image in RGB format (H, W, C) where to draw the bounding box

required
x1 int

X coordinate of the top left corner of the bounding box

required
y1 int

Y coordinate of the top left corner of the bounding box

required
x2 int

X coordinate of the bottom right corner of the bounding box

required
y2 int

Y coordinate of the bottom right corner of the bounding box

required
class_id int

A corresponding class id

required
pred_conf float

Class confidence score (optional)

None
bbox_prefix str

Prefix to add to the title of the bounding boxes

''
Source code in src/super_gradients/training/utils/detection_utils.py, lines 422-460
@staticmethod
def draw_box_title(
    color_mapping: List[Tuple[int]],
    class_names: List[str],
    box_thickness: Optional[int],
    image_np: np.ndarray,
    x1: int,
    y1: int,
    x2: int,
    y2: int,
    class_id: int,
    pred_conf: float = None,
    bbox_prefix: str = "",
):
    """
    Draw a rectangle with class name, confidence on the image
    :param color_mapping: A list of N RGB colors for each class
    :param class_names: A list of N class names
    :param box_thickness: Thickness of the bounding box (in pixels)
    :param image_np: Image in RGB format (H, W, C) where to draw the bounding box
    :param x1: X coordinate of the top left corner of the bounding box
    :param y1: Y coordinate of the top left corner of the bounding box
    :param x2: X coordinate of the bottom right corner of the bounding box
    :param y2: Y coordinate of the bottom right corner of the bounding box
    :param class_id: A corresponding class id
    :param pred_conf: Class confidence score (optional)
    :param bbox_prefix: Prefix to add to the title of the bounding boxes
    """
    color = color_mapping[class_id]
    class_name = class_names[class_id]

    title = class_name
    if bbox_prefix:
        title = f"{bbox_prefix} {class_name}"
    if pred_conf is not None:
        title = f"{title} {str(round(pred_conf, 2))}"

    image_np = draw_bbox(image=image_np, title=title, x1=x1, y1=y1, x2=x2, y2=y2, box_thickness=box_thickness, color=color)
    return image_np

visualize_batch(image_tensor, pred_boxes, target_boxes, batch_name, class_names, checkpoint_dir=None, undo_preprocessing_func=undo_image_preprocessing, box_thickness=None, image_scale=1.0, gt_alpha=0.4) staticmethod

A helper function to visualize detections predicted by a network: saves images into a given path with a name that is {batch_name}_{image_idx_in_the_batch}.jpg, one batch per call. Colors are generated on the fly: uniformly sampled from color wheel to support all given classes.

Adjustable:
* Ground truth box transparency;
* Box width;
* Image size (larger or smaller than what's provided)

Parameters:

Name Type Description Default
image_tensor torch.Tensor

rgb images, (B, H, W, 3)

required
pred_boxes List[torch.Tensor]

boxes after NMS for each image in a batch, each (Num_boxes, 6), values on dim 1 are: x1, y1, x2, y2, confidence, class

required
target_boxes torch.Tensor

(Num_targets, 6), values on dim 1 are: image id in a batch, class, cx cy w h (coordinates scaled to [0, 1])

required
batch_name Union[int, str]

id of the current batch to use for image naming

required
class_names List[str]

names of all classes, each on its own index

required
checkpoint_dir str

a path where images with boxes will be saved. If None, the result images will be returned as a list of numpy image arrays

None
undo_preprocessing_func Callable[[torch.Tensor], np.ndarray]

a function to convert preprocessed images tensor into a batch of cv2-like images

undo_image_preprocessing
box_thickness Optional[int]

box line thickness in px

None
image_scale float

scale of an image w.r.t. given image size, e.g. incoming images are (320x320), use scale = 2. to preview in (640x640)

1.0
gt_alpha float

a value in [0., 1.] transparency on ground truth boxes, 0 for invisible, 1 for fully opaque

0.4
Source code in src/super_gradients/training/utils/detection_utils.py, lines 579-645
@staticmethod
def visualize_batch(
    image_tensor: torch.Tensor,
    pred_boxes: List[torch.Tensor],
    target_boxes: torch.Tensor,
    batch_name: Union[int, str],
    class_names: List[str],
    checkpoint_dir: str = None,
    undo_preprocessing_func: Callable[[torch.Tensor], np.ndarray] = undo_image_preprocessing,
    box_thickness: Optional[int] = None,
    image_scale: float = 1.0,
    gt_alpha: float = 0.4,
):
    """
    A helper function to visualize detections predicted by a network:
    saves images into a given path with a name that is {batch_name}_{image_idx_in_the_batch}.jpg, one batch per call.
    Colors are generated on the fly: uniformly sampled from color wheel to support all given classes.

    Adjustable:
        * Ground truth box transparency;
        * Box width;
        * Image size (larger or smaller than what's provided)

    :param image_tensor:            rgb images, (B, H, W, 3)
    :param pred_boxes:              boxes after NMS for each image in a batch, each (Num_boxes, 6),
                                    values on dim 1 are: x1, y1, x2, y2, confidence, class
    :param target_boxes:            (Num_targets, 6), values on dim 1 are: image id in a batch, class, cx cy w h
                                    (coordinates scaled to [0, 1])
    :param batch_name:              id of the current batch to use for image naming

    :param class_names:             names of all classes, each on its own index
    :param checkpoint_dir:          a path where images with boxes will be saved. if None, the result images will
                                    be returned as a list of numpy image arrays

    :param undo_preprocessing_func: a function to convert preprocessed images tensor into a batch of cv2-like images
    :param box_thickness:           box line thickness in px
    :param image_scale:             scale of an image w.r.t. given image size,
                                    e.g. incoming images are (320x320), use scale = 2. to preview in (640x640)
    :param gt_alpha:                a value in [0., 1.] transparency on ground truth boxes,
                                    0 for invisible, 1 for fully opaque
    """
    image_np = undo_preprocessing_func(image_tensor.detach())
    targets = DetectionVisualization._scaled_ccwh_to_xyxy(target_boxes.detach().cpu().numpy().copy(), *image_np.shape[1:3], image_scale)
    if pred_boxes is None:
        pred_boxes = [None for _ in range(image_np.shape[0])]

    out_images = []
    for i in range(image_np.shape[0]):
        preds = pred_boxes[i].detach().cpu().numpy() if pred_boxes[i] is not None else np.empty((0, 6))
        targets_cur = targets[targets[:, 0] == i]

        image_name = "_".join([str(batch_name), str(i)])
        res_image = DetectionVisualization._visualize_image(
            image_np=image_np[i],
            pred_boxes=preds,
            target_boxes=targets_cur,
            class_names=class_names,
            box_thickness=box_thickness,
            gt_alpha=gt_alpha,
            image_scale=image_scale,
            checkpoint_dir=checkpoint_dir,
            image_name=image_name,
        )
        if res_image is not None:
            out_images.append(res_image)

    return out_images
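
A usage sketch for a single-image batch (shapes and values are illustrative; the lambda stands in for whatever preprocessing-undo function matches your pipeline):

import numpy as np
import torch
from super_gradients.training.utils.detection_utils import DetectionVisualization

images = torch.zeros((1, 3, 320, 320), dtype=torch.uint8)      # batch of one dummy image, already in 0-255 range

# Predictions after NMS, one tensor per image: x1, y1, x2, y2, confidence, class
pred_boxes = [torch.tensor([[40.0, 40.0, 160.0, 200.0, 0.9, 0.0]])]

# Targets: image id in batch, class, cx, cy, w, h (coordinates normalized to [0, 1])
target_boxes = torch.tensor([[0.0, 0.0, 0.4, 0.5, 0.3, 0.4]])

out_images = DetectionVisualization.visualize_batch(
    image_tensor=images,
    pred_boxes=pred_boxes,
    target_boxes=target_boxes,
    batch_name="val_0",
    class_names=["person"],
    checkpoint_dir=None,  # None -> annotated numpy arrays are returned instead of .jpg files
    undo_preprocessing_func=lambda t: t.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8),
)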

DistanceMatching

Bases: DetectionMatching

DistanceMatching is a subclass of DetectionMatching that uses a distance metric for matching detections in object detection models.

Source code in src/super_gradients/training/utils/detection_utils.py, lines 1008-1117
class DistanceMatching(DetectionMatching):
    """
    DistanceMatching is a subclass of DetectionMatching that uses a distance metric
    for matching detections in object detection models.
    """

    def __init__(self, distance_metric, distance_thresholds: torch.Tensor):
        """
        Initializes the DistanceMatching instance with a distance metric and distance thresholds.

        :param distance_metric: The distance metric to be used for matching.
        :param distance_thresholds: (torch.Tensor) The distance thresholds for matching.
        """
        self.distance_metric = distance_metric
        self.distance_thresholds = distance_thresholds

    def get_thresholds(self) -> torch.Tensor:
        """
        Returns the distance thresholds used for detection matching.

        :return: (torch.Tensor) The distance thresholds.
        """
        return torch.tensor(self.distance_thresholds)

    def compute_targets(
        self,
        preds_box_xyxy: torch.Tensor,
        preds_cls: torch.Tensor,
        targets_box_xyxy: torch.Tensor,
        targets_cls: torch.Tensor,
        preds_matched: torch.Tensor,
        targets_matched: torch.Tensor,
        preds_idx_to_use: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes the matching targets based on the distance metric for regular scenarios.

        :param preds_box_xyxy: (torch.Tensor) Predicted bounding boxes in XYXY format.
        :param preds_cls: (torch.Tensor) Predicted classes.
        :param targets_box_xyxy: (torch.Tensor) Target bounding boxes in XYXY format.
        :param targets_cls: (torch.Tensor) Target classes.
        :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched.
        :param targets_matched: (torch.Tensor) Tensor indicating which targets are matched.
        :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use.
        :return: (torch.Tensor) Computed matching targets.
        """
        # Calculate the distances between targets and predictions using the current metric
        # shape = (n_preds x n_targets)
        distances = self.distance_metric.calculate_distance(preds_box_xyxy[preds_idx_to_use], targets_box_xyxy)

        # Invalidate distances when class labels don't match
        cls_mismatch = preds_cls[preds_idx_to_use].view(-1, 1) != targets_cls.view(1, -1)
        distances[cls_mismatch] = float("inf")  # or max(distance_thresholds) + 1

        # Sort distances
        sorted_distances, target_sorted = distances.sort(stable=True)

        # Identify all pairs that are within the max distance threshold
        candidate_pairs = (sorted_distances < max(self.distance_thresholds)).nonzero(as_tuple=False)
        for pred_selected_i, target_sorted_i in candidate_pairs:
            pred_i = preds_idx_to_use[pred_selected_i]
            target_i = target_sorted[pred_selected_i, target_sorted_i]

            distance_thresholds_tensor = torch.tensor(self.distance_thresholds, device=distances.device)
            is_distance_below_threshold = sorted_distances[pred_selected_i, target_sorted_i] < distance_thresholds_tensor
            are_candidates_free = torch.logical_and(~preds_matched[pred_i, :], ~targets_matched[target_i, :])
            are_candidates_good = torch.logical_and(is_distance_below_threshold, are_candidates_free)

            targets_matched[target_i, are_candidates_good] = True
            preds_matched[pred_i, are_candidates_good] = True

            if targets_matched.all():
                break

        return preds_matched

    def compute_crowd_targets(
        self,
        preds_box_xyxy: torch.Tensor,
        preds_cls: torch.Tensor,
        crowd_targets_cls: torch.Tensor,
        crowd_target_box_xyxy: torch.Tensor,
        preds_matched: torch.Tensor,
        preds_to_ignore: torch.Tensor,
        preds_idx_to_use: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the matching targets based on the distance metric for crowd scenarios.

        :param preds_box_xyxy: (torch.Tensor) Predicted bounding boxes in XYXY format.
        :param preds_cls: (torch.Tensor) Predicted classes.
        :param crowd_targets_cls: (torch.Tensor) Crowd target classes.
        :param crowd_target_box_xyxy: (torch.Tensor) Crowd target bounding boxes in XYXY format.
        :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched.
        :param preds_to_ignore: (torch.Tensor) Tensor indicating which predictions to ignore.
        :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use.
        :return: (Tuple[torch.Tensor, torch.Tensor]) Computed matching targets for crowd scenarios.
        """
        cls_mismatch_crowd = preds_cls[preds_idx_to_use].view(-1, 1) != crowd_targets_cls.view(1, -1)

        # Iterate over each distance metric and its corresponding threshold
        distances = self.distance_metric.calculate_distance(preds_box_xyxy[preds_idx_to_use], crowd_target_box_xyxy)
        distances[cls_mismatch_crowd] = float("inf")

        best_distance, _ = distances.min(1)
        is_matching_with_crowd = best_distance.view(-1, 1) < torch.tensor(self.distance_thresholds, device=distances.device).view(1, -1)

        preds_to_ignore[preds_idx_to_use] = torch.logical_or(preds_to_ignore[preds_idx_to_use], is_matching_with_crowd)

        return preds_matched, preds_to_ignore
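
A construction sketch (the threshold values are illustrative); the metric classes documented further down, such as EuclideanDistance, plug in here:

import torch
from super_gradients.training.utils.detection_utils import DistanceMatching, EuclideanDistance

# Match a prediction to a target of the same class when their box centers are within 5 or 10 pixels
matcher = DistanceMatching(distance_metric=EuclideanDistance(), distance_thresholds=torch.tensor([5.0, 10.0]))
matcher.get_thresholds()  # tensor([ 5., 10.])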

__init__(distance_metric, distance_thresholds)

Initializes the DistanceMatching instance with a distance metric and distance thresholds.

Parameters:

Name Type Description Default
distance_metric

The distance metric to be used for matching.

required
distance_thresholds torch.Tensor

(torch.Tensor) The distance thresholds for matching.

required
Source code in src/super_gradients/training/utils/detection_utils.py, lines 1014-1022
def __init__(self, distance_metric, distance_thresholds: torch.Tensor):
    """
    Initializes the DistanceMatching instance with a distance metric and distance thresholds.

    :param distance_metric: The distance metric to be used for matching.
    :param distance_thresholds: (torch.Tensor) The distance thresholds for matching.
    """
    self.distance_metric = distance_metric
    self.distance_thresholds = distance_thresholds

compute_crowd_targets(preds_box_xyxy, preds_cls, crowd_targets_cls, crowd_target_box_xyxy, preds_matched, preds_to_ignore, preds_idx_to_use)

Computes the matching targets based on the distance metric for crowd scenarios.

Parameters:

Name Type Description Default
preds_box_xyxy torch.Tensor

(torch.Tensor) Predicted bounding boxes in XYXY format.

required
preds_cls torch.Tensor

(torch.Tensor) Predicted classes.

required
crowd_targets_cls torch.Tensor

(torch.Tensor) Crowd target classes.

required
crowd_target_box_xyxy torch.Tensor

(torch.Tensor) Crowd target bounding boxes in XYXY format.

required
preds_matched torch.Tensor

(torch.Tensor) Tensor indicating which predictions are matched.

required
preds_to_ignore torch.Tensor

(torch.Tensor) Tensor indicating which predictions to ignore.

required
preds_idx_to_use torch.Tensor

(torch.Tensor) Indices of predictions to use.

required

Returns:

Type Description
Tuple[torch.Tensor, torch.Tensor]

(Tuple[torch.Tensor, torch.Tensor]) Computed matching targets for crowd scenarios.

Source code in src/super_gradients/training/utils/detection_utils.py, lines 1084-1117
def compute_crowd_targets(
    self,
    preds_box_xyxy: torch.Tensor,
    preds_cls: torch.Tensor,
    crowd_targets_cls: torch.Tensor,
    crowd_target_box_xyxy: torch.Tensor,
    preds_matched: torch.Tensor,
    preds_to_ignore: torch.Tensor,
    preds_idx_to_use: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the matching targets based on the distance metric for crowd scenarios.

    :param preds_box_xyxy: (torch.Tensor) Predicted bounding boxes in XYXY format.
    :param preds_cls: (torch.Tensor) Predicted classes.
    :param crowd_targets_cls: (torch.Tensor) Crowd target classes.
    :param crowd_target_box_xyxy: (torch.Tensor) Crowd target bounding boxes in XYXY format.
    :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched.
    :param preds_to_ignore: (torch.Tensor) Tensor indicating which predictions to ignore.
    :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use.
    :return: (Tuple[torch.Tensor, torch.Tensor]) Computed matching targets for crowd scenarios.
    """
    cls_mismatch_crowd = preds_cls[preds_idx_to_use].view(-1, 1) != crowd_targets_cls.view(1, -1)

    # Iterate over each distance metric and its corresponding threshold
    distances = self.distance_metric.calculate_distance(preds_box_xyxy[preds_idx_to_use], crowd_target_box_xyxy)
    distances[cls_mismatch_crowd] = float("inf")

    best_distance, _ = distances.min(1)
    is_matching_with_crowd = best_distance.view(-1, 1) < torch.tensor(self.distance_thresholds, device=distances.device).view(1, -1)

    preds_to_ignore[preds_idx_to_use] = torch.logical_or(preds_to_ignore[preds_idx_to_use], is_matching_with_crowd)

    return preds_matched, preds_to_ignore

compute_targets(preds_box_xyxy, preds_cls, targets_box_xyxy, targets_cls, preds_matched, targets_matched, preds_idx_to_use)

Computes the matching targets based on the distance metric for regular scenarios.

Parameters:

Name Type Description Default
preds_box_xyxy torch.Tensor

(torch.Tensor) Predicted bounding boxes in XYXY format.

required
preds_cls torch.Tensor

(torch.Tensor) Predicted classes.

required
targets_box_xyxy torch.Tensor

(torch.Tensor) Target bounding boxes in XYXY format.

required
targets_cls torch.Tensor

(torch.Tensor) Target classes.

required
preds_matched torch.Tensor

(torch.Tensor) Tensor indicating which predictions are matched.

required
targets_matched torch.Tensor

(torch.Tensor) Tensor indicating which targets are matched.

required
preds_idx_to_use torch.Tensor

(torch.Tensor) Indices of predictions to use.

required

Returns:

Type Description
torch.Tensor

(torch.Tensor) Computed matching targets.

Source code in src/super_gradients/training/utils/detection_utils.py, lines 1032-1082
def compute_targets(
    self,
    preds_box_xyxy: torch.Tensor,
    preds_cls: torch.Tensor,
    targets_box_xyxy: torch.Tensor,
    targets_cls: torch.Tensor,
    preds_matched: torch.Tensor,
    targets_matched: torch.Tensor,
    preds_idx_to_use: torch.Tensor,
) -> torch.Tensor:
    """
    Computes the matching targets based on the distance metric for regular scenarios.

    :param preds_box_xyxy: (torch.Tensor) Predicted bounding boxes in XYXY format.
    :param preds_cls: (torch.Tensor) Predicted classes.
    :param targets_box_xyxy: (torch.Tensor) Target bounding boxes in XYXY format.
    :param targets_cls: (torch.Tensor) Target classes.
    :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched.
    :param targets_matched: (torch.Tensor) Tensor indicating which targets are matched.
    :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use.
    :return: (torch.Tensor) Computed matching targets.
    """
    # Calculate the distances between targets and predictions using the current metric
    # shape = (n_preds x n_targets)
    distances = self.distance_metric.calculate_distance(preds_box_xyxy[preds_idx_to_use], targets_box_xyxy)

    # Invalidate distances when class labels don't match
    cls_mismatch = preds_cls[preds_idx_to_use].view(-1, 1) != targets_cls.view(1, -1)
    distances[cls_mismatch] = float("inf")  # or max(distance_thresholds) + 1

    # Sort distances
    sorted_distances, target_sorted = distances.sort(stable=True)

    # Identify all pairs that are within the max distance threshold
    candidate_pairs = (sorted_distances < max(self.distance_thresholds)).nonzero(as_tuple=False)
    for pred_selected_i, target_sorted_i in candidate_pairs:
        pred_i = preds_idx_to_use[pred_selected_i]
        target_i = target_sorted[pred_selected_i, target_sorted_i]

        distance_thresholds_tensor = torch.tensor(self.distance_thresholds, device=distances.device)
        is_distance_below_threshold = sorted_distances[pred_selected_i, target_sorted_i] < distance_thresholds_tensor
        are_candidates_free = torch.logical_and(~preds_matched[pred_i, :], ~targets_matched[target_i, :])
        are_candidates_good = torch.logical_and(is_distance_below_threshold, are_candidates_free)

        targets_matched[target_i, are_candidates_good] = True
        preds_matched[pred_i, are_candidates_good] = True

        if targets_matched.all():
            break

    return preds_matched

get_thresholds()

Returns the distance thresholds used for detection matching.

Returns:

Type Description
torch.Tensor

(torch.Tensor) The distance thresholds.

Source code in src/super_gradients/training/utils/detection_utils.py, lines 1024-1030
def get_thresholds(self) -> torch.Tensor:
    """
    Returns the distance thresholds used for detection matching.

    :return: (torch.Tensor) The distance thresholds.
    """
    return torch.tensor(self.distance_thresholds)

EuclideanDistance

Bases: DistanceMetric

Source code in src/super_gradients/training/utils/detection_utils.py, lines 1299-1318
class EuclideanDistance(DistanceMetric):
    def calculate_distance(self, predicted: torch.Tensor, target: torch.Tensor):
        """
        Calculate the Euclidean distance (L2 distance) between the centers of preds_box and targets_box.

        :param predicted: (N, 4) tensor for N predicted bounding boxes (x1, y1, x2, y2)
        :param target: (M, 4) tensor for M target bounding boxes (x1, y1, x2, y2)

        :return: (N, M) tensor representing pairwise euclidean distances
        """
        # Calculate the centers of the bounding boxes
        centers1 = (predicted[:, :2] + predicted[:, 2:]) / 2
        centers2 = (target[:, :2] + target[:, 2:]) / 2

        # Calculate squared differences
        diff = centers1.view(-1, 1, 2) - centers2.view(1, -1, 2)
        dist_sq = (diff**2).sum(dim=2)
        dist = torch.sqrt(dist_sq)

        return dist
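
A small worked sketch: distances are measured between box centers, so the boxes' sizes do not affect the result.

import torch
from super_gradients.training.utils.detection_utils import EuclideanDistance

preds = torch.tensor([[0.0, 0.0, 10.0, 10.0],    # center (5, 5)
                      [0.0, 0.0, 20.0, 20.0]])   # center (10, 10)
targets = torch.tensor([[2.0, 2.0, 8.0, 8.0]])   # center (5, 5)

EuclideanDistance().calculate_distance(preds, targets)
# tensor([[0.0000], [7.0711]])  -> second row is sqrt(5**2 + 5**2)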

calculate_distance(predicted, target)

Calculate the Euclidean distance (L2 distance) between the centers of preds_box and targets_box.

Parameters:

Name Type Description Default
predicted torch.Tensor

(N, 4) tensor for N predicted bounding boxes (x1, y1, x2, y2)

required
target torch.Tensor

(M, 4) tensor for M target bounding boxes (x1, y1, x2, y2)

required

Returns:

Type Description

(N, M) tensor representing pairwise euclidean distances

Source code in src/super_gradients/training/utils/detection_utils.py, lines 1300-1318
def calculate_distance(self, predicted: torch.Tensor, target: torch.Tensor):
    """
    Calculate the Euclidean distance (L2 distance) between the centers of preds_box and targets_box.

    :param predicted: (N, 4) tensor for N predicted bounding boxes (x1, y1, x2, y2)
    :param target: (M, 4) tensor for M target bounding boxes (x1, y1, x2, y2)

    :return: (N, M) tensor representing pairwise euclidean distances
    """
    # Calculate the centers of the bounding boxes
    centers1 = (predicted[:, :2] + predicted[:, 2:]) / 2
    centers2 = (target[:, :2] + target[:, 2:]) / 2

    # Calculate squared differences
    diff = centers1.view(-1, 1, 2) - centers2.view(1, -1, 2)
    dist_sq = (diff**2).sum(dim=2)
    dist = torch.sqrt(dist_sq)

    return dist

IoUMatching

Bases: DetectionMatching

IoUMatching is a subclass of DetectionMatching that uses Intersection over Union (IoU) for matching detections in object detection models.

Source code in src/super_gradients/training/utils/detection_utils.py, lines 880-1005
class IoUMatching(DetectionMatching):
    """
    IoUMatching is a subclass of DetectionMatching that uses Intersection over Union (IoU)
    for matching detections in object detection models.
    """

    def __init__(self, iou_thresholds: torch.Tensor):
        """
        Initializes the IoUMatching instance with IoU thresholds.

        :param iou_thresholds: (torch.Tensor) The IoU thresholds for matching.
        """
        self.iou_thresholds = iou_thresholds

    def get_thresholds(self) -> torch.Tensor:
        """
        Returns the IoU thresholds used for detection matching.

        :return: (torch.Tensor) The IoU thresholds.
        """
        return self.iou_thresholds

    def compute_targets(
        self,
        preds_box_xyxy: torch.Tensor,
        preds_cls: torch.Tensor,
        targets_box_xyxy: torch.Tensor,
        targets_cls: torch.Tensor,
        preds_matched: torch.Tensor,
        targets_matched: torch.Tensor,
        preds_idx_to_use: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes the matching targets based on IoU for regular scenarios.

        :param preds_box_xyxy: (torch.Tensor) Predicted bounding boxes in XYXY format.
        :param preds_cls: (torch.Tensor) Predicted classes.
        :param targets_box_xyxy: (torch.Tensor) Target bounding boxes in XYXY format.
        :param targets_cls: (torch.Tensor) Target classes.
        :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched.
        :param targets_matched: (torch.Tensor) Tensor indicating which targets are matched.
        :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use.
        :return: (torch.Tensor) Computed matching targets.
        """
        # shape = (n_preds x n_targets)
        iou = box_iou(preds_box_xyxy[preds_idx_to_use], targets_box_xyxy)

        # Fill IoU values at index (i, j) with 0 when the prediction (i) and target(j) are of different class
        # Filling with 0 is equivalent to ignoring these values since we want IoU > iou_threshold > 0
        cls_mismatch = preds_cls[preds_idx_to_use].view(-1, 1) != targets_cls.view(1, -1)
        iou[cls_mismatch] = 0

        # The matching priority is first detection confidence and then IoU value.
        # The detections are already sorted by confidence in NMS, so here for each prediction we order the targets by iou.
        sorted_iou, target_sorted = iou.sort(descending=True, stable=True)

        # Only iterate over IoU values higher than min threshold to speed up the process
        for pred_selected_i, target_sorted_i in (sorted_iou > self.iou_thresholds[0]).nonzero(as_tuple=False):
            # pred_selected_i and target_sorted_i are relative to filters/sorting, so we extract their absolute indexes
            pred_i = preds_idx_to_use[pred_selected_i]
            target_i = target_sorted[pred_selected_i, target_sorted_i]

            # Vector[j], True when IoU(pred_i, target_i) is above the (j)th threshold
            is_iou_above_threshold = sorted_iou[pred_selected_i, target_sorted_i] > self.iou_thresholds

            # Vector[j], True when both pred_i and target_i are not matched yet for the (j)th threshold
            are_candidates_free = torch.logical_and(~preds_matched[pred_i, :], ~targets_matched[target_i, :])

            # Vector[j], True when (pred_i, target_i) can be matched for the (j)th threshold
            are_candidates_good = torch.logical_and(is_iou_above_threshold, are_candidates_free)

            # For every threshold (j) where target_i and pred_i can be matched together ( are_candidates_good[j]==True )
            # fill the matching placeholders with True
            targets_matched[target_i, are_candidates_good] = True
            preds_matched[pred_i, are_candidates_good] = True

            # When all the targets are matched with a prediction for every IoU Threshold, stop.
            if targets_matched.all():
                break

        return preds_matched

    def compute_crowd_targets(
        self,
        preds_box_xyxy: torch.Tensor,
        preds_cls: torch.Tensor,
        crowd_targets_cls: torch.Tensor,
        crowd_target_box_xyxy: torch.Tensor,
        preds_matched: torch.Tensor,
        preds_to_ignore: torch.Tensor,
        preds_idx_to_use: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the matching targets based on IoU for crowd scenarios.

        :param preds_box_xyxy: (torch.Tensor) Predicted bounding boxes in XYXY format.
        :param preds_cls: (torch.Tensor) Predicted classes.
        :param crowd_targets_cls: (torch.Tensor) Crowd target classes.
        :param crowd_target_box_xyxy: (torch.Tensor) Crowd target bounding boxes in XYXY format.
        :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched.
        :param preds_to_ignore: (torch.Tensor) Tensor indicating which predictions to ignore.
        :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use.
        :return: (Tuple[torch.Tensor, torch.Tensor]) Computed matching targets for crowd scenarios.
        """
        # Crowd targets can be matched with many predictions.
        # Therefore, for every prediction we just need to check if it has IoA large enough with any crowd target.

        # shape = (n_preds_to_use x n_crowd_targets)
        ioa = crowd_ioa(preds_box_xyxy[preds_idx_to_use], crowd_target_box_xyxy)

        # Fill IoA values at index (i, j) with 0 when the prediction (i) and target(j) are of different class
        # Filling with 0 is equivalent to ignoring these values since we want IoA > threshold > 0
        cls_mismatch = preds_cls[preds_idx_to_use].view(-1, 1) != crowd_targets_cls.view(1, -1)
        ioa[cls_mismatch] = 0

        # For each prediction, we keep its highest score with any crowd target (of the same class)
        # shape = (n_preds_to_use)
        best_ioa, _ = ioa.max(1)

        # If a prediction has IoA higher than threshold (with any target of same class), then there is a match
        # shape = (n_preds_to_use x iou_thresholds)
        is_matching_with_crowd = best_ioa.view(-1, 1) > self.iou_thresholds.view(1, -1)

        preds_to_ignore[preds_idx_to_use] = torch.logical_or(preds_to_ignore[preds_idx_to_use], is_matching_with_crowd)

        return preds_matched, preds_to_ignore
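
A construction sketch, reusing the IouThreshold helper documented below for the standard COCO-style 0.5:0.95 range:

from super_gradients.training.utils.detection_utils import IoUMatching, IouThreshold

matcher = IoUMatching(iou_thresholds=IouThreshold.MAP_05_TO_095.to_tensor())
matcher.get_thresholds()  # tensor([0.5000, 0.5500, ..., 0.9500]) - ten thresholds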

__init__(iou_thresholds)

Initializes the IoUMatching instance with IoU thresholds.

Parameters:

Name Type Description Default
iou_thresholds torch.Tensor

(torch.Tensor) The IoU thresholds for matching.

required
Source code in src/super_gradients/training/utils/detection_utils.py, lines 886-892
def __init__(self, iou_thresholds: torch.Tensor):
    """
    Initializes the IoUMatching instance with IoU thresholds.

    :param iou_thresholds: (torch.Tensor) The IoU thresholds for matching.
    """
    self.iou_thresholds = iou_thresholds

compute_crowd_targets(preds_box_xyxy, preds_cls, crowd_targets_cls, crowd_target_box_xyxy, preds_matched, preds_to_ignore, preds_idx_to_use)

Computes the matching targets based on IoU for crowd scenarios.

Parameters:

Name Type Description Default
preds_box_xyxy torch.Tensor

(torch.Tensor) Predicted bounding boxes in XYXY format.

required
preds_cls torch.Tensor

(torch.Tensor) Predicted classes.

required
crowd_targets_cls torch.Tensor

(torch.Tensor) Crowd target classes.

required
crowd_target_box_xyxy torch.Tensor

(torch.Tensor) Crowd target bounding boxes in XYXY format.

required
preds_matched torch.Tensor

(torch.Tensor) Tensor indicating which predictions are matched.

required
preds_to_ignore torch.Tensor

(torch.Tensor) Tensor indicating which predictions to ignore.

required
preds_idx_to_use torch.Tensor

(torch.Tensor) Indices of predictions to use.

required

Returns:

Type Description
Tuple[torch.Tensor, torch.Tensor]

(Tuple[torch.Tensor, torch.Tensor]) Computed matching targets for crowd scenarios.

Source code in src/super_gradients/training/utils/detection_utils.py, lines 962-1005
def compute_crowd_targets(
    self,
    preds_box_xyxy: torch.Tensor,
    preds_cls: torch.Tensor,
    crowd_targets_cls: torch.Tensor,
    crowd_target_box_xyxy: torch.Tensor,
    preds_matched: torch.Tensor,
    preds_to_ignore: torch.Tensor,
    preds_idx_to_use: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the matching targets based on IoU for crowd scenarios.

    :param preds_box_xyxy: (torch.Tensor) Predicted bounding boxes in XYXY format.
    :param preds_cls: (torch.Tensor) Predicted classes.
    :param crowd_targets_cls: (torch.Tensor) Crowd target classes.
    :param crowd_target_box_xyxy: (torch.Tensor) Crowd target bounding boxes in XYXY format.
    :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched.
    :param preds_to_ignore: (torch.Tensor) Tensor indicating which predictions to ignore.
    :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use.
    :return: (Tuple[torch.Tensor, torch.Tensor]) Computed matching targets for crowd scenarios.
    """
    # Crowd targets can be matched with many predictions.
    # Therefore, for every prediction we just need to check if it has IoA large enough with any crowd target.

    # shape = (n_preds_to_use x n_crowd_targets)
    ioa = crowd_ioa(preds_box_xyxy[preds_idx_to_use], crowd_target_box_xyxy)

    # Fill IoA values at index (i, j) with 0 when the prediction (i) and target(j) are of different class
    # Filling with 0 is equivalent to ignoring these values since we want IoA > threshold > 0
    cls_mismatch = preds_cls[preds_idx_to_use].view(-1, 1) != crowd_targets_cls.view(1, -1)
    ioa[cls_mismatch] = 0

    # For each prediction, we keep its highest score with any crowd target (of the same class)
    # shape = (n_preds_to_use)
    best_ioa, _ = ioa.max(1)

    # If a prediction has IoA higher than threshold (with any target of same class), then there is a match
    # shape = (n_preds_to_use x iou_thresholds)
    is_matching_with_crowd = best_ioa.view(-1, 1) > self.iou_thresholds.view(1, -1)

    preds_to_ignore[preds_idx_to_use] = torch.logical_or(preds_to_ignore[preds_idx_to_use], is_matching_with_crowd)

    return preds_matched, preds_to_ignore

compute_targets(preds_box_xyxy, preds_cls, targets_box_xyxy, targets_cls, preds_matched, targets_matched, preds_idx_to_use)

Computes the matching targets based on IoU for regular scenarios.

Parameters:

Name Type Description Default
preds_box_xyxy torch.Tensor

(torch.Tensor) Predicted bounding boxes in XYXY format.

required
preds_cls torch.Tensor

(torch.Tensor) Predicted classes.

required
targets_box_xyxy torch.Tensor

(torch.Tensor) Target bounding boxes in XYXY format.

required
targets_cls torch.Tensor

(torch.Tensor) Target classes.

required
preds_matched torch.Tensor

(torch.Tensor) Tensor indicating which predictions are matched.

required
targets_matched torch.Tensor

(torch.Tensor) Tensor indicating which targets are matched.

required
preds_idx_to_use torch.Tensor

(torch.Tensor) Indices of predictions to use.

required

Returns:

Type Description
torch.Tensor

(torch.Tensor) Computed matching targets.

Source code in src/super_gradients/training/utils/detection_utils.py, lines 902-960
def compute_targets(
    self,
    preds_box_xyxy: torch.Tensor,
    preds_cls: torch.Tensor,
    targets_box_xyxy: torch.Tensor,
    targets_cls: torch.Tensor,
    preds_matched: torch.Tensor,
    targets_matched: torch.Tensor,
    preds_idx_to_use: torch.Tensor,
) -> torch.Tensor:
    """
    Computes the matching targets based on IoU for regular scenarios.

    :param preds_box_xyxy: (torch.Tensor) Predicted bounding boxes in XYXY format.
    :param preds_cls: (torch.Tensor) Predicted classes.
    :param targets_box_xyxy: (torch.Tensor) Target bounding boxes in XYXY format.
    :param targets_cls: (torch.Tensor) Target classes.
    :param preds_matched: (torch.Tensor) Tensor indicating which predictions are matched.
    :param targets_matched: (torch.Tensor) Tensor indicating which targets are matched.
    :param preds_idx_to_use: (torch.Tensor) Indices of predictions to use.
    :return: (torch.Tensor) Computed matching targets.
    """
    # shape = (n_preds x n_targets)
    iou = box_iou(preds_box_xyxy[preds_idx_to_use], targets_box_xyxy)

    # Fill IoU values at index (i, j) with 0 when the prediction (i) and target(j) are of different class
    # Filling with 0 is equivalent to ignoring these values since we want IoU > iou_threshold > 0
    cls_mismatch = preds_cls[preds_idx_to_use].view(-1, 1) != targets_cls.view(1, -1)
    iou[cls_mismatch] = 0

    # The matching priority is first detection confidence and then IoU value.
    # The detections are already sorted by confidence in NMS, so here for each prediction we order the targets by iou.
    sorted_iou, target_sorted = iou.sort(descending=True, stable=True)

    # Only iterate over IoU values higher than min threshold to speed up the process
    for pred_selected_i, target_sorted_i in (sorted_iou > self.iou_thresholds[0]).nonzero(as_tuple=False):
        # pred_selected_i and target_sorted_i are relative to filters/sorting, so we extract their absolute indexes
        pred_i = preds_idx_to_use[pred_selected_i]
        target_i = target_sorted[pred_selected_i, target_sorted_i]

        # Vector[j], True when IoU(pred_i, target_i) is above the (j)th threshold
        is_iou_above_threshold = sorted_iou[pred_selected_i, target_sorted_i] > self.iou_thresholds

        # Vector[j], True when both pred_i and target_i are not matched yet for the (j)th threshold
        are_candidates_free = torch.logical_and(~preds_matched[pred_i, :], ~targets_matched[target_i, :])

        # Vector[j], True when (pred_i, target_i) can be matched for the (j)th threshold
        are_candidates_good = torch.logical_and(is_iou_above_threshold, are_candidates_free)

        # For every threshold (j) where target_i and pred_i can be matched together ( are_candidates_good[j]==True )
        # fill the matching placeholders with True
        targets_matched[target_i, are_candidates_good] = True
        preds_matched[pred_i, are_candidates_good] = True

        # When all the targets are matched with a prediction for every IoU Threshold, stop.
        if targets_matched.all():
            break

    return preds_matched

get_thresholds()

Returns the IoU thresholds used for detection matching.

Returns:

Type Description
torch.Tensor

(torch.Tensor) The IoU thresholds.

Source code in src/super_gradients/training/utils/detection_utils.py, lines 894-900
def get_thresholds(self) -> torch.Tensor:
    """
    Returns the IoU thresholds used for detection matching.

    :return: (torch.Tensor) The IoU thresholds.
    """
    return self.iou_thresholds

IouThreshold

Bases: tuple, Enum

Source code in src/super_gradients/training/utils/detection_utils.py, lines 231-254
class IouThreshold(tuple, Enum):
    MAP_05 = (0.5, 0.5)
    MAP_05_TO_095 = (0.5, 0.95)

    def is_range(self):
        return self[0] != self[1]

    def to_tensor(self):
        if self.is_range():
            return self.from_bounds(self[0], self[1], step=0.05)
        else:
            return torch.tensor([self[0]])

    @classmethod
    def from_bounds(cls, low: float, high: float, step: float = 0.05) -> torch.Tensor:
        """
        Create a tensor with values from low (including) to high (including) with a given step size.
        :param low: Lower bound
        :param high: Upper bound
        :param step: Step size
        :return: Tensor of [low, low + step, low + 2 * step, ..., high]
        """
        n_iou_thresh = int(round((high - low) / step)) + 1
        return torch.linspace(low, high, n_iou_thresh)
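
A quick sketch of both members and the from_bounds helper:

from super_gradients.training.utils.detection_utils import IouThreshold

IouThreshold.MAP_05.to_tensor()               # tensor([0.5000])
IouThreshold.MAP_05_TO_095.is_range()         # True
IouThreshold.MAP_05_TO_095.to_tensor()        # tensor([0.5000, 0.5500, ..., 0.9500]) - ten values
IouThreshold.from_bounds(0.3, 0.7, step=0.1)  # tensor([0.3000, 0.4000, 0.5000, 0.6000, 0.7000])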

from_bounds(low, high, step=0.05) classmethod

Create a tensor with values from low (including) to high (including) with a given step size.

Parameters:

Name Type Description Default
low float

Lower bound

required
high float

Upper bound

required
step float

Step size

0.05

Returns:

Type Description
torch.Tensor

Tensor of [low, low + step, low + 2 * step, ..., high]

Source code in src/super_gradients/training/utils/detection_utils.py, lines 244-254
@classmethod
def from_bounds(cls, low: float, high: float, step: float = 0.05) -> torch.Tensor:
    """
    Create a tensor with values from low (including) to high (including) with a given step size.
    :param low: Lower bound
    :param high: Upper bound
    :param step: Step size
    :return: Tensor of [low, low + step, low + 2 * step, ..., high]
    """
    n_iou_thresh = int(round((high - low) / step)) + 1
    return torch.linspace(low, high, n_iou_thresh)

ManhattanDistance

Bases: DistanceMetric

Source code in src/super_gradients/training/utils/detection_utils.py, lines 1321-1339
class ManhattanDistance(DistanceMetric):
    def calculate_distance(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Calculate the Manhattan distance (L1 distance) between the centers of preds_box and targets_box.

        :param predicted: (N, 4) tensor for N predicted bounding boxes (x1, y1, x2, y2)
        :param target: (M, 4) tensor for M target bounding boxes (x1, y1, x2, y2)

        :return: (N, M) tensor representing pairwise Manhattan distances
        """
        # Calculate the centers of the bounding boxes
        centers1 = (predicted[:, :2] + predicted[:, 2:]) / 2  # (N, 2)
        centers2 = (target[:, :2] + target[:, 2:]) / 2  # (M, 2)

        # Calculate absolute differences
        diff = centers1.view(-1, 1, 2) - centers2.view(1, -1, 2)
        abs_diff = torch.abs(diff).sum(dim=2)

        return abs_diff
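
For the same pair of box centers, the Manhattan distance sums the per-axis offsets instead of taking the straight-line distance; a small sketch comparing the two metrics:

import torch
from super_gradients.training.utils.detection_utils import EuclideanDistance, ManhattanDistance

a = torch.tensor([[0.0, 0.0, 10.0, 10.0]])   # center (5, 5)
b = torch.tensor([[6.0, 8.0, 16.0, 18.0]])   # center (11, 13)

ManhattanDistance().calculate_distance(a, b)  # tensor([[14.]]) -> |11 - 5| + |13 - 5|
EuclideanDistance().calculate_distance(a, b)  # tensor([[10.]]) -> sqrt(6**2 + 8**2)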

calculate_distance(predicted, target)

Calculate the Manhattan distance (L1 distance) between the centers of preds_box and targets_box.

Parameters:

Name Type Description Default
predicted torch.Tensor

(N, 4) tensor for N predicted bounding boxes (x1, y1, x2, y2)

required
target torch.Tensor

(M, 4) tensor for M target bounding boxes (x1, y1, x2, y2)

required

Returns:

Type Description
torch.Tensor

(N, M) tensor representing pairwise Manhattan distances

Source code in src/super_gradients/training/utils/detection_utils.py, lines 1322-1339
def calculate_distance(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Calculate the Manhattan distance (L1 distance) between the centers of preds_box and targets_box.

    :param predicted: (N, 4) tensor for N predicted bounding boxes (x1, y1, x2, y2)
    :param target: (M, 4) tensor for M target bounding boxes (x1, y1, x2, y2)

    :return: (N, M) tensor representing pairwise Manhattan distances
    """
    # Calculate the centers of the bounding boxes
    centers1 = (predicted[:, :2] + predicted[:, 2:]) / 2  # (N, 2)
    centers2 = (target[:, :2] + target[:, 2:]) / 2  # (M, 2)

    # Calculate absolute differences
    diff = centers1.view(-1, 1, 2) - centers2.view(1, -1, 2)
    abs_diff = torch.abs(diff).sum(dim=2)

    return abs_diff

NMS_Type

Bases: str, Enum

Type of non max suppression algorithm that can be used for post processing detection

Source code in src/super_gradients/training/utils/detection_utils.py, lines 393-399
class NMS_Type(str, Enum):
    """
    Type of non max suppression algorithm that can be used for post processing detection
    """

    ITERATIVE = "iterative"
    MATRIX = "matrix"

adjust_box_anns(bbox, scale_ratio, padw, padh, w_max, h_max)

Adjusts the bbox annotations of a rescaled, padded image.

Parameters:

Name Type Description Default
bbox

(np.array) bbox to modify.

required
scale_ratio

(float) scale ratio between rescale output image and original one.

required
padw

(int) width padding size.

required
padh

(int) height padding size.

required
w_max

(int) width border.

required
h_max

(int) height border

required

Returns:

Type Description

modified bbox (np.array)

Source code in src/super_gradients/training/utils/detection_utils.py, lines 771-784
def adjust_box_anns(bbox, scale_ratio, padw, padh, w_max, h_max):
    """
    Adjusts the bbox annotations of a rescaled, padded image.

    :param bbox: (np.array) bbox to modify.
    :param scale_ratio: (float) scale ratio between rescale output image and original one.
    :param padw: (int) width padding size.
    :param padh: (int) height padding size.
    :param w_max: (int) width border.
    :param h_max: (int) height border
    :return: modified bbox (np.array)
    """
    scaled_bboxes = bbox * scale_ratio + np.array([[padw, padh, padw, padh]])
    return change_bbox_bounds_for_image_size_inplace(scaled_bboxes, img_shape=(h_max, w_max))
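
A numeric sketch with made-up values: a box from an image that was downscaled by a factor of 0.5 and then padded by 10 px on the left and 20 px on top.

import numpy as np
from super_gradients.training.utils.detection_utils import adjust_box_anns

bbox = np.array([[40.0, 60.0, 200.0, 180.0]])  # x1, y1, x2, y2 in the original image
adjust_box_anns(bbox, scale_ratio=0.5, padw=10, padh=20, w_max=320, h_max=320)
# scale then shift: [40*0.5+10, 60*0.5+20, 200*0.5+10, 180*0.5+20] = [30., 50., 110., 110.],
# finally clipped to the (h_max, w_max) image bounds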

box_iou(box1, box2)

Return intersection-over-union (Jaccard index) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

Parameters:

Name Type Description Default
box1 torch.Tensor

Tensor of shape [N, 4]

required
box2 torch.Tensor

Tensor of shape [M, 4]

required

Returns:

Type Description
torch.Tensor

iou, Tensor of shape [N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2

Source code in src/super_gradients/training/utils/detection_utils.py, lines 257-276
def box_iou(box1: torch.Tensor, box2: torch.Tensor) -> torch.Tensor:
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    :param box1: Tensor of shape [N, 4]
    :param box2: Tensor of shape [M, 4]
    :return:     iou, Tensor of shape [N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
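
A worked sketch: two candidate boxes against one reference box, one overlapping on a quarter of its area and one identical.

import torch
from super_gradients.training.utils.detection_utils import box_iou

boxes_a = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
boxes_b = torch.tensor([[5.0, 5.0, 15.0, 15.0],
                        [0.0, 0.0, 10.0, 10.0]])

box_iou(boxes_a, boxes_b)
# tensor([[0.1429, 1.0000]]) -> first pair: 25 / (100 + 100 - 25), second pair: identical boxes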

calc_bbox_iou_matrix(pred)

calculate iou for every pair of boxes in the boxes vector

Parameters:

Name Type Description Default
pred torch.Tensor

a 3-dimensional tensor containing all boxes for a batch of images [N, num_boxes, 4], where each box format is [x1,y1,x2,y2]

required

Returns:

Type Description

a 3-dimensional matrix where M_i_j_k is the iou of box j and box k of the i'th image in the batch

Source code in src/super_gradients/training/utils/detection_utils.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
def calc_bbox_iou_matrix(pred: torch.Tensor):
    """
    calculate iou for every pair of boxes in the boxes vector
    :param pred: a 3-dimensional tensor containing all boxes for a batch of images [N, num_boxes, 4], where
                 each box format is [x1,y1,x2,y2]
    :return: a 3-dimensional matrix where M_i_j_k is the iou of box j and box k of the i'th image in the batch
    """
    box = pred[:, :, :4]  #
    b1_x1, b1_y1 = box[:, :, 0].unsqueeze(1), box[:, :, 1].unsqueeze(1)
    b1_x2, b1_y2 = box[:, :, 2].unsqueeze(1), box[:, :, 3].unsqueeze(1)

    b2_x1 = b1_x1.transpose(2, 1)
    b2_x2 = b1_x2.transpose(2, 1)
    b2_y1 = b1_y1.transpose(2, 1)
    b2_y2 = b1_y2.transpose(2, 1)
    intersection_area = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1
    union_area = (w1 * h1 + 1e-16) + w2 * h2 - intersection_area
    ious = intersection_area / union_area
    return ious
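
A minimal sketch with a single-image "batch" of three boxes (illustrative values only):

import torch

from super_gradients.training.utils.detection_utils import calc_bbox_iou_matrix

boxes = torch.tensor([[[0.0, 0.0, 10.0, 10.0],
                       [5.0, 5.0, 15.0, 15.0],
                       [20.0, 20.0, 30.0, 30.0]]])
ious = calc_bbox_iou_matrix(boxes)  # shape [1, 3, 3]
# the diagonal is ~1.0; ious[0, 0, 2] is 0.0 since the first and last boxes do not overlap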

calculate_bbox_iou_matrix(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-09)

Calculate an IoU matrix containing the IoU iou(i, j) for every pair where box i is in box1 and box j is in box2

Parameters:

Name Type Description Default
box1

a 2D tensor of boxes (shape N x 4)

required
box2

a 2D tensor of boxes (shape M x 4)

required
x1y1x2y2

boxes format is x1y1x2y2 (True) or xywh where xy is the center (False)

True

Returns:

Type Description

a 2D iou matrix (shape NxM)

Source code in src/super_gradients/training/utils/detection_utils.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def calculate_bbox_iou_matrix(box1, box2, x1y1x2y2=True, GIoU: bool = False, DIoU=False, CIoU=False, eps=1e-9):
    """
    calculate an IoU matrix containing the IoU iou(i, j) for every pair where box i is in box1 and box j is in box2
    :param box1: a 2D tensor of boxes (shape N x 4)
    :param box2: a 2D tensor of boxes (shape M x 4)
    :param x1y1x2y2: boxes format is x1y1x2y2 (True) or xywh where xy is the center (False)
    :return: a 2D iou matrix (shape NxM)
    """
    if box1.dim() > 1:
        box1 = box1.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
    else:  # x, y, w, h = box1
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
        b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2

    b1_x1, b1_y1, b1_x2, b1_y2 = b1_x1.unsqueeze(1), b1_y1.unsqueeze(1), b1_x2.unsqueeze(1), b1_y2.unsqueeze(1)

    return _iou(CIoU, DIoU, GIoU, b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2, eps)
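
A brief usage sketch in plain IoU mode (GIoU/DIoU/CIoU left at their defaults; values are illustrative):

import torch

from super_gradients.training.utils.detection_utils import calculate_bbox_iou_matrix

anchors = torch.tensor([[0.0, 0.0, 10.0, 10.0], [2.0, 2.0, 12.0, 12.0]])
gt_boxes = torch.tensor([[1.0, 1.0, 11.0, 11.0], [20.0, 20.0, 25.0, 25.0]])
iou_matrix = calculate_bbox_iou_matrix(anchors, gt_boxes, x1y1x2y2=True)  # shape [2, 2]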

change_bbox_bounds_for_image_size(boxes, img_shape, inplace=True)

Clips bboxes to image boundaries. The function can operate either in-place or on a copy of the input, controlled by the inplace parameter. The inplace parameter exists only for backward compatibility; it will be removed in SG 3.8.0, after which this method will not modify the input. An in-place version of this method is available as change_bbox_bounds_for_image_size_inplace.

Parameters:

Name Type Description Default
boxes

(np.ndarray) Input bounding boxes in XYXY format of [..., 4] shape

required
img_shape Tuple[int, int]

Tuple[int,int] of image shape (height, width).

required
inplace

(bool) If True, the function operates in-place. Otherwise, it returns a modified copy. If True this will trigger a deprecated warning to inform the user to use change_bbox_bounds_for_image_size_inplace instead.

True

Returns:

Type Description
np.ndarray

(np.ndarray)clipped bboxes in XYXY format of [..., 4] shape

Source code in src/super_gradients/training/utils/detection_utils.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def change_bbox_bounds_for_image_size(boxes: np.ndarray, img_shape: Tuple[int, int], inplace=True) -> np.ndarray:
    """
    Clips bboxes to image boundaries.
    The function can operate either in-place or on a copy of the input, controlled by the inplace parameter.
    The inplace parameter exists only for backward compatibility; it will be removed in SG 3.8.0, after which this method will not modify the input.
    An in-place version of this method is available as change_bbox_bounds_for_image_size_inplace.

    :param boxes:      (np.ndarray) Input bounding boxes in XYXY format of [..., 4] shape
    :param img_shape:  Tuple[int,int] of image shape (height, width).
    :param inplace:    (bool) If True, the function operates in-place. Otherwise, it returns a modified copy.
                       If True this will trigger a deprecated warning to inform the user to use
                       change_bbox_bounds_for_image_size_inplace instead.
    :return:           (np.ndarray)clipped bboxes in XYXY format of [..., 4] shape
    """
    if not inplace:
        boxes = boxes.copy()
    else:
        deprecate_param(
            deprecated_param_name="inplace",
            deprecated_since="3.7.0",
            removed_from="3.8.0",
            reason="For in-place operation, use change_bbox_bounds_for_image_size_inplace",
        )
    return change_bbox_bounds_for_image_size_inplace(boxes, img_shape)

change_bbox_bounds_for_image_size_inplace(boxes, img_shape)

Clips bboxes to image boundaries. The function operates in-place.

Parameters:

Name Type Description Default
boxes

(np.ndarray) Input bounding boxes in XYXY format of [..., 4] shape

required
img_shape Tuple[int, int]

Tuple[int,int] of image shape (height, width).

required

Returns:

Type Description
np.ndarray

(np.ndarray)clipped bboxes in XYXY format of [..., 4] shape

Source code in src/super_gradients/training/utils/detection_utils.py
174
175
176
177
178
179
180
181
182
183
184
def change_bbox_bounds_for_image_size_inplace(boxes: np.ndarray, img_shape: Tuple[int, int]) -> np.ndarray:
    """
    Clips bboxes to image boundaries. The function operates in-place.

    :param boxes:      (np.ndarray) Input bounding boxes in XYXY format of [..., 4] shape
    :param img_shape:  Tuple[int,int] of image shape (height, width).
    :return:           (np.ndarray)clipped bboxes in XYXY format of [..., 4] shape
    """
    boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(min=0, max=img_shape[1])
    boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(min=0, max=img_shape[0])
    return boxes
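
A minimal in-place usage sketch (illustrative values, 640x480 image):

import numpy as np

from super_gradients.training.utils.detection_utils import change_bbox_bounds_for_image_size_inplace

boxes = np.array([[-5.0, 10.0, 700.0, 300.0]])
change_bbox_bounds_for_image_size_inplace(boxes, img_shape=(480, 640))  # (height, width)
# boxes is modified in place -> [[0., 10., 640., 300.]]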

compute_box_area(box)

Compute the area of one or many boxes.

Parameters:

Name Type Description Default
box torch.Tensor

One or many boxes, shape = (4, ?), each box in format (x1, y1, x2, y2)

required

Returns:

Type Description
torch.Tensor

Area of every box, shape = (1, ?)

Source code in src/super_gradients/training/utils/detection_utils.py
787
788
789
790
791
792
793
794
def compute_box_area(box: torch.Tensor) -> torch.Tensor:
    """
    Compute the area of one or many boxes.
    :param box: One or many boxes, shape = (4, ?), each box in format (x1, y1, x2, y2)
    :return: Area of every box, shape = (1, ?)
    """
    # box = 4xn
    return (box[2] - box[0]) * (box[3] - box[1])
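
Note that the boxes are expected with coordinates on the first axis (shape (4, n)), so a regular (n, 4) tensor should be transposed first. A small illustrative sketch:

import torch

from super_gradients.training.utils.detection_utils import compute_box_area

boxes = torch.tensor([[0.0, 0.0, 10.0, 20.0], [5.0, 5.0, 15.0, 25.0]])  # (n, 4) in XYXY
areas = compute_box_area(boxes.T)
# areas -> tensor([200., 200.])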

compute_detection_matching(output, targets, height, width, denormalize_targets, device, iou_thresholds=None, crowd_targets=None, top_k=100, return_on_cpu=True, matching_strategy=None)

Match predictions (NMS output) and the targets (ground truth) with respect to IoU and confidence score.

Parameters:

Name Type Description Default
output List[torch.Tensor]

list (of length batch_size) of Tensors of shape (num_predictions, 6) format: (x1, y1, x2, y2, confidence, class_label) where x1,y1,x2,y2 are according to image size

required
targets torch.Tensor

targets for all images of shape (total_num_targets, 6) format: (index, label, x, y, w, h) where x,y,w,h are in range [0,1]

required
height int

dimensions of the image

required
width int

dimensions of the image

required
iou_thresholds torch.Tensor

Threshold to compute the mAP

None
device str

Device

required
crowd_targets Optional[torch.Tensor]

crowd targets for all images of shape (total_num_crowd_targets, 6) format: (index, label, x, y, w, h) where x,y,w,h are in range [0,1]

None
top_k int

Number of predictions to keep per class, ordered by confidence score

100
denormalize_targets bool

If True, denormalize the targets and crowd_targets

required
return_on_cpu bool

If True, the output will be returned on "CPU", otherwise it will be returned on "device"

True
matching_strategy DetectionMatching

Method to match predictions to ground truth targets, IoU, distance based

None

Returns:

Type Description
List[Tuple]

list of the following tensors, for every image: :preds_matched: Tensor of shape (num_img_predictions, n_thresholds) True when prediction (i) is matched with a target with respect to the (j)th IoU threshold :preds_to_ignore: Tensor of shape (num_img_predictions, n_thresholds) True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold :preds_scores: Tensor of shape (num_img_predictions), confidence score for every prediction :preds_cls: Tensor of shape (num_img_predictions), predicted class for every prediction :targets_cls: Tensor of shape (num_img_targets), ground truth class for every target

Source code in src/super_gradients/training/utils/detection_utils.py
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
def compute_detection_matching(
    output: List[torch.Tensor],
    targets: torch.Tensor,
    height: int,
    width: int,
    denormalize_targets: bool,
    device: str,
    iou_thresholds: torch.Tensor = None,
    crowd_targets: Optional[torch.Tensor] = None,
    top_k: int = 100,
    return_on_cpu: bool = True,
    matching_strategy: DetectionMatching = None,
) -> List[Tuple]:
    """
    Match predictions (NMS output) and the targets (ground truth) with respect to IoU and confidence score.
    :param output:          list (of length batch_size) of Tensors of shape (num_predictions, 6)
                            format:     (x1, y1, x2, y2, confidence, class_label) where x1,y1,x2,y2 are according to image size
    :param targets:         targets for all images of shape (total_num_targets, 6)
                            format:     (index, label, x, y, w, h) where x,y,w,h are in range [0,1]
    :param height:          dimensions of the image
    :param width:           dimensions of the image
    :param iou_thresholds:  Threshold to compute the mAP
    :param device:          Device
    :param crowd_targets:   crowd targets for all images of shape (total_num_crowd_targets, 6)
                            format:     (index, label, x, y, w, h) where x,y,w,h are in range [0,1]
    :param top_k:           Number of predictions to keep per class, ordered by confidence score
    :param denormalize_targets: If True, denormalize the targets and crowd_targets
    :param return_on_cpu:   If True, the output will be returned on "CPU", otherwise it will be returned on "device"
    :param matching_strategy: Method to match predictions to ground truth targets, IoU, distance based

    :return:                list of the following tensors, for every image:
        :preds_matched:     Tensor of shape (num_img_predictions, n_thresholds)
                            True when prediction (i) is matched with a target with respect to the (j)th IoU threshold
        :preds_to_ignore:   Tensor of shape (num_img_predictions, n_thresholds)
                            True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold
        :preds_scores:      Tensor of shape (num_img_predictions), confidence score for every prediction
        :preds_cls:         Tensor of shape (num_img_predictions), predicted class for every prediction
        :targets_cls:       Tensor of shape (num_img_targets), ground truth class for every target
    """
    if matching_strategy is None:
        raise ValueError("matching_strategy must not be None")
    if isinstance(matching_strategy, IoUMatching) and iou_thresholds is None:
        raise ValueError("iou_thresholds is required for IoU matching strategy")

    output = map(lambda tensor: None if tensor is None else tensor.to(device), output)
    thresholds = matching_strategy.get_thresholds()
    targets, thresholds = targets.to(device), thresholds.to(device)

    # If crowd_targets is not provided, we patch it with an empty tensor
    crowd_targets = torch.zeros(size=(0, 6), device=device) if crowd_targets is None else crowd_targets.to(device)

    batch_metrics = []
    for img_i, img_preds in enumerate(output):
        # If img_preds is None (not prediction for this image), we patch it with an empty tensor
        img_preds = img_preds if img_preds is not None else torch.zeros(size=(0, 6), device=device)
        img_targets = targets[targets[:, 0] == img_i, 1:]
        img_crowd_targets = crowd_targets[crowd_targets[:, 0] == img_i, 1:]

        img_matching_tensors = compute_img_detection_matching(
            preds=img_preds,
            targets=img_targets,
            crowd_targets=img_crowd_targets,
            denormalize_targets=denormalize_targets,
            height=height,
            width=width,
            iou_thresholds=iou_thresholds,
            device=device,
            top_k=top_k,
            return_on_cpu=return_on_cpu,
            matching_strategy=matching_strategy,
        )
        batch_metrics.append(img_matching_tensors)

    return batch_metrics

compute_detection_metrics(preds_matched, preds_to_ignore, preds_scores, preds_cls, targets_cls, device, recall_thresholds=None, score_threshold=0.1, calc_best_score_thresholds=None)

Compute the list of precision, recall, MaP and f1 for every recall IoU threshold and for every class.

Parameters:

Name Type Description Default
preds_matched torch.Tensor

Tensor of shape (num_predictions, n_iou_thresholds) True when prediction (i) is matched with a target with respect to the (j)th IoU threshold

required
preds_to_ignore torch.Tensor

Tensor of shape (num_predictions, n_iou_thresholds) True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold

required
preds_scores torch.Tensor

Tensor of shape (num_predictions), confidence score for every prediction

required
preds_cls torch.Tensor

Tensor of shape (num_predictions), predicted class for every prediction

required
targets_cls torch.Tensor

Tensor of shape (num_targets), ground truth class for every target box to be detected

required
recall_thresholds Optional[torch.Tensor]

Recall thresholds used to compute MaP.

None
score_threshold Optional[float]

Minimum confidence score to consider a prediction for the computation of precision, recall and f1 (not MaP)

0.1
device str

Device

required
calc_best_score_thresholds bool

(Deprecated) If True, the best confidence score threshold is computed for each class. This parameter is deprecated and ignored; the function always computes the best threshold.

None

Returns:

Type Description
Tuple

:ap, precision, recall, f1: Tensors of shape (n_class, nb_iou_thrs) :unique_classes: Vector with all unique target classes :best_score_threshold: torch.float with the best overall score threshold if calc_best_score_thresholds is True else None :best_score_threshold_per_cls: Array that stores the best score threshold for each class , if calc_best_score_thresholds is True else None

Source code in src/super_gradients/training/utils/detection_utils.py
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
def compute_detection_metrics(
    preds_matched: torch.Tensor,
    preds_to_ignore: torch.Tensor,
    preds_scores: torch.Tensor,
    preds_cls: torch.Tensor,
    targets_cls: torch.Tensor,
    device: str,
    recall_thresholds: Optional[torch.Tensor] = None,
    score_threshold: Optional[float] = 0.1,
    calc_best_score_thresholds: bool = None,
) -> Tuple:
    """
    Compute the list of precision, recall, MaP and f1 for every recall IoU threshold and for every class.

    :param preds_matched:      Tensor of shape (num_predictions, n_iou_thresholds)
                                    True when prediction (i) is matched with a target with respect to the (j)th IoU threshold
    :param preds_to_ignore:    Tensor of shape (num_predictions, n_iou_thresholds)
                                    True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold
    :param preds_scores:       Tensor of shape (num_predictions), confidence score for every prediction
    :param preds_cls:          Tensor of shape (num_predictions), predicted class for every prediction
    :param targets_cls:        Tensor of shape (num_targets), ground truth class for every target box to be detected
    :param recall_thresholds:   Recall thresholds used to compute MaP.
    :param score_threshold:    Minimum confidence score to consider a prediction for the computation of
                                    precision, recall and f1 (not MaP)
    :param device:             Device
    :param calc_best_score_thresholds: (Deprecated) If True, the best confidence score threshold is computed for each class.
                                       This parameter is deprecated and ignored; the function always computes the best threshold.

    :return:
        :ap, precision, recall, f1: Tensors of shape (n_class, nb_iou_thrs)
        :unique_classes:            Vector with all unique target classes
        :best_score_threshold:      torch.float with the best overall score threshold if calc_best_score_thresholds
                                    is True else None
        :best_score_threshold_per_cls: Array that stores the best score threshold for each class , if
                                            calc_best_score_thresholds is True else None

    """
    if calc_best_score_thresholds is not None:
        warnings.warn(
            "calc_best_score_thresholds argument is deprecated and will be removed in SG 3.8.0.\n"
            "Best score threhsold is always computed by compute_detection_metrics since SG 3.6.0.\n"
            "Please update your code and remove explicitely passing calc_best_score_thresholds.\n"
        )

    preds_matched, preds_to_ignore = preds_matched.to(device), preds_to_ignore.to(device)
    preds_scores, preds_cls, targets_cls = preds_scores.to(device), preds_cls.to(device), targets_cls.to(device)

    recall_thresholds = torch.linspace(0, 1, 101, device=device) if recall_thresholds is None else recall_thresholds.to(device)

    unique_classes = torch.unique(targets_cls).long()

    n_class, nb_iou_thrs = len(unique_classes), preds_matched.shape[-1]

    ap = torch.zeros((n_class, nb_iou_thrs), device=device)
    precision = torch.zeros((n_class, nb_iou_thrs), device=device)
    recall = torch.zeros((n_class, nb_iou_thrs), device=device)

    nb_score_thrs = len(recall_thresholds)
    all_score_thresholds = torch.linspace(0, 1, nb_score_thrs, device=device)
    f1_per_class_per_threshold = torch.zeros((n_class, nb_score_thrs), device=device)
    best_score_threshold_per_cls = torch.zeros(n_class, device=device)

    for cls_i, class_value in enumerate(unique_classes):
        cls_preds_idx, cls_targets_idx = (preds_cls == class_value), (targets_cls == class_value)
        cls_ap, cls_precision, cls_recall, cls_f1_per_threshold, cls_best_score_threshold = compute_detection_metrics_per_cls(
            preds_matched=preds_matched[cls_preds_idx],
            preds_to_ignore=preds_to_ignore[cls_preds_idx],
            preds_scores=preds_scores[cls_preds_idx],
            n_targets=cls_targets_idx.sum(),
            recall_thresholds=recall_thresholds,
            score_threshold=score_threshold,
            device=device,
        )
        ap[cls_i, :] = cls_ap
        precision[cls_i, :] = cls_precision
        recall[cls_i, :] = cls_recall

        f1_per_class_per_threshold[cls_i, :] = cls_f1_per_threshold
        best_score_threshold_per_cls[cls_i] = cls_best_score_threshold

    f1 = 2 * precision * recall / (precision + recall + 1e-16)

    mean_f1_across_classes = torch.mean(f1_per_class_per_threshold, dim=0)
    best_score_threshold = all_score_thresholds[torch.argmax(mean_f1_across_classes)]

    return ap, precision, recall, f1, unique_classes, best_score_threshold, best_score_threshold_per_cls
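
A tiny synthetic sketch (hand-made matching results, a single IoU threshold and a single class) just to show the calling convention and the output structure:

import torch

from super_gradients.training.utils.detection_utils import compute_detection_metrics

preds_matched = torch.tensor([[True], [False]])        # 2 predictions, 1 threshold
preds_to_ignore = torch.zeros((2, 1), dtype=torch.bool)
preds_scores = torch.tensor([0.9, 0.4])
preds_cls = torch.tensor([0.0, 0.0])
targets_cls = torch.tensor([0.0, 0.0])                 # 2 ground-truth boxes of class 0

ap, precision, recall, f1, classes, best_thr, best_thr_per_cls = compute_detection_metrics(
    preds_matched, preds_to_ignore, preds_scores, preds_cls, targets_cls, device="cpu"
)
# ap, precision, recall and f1 have shape (n_class, nb_iou_thrs) == (1, 1) here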

compute_detection_metrics_per_cls(preds_matched, preds_to_ignore, preds_scores, n_targets, recall_thresholds, score_threshold, device, calc_best_score_thresholds=None)

Compute the list of precision, recall and MaP of a given class for every recall threshold.

Parameters:

Name Type Description Default
preds_matched torch.Tensor

Tensor of shape (num_predictions, n_thresholds) True when prediction (i) is matched with a target with respect to the (j)th threshold

required
preds_to_ignore torch.Tensor

Tensor of shape (num_predictions, n_thresholds) True when prediction (i) is matched with a crowd target with respect to the (j)th threshold

required
preds_scores torch.Tensor

Tensor of shape (num_predictions), confidence score for every prediction

required
n_targets int

Number of target boxes of this class

required
recall_thresholds torch.Tensor

Tensor of shape (max_n_rec_thresh) list of recall thresholds used to compute MaP

required
score_threshold float

Minimum confidence score to consider a prediction for the computation of precision and recall (not MaP)

required
device str

Device

required
nb_score_thrs

Number of score thresholds to consider when calc_best_score_thresholds is True

required
calc_best_score_thresholds

(Deprecated) If True, the best confidence score threshold is computed for each class. This parameter is deprecated and ignored; the function always computes the best threshold.

None

Returns:

Type Description

:ap, precision, recall: Tensors of shape (nb_thrs) :mean_f1_per_threshold: Tensor of shape (nb_score_thresholds) if calc_best_score_thresholds is True else None :best_score_threshold: torch.float if calc_best_score_thresholds is True else None

Source code in src/super_gradients/training/utils/detection_utils.py
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
def compute_detection_metrics_per_cls(
    preds_matched: torch.Tensor,
    preds_to_ignore: torch.Tensor,
    preds_scores: torch.Tensor,
    n_targets: int,
    recall_thresholds: torch.Tensor,
    score_threshold: float,
    device: str,
    calc_best_score_thresholds=None,
):
    """
    Compute the list of precision, recall and MaP of a given class for every recall threshold.

    :param preds_matched:      Tensor of shape (num_predictions, n_thresholds)
                                    True when prediction (i) is matched with a target
                                    with respect to the (j)th threshold
    :param preds_to_ignore:    Tensor of shape (num_predictions, n_thresholds)
                                    True when prediction (i) is matched with a crowd target
                                    with respect to the (j)th threshold
    :param preds_scores:       Tensor of shape (num_predictions), confidence score for every prediction
    :param n_targets:          Number of target boxes of this class
    :param recall_thresholds:  Tensor of shape (max_n_rec_thresh) list of recall thresholds used to compute MaP
    :param score_threshold:    Minimum confidence score to consider a prediction for the computation of
                                    precision and recall (not MaP)
    :param device:             Device
    :param nb_score_thrs:              Number of score thresholds to consider when calc_best_score_thresholds is True
    :param calc_best_score_thresholds: (Deprecated) If True, the best confidence score threshold is computed for each class.
                                       This parameter is deprecated and ignored; the function always computes the best threshold.
    :return:
        :ap, precision, recall:     Tensors of shape (nb_thrs)
        :mean_f1_per_threshold:     Tensor of shape (nb_score_thresholds) if calc_best_score_thresholds is True else None
        :best_score_threshold:      torch.float if calc_best_score_thresholds is True else None
    """
    if calc_best_score_thresholds is not None:
        warnings.warn(
            "calc_best_score_thresholds argument is deprecated and will be removed in SG 3.8.0.\n"
            "Best score threhsold is always computed by compute_detection_metrics since SG 3.6.0.\n"
            "Please update your code and remove explicitely passing calc_best_score_thresholds.\n"
        )

    nb_iou_thrs = preds_matched.shape[-1]
    nb_score_thrs = len(recall_thresholds)

    mean_f1_per_threshold = torch.zeros(nb_score_thrs, device=device)
    best_score_threshold = torch.tensor(0.0, dtype=torch.float, device=device)

    tps = preds_matched
    fps = torch.logical_and(torch.logical_not(preds_matched), torch.logical_not(preds_to_ignore))

    if len(tps) == 0:
        return (
            torch.zeros(nb_iou_thrs, device=device),
            torch.zeros(nb_iou_thrs, device=device),
            torch.zeros(nb_iou_thrs, device=device),
            mean_f1_per_threshold,
            best_score_threshold,
        )

    # Sort by decreasing score
    dtype = torch.uint8 if preds_scores.is_cuda and preds_scores.dtype is torch.bool else preds_scores.dtype
    sort_ind = torch.argsort(preds_scores.to(dtype), descending=True)
    tps = tps[sort_ind, :]
    fps = fps[sort_ind, :]
    preds_scores = preds_scores[sort_ind].contiguous()

    # Rolling sum over the predictions
    rolling_tps = torch.cumsum(tps, axis=0, dtype=torch.float)
    rolling_fps = torch.cumsum(fps, axis=0, dtype=torch.float)

    rolling_recalls = rolling_tps / n_targets
    rolling_precisions = rolling_tps / (rolling_tps + rolling_fps + torch.finfo(torch.float64).eps)

    # Reversed cummax to only have decreasing values
    rolling_precisions = rolling_precisions.flip(0).cummax(0).values.flip(0)

    # ==================
    # RECALL & PRECISION

    # We want the rolling precision/recall at index i so that: preds_scores[i-1] >= score_threshold > preds_scores[i]
    # Note: torch.searchsorted works on increasing sequence and preds_scores is decreasing, so we work with "-"
    # Note2: right=True due to negation
    lowest_score_above_threshold = torch.searchsorted(-preds_scores, -score_threshold, right=True)

    if lowest_score_above_threshold == 0:  # Here score_threshold > preds_scores[0], so no pred is above the threshold
        recall = torch.zeros(nb_iou_thrs, device=device)
        precision = torch.zeros(nb_iou_thrs, device=device)  # the precision is not really defined when no pred but we need to give it a value
    else:
        recall = rolling_recalls[lowest_score_above_threshold - 1]
        precision = rolling_precisions[lowest_score_above_threshold - 1]

    # ==================
    # BEST CONFIDENCE SCORE THRESHOLD PER CLASS
    all_score_thresholds = torch.linspace(0, 1, nb_score_thrs, device=device)

    # We want the rolling precision/recall at index i so that: preds_scores[i-1] > score_threshold >= preds_scores[i]
    # Note: torch.searchsorted works on increasing sequence and preds_scores is decreasing, so we work with "-"
    lowest_scores_above_thresholds = torch.searchsorted(-preds_scores, -all_score_thresholds, right=True)

    # When score_threshold > preds_scores[0], then no pred is above the threshold, so we pad with zeros
    rolling_recalls_padded = torch.cat((torch.zeros(1, nb_iou_thrs, device=device), rolling_recalls), dim=0)
    rolling_precisions_padded = torch.cat((torch.zeros(1, nb_iou_thrs, device=device), rolling_precisions), dim=0)

    # shape = (n_score_thresholds, nb_iou_thrs)
    recalls_per_threshold = torch.index_select(input=rolling_recalls_padded, dim=0, index=lowest_scores_above_thresholds)
    precisions_per_threshold = torch.index_select(input=rolling_precisions_padded, dim=0, index=lowest_scores_above_thresholds)

    # shape (n_score_thresholds, nb_iou_thrs)
    f1_per_threshold = 2 * recalls_per_threshold * precisions_per_threshold / (recalls_per_threshold + precisions_per_threshold + 1e-16)
    mean_f1_per_threshold = torch.mean(f1_per_threshold, dim=1)  # average over iou thresholds
    best_score_threshold = all_score_thresholds[torch.argmax(mean_f1_per_threshold)]

    # ==================
    # AVERAGE PRECISION

    # shape = (nb_iou_thrs, n_recall_thresholds)
    recall_thresholds = recall_thresholds.view(1, -1).repeat(nb_iou_thrs, 1)

    # We want the index i so that: rolling_recalls[i-1] < recall_thresholds[k] <= rolling_recalls[i]
    # Note:  when recall_thresholds[k] > max(rolling_recalls), i = len(rolling_recalls)
    # Note2: we work with transpose (.T) to apply torch.searchsorted on first dim instead of the last one
    recall_threshold_idx = torch.searchsorted(rolling_recalls.T.contiguous(), recall_thresholds, right=False).T

    # When recall_thresholds[k] > max(rolling_recalls), rolling_precisions[i] is not defined, and we want precision = 0
    rolling_precisions = torch.cat((rolling_precisions, torch.zeros(1, nb_iou_thrs, device=device)), dim=0)

    # shape = (n_recall_thresholds, nb_iou_thrs)
    sampled_precision_points = torch.gather(input=rolling_precisions, index=recall_threshold_idx, dim=0)

    # Average over the recall_thresholds
    ap = sampled_precision_points.mean(0)

    return ap, precision, recall, mean_f1_per_threshold, best_score_threshold

compute_img_detection_matching(preds, targets, crowd_targets, height, width, device, denormalize_targets, iou_thresholds=None, top_k=100, return_on_cpu=True, matching_strategy=None)

Match predictions (NMS output) and the targets (ground truth) with respect to metric and confidence score for a given image.

Parameters:

Name Type Description Default
preds torch.Tensor

Tensor of shape (num_img_predictions, 6) format: (x1, y1, x2, y2, confidence, class_label) where x1,y1,x2,y2 are according to image size

required
targets torch.Tensor

targets for this image of shape (num_img_targets, 6) format: (label, cx, cy, w, h) where cx,cy,w,h are in range [0,1]

required
height int

dimensions of the image

required
width int

dimensions of the image

required
device str

Device

required
crowd_targets torch.Tensor

crowd targets for all images of shape (total_num_crowd_targets, 6) format: (index, x, y, w, h) where x,y,w,h are in range [0,1]

required
iou_thresholds torch.Tensor

Threshold to compute the mAP

None
top_k int

Number of predictions to keep per class, ordered by confidence score

100
denormalize_targets bool

If True, denormalize the targets and crowd_targets

required
return_on_cpu bool

If True, the output will be returned on "CPU", otherwise it will be returned on "device"

True
matching_strategy DetectionMatching

Method to match predictions to ground truth targets: IoU, distance based

None

Returns:

Type Description
Tuple

:preds_matched: Tensor of shape (num_img_predictions, n_thresholds) True when prediction (i) is matched with a target with respect to the (j)th threshold :preds_to_ignore: Tensor of shape (num_img_predictions, n_thresholds) True when prediction (i) is matched with a crowd target with respect to the (j)th threshold :preds_scores: Tensor of shape (num_img_predictions), confidence score for every prediction :preds_cls: Tensor of shape (num_img_predictions), predicted class for every prediction :targets_cls: Tensor of shape (num_img_targets), ground truth class for every target

Source code in src/super_gradients/training/utils/detection_utils.py
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
def compute_img_detection_matching(
    preds: torch.Tensor,
    targets: torch.Tensor,
    crowd_targets: torch.Tensor,
    height: int,
    width: int,
    device: str,
    denormalize_targets: bool,
    iou_thresholds: torch.Tensor = None,
    top_k: int = 100,
    return_on_cpu: bool = True,
    matching_strategy: DetectionMatching = None,
) -> Tuple:
    """
    Match predictions (NMS output) and the targets (ground truth) with respect to metric and confidence score
    for a given image.
    :param preds:           Tensor of shape (num_img_predictions, 6)
                            format:     (x1, y1, x2, y2, confidence, class_label) where x1,y1,x2,y2 are according to image size
    :param targets:         targets for this image of shape (num_img_targets, 6)
                            format:     (label, cx, cy, w, h) where cx,cy,w,h are in range [0,1]
    :param height:          dimensions of the image
    :param width:           dimensions of the image
    :param crowd_targets:   crowd targets for all images of shape (total_num_crowd_targets, 6)
                            format:     (index, x, y, w, h) where x,y,w,h are in range [0,1]
    :param iou_thresholds:  Threshold to compute the mAP
    :param top_k:           Number of predictions to keep per class, ordered by confidence score
    :param device:          Device
    :param denormalize_targets: If True, denormalize the targets and crowd_targets
    :param return_on_cpu:   If True, the output will be returned on "CPU", otherwise it will be returned on "device"
    :param matching_strategy: Method to match predictions to ground truth targets: IoU, distance based

    :return:
        :preds_matched:     Tensor of shape (num_img_predictions, n_thresholds)
                                True when prediction (i) is matched with a target with respect to the (j)th threshold
        :preds_to_ignore:   Tensor of shape (num_img_predictions, n_thresholds)
                                True when prediction (i) is matched with a crowd target with respect to the (j)th threshold
        :preds_scores:      Tensor of shape (num_img_predictions), confidence score for every prediction
        :preds_cls:         Tensor of shape (num_img_predictions), predicted class for every prediction
        :targets_cls:       Tensor of shape (num_img_targets), ground truth class for every target
    """
    num_thresholds = len(matching_strategy.get_thresholds())

    if preds is None or len(preds) == 0:
        if return_on_cpu:
            device = "cpu"
        preds_matched = torch.zeros((0, num_thresholds), dtype=torch.bool, device=device)
        preds_to_ignore = torch.zeros((0, num_thresholds), dtype=torch.bool, device=device)
        preds_scores = torch.tensor([], dtype=torch.float32, device=device)
        preds_cls = torch.tensor([], dtype=torch.float32, device=device)
        targets_cls = targets[:, 0].to(device=device)
        return preds_matched, preds_to_ignore, preds_scores, preds_cls, targets_cls

    preds_matched = torch.zeros(len(preds), num_thresholds, dtype=torch.bool, device=preds.device)
    targets_matched = torch.zeros(len(targets), num_thresholds, dtype=torch.bool, device=preds.device)
    preds_to_ignore = torch.zeros(len(preds), num_thresholds, dtype=torch.bool, device=preds.device)

    preds_cls, preds_box, preds_scores = preds[:, -1], preds[:, 0:4], preds[:, 4]
    targets_cls, targets_box = targets[:, 0], targets[:, 1:5]
    crowd_targets_cls, crowd_target_box = crowd_targets[:, 0], crowd_targets[:, 1:5]

    # Ignore all but the predictions that were top_k for their class
    preds_idx_to_use = get_top_k_idx_per_cls(preds_scores, preds_cls, top_k)
    preds_to_ignore[:, :] = True
    preds_to_ignore[preds_idx_to_use] = False

    if len(targets) > 0 or len(crowd_targets) > 0:
        # CHANGE bboxes TO FIT THE IMAGE SIZE
        change_bbox_bounds_for_image_size_inplace(preds, (height, width))

        targets_box = cxcywh2xyxy(targets_box)
        crowd_target_box = cxcywh2xyxy(crowd_target_box)

        if denormalize_targets:
            targets_box[:, [0, 2]] *= width
            targets_box[:, [1, 3]] *= height
            crowd_target_box[:, [0, 2]] *= width
            crowd_target_box[:, [1, 3]] *= height

        if len(targets) > 0:
            preds_matched = matching_strategy.compute_targets(preds_box, preds_cls, targets_box, targets_cls, preds_matched, targets_matched, preds_idx_to_use)

        if len(crowd_targets) > 0:
            preds_matched, preds_to_ignore = matching_strategy.compute_crowd_targets(
                preds_box, preds_cls, crowd_targets_cls, crowd_target_box, preds_matched, preds_to_ignore, preds_idx_to_use
            )

    if return_on_cpu:
        preds_matched = preds_matched.to("cpu")
        preds_to_ignore = preds_to_ignore.to("cpu")
        preds_scores = preds_scores.to("cpu")
        preds_cls = preds_cls.to("cpu")
        targets_cls = targets_cls.to("cpu")

    return preds_matched, preds_to_ignore, preds_scores, preds_cls, targets_cls

convert_cxcywh_bbox_to_xyxy(input_bbox)

Converts bounding box format from [cx, cy, w, h] to [x1, y1, x2, y2] :param input_bbox: input bbox either 2-dimensional (for all boxes of a single image) or 3-dimensional (for boxes of a batch of images) :return: Converted bbox in same dimensions as the original

Source code in src/super_gradients/training/utils/detection_utils.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def convert_cxcywh_bbox_to_xyxy(input_bbox: torch.Tensor):
    """
    Converts bounding box format from [cx, cy, w, h] to [x1, y1, x2, y2]
        :param input_bbox:  input bbox either 2-dimensional (for all boxes of a single image) or 3-dimensional (for
                            boxes of a batch of images)
        :return:            Converted bbox in same dimensions as the original
    """
    need_squeeze = False
    # the input is always processed as a batch. in case it is not a batch, it is unsqueezed, processed and then squeezed back.
    if input_bbox.dim() < 3:
        need_squeeze = True
        input_bbox = input_bbox.unsqueeze(0)

    converted_bbox = torch.zeros_like(input_bbox) if isinstance(input_bbox, torch.Tensor) else np.zeros_like(input_bbox)
    converted_bbox[:, :, 0] = input_bbox[:, :, 0] - input_bbox[:, :, 2] / 2
    converted_bbox[:, :, 1] = input_bbox[:, :, 1] - input_bbox[:, :, 3] / 2
    converted_bbox[:, :, 2] = input_bbox[:, :, 0] + input_bbox[:, :, 2] / 2
    converted_bbox[:, :, 3] = input_bbox[:, :, 1] + input_bbox[:, :, 3] / 2

    # squeeze back if needed
    if need_squeeze:
        converted_bbox = converted_bbox[0]

    return converted_bbox
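
A short illustrative example for the 2-dimensional (single image) case:

import torch

from super_gradients.training.utils.detection_utils import convert_cxcywh_bbox_to_xyxy

boxes_cxcywh = torch.tensor([[50.0, 50.0, 20.0, 10.0]])
boxes_xyxy = convert_cxcywh_bbox_to_xyxy(boxes_cxcywh)
# boxes_xyxy -> tensor([[40., 45., 60., 55.]])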

crowd_ioa(det_box, crowd_box)

Return intersection-over-detection_area of boxes, used for crowd ground truths. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

Parameters:

Name Type Description Default
det_box torch.Tensor

Tensor of shape [N, 4]

required
crowd_box torch.Tensor

Tensor of shape [M, 4]

required

Returns:

Type Description
torch.Tensor

crowd_ioa, Tensor of shape [N, M]: the NxM matrix containing the pairwise IoA values for every element in det_box and crowd_box

Source code in src/super_gradients/training/utils/detection_utils.py
797
798
799
800
801
802
803
804
805
806
807
808
809
810
def crowd_ioa(det_box: torch.Tensor, crowd_box: torch.Tensor) -> torch.Tensor:
    """
    Return intersection-over-detection_area of boxes, used for crowd ground truths.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    :param det_box:     Tensor of shape [N, 4]
    :param crowd_box:   Tensor of shape [M, 4]
    :return: crowd_ioa, Tensor of shape [N, M]: the NxM matrix containing the pairwise IoA values for every element in det_box and crowd_box
    """
    det_area = compute_box_area(det_box.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(det_box[:, None, 2:], crowd_box[:, 2:]) - torch.max(det_box[:, None, :2], crowd_box[:, :2])).clamp(0).prod(2)
    return inter / det_area[:, None]  # crowd_ioa = inter / det_area
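
A minimal illustrative example: a detection lying fully inside a crowd region has IoA 1.0.

import torch

from super_gradients.training.utils.detection_utils import crowd_ioa

det = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
crowd = torch.tensor([[0.0, 0.0, 100.0, 100.0]])
ioa = crowd_ioa(det, crowd)
# ioa -> tensor([[1.]])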

cxcywh2xyxy(bboxes)

Transforms bboxes from centered (cx, cy, w, h) format to xyxy format

Parameters:

Name Type Description Default
bboxes

array, shaped (nboxes, 4)

required

Returns:

Type Description

modified bboxes

Source code in src/super_gradients/training/utils/detection_utils.py
725
726
727
728
729
730
731
732
733
734
735
def cxcywh2xyxy(bboxes):
    """
    Transforms bboxes from centered (cx, cy, w, h) format to xyxy format
    :param bboxes: array, shaped (nboxes, 4)
    :return: modified bboxes
    """
    bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] * 0.5
    bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] * 0.5
    bboxes[:, 3] = bboxes[:, 3] + bboxes[:, 1]
    bboxes[:, 2] = bboxes[:, 2] + bboxes[:, 0]
    return bboxes
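
Note that the conversion is done in place. A small illustrative sketch:

import numpy as np

from super_gradients.training.utils.detection_utils import cxcywh2xyxy

boxes = np.array([[50.0, 50.0, 20.0, 10.0]])  # (cx, cy, w, h)
cxcywh2xyxy(boxes)
# boxes -> [[40., 45., 60., 55.]]  (x1, y1, x2, y2)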

get_class_index_in_target(target_format)

Get the label of a given target

Parameters:

Name Type Description Default
target_format DetectionTargetsFormat

Representation of the target (ex: LABEL_XYXY)

required

Returns:

Type Description
int

Position of the class id in a bbox ex: 0 if bbox of format label_xyxy | -1 if bbox of format xyxy_label

Source code in src/super_gradients/training/utils/detection_utils.py
42
43
44
45
46
47
48
49
50
51
52
53
54
def get_class_index_in_target(target_format: DetectionTargetsFormat) -> int:
    """Get the label of a given target
    :param target_format:   Representation of the target (ex: LABEL_XYXY)
    :return:                Position of the class id in a bbox
                                ex: 0 if bbox of format label_xyxy | -1 if bbox of format xyxy_label
    """
    format_split = target_format.value.split("_")
    if format_split[0] == "LABEL":
        return 0
    elif format_split[-1] == "LABEL":
        return -1
    else:
        raise NotImplementedError(f"No implementation to find index of LABEL in {target_format.value}")
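
A minimal sketch using the LABEL_XYXY format mentioned above; the exact import path of the DetectionTargetsFormat enum may differ between SG versions, so treat it as an assumption:

from super_gradients.training.utils.detection_utils import DetectionTargetsFormat, get_class_index_in_target

idx = get_class_index_in_target(DetectionTargetsFormat.LABEL_XYXY)
# idx == 0: the class id is the first element of each target row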

get_mosaic_coordinate(mosaic_index, xc, yc, w, h, input_h, input_w)

Returns the mosaic coordinates of the final mosaic image according to the mosaic image index.

Parameters:

Name Type Description Default
mosaic_index

(int) mosaic image index

required
xc

(int) center x coordinate of the entire mosaic grid.

required
yc

(int) center y coordinate of the entire mosaic grid.

required
w

(int) width of bbox

required
h

(int) height of bbox

required
input_h

(int) image input height (should be 1/2 of the final mosaic output image height).

required
input_w

(int) image input width (should be 1/2 of the final mosaic output image width).

required

Returns:

Type Description

(x1, y1, x2, y2), (x1s, y1s, x2s, y2s) where (x1, y1, x2, y2) are the coordinates in the final mosaic output image, and (x1s, y1s, x2s, y2s) are the coordinates in the placed image.

Source code in src/super_gradients/training/utils/detection_utils.py
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
def get_mosaic_coordinate(mosaic_index, xc, yc, w, h, input_h, input_w):
    """
    Returns the mosaic coordinates of the final mosaic image according to the mosaic image index.

    :param mosaic_index: (int) mosaic image index
    :param xc: (int) center x coordinate of the entire mosaic grid.
    :param yc: (int) center y coordinate of the entire mosaic grid.
    :param w: (int) width of bbox
    :param h: (int) height of bbox
    :param input_h: (int) image input height (should be 1/2 of the final mosaic output image height).
    :param input_w: (int) image input width (should be 1/2 of the final mosaic output image width).
    :return: (x1, y1, x2, y2), (x1s, y1s, x2s, y2s) where (x1, y1, x2, y2) are the coordinates in the final mosaic
        output image, and (x1s, y1s, x2s, y2s) are the coordinates in the placed image.
    """
    # index0 to top left part of image
    if mosaic_index == 0:
        x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
        small_coord = w - (x2 - x1), h - (y2 - y1), w, h
    # index1 to top right part of image
    elif mosaic_index == 1:
        x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
        small_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
    # index2 to bottom left part of image
    elif mosaic_index == 2:
        x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
        small_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
    # index3 to bottom right part of image
    elif mosaic_index == 3:
        x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2, yc + h)  # noqa
        small_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
    return (x1, y1, x2, y2), small_coord
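
A worked illustrative example for the top-left tile of a 2x2 mosaic (grid center at (640, 640), a 600x500 source image, 640x640 network input):

from super_gradients.training.utils.detection_utils import get_mosaic_coordinate

(x1, y1, x2, y2), (x1s, y1s, x2s, y2s) = get_mosaic_coordinate(
    mosaic_index=0, xc=640, yc=640, w=600, h=500, input_h=640, input_w=640
)
# (x1, y1, x2, y2) == (40, 140, 640, 640)   -> placement in the 1280x1280 mosaic canvas
# (x1s, y1s, x2s, y2s) == (0, 0, 600, 500)  -> crop taken from the source image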

get_top_k_idx_per_cls(preds_scores, preds_cls, top_k)

Get the indexes of all the top k predictions for every class

Parameters:

Name Type Description Default
preds_scores torch.Tensor

The confidence scores, vector of shape (n_pred)

required
preds_cls torch.Tensor

The predicted class, vector of shape (n_pred)

required
top_k int

Number of predictions to keep per class, ordered by confidence score

required

Returns:

Type Description

Indexes of the top k predictions. length <= (k * n_unique_class)

Source code in src/super_gradients/training/utils/detection_utils.py
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
def get_top_k_idx_per_cls(preds_scores: torch.Tensor, preds_cls: torch.Tensor, top_k: int):
    """Get the indexes of all the top k predictions for every class

    :param preds_scores:   The confidence scores, vector of shape (n_pred)
    :param preds_cls:      The predicted class, vector of shape (n_pred)
    :param top_k:          Number of predictions to keep per class, ordered by confidence score

    :return top_k_idx:     Indexes of the top k predictions. length <= (k * n_unique_class)
    """
    n_unique_cls = torch.max(preds_cls)
    mask = preds_cls.view(-1, 1) == torch.arange(n_unique_cls + 1, device=preds_scores.device).view(1, -1)
    preds_scores_per_cls = preds_scores.view(-1, 1) * mask

    sorted_scores_per_cls, sorting_idx = preds_scores_per_cls.sort(0, descending=True)
    idx_with_satisfying_scores = sorted_scores_per_cls[:top_k, :].nonzero(as_tuple=False)
    top_k_idx = sorting_idx[idx_with_satisfying_scores.split(1, dim=1)]
    return top_k_idx.view(-1)
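
A small illustrative sketch: keep at most two predictions per class.

import torch

from super_gradients.training.utils.detection_utils import get_top_k_idx_per_cls

scores = torch.tensor([0.9, 0.8, 0.7, 0.6])
classes = torch.tensor([0.0, 0.0, 1.0, 0.0])
keep = get_top_k_idx_per_cls(scores, classes, top_k=2)
# keep contains indices 0 and 1 (the two best class-0 predictions) and 2 (the only class-1 prediction)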

matrix_non_max_suppression(pred, conf_thres=0.1, kernel='gaussian', sigma=3.0, max_num_of_detections=500, class_agnostic_nms=False)

Performs Matrix Non-Maximum Suppression (NMS) on inference results https://arxiv.org/pdf/1912.04488.pdf

Parameters:

Name Type Description Default
pred

Raw model prediction (in test mode) - a Tensor of shape [batch, num_predictions, 85] where each item format is (x, y, w, h, object_conf, class_conf, ... 80 classes score ...)

required
conf_thres float

Threshold under which prediction are discarded

0.1
kernel str

Type of kernel to use ['gaussian', 'linear']

'gaussian'
sigma float

Sigma for the gaussian kernel

3.0
max_num_of_detections int

Maximum number of boxes to output

500

Returns:

Type Description
List[torch.Tensor]

Detections list with shape (x1, y1, x2, y2, object_conf, class_conf, class)

Source code in src/super_gradients/training/utils/detection_utils.py
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
def matrix_non_max_suppression(
    pred, conf_thres: float = 0.1, kernel: str = "gaussian", sigma: float = 3.0, max_num_of_detections: int = 500, class_agnostic_nms: bool = False
) -> List[torch.Tensor]:
    """Performs Matrix Non-Maximum Suppression (NMS) on inference results https://arxiv.org/pdf/1912.04488.pdf

    :param pred:        Raw model prediction (in test mode) - a Tensor of shape [batch, num_predictions, 85]
                        where each item format is (x, y, w, h, object_conf, class_conf, ... 80 classes score ...)
    :param conf_thres:  Threshold under which predictions are discarded
    :param kernel:      Type of kernel to use ['gaussian', 'linear']
    :param sigma:       Sigma for the gaussian kernel
    :param max_num_of_detections: Maximum number of boxes to output

    :return: Detections list with shape (x1, y1, x2, y2, object_conf, class_conf, class)
    """
    # MULTIPLY CONF BY CLASS CONF TO GET COMBINED CONFIDENCE
    class_conf, class_pred = pred[:, :, 5:].max(2)
    pred[:, :, 4] *= class_conf

    # BOX (CENTER X, CENTER Y, WIDTH, HEIGHT) TO (X1, Y1, X2, Y2)
    pred[:, :, :4] = convert_cxcywh_bbox_to_xyxy(pred[:, :, :4])

    # DETECTIONS ORDERED AS (x1y1x2y2, obj_conf, class_conf, class_pred)
    pred = torch.cat((pred[:, :, :5], class_pred.unsqueeze(2)), 2)

    # SORT DETECTIONS BY DECREASING CONFIDENCE SCORES
    sort_ind = (-pred[:, :, 4]).argsort()
    pred = torch.stack([pred[i, sort_ind[i]] for i in range(pred.shape[0])])[:, 0:max_num_of_detections]

    ious = calc_bbox_iou_matrix(pred)

    ious = ious.triu(1)

    if not class_agnostic_nms:
        # CREATE A LABELS MASK, WE WANT ONLY BOXES WITH THE SAME LABEL TO AFFECT EACH OTHER
        labels = pred[:, :, 5:]
        labeles_matrix = (labels == labels.transpose(2, 1)).float().triu(1)
        ious *= labeles_matrix

    ious_cmax, _ = ious.max(1)
    ious_cmax = ious_cmax.unsqueeze(2).repeat(1, 1, max_num_of_detections)

    if kernel == "gaussian":
        decay_matrix = torch.exp(-1 * sigma * (ious**2))
        compensate_matrix = torch.exp(-1 * sigma * (ious_cmax**2))
        decay, _ = (decay_matrix / compensate_matrix).min(dim=1)
    else:
        decay = (1 - ious) / (1 - ious_cmax)
        decay, _ = decay.min(dim=1)

    pred[:, :, 4] *= decay

    output = [pred[i, pred[i, :, 4] > conf_thres] for i in range(pred.shape[0])]

    return output
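
A compact illustrative sketch with three raw predictions for one image and two classes; the last two columns of each row are the per-class scores, and max_num_of_detections is set to the number of candidates here so the decay matrices line up:

import torch

from super_gradients.training.utils.detection_utils import matrix_non_max_suppression

# (cx, cy, w, h, obj_conf, class0_score, class1_score) per prediction
raw = torch.tensor([[[50.0, 50.0, 20.0, 20.0, 0.9, 1.0, 0.0],
                     [52.0, 50.0, 20.0, 20.0, 0.8, 1.0, 0.0],
                     [200.0, 200.0, 20.0, 20.0, 0.7, 0.0, 1.0]]])
decayed = matrix_non_max_suppression(raw, conf_thres=0.05, max_num_of_detections=3)
# decayed is a list with one tensor per image; each row holds the xyxy box,
# the decayed confidence and the predicted class index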

non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label_per_box=True, with_confidence=False, class_agnostic_nms=False)

Performs Non-Maximum Suppression (NMS) on inference results

Parameters:

Name Type Description Default
prediction

raw model prediction. Should be a list of Tensors of shape (cx, cy, w, h, confidence, cls0, cls1, ...)

required
conf_thres

below the confidence threshold - predictions are discarded

0.1
iou_thres

IoU threshold for the nms algorithm

0.6
multi_label_per_box bool

controls whether to decode multiple labels per box. True - each anchor can produce multiple labels of different classes that pass confidence threshold check (default). False - each anchor can produce only one label of the class with the highest score.

True
with_confidence bool

whether to multiply objectness score with class score. usually valid for Yolo models only.

False
class_agnostic_nms bool

indicates how boxes of different classes will be treated during NMS True - NMS will be performed on all classes together. False - NMS will be performed on each class separately (default).

False

Returns:

Type Description

detections with shape nx6 (x1, y1, x2, y2, object_conf, class_conf, class)

Source code in src/super_gradients/training/utils/detection_utils.py
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
def non_max_suppression(
    prediction, conf_thres=0.1, iou_thres=0.6, multi_label_per_box: bool = True, with_confidence: bool = False, class_agnostic_nms: bool = False
):
    """
    Performs Non-Maximum Suppression (NMS) on inference results

    :param prediction: raw model prediction. Should be a list of Tensors of shape (cx, cy, w, h, confidence, cls0, cls1, ...)
    :param conf_thres: below the confidence threshold - predictions are discarded
    :param iou_thres: IoU threshold for the nms algorithm
    :param multi_label_per_box: controls whether to decode multiple labels per box.
                                True - each anchor can produce multiple labels of different classes
                                       that pass confidence threshold check (default).
                                False - each anchor can produce only one label of the class with the highest score.
    :param with_confidence: whether to multiply objectness score with class score.
                            usually valid for Yolo models only.
    :param class_agnostic_nms: indicates how boxes of different classes will be treated during NMS
                               True - NMS will be performed on all classes together.
                               False - NMS will be performed on each class separately (default).
    :return: detections with shape nx6 (x1, y1, x2, y2, object_conf, class_conf, class)

    """
    candidates_above_thres = prediction[..., 4] > conf_thres  # filter by confidence
    output = [None] * prediction.shape[0]

    for image_idx, pred in enumerate(prediction):
        pred = pred[candidates_above_thres[image_idx]]  # confident

        if not pred.shape[0]:  # If none remain process next image
            continue

        if with_confidence:
            pred[:, 5:] *= pred[:, 4:5]  # multiply objectness score with class score

        box = convert_cxcywh_bbox_to_xyxy(pred[:, :4])  # cxcywh to xyxy

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label_per_box:  # try for all good confidence classes
            i, j = (pred[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            pred = torch.cat((box[i], pred[i, j + 5, None], j[:, None].float()), 1)

        else:  # best class only
            conf, j = pred[:, 5:].max(1, keepdim=True)
            pred = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        if not pred.shape[0]:  # If none remain process next image
            continue

        # Apply torch batched NMS algorithm
        boxes, scores, cls_idx = pred[:, :4], pred[:, 4], pred[:, 5]
        if class_agnostic_nms:
            idx_to_keep = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
        else:
            idx_to_keep = torchvision.ops.boxes.batched_nms(boxes, scores, cls_idx, iou_thres)
        output[image_idx] = pred[idx_to_keep]

    return output
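
A brief illustrative sketch for a single-class model output of shape (batch, num_predictions, 6), where the lower-scoring of the two overlapping boxes is suppressed and the distant one is kept:

import torch

from super_gradients.training.utils.detection_utils import non_max_suppression

# (cx, cy, w, h, confidence, class0_score) per prediction
preds = torch.tensor([[[50.0, 50.0, 20.0, 20.0, 0.9, 0.9],
                       [52.0, 50.0, 20.0, 20.0, 0.8, 0.8],
                       [200.0, 200.0, 20.0, 20.0, 0.7, 0.7]]])
detections = non_max_suppression(preds, conf_thres=0.1, iou_thres=0.5)
# detections[0] keeps two boxes; each row is (x1, y1, x2, y2, class_conf, class)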

undo_image_preprocessing(im_tensor)

Parameters:

Name Type Description Default
im_tensor torch.Tensor

images in a batch after preprocessing for inference, RGB, (B, C, H, W)

required

Returns:

Type Description
np.ndarray

images in a batch in cv2 format, BGR, (B, H, W, C)

Source code in src/super_gradients/training/utils/detection_utils.py
402
403
404
405
406
407
408
409
410
def undo_image_preprocessing(im_tensor: torch.Tensor) -> np.ndarray:
    """
    :param im_tensor: images in a batch after preprocessing for inference, RGB, (B, C, H, W)
    :return:          images in a batch in cv2 format, BGR, (B, H, W, C)
    """
    im_np = im_tensor.cpu().numpy()
    im_np = im_np[:, ::-1, :, :].transpose(0, 2, 3, 1)
    im_np *= 255.0
    return np.ascontiguousarray(im_np, dtype=np.uint8)
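
A quick illustrative sketch:

import torch

from super_gradients.training.utils.detection_utils import undo_image_preprocessing

batch = torch.rand(2, 3, 64, 64)  # RGB images in [0, 1], (B, C, H, W)
images_bgr = undo_image_preprocessing(batch)
# images_bgr.shape == (2, 64, 64, 3), dtype uint8, BGR channel order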

xyxy2cxcywh(bboxes)

Transforms bboxes from xyxy format to centered (cx, cy, w, h) format

Parameters:

Name Type Description Default
bboxes

array, shaped (nboxes, 4)

required

Returns:

Type Description

modified bboxes

Source code in src/super_gradients/training/utils/detection_utils.py
712
713
714
715
716
717
718
719
720
721
722
def xyxy2cxcywh(bboxes):
    """
    Transforms bboxes from xyxy format to centered (cx, cy, w, h) format
    :param bboxes: array, shaped (nboxes, 4)
    :return: modified bboxes
    """
    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
    bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5
    bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5
    return bboxes
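
A quick worked example; the import path is assumed from the "Source code in" location above.

import numpy as np
from super_gradients.training.utils.detection_utils import xyxy2cxcywh  # assumed import path

boxes = np.array([[10.0, 20.0, 50.0, 80.0]])  # (x1, y1, x2, y2)
print(xyxy2cxcywh(boxes.copy()))              # [[30. 50. 40. 60.]] -> (cx, cy, w, h)
# The transform mutates its input in-place, hence the .copy() above when the original boxes are still needed.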

DDPNotSetupException

Bases: Exception

Exception raised when DDP setup is required but was not done

Source code in src/super_gradients/training/utils/distributed_training_utils.py
370
371
372
373
374
375
376
377
378
379
380
381
class DDPNotSetupException(Exception):
    """Exception raised when DDP setup is required but was not done"""

    def __init__(self):
        self.message = (
            "Your environment was not setup correctly for DDP.\n"
            "Please run at the beginning of your script:\n"
            ">>> from super_gradients.training.utils.distributed_training_utils import setup_device'\n"
            ">>> from super_gradients.common.data_types.enum import MultiGPUMode\n"
            ">>> setup_device(multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, num_gpus=...)"
        )
        super().__init__(self.message)

compute_precise_bn_stats(model, loader, precise_bn_batch_size, num_gpus)

Parameters:

Name Type Description Default
model nn.Module

The model being trained (ie: Trainer.net)

required
loader torch.utils.data.DataLoader

Training dataloader (ie: Trainer.train_loader)

required
precise_bn_batch_size int

The effective batch size we want to calculate the batchnorm on. For example, if we are training a model on 8 gpus, with a batch of 128 on each gpu, a good rule of thumb would be to give it 8192 (ie: effective_batch_size * num_gpus = batch_per_gpu * num_gpus * num_gpus). If precise_bn_batch_size is not provided in the training_params, the latter heuristic will be taken.

required
num_gpus int

The number of gpus we are training on

required
Source code in src/super_gradients/training/utils/distributed_training_utils.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
@torch.no_grad()
def compute_precise_bn_stats(model: nn.Module, loader: torch.utils.data.DataLoader, precise_bn_batch_size: int, num_gpus: int):
    """
    :param model:                   The model being trained (ie: Trainer.net)
    :param loader:                  Training dataloader (ie: Trainer.train_loader)
    :param precise_bn_batch_size:   The effective batch size we want to calculate the batchnorm on. For example, if we are training a model
                                    on 8 gpus, with a batch of 128 on each gpu, a good rule of thumb would be to give it 8192
                                    (ie: effective_batch_size * num_gpus = batch_per_gpu * num_gpus * num_gpus).
                                    If precise_bn_batch_size is not provided in the training_params, the latter heuristic
                                    will be taken.
    :param num_gpus:                The number of gpus we are training on
    """

    # Compute the number of minibatches to use
    num_iter = int(precise_bn_batch_size / (loader.batch_size * num_gpus)) if precise_bn_batch_size else num_gpus
    num_iter = min(num_iter, len(loader))

    # Retrieve the BN layers
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]

    # Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch))
    running_means = [torch.zeros_like(bn.running_mean) for bn in bns]
    running_vars = [torch.zeros_like(bn.running_var) for bn in bns]

    # Remember momentum values
    momentums = [bn.momentum for bn in bns]

    # Set momentum to 1.0 to compute BN stats that only reflect the current batch
    for bn in bns:
        bn.momentum = 1.0

    # Average the BN stats for each BN layer over the batches
    for inputs, _labels in itertools.islice(loader, num_iter):
        model(inputs.cuda())
        for i, bn in enumerate(bns):
            running_means[i] += bn.running_mean / num_iter
            running_vars[i] += bn.running_var / num_iter

    # Sync BN stats across GPUs (no reduction if 1 GPU used)
    running_means = scaled_all_reduce(running_means, num_gpus=num_gpus)
    running_vars = scaled_all_reduce(running_vars, num_gpus=num_gpus)

    # Set BN stats and restore original momentum values
    for i, bn in enumerate(bns):
        bn.running_mean = running_means[i]
        bn.running_var = running_vars[i]
        bn.momentum = momentums[i]
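
A minimal single-GPU sketch: a CUDA device is required because the function moves inputs to .cuda(); the toy model and loader below are placeholders for your own training objects.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from super_gradients.training.utils.distributed_training_utils import compute_precise_bn_stats

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).cuda()
loader = DataLoader(TensorDataset(torch.rand(64, 3, 32, 32), torch.zeros(64)), batch_size=16)

# Re-estimates BN running stats over precise_bn_batch_size / (loader.batch_size * num_gpus) = 2 mini-batches
compute_precise_bn_stats(model=model, loader=loader, precise_bn_batch_size=32, num_gpus=1)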

distributed_all_reduce_tensor_average(tensor, n)

This method performs an all-reduce operation across the nodes running distributed training: it first sums the tensor from all nodes and then divides the sum by the number of nodes.

Parameters:

Name Type Description Default
tensor

The tensor to perform the reduce operation for

required
n

Number of nodes

required

Returns:

Type Description

Averaged tensor from all of the nodes

Source code in src/super_gradients/training/utils/distributed_training_utils.py
33
34
35
36
37
38
39
40
41
42
43
44
def distributed_all_reduce_tensor_average(tensor, n):
    """
    This method performs an all-reduce operation across the nodes running distributed training.
    It first sums the tensor from all nodes and then divides the sum by the number of nodes.
    :param tensor:  The tensor to perform the reduce operation for
    :param n:  Number of nodes
    :return:   Averaged tensor from all of the nodes
    """
    rt = tensor.clone()
    torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
    rt /= n
    return rt

get_gpu_mem_utilization()

GPU memory managed by the caching allocator in bytes for a given device.

Source code in src/super_gradients/training/utils/distributed_training_utils.py
360
361
362
363
364
365
366
367
def get_gpu_mem_utilization():
    """GPU memory managed by the caching allocator in bytes for a given device."""

    # Workaround to work on any torch version
    if hasattr(torch.cuda, "memory_reserved"):
        return torch.cuda.memory_reserved()
    else:
        return torch.cuda.memory_cached()
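
For example, to log the currently reserved GPU memory in gigabytes (a small sketch; on a machine without CUDA the reported value is simply zero):

from super_gradients.training.utils.distributed_training_utils import get_gpu_mem_utilization

reserved_gb = get_gpu_mem_utilization() / 1024 ** 3
print(f"GPU memory reserved by the caching allocator: {reserved_gb:.2f} GB")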

initialize_ddp()

Initialize Distributed Data Parallel

Important note: (1) in distributed training it is customary to specify learning rates and batch sizes per GPU. Whatever learning rate and schedule you specify will be applied to each GPU individually. Since gradients are summed (reduced) across all GPUs, the effective batch size is the batch you specify times the number of GPUs. The literature offers several "best practices" for setting learning rates and schedules for large batch sizes.

Source code in src/super_gradients/training/utils/distributed_training_utils.py
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
def initialize_ddp():
    """
    Initialize Distributed Data Parallel

    Important note: (1) in distributed training it is customary to specify learning rates and batch sizes per GPU.
    Whatever learning rate and schedule you specify will be applied to each GPU individually.
    Since gradients are summed (reduced) across all GPUs, the effective batch size is the
    batch you specify times the number of GPUs. The literature offers several "best practices" to set
    learning rates and schedules for large batch sizes.
    """

    if device_config.assigned_rank > 0:
        mute_current_process()

    logger.info("Distributed training starting...")
    if not torch.distributed.is_initialized():
        backend = "gloo" if os.name == "nt" else "nccl"
        torch.distributed.init_process_group(backend=backend, init_method="env://")
    torch.cuda.set_device(device_config.assigned_rank)

    if torch.distributed.get_rank() == 0:
        logger.info(f"Training in distributed mode... with {str(torch.distributed.get_world_size())} GPUs")
    device_config.device = "cuda:%d" % device_config.assigned_rank

maybe_all_gather_as_list(inputs)

When in DDP - gathers inputs from all processes. When not in DDP - returns a single-element list containing the input.

Parameters:

Name Type Description Default
inputs

the local rank's object to gather

required

Returns:

Type Description
List

a list with one entry per rank when in DDP, or a single-element list containing the input otherwise

Source code in src/super_gradients/training/utils/distributed_training_utils.py
417
418
419
420
421
422
423
424
425
426
427
428
429
430
def maybe_all_gather_as_list(inputs) -> List:
    """
    When in DDP - gathers inputs from all processes.
    When not in DDP - returns a single-element list containing the input.

    :param inputs: the local rank's object to gather

    :return: a list with one entry per rank when in DDP, or a single-element list containing the input otherwise
    """
    if is_distributed():
        output_container = [None for _ in range(get_world_size())]
        all_gather_object(output_container, inputs)
        return output_container
    return [inputs]
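
A small sketch showing the non-DDP behaviour; the same call returns one entry per rank when a process group is initialized.

from super_gradients.training.utils.distributed_training_utils import maybe_all_gather_as_list

gathered = maybe_all_gather_as_list({"rank_metric": 0.87})
print(gathered)  # [{'rank_metric': 0.87}] when not running distributed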

maybe_all_gather_np_images(image)

When in DDP - gathers images (as np.ndarray objects) from all processes and returns the np.ndarray concatenated along dim=0. When not in DDP - returns the input image unchanged.

Parameters:

Name Type Description Default
image np.ndarray

np.ndarray, the local rank's tensor to gather

required

Returns:

Type Description
np.ndarray

np.ndarray, the output image as described above

Source code in src/super_gradients/training/utils/distributed_training_utils.py
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
def maybe_all_gather_np_images(image: np.ndarray) -> np.ndarray:
    """
    When in DDP - gathers images (as np.ndarray objects) from all processes
    and returns the np.ndarray concatenated along dim=0 (on rank 0).
    When not in DDP - returns the input image unchanged.

    :param image: np.ndarray, the local rank's tensor to gather

    :return: np.ndarray, the output image as described above
    """
    if is_distributed():
        rank = get_rank()
        output_container = [None for _ in range(get_world_size())]
        all_gather_object(output_container, image)
        if rank == 0:
            image = np.concatenate(output_container, 0)
    return image

maybe_all_reduce_tensor_average(tensor)

When in DDP - mean-reduces the tensor across all devices. When not in DDP - returns the input tensor unchanged.

Parameters:

Name Type Description Default
tensor torch.Tensor

tensor to (maybe) reduce

required

Returns:

Type Description
torch.Tensor

the mean-reduced tensor when in DDP, otherwise the input tensor unchanged

Source code in src/super_gradients/training/utils/distributed_training_utils.py
384
385
386
387
388
389
390
391
392
393
394
395
def maybe_all_reduce_tensor_average(tensor: torch.Tensor) -> torch.Tensor:
    """
    When in DDP - mean-reduces the tensor across all devices.
    When not in DDP - returns the input tensor unchanged.

    :param tensor: tensor to (maybe) reduce
    :return: the mean-reduced tensor when in DDP, otherwise the input tensor unchanged
    """
    if is_distributed():
        # .to_dense() is required to ensure we can do maybe_all_reduce_tensor_average(some_vector[3])
        tensor = distributed_all_reduce_tensor_average(tensor=tensor.to_dense(), n=torch.distributed.get_world_size())
    return tensor
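
A small sketch: outside DDP the value passes through unchanged, inside DDP it is averaged across all ranks.

import torch
from super_gradients.training.utils.distributed_training_utils import maybe_all_reduce_tensor_average

local_loss = torch.tensor(0.42)
avg_loss = maybe_all_reduce_tensor_average(local_loss)  # averaged across ranks in DDP, unchanged otherwise
print(avg_loss)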

reduce_results_tuple_for_ddp(validation_results_tuple, device)

Gather all validation tuples from the various devices and average them

Source code in src/super_gradients/training/utils/distributed_training_utils.py
47
48
49
50
51
52
53
54
55
56
57
def reduce_results_tuple_for_ddp(validation_results_tuple, device):
    """Gather all validation tuples from the various devices and average them"""
    validation_results_list = list(validation_results_tuple)
    for i, validation_result in enumerate(validation_results_list):
        if torch.is_tensor(validation_result):
            validation_result = validation_result.clone().detach()
        else:
            validation_result = torch.tensor(validation_result)
        validation_results_list[i] = distributed_all_reduce_tensor_average(tensor=validation_result.to(device), n=torch.distributed.get_world_size())
    validation_results_tuple = tuple(validation_results_list)
    return validation_results_tuple

restart_script_with_ddp(num_gpus=None)

Launch the same script as the one that was launched (i.e. the command used to start the current process is re-used) but on subprocesses (i.e. with DDP).

Parameters:

Name Type Description Default
num_gpus int

How many gpu's you want to run the script on. If not specified, every available device will be used.

None
Source code in src/super_gradients/training/utils/distributed_training_utils.py
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
@record
def restart_script_with_ddp(num_gpus: int = None):
    """Launch the same script as the one that was launched (i.e. the command used to start the current process is re-used) but on subprocesses (i.e. with DDP).

    :param num_gpus: How many gpu's you want to run the script on. If not specified, every available device will be used.
    """
    ddp_port = find_free_port()

    # Get the value from the recipe if specified, otherwise take all available devices.
    num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
    if num_gpus > torch.cuda.device_count():
        raise ValueError(f"You specified num_gpus={num_gpus} but only {torch.cuda.device_count()} GPU's are available")

    logger.info(
        "Launching DDP with:\n"
        f"   - ddp_port = {ddp_port}\n"
        f"   - num_gpus = {num_gpus}/{torch.cuda.device_count()} available\n"
        "-------------------------------------\n"
    )

    config = LaunchConfig(
        nproc_per_node=num_gpus,
        min_nodes=1,
        max_nodes=1,
        run_id="sg_initiated",
        role="default",
        rdzv_endpoint=f"127.0.0.1:{ddp_port}",
        rdzv_backend="static",
        rdzv_configs={"rank": 0, "timeout": 900},
        rdzv_timeout=-1,
        max_restarts=0,
        monitor_interval=5,
        start_method="spawn",
        log_dir=None,
        redirects=Std.NONE,
        tee=Std.NONE,
        metrics_cfg={},
    )

    elastic_launch(config=config, entrypoint=sys.executable)(*sys.argv, *EXTRA_ARGS)

    # The code below should actually never be reached as the process will be in a loop inside elastic_launch until any subprocess crashes.
    sys.exit(0)

scaled_all_reduce(tensors, num_gpus)

Performs the scaled all_reduce operation on the provided tensors. The input tensors are modified in-place. Currently supports only the sum reduction operator. The reduced values are scaled by the inverse size of the process group (equivalent to num_gpus).

Source code in src/super_gradients/training/utils/distributed_training_utils.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def scaled_all_reduce(tensors: torch.Tensor, num_gpus: int):
    """
    Performs the scaled all_reduce operation on the provided tensors.
    The input tensors are modified in-place.
    Currently supports only the sum
    reduction operator.
    The reduced values are scaled by the inverse size of the
    process group (equivalent to num_gpus).
    """
    # There is no need for reduction in the single-proc case
    if num_gpus == 1:
        return tensors

    # Queue the reductions
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)

    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()

    # Scale the results
    for tensor in tensors:
        tensor.mul_(1.0 / num_gpus)
    return tensors
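
A sketch of the single-process case; with more GPUs the process group must already be initialized and each tensor is summed across ranks in-place, then scaled by 1 / num_gpus.

import torch
from super_gradients.training.utils.distributed_training_utils import scaled_all_reduce

tensors = [torch.ones(3), torch.full((2,), 2.0)]
tensors = scaled_all_reduce(tensors, num_gpus=1)  # no-op for a single GPU
print(tensors)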

setup_cpu(multi_gpu=MultiGPUMode.AUTO, num_gpus=None)

Parameters:

Name Type Description Default
multi_gpu MultiGPUMode

DDP, DP, Off or AUTO

MultiGPUMode.AUTO
num_gpus int

Number of GPU's to use.

None
Source code in src/super_gradients/training/utils/distributed_training_utils.py
211
212
213
214
215
216
217
218
219
220
221
222
223
def setup_cpu(multi_gpu: MultiGPUMode = MultiGPUMode.AUTO, num_gpus: int = None):
    """
    :param multi_gpu:    DDP, DP, Off or AUTO
    :param num_gpus:     Number of GPU's to use.
    """
    if multi_gpu not in (MultiGPUMode.OFF, MultiGPUMode.AUTO, None):
        raise ValueError(f"device='cpu' and multi_gpu={multi_gpu} are not compatible together.")

    if num_gpus not in (0, None):
        raise ValueError(f"device='cpu' and num_gpus={num_gpus} are not compatible together.")

    device_config.device = "cpu"
    device_config.multi_gpu = MultiGPUMode.OFF

setup_device(multi_gpu=None, num_gpus=None, device='cuda')

If required, launch ddp subprocesses.

Parameters:

Name Type Description Default
multi_gpu MultiGPUMode

DDP, DP, Off or AUTO

None
num_gpus int

Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.

None
device str

The device you want to use ('cpu' or 'cuda'). If you only set num_gpus, your device will be set up according to the following logic:
- setup_device(num_gpus=0) => gpu_mode='OFF' and device='cpu'
- setup_device(num_gpus=1) => gpu_mode='OFF' and device='gpu'
- setup_device(num_gpus>=2) => gpu_mode='DDP' and device='gpu'
- setup_device(num_gpus=-1) => gpu_mode='DDP' and device='gpu' and num_gpus=<N-AVAILABLE-GPUs>

'cuda'
Source code in src/super_gradients/training/utils/distributed_training_utils.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
@resolve_param("multi_gpu", TypeFactory(MultiGPUMode.dict()))
def setup_device(multi_gpu: MultiGPUMode = None, num_gpus: int = None, device: str = "cuda"):
    """
    If required, launch ddp subprocesses.
    :param multi_gpu:   DDP, DP, Off or AUTO
    :param num_gpus:    Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.
    :param device:      The device you want to use ('cpu' or 'cuda')

    If you only set num_gpus, your device will be set up according to the following logic:
        - `setup_device(num_gpus=0)`  => `gpu_mode='OFF'` and `device='cpu'`
        - `setup_device(num_gpus=1)`  => `gpu_mode='OFF'` and `device='gpu'`
        - `setup_device(num_gpus>=2)` => `gpu_mode='DDP'` and `device='gpu'`
        - `setup_device(num_gpus=-1)` => `gpu_mode='DDP'` and `device='gpu'` and `num_gpus=<N-AVAILABLE-GPUs>`

    """
    init_trainer()

    # When launching with torch.distributed.launch or torchrun, multi_gpu might not be set to DDP (since we are not using the recipe params)
    # To avoid any issue we force multi_gpu to be DDP if the current process is ddp subprocess. We also set num_gpus, device to run smoothly.
    if not is_launched_using_sg() and is_distributed():
        multi_gpu, num_gpus, device = MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, None, "cuda"

    if device is None:
        device = "cuda"

    if device == "cuda" and not torch.cuda.is_available():
        logger.warning("CUDA device is not available on your device... Moving to CPU.")
        multi_gpu, num_gpus, device = MultiGPUMode.OFF, 0, "cpu"

    if device == "cpu":
        setup_cpu(multi_gpu, num_gpus)
    elif device == "cuda":
        setup_gpu(multi_gpu, num_gpus)
    else:
        raise ValueError(f"Only valid values for device are: 'cpu' and 'cuda'. Received: '{device}'")

setup_gpu(multi_gpu=MultiGPUMode.AUTO, num_gpus=None)

If required, launch ddp subprocesses.

Parameters:

Name Type Description Default
multi_gpu MultiGPUMode

DDP, DP, Off or AUTO

MultiGPUMode.AUTO
num_gpus int

Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.

None
Source code in src/super_gradients/training/utils/distributed_training_utils.py
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
def setup_gpu(multi_gpu: MultiGPUMode = MultiGPUMode.AUTO, num_gpus: int = None):
    """
    If required, launch ddp subprocesses.
    :param multi_gpu:    DDP, DP, Off or AUTO
    :param num_gpus:     Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.
    """

    if num_gpus == 0:
        raise ValueError("device='cuda' and num_gpus=0 are not compatible together.")

    multi_gpu, num_gpus = _resolve_gpu_params(multi_gpu=multi_gpu, num_gpus=num_gpus)

    device_config.device = "cuda"
    device_config.multi_gpu = multi_gpu
    device_config.num_gpus = num_gpus

    if is_distributed():
        initialize_ddp()
    elif multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
        restart_script_with_ddp(num_gpus=num_gpus)

setup_gpu_mode(gpu_mode=MultiGPUMode.OFF, num_gpus=None)

[DEPRECATED in favor of setup_device] If required, launch ddp subprocesses.

Parameters:

Name Type Description Default
gpu_mode MultiGPUMode

DDP, DP, Off or AUTO

MultiGPUMode.OFF
num_gpus int

Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.

None
Source code in src/super_gradients/training/utils/distributed_training_utils.py
165
166
167
168
169
170
171
def setup_gpu_mode(gpu_mode: MultiGPUMode = MultiGPUMode.OFF, num_gpus: int = None):
    """[DEPRECATED in favor of setup_device] If required, launch ddp subprocesses.
    :param gpu_mode:    DDP, DP, Off or AUTO
    :param num_gpus:    Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.
    """
    logger.warning("setup_gpu_mode is now deprecated in favor of setup_device")
    setup_device(multi_gpu=gpu_mode, num_gpus=num_gpus)

wait_for_the_master(local_rank)

Make all processes wait for the master to finish its task.

Source code in src/super_gradients/training/utils/distributed_training_utils.py
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
@contextmanager
def wait_for_the_master(local_rank: int):
    """
    Make all processes wait for the master to finish its task.
    """
    if local_rank > 0:
        dist.barrier()
    yield
    if local_rank == 0:
        if not dist.is_available():
            return
        if not dist.is_initialized():
            return
        else:
            dist.barrier()
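
A sketch of letting rank 0 prepare data while the other ranks wait. It assumes DDP is already initialized; the get_local_rank import path and the prepare_dataset_cache helper are assumptions for illustration.

from super_gradients.common.environment.ddp_utils import get_local_rank  # assumed import path
from super_gradients.training.utils.distributed_training_utils import wait_for_the_master

local_rank = get_local_rank()
with wait_for_the_master(local_rank):
    if local_rank == 0:
        prepare_dataset_cache()  # hypothetical helper, executed only by the master process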

EarlyStop

Bases: PhaseCallback

Callback to monitor a metric and stop training when it stops improving. Inspired by pytorch_lightning.callbacks.early_stopping and tf.keras.callbacks.EarlyStopping

Source code in src/super_gradients/training/utils/early_stopping.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
@register_callback(Callbacks.EARLY_STOP)
class EarlyStop(PhaseCallback):
    """
    Callback to monitor a metric and stop training when it stops improving.
    Inspired by pytorch_lightning.callbacks.early_stopping and tf.keras.callbacks.EarlyStopping
    """

    mode_dict = {"min": torch.lt, "max": torch.gt}
    supported_phases = (Phase.VALIDATION_EPOCH_END, Phase.TRAIN_EPOCH_END)

    def __init__(
        self,
        phase: Phase,
        monitor: str,
        mode: str = "min",
        min_delta: float = 0.0,
        patience: int = 3,
        check_finite: bool = True,
        threshold: Optional[float] = None,
        verbose: bool = False,
        strict: bool = True,
    ):
        """

        :param phase: Callback phase event.
        :param monitor: name of the metric to be monitored.
        :param mode: one of 'min', 'max'. In 'min' mode, training will stop when the quantity
           monitored has stopped decreasing and in 'max' mode it will stop when the quantity
           monitored has stopped increasing.
        :param min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute
           change of less than `min_delta`, will count as no improvement.
        :param patience: number of checks with no improvement after which training will be stopped.
            One check happens after every phase event.
        :param check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
        :param threshold: Stop training immediately once the monitored quantity reaches this threshold. For mode 'min'
            stops training when below threshold, For mode 'max' stops training when above threshold.
        :param verbose: If `True` print logs.
        :param strict: whether to crash the training if `monitor` is not found in the metrics.
        """
        super(EarlyStop, self).__init__(phase)

        if phase not in self.supported_phases:
            raise ValueError(f"EarlyStop doesn't support phase: {phase}, " f"excepted {', '.join([str(x) for x in self.supported_phases])}")
        self.phase = phase
        self.monitor_key = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.mode = mode
        self.check_finite = check_finite
        self.threshold = threshold
        self.verbose = verbose
        self.strict = strict

        self.wait_count = 0

        if self.mode not in self.mode_dict:
            raise Exception(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
        self.monitor_op = self.mode_dict[self.mode]
        self.min_delta *= 1 if self.monitor_op == torch.gt else -1

        torch_inf = torch.tensor(np.Inf)
        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

    def _get_metric_value(self, metrics_dict):
        if self.monitor_key not in metrics_dict.keys():
            msg = f"Can't find EarlyStop monitor {self.monitor_key} in metrics_dict: {metrics_dict.keys()}"
            exception_cls = RuntimeError if self.strict else MissingMonitorKeyException
            raise exception_cls(msg)
        return metrics_dict[self.monitor_key]

    def _check_for_early_stop(self, current: torch.Tensor):
        should_stop = False

        # check if current value is Nan or inf
        if self.check_finite and not torch.isfinite(current):
            should_stop = True
            reason = (
                f"Monitored metric {self.monitor_key} = {current} is not finite." f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
            )

        # check if current value reached threshold value
        elif self.threshold is not None and self.monitor_op(current, self.threshold):
            should_stop = True
            reason = "Stopping threshold reached:" f" {self.monitor_key} = {current} {self.monitor_op} {self.threshold}." " Signaling Trainer to stop."

        # check if current is an improvement of monitor_key metric.
        elif self.monitor_op(current - self.min_delta, self.best_score.to(current.device)):
            should_stop = False
            if torch.isfinite(self.best_score):
                reason = (
                    f"Metric {self.monitor_key} improved by {abs(self.best_score - current):.3f} >="
                    f" min_delta = {abs(self.min_delta)}. New best score: {current:.3f}"
                )
            else:
                reason = f"Metric {self.monitor_key} improved. New best score: {current:.3f}"
            self.best_score = current
            self.wait_count = 0

        # no improvement in monitor_key metric, check if wait_count is bigger than patience.
        else:
            self.wait_count += 1
            reason = f"Monitored metric {self.monitor_key} did not improve in the last {self.wait_count} records."
            if self.wait_count >= self.patience:
                should_stop = True
                reason += f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."

        return reason, should_stop

    def __call__(self, context: PhaseContext):
        try:
            current = self._get_metric_value(context.metrics_dict)
        except MissingMonitorKeyException as e:
            logger.warning(e)
            return

        if not isinstance(current, torch.Tensor):
            current = torch.tensor(current)

        reason, self.should_stop = self._check_for_early_stop(current)

        # log reason message, and signal early stop if should_stop=True.
        if self.should_stop:
            self._signal_early_stop(context, reason)

        elif self.verbose:
            logger.info(reason)

    def _signal_early_stop(self, context: PhaseContext, reason: str):
        logger.info(reason)
        context.update_context(stop_training=True)
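
A minimal sketch of wiring EarlyStop into training. The metric name "Accuracy" and the import path for Phase are assumptions; only the constructor arguments come from the signature above.

from super_gradients.training.utils.callbacks import Phase  # assumed import path
from super_gradients.training.utils.early_stopping import EarlyStop

early_stop = EarlyStop(
    phase=Phase.VALIDATION_EPOCH_END,
    monitor="Accuracy",  # assumed metric name reported in metrics_dict
    mode="max",          # stop once the monitored accuracy stops increasing
    patience=5,
    verbose=True,
)
# The callback is then passed to the training hyper-parameters,
# e.g. training_params = {"phase_callbacks": [early_stop], ...}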

__init__(phase, monitor, mode='min', min_delta=0.0, patience=3, check_finite=True, threshold=None, verbose=False, strict=True)

Parameters:

Name Type Description Default
phase Phase

Callback phase event.

required
monitor str

name of the metric to be monitored.

required
mode str

one of 'min', 'max'. In 'min' mode, training will stop when the quantity monitored has stopped decreasing and in 'max' mode it will stop when the quantity monitored has stopped increasing.

'min'
min_delta float

minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta, will count as no improvement.

0.0
patience int

number of checks with no improvement after which training will be stopped. One check happens after every phase event.

3
check_finite bool

When set True, stops training when the monitor becomes NaN or infinite.

True
threshold Optional[float]

Stop training immediately once the monitored quantity reaches this threshold. For mode 'min' stops training when below threshold, For mode 'max' stops training when above threshold.

None
verbose bool

If True print logs.

False
strict bool

whether to crash the training if monitor is not found in the metrics.

True
Source code in src/super_gradients/training/utils/early_stopping.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def __init__(
    self,
    phase: Phase,
    monitor: str,
    mode: str = "min",
    min_delta: float = 0.0,
    patience: int = 3,
    check_finite: bool = True,
    threshold: Optional[float] = None,
    verbose: bool = False,
    strict: bool = True,
):
    """

    :param phase: Callback phase event.
    :param monitor: name of the metric to be monitored.
    :param mode: one of 'min', 'max'. In 'min' mode, training will stop when the quantity
       monitored has stopped decreasing and in 'max' mode it will stop when the quantity
       monitored has stopped increasing.
    :param min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute
       change of less than `min_delta`, will count as no improvement.
    :param patience: number of checks with no improvement after which training will be stopped.
        One check happens after every phase event.
    :param check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
    :param threshold: Stop training immediately once the monitored quantity reaches this threshold. For mode 'min'
        stops training when below threshold, For mode 'max' stops training when above threshold.
    :param verbose: If `True` print logs.
    :param strict: whether to crash the training if `monitor` is not found in the metrics.
    """
    super(EarlyStop, self).__init__(phase)

    if phase not in self.supported_phases:
        raise ValueError(f"EarlyStop doesn't support phase: {phase}, " f"excepted {', '.join([str(x) for x in self.supported_phases])}")
    self.phase = phase
    self.monitor_key = monitor
    self.min_delta = min_delta
    self.patience = patience
    self.mode = mode
    self.check_finite = check_finite
    self.threshold = threshold
    self.verbose = verbose
    self.strict = strict

    self.wait_count = 0

    if self.mode not in self.mode_dict:
        raise Exception(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
    self.monitor_op = self.mode_dict[self.mode]
    self.min_delta *= 1 if self.monitor_op == torch.gt else -1

    torch_inf = torch.tensor(np.Inf)
    self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

MissingMonitorKeyException

Bases: Exception

Exception raised for missing monitor key in metrics_dict.

Source code in src/super_gradients/training/utils/early_stopping.py
146
147
148
149
150
151
class MissingMonitorKeyException(Exception):
    """
    Exception raised for missing monitor key in metrics_dict.
    """

    pass

KDModelEMA

Bases: ModelEMA

Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models Keep a moving average of everything in the model state_dict (parameters and buffers). This is intended to allow functionality like https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage A smoothed version of the weights is necessary for some training schemes to perform well. This class is sensitive where it is initialized in the sequence of model init, GPU assignment and distributed training wrappers.

Source code in src/super_gradients/training/utils/ema.py
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
class KDModelEMA(ModelEMA):
    """Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, kd_model: KDModule, decay: float, decay_function: IDecayFunction):
        """
        Init the EMA
        :param kd_model: KDModule, the training Knowledge distillation model to construct the EMA model by
                    IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE
                    AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.
        :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value
                      until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
        :param decay_function: the decay schedule (IDecayFunction) that maps the maximum decay value and the training
                     progress to the effective decay at each step (e.g. for the exponential schedule, beta=15
                     saturates the decay at ~40% of the training process).
        """
        # Only work on the student (we don't want to update and to have a duplicate of the teacher)
        super().__init__(model=unwrap_model(kd_model).student, decay=decay, decay_function=decay_function)

        # Overwrite current ema attribute with combination of the student model EMA (current self.ema)
        # with already the instantiated teacher, to have the final KD EMA
        self.ema = KDModule(
            arch_params=unwrap_model(kd_model).arch_params,
            student=self.ema,
            teacher=unwrap_model(kd_model).teacher,
            run_teacher_on_eval=unwrap_model(kd_model).run_teacher_on_eval,
        )

__init__(kd_model, decay, decay_function)

Init the EMA

Parameters:

Name Type Description Default
kd_model KDModule

KDModule, the training Knowledge distillation model to construct the EMA model by IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.

required
decay float

the maximum decay value. as the training process advances, the decay will climb towards this value until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)

required
decay_function IDecayFunction

the decay schedule that maps the maximum decay value and the training progress to the effective decay at each step (e.g. for the exponential schedule, beta=15 saturates the decay at ~40% of the training process)

required
Source code in src/super_gradients/training/utils/ema.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def __init__(self, kd_model: KDModule, decay: float, decay_function: IDecayFunction):
    """
    Init the EMA
    :param kd_model: KDModule, the training Knowledge distillation model to construct the EMA model by
                IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE
                AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.
    :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value
                  until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
    :param decay_function: the decay schedule (IDecayFunction) that maps the maximum decay value and the training
                 progress to the effective decay at each step (e.g. for the exponential schedule, beta=15
                 saturates the decay at ~40% of the training process).
    """
    # Only work on the student (we don't want to update and to have a duplicate of the teacher)
    super().__init__(model=unwrap_model(kd_model).student, decay=decay, decay_function=decay_function)

    # Overwrite current ema attribute with combination of the student model EMA (current self.ema)
    # with already the instantiated teacher, to have the final KD EMA
    self.ema = KDModule(
        arch_params=unwrap_model(kd_model).arch_params,
        student=self.ema,
        teacher=unwrap_model(kd_model).teacher,
        run_teacher_on_eval=unwrap_model(kd_model).run_teacher_on_eval,
    )

ModelEMA

Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models Keep a moving average of everything in the model state_dict (parameters and buffers). This is intended to allow functionality like https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage A smoothed version of the weights is necessary for some training schemes to perform well. This class is sensitive where it is initialized in the sequence of model init, GPU assignment and distributed training wrappers.

Source code in src/super_gradients/training/utils/ema.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
class ModelEMA:
    """Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model: nn.Module, decay: float, decay_function: IDecayFunction):
        """
        Init the EMA
        :param model: Union[SgModule, nn.Module], the training model to construct the EMA model by
                    IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE
                    AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.
        :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value
                      until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
        :param decay_function: the decay schedule (IDecayFunction) that maps the maximum decay value and the training
                     progress to the effective decay at each step (e.g. for the exponential schedule, beta=15
                     saturates the decay at ~40% of the training process).
        """
        # Create EMA
        model = unwrap_model(model)
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.decay_function = decay_function

        """"
        we hold a list of model attributes (not wights and biases) which we would like to include in each
        attribute update or exclude from each update. a SgModule declare these attribute using
        get_include_attributes and get_exclude_attributes functions. for a nn.Module which is not a SgModule
        all non-private (not starting with '_') attributes will be updated (and only them).
        """
        if isinstance(model, SgModule):
            self.include_attributes = model.get_include_attributes()
            self.exclude_attributes = model.get_exclude_attributes()
        else:
            warnings.warn("Warning: EMA should be used with SgModule instance. All attributes of the model will be " "included in EMA")
            self.include_attributes = []
            self.exclude_attributes = []
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @classmethod
    def from_params(cls, model: nn.Module, decay_type: str = None, decay: float = None, **kwargs):
        if decay is None:
            logger.warning(
                "Parameter `decay` is not specified for EMA params. Please specify `decay` parameter explicitly in your config:\n"
                "ema: True\n"
                "ema_params: \n"
                "  decay: 0.9999\n"
                "  decay_type: exp\n"
                "  beta: 15\n"
                "Will default to decay: 0.9999\n"
                "In the next major release of SG this warning will become an error."
            )
            decay = 0.9999

        if "exp_activation" in kwargs:
            logger.warning(
                "Parameter `exp_activation` is deprecated for EMA model. Please update your config to use decay_type: str (constant|exp|threshold) instead:\n"
                "ema: True\n"
                "ema_params: \n"
                "  decay: 0.9999\n"
                "  decay_type: exp # Equivalent to exp_activation: True\n"
                "  beta: 15\n"
                "\n"
                "ema: True\n"
                "ema_params: \n"
                "  decay: 0.9999\n"
                "  decay_type: constant # Equivalent to exp_activation: False\n"
                "\n"
                "In the next major release of SG this warning will become an error."
            )
            decay_type = "exp" if bool(kwargs.pop("exp_activation")) else "constant"

        if decay_type is None:
            logger.warning(
                "Parameter decay_type is not specified for EMA model. Please specify decay_type parameter explicitly in your config:\n"
                "ema: True\n"
                "ema_params: \n"
                "  decay: 0.9999\n"
                "  decay_type: constant|exp|threshold\n"
                "Will default to `exp` decay with beta = 15\n"
                "In the next major release of SG this warning will become an error."
            )
            decay_type = "exp"
            if "beta" not in kwargs:
                kwargs["beta"] = 15

        try:
            decay_cls = EMA_DECAY_FUNCTIONS[decay_type]
        except KeyError:
            raise UnknownTypeException(decay_type, list(EMA_DECAY_FUNCTIONS.keys()))

        decay_function = decay_cls(**kwargs)
        return cls(model, decay, decay_function)

    def update(self, model, step: int, total_steps: int):
        """
        Update the state of the EMA model.

        :param model: Current training model
        :param step: Current training step
        :param total_steps: Total training steps
        """
        # Update EMA parameters
        model = unwrap_model(model)
        with torch.no_grad():
            decay = self.decay_function(self.decay, step, total_steps)

            for ema_v, model_v in zip(self.ema.state_dict().values(), model.state_dict().values()):
                if ema_v.dtype.is_floating_point:
                    ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v.detach())

    def update_attr(self, model):
        """
        This function updates model attributes (not weights and biases) from the original model to the EMA model.
        Attributes of the original model, such as anchors and grids (of detection models), may be crucial to the
        model operation and need to be updated.
        If the include_attributes and exclude_attributes lists were not defined, all non-private (not starting with '_')
        attributes will be updated (and only them).
        :param model: the source model
        """
        copy_attr(self.ema, unwrap_model(model), self.include_attributes, self.exclude_attributes)
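
A minimal sketch of maintaining an EMA copy of a model inside a hand-written training loop, using the config-style entry point from_params with an exponential decay schedule; the toy model and step count are placeholders.

import torch
from torch import nn
from super_gradients.training.utils.ema import ModelEMA

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))
ema = ModelEMA.from_params(model, decay_type="exp", decay=0.9999, beta=15)  # warns: plain nn.Module, not SgModule

total_steps = 1000
for step in range(total_steps):
    # ... forward / backward / optimizer.step() on `model` ...
    ema.update(model, step=step, total_steps=total_steps)

# ema.ema now holds the smoothed copy used for evaluation / checkpointing.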

decay_function = decay_function instance-attribute

" we hold a list of model attributes (not wights and biases) which we would like to include in each attribute update or exclude from each update. a SgModule declare these attribute using get_include_attributes and get_exclude_attributes functions. for a nn.Module which is not a SgModule all non-private (not starting with '_') attributes will be updated (and only them).

__init__(model, decay, decay_function)

Init the EMA

Parameters:

Name Type Description Default
model nn.Module

Union[SgModule, nn.Module], the training model to construct the EMA model by IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.

required
decay float

the maximum decay value. as the training process advances, the decay will climb towards this value until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)

required
decay_function IDecayFunction

the decay schedule that maps the maximum decay value and the training progress to the effective decay at each step (e.g. for the exponential schedule, beta=15 saturates the decay at ~40% of the training process)

required
Source code in src/super_gradients/training/utils/ema.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def __init__(self, model: nn.Module, decay: float, decay_function: IDecayFunction):
    """
    Init the EMA
    :param model: Union[SgModule, nn.Module], the training model to construct the EMA model by
                IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE
                AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.
    :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value
                  until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
    :param decay_function: the decay schedule (IDecayFunction) that maps the maximum decay value and the training
                 progress to the effective decay at each step (e.g. for the exponential schedule, beta=15
                 saturates the decay at ~40% of the training process).
    """
    # Create EMA
    model = unwrap_model(model)
    self.ema = deepcopy(model)
    self.ema.eval()
    self.decay = decay
    self.decay_function = decay_function

    """"
    we hold a list of model attributes (not wights and biases) which we would like to include in each
    attribute update or exclude from each update. a SgModule declare these attribute using
    get_include_attributes and get_exclude_attributes functions. for a nn.Module which is not a SgModule
    all non-private (not starting with '_') attributes will be updated (and only them).
    """
    if isinstance(model, SgModule):
        self.include_attributes = model.get_include_attributes()
        self.exclude_attributes = model.get_exclude_attributes()
    else:
        warnings.warn("Warning: EMA should be used with SgModule instance. All attributes of the model will be " "included in EMA")
        self.include_attributes = []
        self.exclude_attributes = []
    for p in self.ema.parameters():
        p.requires_grad_(False)

update(model, step, total_steps)

Update the state of the EMA model.

Parameters:

Name Type Description Default
model

Current training model

required
step int

Current training step

required
total_steps int

Total training steps

required
Source code in src/super_gradients/training/utils/ema.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
def update(self, model, step: int, total_steps: int):
    """
    Update the state of the EMA model.

    :param model: Current training model
    :param step: Current training step
    :param total_steps: Total training steps
    """
    # Update EMA parameters
    model = unwrap_model(model)
    with torch.no_grad():
        decay = self.decay_function(self.decay, step, total_steps)

        for ema_v, model_v in zip(self.ema.state_dict().values(), model.state_dict().values()):
            if ema_v.dtype.is_floating_point:
                ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v.detach())

update_attr(model)

This function updates model attributes (not weights and biases) from the original model to the EMA model. Attributes of the original model, such as anchors and grids (of detection models), may be crucial to the model operation and need to be updated. If the include_attributes and exclude_attributes lists were not defined, all non-private (not starting with '_') attributes will be updated (and only them).

Parameters:

Name Type Description Default
model

the source model

required
Source code in src/super_gradients/training/utils/ema.py
143
144
145
146
147
148
149
150
151
152
def update_attr(self, model):
    """
    This function updates model attributes (not weights and biases) from the original model to the EMA model.
    Attributes of the original model, such as anchors and grids (of detection models), may be crucial to the
    model operation and need to be updated.
    If the include_attributes and exclude_attributes lists were not defined, all non-private (not starting with '_')
    attributes will be updated (and only them).
    :param model: the source model
    """
    copy_attr(self.ema, unwrap_model(model), self.include_attributes, self.exclude_attributes)

ConstantDecay

Bases: IDecayFunction

Constant decay schedule.

Source code in src/super_gradients/training/utils/ema_decay_schedules.py
26
27
28
29
30
31
32
33
34
35
class ConstantDecay(IDecayFunction):
    """
    Constant decay schedule.
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, decay: float, step: int, total_steps: int) -> float:
        return decay

ExpDecay

Bases: IDecayFunction

Gradually increases the EMA decay towards the maximum value using the following formula: decay * (1 - math.exp(-x * self.beta))

Source code in src/super_gradients/training/utils/ema_decay_schedules.py
50
51
52
53
54
55
56
57
58
59
60
61
class ExpDecay(IDecayFunction):
    """
    Gradually increases the EMA decay towards the maximum value using the following formula: decay * (1 - math.exp(-x * self.beta))

    """

    def __init__(self, beta: float, **kwargs):
        self.beta = beta

    def __call__(self, decay: float, step, total_steps: int) -> float:
        x = step / total_steps
        return decay * (1 - np.exp(-x * self.beta))

IDecayFunction

Interface for EMA decay schedules. The decay schedule is a function of the maximum decay value and the training progress. Usually it gradually increases the EMA decay from a small initial value towards the maximum value. The exact ramp-up schedule is defined by the concrete implementation.

Source code in src/super_gradients/training/utils/ema_decay_schedules.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class IDecayFunction:
    """
    Interface for EMA decay schedules. The decay schedule is a function of the maximum decay value and the training progress.
    Usually it gradually increases the EMA decay from a small initial value towards the maximum value. The exact ramp-up
    schedule is defined by the concrete implementation.
    """

    @abstractmethod
    def __call__(self, decay: float, step: int, total_steps: int) -> float:
        """

        :param decay: The maximum decay value.
        :param step: Current training step. The unit-range training percentage can be obtained by `step / total_steps`.
        :param total_steps:  Total number of training steps.
        :return: Computed decay value for a given step.
        """
        pass

__call__(decay, step, total_steps) abstractmethod

Parameters:

Name Type Description Default
decay float

The maximum decay value.

required
step int

Current training step. The unit-range training percentage can be obtained by step / total_steps.

required
total_steps int

Total number of training steps.

required

Returns:

Type Description
float

Computed decay value for a given step.

Source code in src/super_gradients/training/utils/ema_decay_schedules.py
14
15
16
17
18
19
20
21
22
23
@abstractmethod
def __call__(self, decay: float, step: int, total_steps: int) -> float:
    """

    :param decay: The maximum decay value.
    :param step: Current training step. The unit-range training percentage can be obtained by `step / total_steps`.
    :param total_steps:  Total number of training steps.
    :return: Computed decay value for a given step.
    """
    pass

ThresholdDecay

Bases: IDecayFunction

Gradually increases the EMA decay from 0.1 towards the maximum value using the following formula: min(decay, (1 + step) / (10 + step))

Source code in src/super_gradients/training/utils/ema_decay_schedules.py
38
39
40
41
42
43
44
45
46
47
class ThresholdDecay(IDecayFunction):
    """
    Gradually increases the EMA decay from 0.1 towards the maximum value using the following formula: min(decay, (1 + step) / (10 + step))
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, decay: float, step, total_steps: int) -> float:
        return np.minimum(decay, (1 + step) / (10 + step))
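
To compare how the three schedules ramp the effective decay, a short sketch evaluating each of them over a 100-step run with a maximum decay of 0.9999:

from super_gradients.training.utils.ema_decay_schedules import ConstantDecay, ExpDecay, ThresholdDecay

schedules = {"constant": ConstantDecay(), "exp": ExpDecay(beta=15), "threshold": ThresholdDecay()}
for name, schedule in schedules.items():
    values = [float(schedule(0.9999, step, 100)) for step in (0, 10, 50, 100)]
    print(name, [round(v, 4) for v in values])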

fuse_conv_bn(model, replace_bn_with_identity=False)

Fuses consecutive nn.Conv2d and nn.BatchNorm2d layers recursively, in-place, throughout the model

Parameters:

Name Type Description Default
replace_bn_with_identity bool

if set to true, bn will be replaced with identity. otherwise, bn will be removed

False
model nn.Module

the target model

required

Returns:

Type Description

the number of fuses executed

Source code in src/super_gradients/training/utils/export_utils.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def fuse_conv_bn(model: nn.Module, replace_bn_with_identity: bool = False):
    """
    Fuses consecutive nn.Conv2d and nn.BatchNorm2d layers recursively, in-place, throughout the model
    :param replace_bn_with_identity: if set to True, the BatchNorm will be replaced with nn.Identity; otherwise it will be removed
    :param model: the target model
    :return: the number of fusions executed
    """
    children = list(model.named_children())
    counter = 0
    for i in range(len(children) - 1):
        if isinstance(children[i][1], torch.nn.Conv2d) and isinstance(children[i + 1][1], torch.nn.BatchNorm2d):
            setattr(model, children[i][0], torch.nn.utils.fuse_conv_bn_eval(children[i][1], children[i + 1][1]))
            if replace_bn_with_identity:
                setattr(model, children[i + 1][0], nn.Identity())
            else:
                delattr(model, children[i + 1][0])
            counter += 1
    for child_name, child in children:
        counter += fuse_conv_bn(child, replace_bn_with_identity)

    return counter
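
A short sketch of fusing a Conv2d/BatchNorm2d pair before export; the model should be in eval mode since fusion bakes the current BN statistics into the convolution weights. The import path is assumed from the "Source code in" location above.

import torch
from torch import nn
from super_gradients.training.utils.export_utils import fuse_conv_bn  # assumed import path

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU()).eval()
num_fused = fuse_conv_bn(model, replace_bn_with_identity=True)
print(num_fused)  # 1 -- the BatchNorm2d slot is now nn.Identity and its statistics are folded into the Conv2d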

infer_image_shape_from_model(model)

Infer the image shape from the model. This function takes the preprocessing parameters if they are available and gets the input image shape from them. If the preprocessing parameters are not available, the function returns None

Parameters:

Name Type Description Default
model Union[nn.Module, HasPredict]

the model to inspect

required

Returns:

Type Description
Optional[Tuple[int, int]]

A tuple of (height, width) or None

Source code in src/super_gradients/training/utils/export_utils.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def infer_image_shape_from_model(model: Union[nn.Module, HasPredict]) -> Optional[Tuple[int, int]]:
    """
    Infer the image shape from the model. This function takes the preprocessing parameters if they are available
    and gets the input image shape from them. If the preprocessing parameters are not available, the function returns None.
    :param model: the model to inspect
    :return: A tuple of (height, width) or None
    """
    model = unwrap_model(model)
    if isinstance(model, HasPredict):
        processing = model.get_processing_params()
        if processing is not None:
            shape = processing.infer_image_input_shape()
            return shape
    return None
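
A minimal sketch, assuming a pretrained checkpoint (which carries preprocessing parameters) can be downloaded; "yolo_nas_s" is one of the SuperGradients detection models:

from super_gradients.training import models
from super_gradients.training.utils.export_utils import infer_image_shape_from_model

model = models.get("yolo_nas_s", pretrained_weights="coco")
shape = infer_image_shape_from_model(model)
print(shape)  # a (height, width) tuple, or None if no preprocessing params are attached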

get_input_output_shapes(batch_size, input_dims, output_dims)

Returns the input/output shapes for single or multiple inputs/outputs

Source code in src/super_gradients/training/utils/get_model_stats.py
def get_input_output_shapes(batch_size: int, input_dims: Union[list, tuple], output_dims: Union[list, tuple]):
    """
    Returns input/output shapes for single/multiple input/s output/s
    """
    if isinstance(input_dims[0], list):
        input_shape = [i.size() for i in input_dims[0] if i is not None]
    else:
        input_shape = list(input_dims[0].size())
    input_shape[0] = batch_size
    if isinstance(output_dims, (list, tuple)):
        output_shape = [[-1] + list(o.size())[1:] for o in output_dims if o is not None]
    else:
        output_shape = list(output_dims.size())
        output_shape[0] = batch_size
    return input_shape, output_shape

get_model_stats(model, input_dims, high_verbosity=True, batch_size=1, device='cuda', dtypes=None, iterations=100)

Return the model summary as a string. The block (type) column represents the layers listed above it.

:param dtypes: The input types (list of input types)
:param high_verbosity: prints layer-by-layer information

Source code in src/super_gradients/training/utils/get_model_stats.py
def get_model_stats(
    model: nn.Module,
    input_dims: Union[list, tuple],
    high_verbosity: bool = True,
    batch_size: int = 1,
    device: str = "cuda",  # noqa: C901
    dtypes=None,
    iterations: int = 100,
):
    """
    return the model summary as a string
    The block(type) column represents the lines (layers) above
        :param dtypes:          The input types (list of inputs types)
        :param high_verbosity:  prints layer by layer information
    """
    dtypes = dtypes or [torch.FloatTensor] * len(input_dims)

    def register_hook(module):
        """
        add a hook (all the desirable actions) for every layer that is not nn.Sequential/nn.ModuleList
        """

        def hook(module, input, output):

            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = f"{class_name}-{module_idx + 1}"
            summary[m_key] = OrderedDict()

            # block_name refers to all layers that contains other layers
            if len(module._modules) != 0:
                summary[m_key]["block_name"] = class_name

            summary[m_key]["inference_time"] = np.round(timer.stop(), 3)
            timer.start()

            summary[m_key]["gpu_occupation"] = (round(torch.cuda.memory_allocated(0) / 1024**3, 2), "GB") if torch.cuda.is_available() else [0]
            summary[m_key]["gpu_cached_memory"] = (round(torch.cuda.memory_reserved(0) / 1024**3, 2), "GB") if torch.cuda.is_available() else [0]

            summary[m_key]["input_shape"], summary[m_key]["output_shape"] = get_input_output_shapes(batch_size=batch_size, input_dims=input, output_dims=output)

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList):
            hooks.append(module.register_forward_hook(hook))

    # multiple inputs to the network
    if isinstance(input_dims, tuple):
        input_dims = [input_dims]

    x = [torch.rand(batch_size, *input_dim).type(dtype).to(device=device) for input_dim, dtype in zip(input_dims, dtypes)]

    summary_list = []
    with torch.no_grad():
        for i in range(iterations + 10):
            # create properties
            summary = OrderedDict()
            hooks = []

            # register hook
            model.apply(register_hook)

            timer = Timer(device=device)
            timer.start()
            # make a forward pass
            model(*x)

            # remove these hooks
            for h in hooks:
                h.remove()

            # we start counting from the 10th iteration for warmup
            if i >= 10:
                summary_list.append(summary)

    summary = _average_inference_time(summary_list=summary_list, summary=summary, divisor=iterations)

    return _convert_summary_dict_to_string(summary=summary, high_verbosity=high_verbosity, input_dims=input_dims, batch_size=batch_size, device=device)
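
A minimal sketch on a toy model, assuming a CUDA-capable machine (the default device is "cuda" and the input tensors are created there); the model architecture is illustrative:

from torch import nn
from super_gradients.training.utils.get_model_stats import get_model_stats

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10)
).cuda()

summary = get_model_stats(model, input_dims=(3, 64, 64), batch_size=1, iterations=20)
print(summary)  # layer-by-layer table when high_verbosity=True (the default)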

check_image_typing(image)

Check if the given object respects typing of image.

Parameters:

Name Type Description Default
image ImageSource

Image to check.

required

Returns:

Type Description
bool

True if the object is an image, False otherwise.

Source code in src/super_gradients/training/utils/media/image.py
def check_image_typing(image: ImageSource) -> bool:
    """Check if the given object respects typing of image.
    :param image: Image to check.
    :return: True if the object is an image, False otherwise.
    """
    if isinstance(image, get_args(SingleImageSource)):
        return True
    elif isinstance(image, list):
        return all([isinstance(image_item, get_args(SingleImageSource)) for image_item in image])
    else:
        return False

generate_image_loader(images)

Generator that loads images one at a time.

Supported types include:

- str: A string representing either an image path or a URL.
- numpy.ndarray: A numpy array representing the image
- torch.Tensor: A PyTorch tensor representing the image
- PIL.Image.Image: A PIL Image object
- List: A list of images of any of the above types.

Parameters:

Name Type Description Default
images Union[List[ImageSource], ImageSource]

Single image or a list of images of supported types.

required

Returns:

Type Description
Iterable[np.ndarray]

Generator of images as numpy arrays (H, W, C). If loaded from string, the image will be returned as RGB.

Source code in src/super_gradients/training/utils/media/image.py
def generate_image_loader(images: Union[List[ImageSource], ImageSource]) -> Iterable[np.ndarray]:
    """Generator that loads images one at a time.

    Supported types include:
        - str:              A string representing either an image or an URL.
        - numpy.ndarray:    A numpy array representing the image
        - torch.Tensor:     A PyTorch tensor representing the image
        - PIL.Image.Image:  A PIL Image object
        - List:             A list of images of any of the above types.

    :param images:  Single image or a list of images of supported types.
    :return:        Generator of images as numpy arrays (H, W, C). If loaded from string, the image will be returned as RGB.
    """
    if isinstance(images, str) and os.path.isdir(images):
        images_paths = list_images_in_folder(images)
        for image_path in images_paths:
            yield load_image(image=image_path)
    elif _is_4d_array(images):
        warnings.warn(
            "It seems you are using predict() with 4D array as input. "
            "Please note we cannot track whether the input was already normalized or not. "
            "You will get incorrect results if you feed batches from train/validation dataloader that were already normalized."
            "Please check https://docs.deci.ai/super-gradients/latest/documentation/source/ModelPredictions.html for more details."
        )
        for image in images:
            yield load_image(image=image)
    elif _is_list_of_images(images=images):
        warnings.warn("It seems you are using predict() with batch input")
        for image in images:
            yield load_image(image=image)
    else:
        yield load_image(image=images)

is_image(filename)

Check if the given file name refers to an image.

Parameters:

Name Type Description Default
filename str

The filename to check.

required

Returns:

Type Description
bool

True if the file is an image, False otherwise.

Source code in src/super_gradients/training/utils/media/image.py
def is_image(filename: str) -> bool:
    """Check if the given file name refers to image.

    :param filename:    The filename to check.
    :return:            True if the file is an image, False otherwise.
    """
    return filename.split(".")[-1].lower() in IMG_EXTENSIONS

is_url(url)

Check if the given string is a URL.

Parameters:

Name Type Description Default
url str

String to check.

required
Source code in src/super_gradients/training/utils/media/image.py
def is_url(url: str) -> bool:
    """Check if the given string is a URL.
    :param url:  String to check.
    """
    try:
        result = urlparse(url)
        return all([result.scheme, result.netloc, result.path])
    except Exception:
        return False

list_images_in_folder(directory)

List all the images in a directory.

Parameters:

Name Type Description Default
directory str

The path to the directory containing the images.

required

Returns:

Type Description
List[str]

A list of image file names.

Source code in src/super_gradients/training/utils/media/image.py
def list_images_in_folder(directory: str) -> List[str]:
    """List all the images in a directory.
    :param directory: The path to the directory containing the images.
    :return: A list of image file names.
    """
    files = os.listdir(directory)
    images_paths = [os.path.join(directory, f) for f in files if is_image(f)]
    return images_paths

load_image(image, input_image_channels=3)

Load a single image and return it as a numpy array (H, W, C).

Supported image types include:

- numpy.ndarray: A numpy array representing the image
- torch.Tensor: A PyTorch tensor representing the image
- PIL.Image.Image: A PIL Image object
- str: A string representing either a local file path or a URL to an image

Parameters:

Name Type Description Default
image ImageSource

Single image of supported types.

required
input_image_channels int

Number of channels that the model expects as input. This value helps to infer the layout of the input image array. As of now, this argument has a default value of 3, but in the future it will become mandatory.

3

Returns:

Type Description
np.ndarray

Image as a numpy array (H, W, C). If loaded from string, the image will be returned as RGB.

Source code in src/super_gradients/training/utils/media/image.py
def load_image(image: ImageSource, input_image_channels: int = 3) -> np.ndarray:
    """Load a single image and return it as a numpy arrays (H, W, C).

    Supported image types include:
        - numpy.ndarray:    A numpy array representing the image
        - torch.Tensor:     A PyTorch tensor representing the image
        - PIL.Image.Image:  A PIL Image object
        - str:              A string representing either a local file path or a URL to an image

    :param image: Single image of supported types.
    :param input_image_channels: Number of channels that model expects as input.
                                 This value helps to infer the layout of the input image array.
                                 As of now this argument has default value of 3, but in future it will become mandatory.

    :return:      Image as numpy arrays (H, W, C). If loaded from string, the image will be returned as RGB.
    """
    if isinstance(image, np.ndarray):
        if image.ndim != 3:
            raise ValueError(f"Unsupported image shape: {image.shape}. This function only supports 3-dimensional images.")
        if image.shape[0] == input_image_channels:
            image = np.ascontiguousarray(image.transpose((1, 2, 0)))
        elif image.shape[2] == input_image_channels:
            pass
        else:
            raise ValueError(f"Cannot infer image layout (HWC or CHW) for image of shape {image.shape} while C is {input_image_channels}")

        return image
    elif isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
        return load_image(image=image, input_image_channels=input_image_channels)
    elif isinstance(image, PIL.Image.Image):
        return load_np_image_from_pil(image)
    elif isinstance(image, str):
        image = load_pil_image_from_str(image_str=image)
        return load_np_image_from_pil(image)
    else:
        raise ValueError(f"Unsupported image type: {type(image)}")
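
A minimal sketch of the layout inference for numpy inputs:

import numpy as np
from super_gradients.training.utils.media.image import load_image

chw = np.zeros((3, 224, 224), dtype=np.uint8)  # channel-first input
hwc = np.zeros((224, 224, 3), dtype=np.uint8)  # channel-last input

print(load_image(chw).shape)  # (224, 224, 3) - transposed to HWC
print(load_image(hwc).shape)  # (224, 224, 3) - returned as-is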

load_images(images)

Load a single image or a list of images and return them as a list of numpy arrays.

Supported types include:

- str: A string representing either an image path or a URL.
- numpy.ndarray: A numpy array representing the image
- torch.Tensor: A PyTorch tensor representing the image
- PIL.Image.Image: A PIL Image object
- List: A list of images of any of the above types.

Parameters:

Name Type Description Default
images Union[List[ImageSource], ImageSource]

Single image or a list of images of supported types.

required

Returns:

Type Description
List[np.ndarray]

List of images as numpy arrays. If loaded from string, the image will be returned as RGB.

Source code in src/super_gradients/training/utils/media/image.py
def load_images(images: Union[List[ImageSource], ImageSource]) -> List[np.ndarray]:
    """Load a single image or a list of images and return them as a list of numpy arrays.

    Supported types include:
        - str:              A string representing either an image or an URL.
        - numpy.ndarray:    A numpy array representing the image
        - torch.Tensor:     A PyTorch tensor representing the image
        - PIL.Image.Image:  A PIL Image object
        - List:             A list of images of any of the above types.

    :param images:  Single image or a list of images of supported types.
    :return:        List of images as numpy arrays. If loaded from string, the image will be returned as RGB.
    """
    return [image for image in generate_image_loader(images=images)]

load_np_image_from_pil(image)

Convert a PIL image to numpy array in RGB format.

Source code in src/super_gradients/training/utils/media/image.py
def load_np_image_from_pil(image: PIL.Image.Image) -> np.ndarray:
    """Convert a PIL image to numpy array in RGB format."""
    return np.asarray(image.convert("RGB"))

load_pil_image_from_str(image_str)

Load an image based on a string (local file path or URL).

Source code in src/super_gradients/training/utils/media/image.py
def load_pil_image_from_str(image_str: str) -> PIL.Image.Image:
    """Load an image based on a string (local file path or URL)."""

    if is_url(image_str):
        response = requests.get(image_str, stream=True)
        response.raise_for_status()
        return PIL.Image.open(io.BytesIO(response.content))
    else:
        return PIL.Image.open(image_str)

save_image(image, path)

Save a numpy array as an image.

Parameters:

Name Type Description Default
image np.ndarray

Image to save, (H, W, C), RGB.

required
path str

Path to save the image to.

required
Source code in src/super_gradients/training/utils/media/image.py
def save_image(image: np.ndarray, path: str) -> None:
    """Save a numpy array as an image.
    :param image:  Image to save, (H, W, C), RGB.
    :param path:   Path to save the image to.
    """
    Image.fromarray(image).save(path)

show_image(image)

Show an image using matplotlib.

Parameters:

Name Type Description Default
image np.ndarray

Image to show in (H, W, C), RGB.

required
Source code in src/super_gradients/training/utils/media/image.py
def show_image(image: np.ndarray) -> None:
    """Show an image using matplotlib.
    :param image: Image to show in (H, W, C), RGB.
    """
    plt.figure(figsize=(image.shape[1] / 100.0, image.shape[0] / 100.0), dpi=100)
    plt.imshow(image, interpolation="nearest")
    plt.axis("off")
    plt.tight_layout()
    plt.show()

FPSCounter

Class for calculating the FPS of a video stream.

Source code in src/super_gradients/training/utils/media/stream.py
class FPSCounter:
    """Class for calculating the FPS of a video stream."""

    def __init__(self, update_frequency: Optional[float] = None):
        """Create a new FPSCounter object.

        :param update_frequency: Minimum time (in seconds) between updates to the FPS counter.
                                 If None, the counter is updated every frame.
        """
        self._update_frequency = update_frequency

        self._start_time = time.time()
        self._frame_count = 0
        self._fps = 0.0

    def _update_fps(self, elapsed_time, current_time) -> None:
        """Compute new value of FPS and reset the counter."""
        self._fps = self._frame_count / elapsed_time
        self._start_time = current_time
        self._frame_count = 0

    @property
    def fps(self) -> float:
        """Current FPS value."""

        self._frame_count += 1
        current_time, elapsed_time = time.time(), time.time() - self._start_time

        if self._update_frequency is None or elapsed_time > self._update_frequency:
            self._update_fps(elapsed_time=elapsed_time, current_time=current_time)

        return self._fps
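
A minimal sketch; the 0.02 s sleep below stands in for per-frame processing work and is purely illustrative:

import time
from super_gradients.training.utils.media.stream import FPSCounter

counter = FPSCounter(update_frequency=1.0)  # refresh the estimate at most once per second
fps = 0.0
for _ in range(100):
    time.sleep(0.02)   # roughly 50 FPS worth of "work" per frame
    fps = counter.fps  # reading the property also counts the frame
print(f"~{fps:.1f} FPS")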

fps: float property

Current FPS value.

__init__(update_frequency=None)

Create a new FPSCounter object.

Parameters:

Name Type Description Default
update_frequency Optional[float]

Minimum time (in seconds) between updates to the FPS counter. If None, the counter is updated every frame.

None
Source code in src/super_gradients/training/utils/media/stream.py
def __init__(self, update_frequency: Optional[float] = None):
    """Create a new FPSCounter object.

    :param update_frequency: Minimum time (in seconds) between updates to the FPS counter.
                             If None, the counter is updated every frame.
    """
    self._update_frequency = update_frequency

    self._start_time = time.time()
    self._frame_count = 0
    self._fps = 0.0

WebcamStreaming

Stream video from a webcam. Press 'q' to quit the streaming.

Parameters:

Name Type Description Default
window_name str

Name of the window to display the video stream.

''
frame_processing_fn Optional[Callable[[np.ndarray], np.ndarray]]

Function to apply to each frame before displaying it. If None, frames are displayed as is.

None
capture int

ID of the video capture device to use. Default is cv2.CAP_ANY (which selects the first available device).

cv2.CAP_ANY
fps_update_frequency Optional[float]

Minimum time (in seconds) between updates to the FPS counter. If None, the counter is updated every frame.

None
Source code in src/super_gradients/training/utils/media/stream.py
class WebcamStreaming:
    """Stream video from a webcam. Press 'q' to quit the streaming.

    :param window_name:          Name of the window to display the video stream.
    :param frame_processing_fn:  Function to apply to each frame before displaying it.
                                 If None, frames are displayed as is.
    :param capture:              ID of the video capture device to use.
                                 Default is cv2.CAP_ANY (which selects the first available device).
    :param fps_update_frequency: Minimum time (in seconds) between updates to the FPS counter.
                                 If None, the counter is updated every frame.
    """

    def __init__(
        self,
        window_name: str = "",
        frame_processing_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
        capture: int = cv2.CAP_ANY,
        fps_update_frequency: Optional[float] = None,
    ):
        self.window_name = window_name
        self.frame_processing_fn = frame_processing_fn
        self.cap = cv2.VideoCapture(capture)
        if not self.cap.isOpened():
            message = "Could not open video capture device. Please check whether you have the webcam connected."
            if sys.platform == "darwin":
                message += " On macOS, you may need to grant the terminal access to the webcam in System Preferences."
                message += " Check https://stackoverflow.com/search?q=OpenCV+macOS+camera+access for more information."
            elif sys.platform == "win32":  # Windows; sys.platform reports "win32" (os.name is "nt")
                message += " On Windows, you may need to grant the terminal access to the webcam in the settings."
                message += " Check https://support.microsoft.com/en-us/windows/manage-app-permissions-for-your-camera-in-windows-87ebc757-1f87-7bbf-84b5-0686afb6ca6b#WindowsVersion=Windows_11 for more information."  # noqa
            raise ValueError(message)

        self._fps_counter = FPSCounter(update_frequency=fps_update_frequency)

    def run(self) -> None:
        """Start streaming video from the webcam and displaying it in a window.

        Press 'q' to quit the streaming.
        """
        while not self._stop() and self._display_single_frame():
            pass

    def _display_single_frame(self) -> bool:
        """Read a single frame from the video capture device, apply any specified frame processing,
        and display the resulting frame in the window.

        Also updates the FPS counter and displays it in the frame.
        """
        _ret, frame = self.cap.read()
        if not _ret or frame is None:
            logger.warning("Could not read frame from video capture device.")
            return False

        if self.frame_processing_fn:
            # Convert the frame to RGB since this is the format expected
            # by the predict function
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = self.frame_processing_fn(frame)
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        _write_fps_to_frame(frame, self.fps)
        cv2.imshow(self.window_name, frame)
        return _ret

    def _stop(self) -> bool:
        """Stopping condition for the streaming."""
        return cv2.waitKey(1) & 0xFF == ord("q")

    @property
    def fps(self) -> float:
        return self._fps_counter.fps

    def __del__(self):
        """Release the video capture device and close the window."""
        self.cap.release()
        cv2.destroyAllWindows()
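
A minimal usage sketch; it requires a connected webcam, and the grayscale frame_processing_fn below is illustrative rather than part of the library:

import cv2
import numpy as np
from super_gradients.training.utils.media.stream import WebcamStreaming

def to_grayscale(frame: np.ndarray) -> np.ndarray:
    # Frames arrive in RGB; return RGB as well so the stream can convert back to BGR for display.
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    return cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

streaming = WebcamStreaming(window_name="Demo", frame_processing_fn=to_grayscale, fps_update_frequency=1.0)
streaming.run()  # press 'q' in the window to stop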

__del__()

Release the video capture device and close the window.

Source code in src/super_gradients/training/utils/media/stream.py
def __del__(self):
    """Release the video capture device and close the window."""
    self.cap.release()
    cv2.destroyAllWindows()

run()

Start streaming video from the webcam and displaying it in a window.

Press 'q' to quit the streaming.

Source code in src/super_gradients/training/utils/media/stream.py
def run(self) -> None:
    """Start streaming video from the webcam and displaying it in a window.

    Press 'q' to quit the streaming.
    """
    while not self._stop() and self._display_single_frame():
        pass

includes_video_extension(file_path)

Check if a file includes a video extension.

Parameters:

Name Type Description Default
file_path str

Path to the video file.

required

Returns:

Type Description
bool

True if the file includes a video extension.

Source code in src/super_gradients/training/utils/media/video.py
def includes_video_extension(file_path: str) -> bool:
    """Check if a file includes a video extension.
    :param file_path:   Path to the video file.
    :return:            True if the file includes a video extension.
    """
    return isinstance(file_path, str) and file_path.lower().endswith(VIDEO_EXTENSIONS)

lazy_load_video(file_path, max_frames=None)

Open a video file and return a generator that yields frames.

Parameters:

Name Type Description Default
file_path str

Path to the video file.

required
max_frames Optional[int]

Optional, maximum number of frames to extract.

None

Returns:

Type Description
Tuple[Iterator[np.ndarray], int, int]
- Generator yielding frames representing the video, each in (H, W, C), RGB.
- Frames per second (FPS).
- Number of frames in the video.
Source code in src/super_gradients/training/utils/media/video.py
def lazy_load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[Iterator[np.ndarray], int, int]:
    """Open a video file and returns a generator which yields frames.

    :param file_path:   Path to the video file.
    :param max_frames:  Optional, maximum number of frames to extract.
    :return:
                - Generator yielding frames representing the video, each in (H, W, C), RGB.
                - Frames per Second (FPS).
                - Amount of frames in video.
    """
    cap = _open_video(file_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frames = _lazy_extract_frames(cap, max_frames)
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    return frames, fps, num_frames
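
A minimal sketch; "my_video.mp4" is a placeholder path, not a bundled asset:

from super_gradients.training.utils.media.video import lazy_load_video

frames, fps, num_frames = lazy_load_video("my_video.mp4", max_frames=100)  # placeholder path
print(fps, num_frames)
for frame in frames:  # frames are yielded lazily, one (H, W, 3) RGB array at a time
    pass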

load_video(file_path, max_frames=None)

Open a video file and extract each frame into a numpy array.

Parameters:

Name Type Description Default
file_path str

Path to the video file.

required
max_frames Optional[int]

Optional, maximum number of frames to extract.

None

Returns:

Type Description
Tuple[List[np.ndarray], int]
- Frames representing the video, each in (H, W, C), RGB.
- Frames per second (FPS).
Source code in src/super_gradients/training/utils/media/video.py
def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[np.ndarray], int]:
    """Open a video file and extract each frame into numpy array.

    :param file_path:   Path to the video file.
    :param max_frames:  Optional, maximum number of frames to extract.
    :return:
                - Frames representing the video, each in (H, W, C), RGB.
                - Frames per Second (FPS).
    """
    cap = _open_video(file_path)
    frames = _extract_frames(cap, max_frames)
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return frames, fps

save_gif(output_path, frames, fps)

Save a video locally in .gif format. Safe to use with a generator of frames.

Parameters:

Name Type Description Default
output_path str

Where the video will be saved

required
frames Iterable[np.ndarray]

Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.

required
fps int

Frames per second

required
Source code in src/super_gradients/training/utils/media/video.py
def save_gif(output_path: str, frames: Iterable[np.ndarray], fps: int) -> None:
    """Save a video locally in .gif format. Safe for generator of frames object.

    :param output_path: Where the video will be saved
    :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
    :param fps:         Frames per second
    """
    frame_iter_obj = iter(frames)
    pil_frames_iter_obj = map(PIL.Image.fromarray, frame_iter_obj)

    first_frame = next(pil_frames_iter_obj)

    first_frame.save(output_path, save_all=True, append_images=pil_frames_iter_obj, duration=int(1000 / fps), loop=0)

save_mp4(output_path, frames, fps)

Save a video locally in .mp4 format. Safe to use with a generator of frames.

Parameters:

Name Type Description Default
output_path str

Where the video will be saved

required
frames Iterable[np.ndarray]

Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.

required
fps int

Frames per second

required
Source code in src/super_gradients/training/utils/media/video.py
def save_mp4(output_path: str, frames: Iterable[np.ndarray], fps: int) -> None:
    """Save a video locally in .mp4 format. Safe for generator of frames object.

    :param output_path: Where the video will be saved
    :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
    :param fps:         Frames per second
    """
    video_height, video_width, video_writer = None, None, None

    for frame in frames:
        if video_height is None:
            video_height, video_width = frame.shape[:2]
            video_writer = cv2.VideoWriter(
                output_path,
                cv2.VideoWriter_fourcc(*"mp4v"),
                fps,
                (video_width, video_height),
            )
        _validate_frame(frame, video_height, video_width)
        video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

    video_writer.release()

save_video(output_path, frames, fps)

Save a video locally. Depending on the extension, the video will be saved as a .mp4 file or as a .gif file.

Parameters:

Name Type Description Default
output_path str

Where the video will be saved

required
frames List[np.ndarray]

Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.

required
fps int

Frames per second

required
Source code in src/super_gradients/training/utils/media/video.py
def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
    """Save a video locally. Depending on the extension, the video will be saved as a .mp4 file or as a .gif file.

    :param output_path: Where the video will be saved
    :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
    :param fps:         Frames per second
    """
    if not includes_video_extension(output_path):
        logger.info(f'Output path "{output_path}" does not have a video extension, and therefore will be saved as {output_path}.mp4')
        output_path += ".mp4"

    if check_is_gif(output_path):
        save_gif(output_path, frames, fps)
    else:
        save_mp4(output_path, frames, fps)
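
A minimal sketch writing two seconds of synthetic frames; the output file name is a placeholder, and the container is chosen from the extension:

import numpy as np
from super_gradients.training.utils.media.video import save_video

frames = [np.random.randint(0, 255, (240, 320, 3), dtype=np.uint8) for _ in range(50)]
save_video("demo_output.mp4", frames, fps=25)  # a .gif extension would route to save_gif instead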

show_video_from_disk(video_path, window_name='Prediction')

Display a video from disk using OpenCV.

Parameters:

Name Type Description Default
video_path str

Path to the video file.

required
window_name str

Name of the window to display the video

'Prediction'
Source code in src/super_gradients/training/utils/media/video.py
def show_video_from_disk(video_path: str, window_name: str = "Prediction"):
    """Display a video from disk using OpenCV.

    :param video_path:   Path to the video file.
    :param window_name:  Name of the window to display the video
    """
    cap = _open_video(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)

    while cap.isOpened():
        ret, frame = cap.read()

        if ret:
            # Display the frame
            cv2.imshow(window_name, frame)

            # Wait for the specified number of milliseconds before displaying the next frame
            if cv2.waitKey(int(1000 / fps)) & 0xFF == ord("q"):
                break
        else:
            break

    # Release the VideoCapture object and destroy the window
    cap.release()
    cv2.destroyAllWindows()
    cv2.waitKey(1)

show_video_from_frames(frames, fps, window_name='Prediction')

Display a video from a list of frames using OpenCV.

Parameters:

Name Type Description Default
frames List[np.ndarray]

Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.

required
fps float

Frames per second

required
window_name str

Name of the window to display the video

'Prediction'
Source code in src/super_gradients/training/utils/media/video.py
def show_video_from_frames(frames: List[np.ndarray], fps: float, window_name: str = "Prediction") -> None:
    """Display a video from a list of frames using OpenCV.

    :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
    :param fps:         Frames per second
    :param window_name:  Name of the window to display the video
    """
    for frame in frames:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        cv2.imshow(window_name, frame)
        cv2.waitKey(int(1000 / fps))
    cv2.destroyAllWindows()
    cv2.waitKey(1)

build_optimizer(net, lr, training_params)

Wrapper function for initializing the optimizer.

:param net: the nn.Module to build the optimizer for
:param lr: initial learning rate
:param training_params: training parameters

Source code in src/super_gradients/training/utils/optimizer_utils.py
def build_optimizer(net: nn.Module, lr: float, training_params) -> optim.Optimizer:
    """
    Wrapper function for initializing the optimizer
        :param net: the nn_module to build the optimizer for
        :param lr: initial learning rate
        :param training_params: training_parameters
    """
    if is_model_wrapped(net):
        raise ValueError("Argument net for build_optimizer must be an unwrapped model. " "Please use build_optimizer(unwrap_model(net), ...).")
    if isinstance(training_params.optimizer, str):
        optimizer_cls = OptimizersTypeFactory().get(training_params.optimizer)
    else:
        optimizer_cls = training_params.optimizer
    optimizer_params = OPTIMIZERS_DEFAULT_PARAMS[optimizer_cls].copy() if optimizer_cls in OPTIMIZERS_DEFAULT_PARAMS.keys() else dict()
    optimizer_params.update(**training_params.optimizer_params)
    training_params.optimizer_params = optimizer_params

    weight_decay = get_param(training_params.optimizer_params, "weight_decay", 0.0)
    # OPTIMIZER PARAM GROUPS ARE SET USING DEFAULT OR MODEL SPECIFIC INIT
    if hasattr(net, "initialize_param_groups") or hasattr(net, "update_param_groups"):
        warnings.warn(
            "initialize_param_groups and update_param_groups usages are deprecated since 3.4.0, will be removed in "
            "3.5.0 and have no effect. \n "
            "Assign different learning rates by passing a mapping of layer name prefixes to lr values through "
            "initial_lr training hyperparameter (i.e initial_lr={'backbone': 0.01, 'default':0.1})",
            DeprecationWarning,
        )
    if training_params.finetune:
        if not isinstance(net, SupportsFineTune):
            warnings.warn(
                "training hyperparameter finetune=True but will have no effect. get_finetune_lr_dict is not implemented for this model, which is required."
            )
        elif not isinstance(lr, float):
            raise RuntimeError("When training with fine_tune=True, initial_lr must be a scalar.")
        lr = net.get_finetune_lr_dict(lr)
        logger.info(f"Training with finetune=True: setting initial_lr to predefined mapping {lr}")
        training_params.initial_lr = lr

    net_named_params = initialize_param_groups(net, lr)

    if training_params.zero_weight_decay_on_bias_and_bn:
        optimizer_training_params = separate_zero_wd_params_groups_for_optimizer(net, net_named_params, weight_decay)

    else:
        # Overwrite groups to include params instead of named params
        for ind_group, param_group in enumerate(net_named_params):
            param_group["params"] = [param[1] for param in list(param_group["named_params"])]
            del param_group["named_params"]
            net_named_params[ind_group] = param_group
        optimizer_training_params = net_named_params

    # CREATE AN OPTIMIZER OBJECT AND INITIALIZE IT
    optimizer = optimizer_cls(optimizer_training_params, **training_params.optimizer_params)

    return optimizer

get_initial_lr_from_optimizer(optimizer)

Returns the initial learning rate as either:

- a float, when the learning rate was passed as a scalar, or
- a dictionary where keys are group names and values are the learning rates, for example {"default": 0.01, "head": 0.1}.

Does so by iterating over the optimizer.param_groups and extracting the "lr" values. If the optimizer was initialized with .parameters() and not named_parameters(), names will be assigned to the optimizer parameter groups by index.

Parameters:

Name Type Description Default
optimizer torch.optim.Optimizer

torch.optim.Optimizer, The optimizer to extract the lrs from.

required

Returns:

Type Description
Union[Dict[str, float], float]

initial_lr as described above.

Source code in src/super_gradients/training/utils/optimizer_utils.py
def get_initial_lr_from_optimizer(optimizer: torch.optim.Optimizer) -> Union[Dict[str, float], float]:
    """
    Returns Initial learning rate as:

    float - learning rate value when passed as a scalar
    Dictionary where keys are group names and values are the learning rates.
    For example {"default": 0.01, "head": 0.1}

    Does so by iterating over the optimizer.param_groups and extracting the "lr" values.
    If the optimizer was initialized with .parameters() and not named_parameters(), names will be assigned to the
     optimizer parameter groups by index.

    :param optimizer: torch.optim.Optimizer, The optimizer to extract the lrs from.
    :return: initial_lr as described above.
    """
    if "name" not in optimizer.param_groups[0].keys():
        optimizer = name_optimizer_param_groups_inplace(optimizer)
    if len(optimizer.param_groups) == 1:
        initial_lr = optimizer.param_groups[0]["lr"]
    else:
        initial_lr = {group["name"]: group["lr"] for group in optimizer.param_groups}
    return initial_lr
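
A minimal sketch with a plain torch optimizer; a single unnamed param group yields a scalar, while multiple named groups would yield a name-to-lr mapping:

import torch
from torch import nn
from super_gradients.training.utils.optimizer_utils import get_initial_lr_from_optimizer

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
print(get_initial_lr_from_optimizer(optimizer))  # 0.1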

initialize_param_groups(model, lr)

Custom param groups for training with specified learning rates for each group in the model.

Parameters:

Name Type Description Default
model nn.Module

nn.Module model.

required
lr Union[float, Dict[str, float]]

Dictionary where keys are group names and values are the learning rates, or a learning rate value when passed as a scalar.

required

Returns:

Type Description
List[Dict]

List of param groups.

Source code in src/super_gradients/training/utils/optimizer_utils.py
def initialize_param_groups(model: nn.Module, lr: Union[float, Dict[str, float]]) -> List[Dict]:
    """
    Custom param groups for training with specified learning rates for each group in the model.
    :param model: nn.Module model.
    :param lr: Dictionary where keys are group names and values are the learning rates,
     or a learning rate value when passed as a scalar.
    :return: List of param groups.
    """
    if isinstance(lr, float) or isinstance(lr, int):
        model_named_params = [{"named_params": model.named_parameters(), "lr": lr, "name": "default"}]
    else:
        model_named_params = separate_lr_groups(model, lr)
    return model_named_params
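
A minimal sketch of per-group learning rates keyed by parameter-name prefixes; the "backbone"/"head" names below match the toy model, not a real architecture:

from torch import nn
from super_gradients.training.utils.optimizer_utils import initialize_param_groups

model = nn.Sequential()
model.add_module("backbone", nn.Linear(8, 8))
model.add_module("head", nn.Linear(8, 2))

groups = initialize_param_groups(model, lr={"backbone": 0.001, "head": 0.01, "default": 0.01})
print(sorted((g["name"], g["lr"]) for g in groups))  # [('backbone', 0.001), ('head', 0.01)]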

name_optimizer_param_groups_inplace(optimizer)

Convert an optimizer's param_groups to use named parameters, modifying it in place.

Parameters:

Name Type Description Default
optimizer torch.optim.Optimizer

torch.optim.Optimizer, The optimizer to be converted.

required

Returns:

Type Description
torch.optim.Optimizer

The same optimizer with modified param_groups.
Source code in src/super_gradients/training/utils/optimizer_utils.py
def name_optimizer_param_groups_inplace(optimizer: torch.optim.Optimizer) -> torch.optim.Optimizer:
    """
    Convert an optimizer's param_groups to use named parameters, modifying it in place.

    :param optimizer: torch.optim.Optimizer, The optimizer to be converted.

    Returns:
        torch.optim.Optimizer: The same optimizer with modified param_groups.
    """

    named_parameters = list(optimizer.param_groups[0]["params"])
    num_param_groups = len(optimizer.param_groups)
    group_name = [f"group_{i}" for i in range(num_param_groups)] if num_param_groups > 1 else "default"

    for i, param_group in enumerate(optimizer.param_groups):
        param_group["params"] = named_parameters
        param_group["name"] = group_name if num_param_groups == 1 else group_name[i]

    return optimizer

separate_lr_groups(model, lr_dict)

Separate parameters based on specified learning rates for each group in the model.

Parameters:

Name Type Description Default
model nn.Module

nn.Module model.

required
lr_dict Dict[str, float]

Dictionary where keys are group names and values are the learning rates.

required

Returns:

Type Description
List[Dict]

List of param groups with named_parameters and corresponding learning rates.

Source code in src/super_gradients/training/utils/optimizer_utils.py
def separate_lr_groups(model: nn.Module, lr_dict: Dict[str, float]) -> List[Dict]:
    """
    Separate parameters based on specified learning rates for each group in the model.
    :param model: nn.Module model.
    :param lr_dict: Dictionary where keys are group names and values are the learning rates.
    :return: List of param groups with named_parameters and corresponding learning rates.
    """
    param_groups = []
    default_lr = lr_dict.get("default", None)
    if default_lr is None:
        raise RuntimeError("When passing initial_lr as dictionary, must pass 'default'.")
    group_names = set(lr_dict.keys()) - {"default"}

    for group_name in group_names:
        lr = lr_dict[group_name]
        named_params = [(name, param) for name, param in model.named_parameters() if name.startswith(group_name)]

        if lr == 0:
            for name, param in named_params:
                param.requires_grad = False  # Freeze the layer
        else:
            param_groups.append({"named_params": named_params, "lr": lr, "name": group_name})

    default_named_params = [
        (name, param) for name, param in model.named_parameters() if all(name.startswith(group) is False for group in group_names) and param.requires_grad
    ]
    if default_named_params:
        if default_lr != 0:
            param_groups.append({"named_params": default_named_params, "lr": default_lr, "name": "default"})
        else:
            for name, param in default_named_params:
                param.requires_grad = False  # Freeze the layer

    return param_groups
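
A minimal sketch of the freezing behaviour: a learning rate of 0 for a prefix freezes those parameters instead of creating a param group. The toy module names below are illustrative:

from torch import nn
from super_gradients.training.utils.optimizer_utils import separate_lr_groups

model = nn.Sequential()
model.add_module("backbone", nn.Linear(8, 8))
model.add_module("head", nn.Linear(8, 2))

groups = separate_lr_groups(model, {"backbone": 0.0, "default": 0.01})
print([g["name"] for g in groups])           # ['default'] - only the non-frozen parameters
print(model.backbone.weight.requires_grad)   # False - backbone parameters were frozen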

separate_zero_wd_params_groups_for_optimizer(module, net_named_params, weight_decay)

Separate param groups into batchnorm/bias parameters and everything else, and return a list of param groups in the format required by torch Optimizer classes: biases and BN parameters get weight_decay=0, the rest get the given weight decay.

:param module: train net module.
:param net_named_params: list of param groups, output of SgModule.initialize_param_groups
:param weight_decay: value to set for the non-BN and non-bias parameters

Source code in src/super_gradients/training/utils/optimizer_utils.py
def separate_zero_wd_params_groups_for_optimizer(module: nn.Module, net_named_params, weight_decay: float):
    """
    separate param groups for batchnorm and biases and others with weight decay. return list of param groups in format
     required by torch Optimizer classes.
    bias + BN with weight decay=0 and the rest with the given weight decay
        :param module: train net module.
        :param net_named_params: list of params groups, output of SgModule.initialize_param_groups
        :param weight_decay: value to set for the non BN and bias parameters
    """
    # FIXME - replace usage of ids addresses to find batchnorm and biases params.
    #  This solution iterate 2 times over module parameters, find a way to iterate only one time.
    no_decay_ids = _get_no_decay_param_ids(module)
    # split param groups for optimizer
    optimizer_param_groups = []
    for param_group in net_named_params:
        no_decay_params = []
        decay_params = []
        for name, param in param_group["named_params"]:
            if id(param) in no_decay_ids:
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        # append two param groups from the original param group, with and without weight decay.
        extra_optim_params = {key: param_group[key] for key in param_group if key not in ["named_params", "weight_decay"]}
        optimizer_param_groups.append({"params": no_decay_params, "weight_decay": 0.0, **extra_optim_params})
        optimizer_param_groups.append({"params": decay_params, "weight_decay": weight_decay, **extra_optim_params})

    return optimizer_param_groups
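
A minimal sketch chaining it with initialize_param_groups; the toy Conv/BN model is illustrative:

from torch import nn
from super_gradients.training.utils.optimizer_utils import (
    initialize_param_groups,
    separate_zero_wd_params_groups_for_optimizer,
)

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))
named_groups = initialize_param_groups(model, lr=0.1)
optimizer_groups = separate_zero_wd_params_groups_for_optimizer(model, named_groups, weight_decay=1e-4)
print([(g["weight_decay"], len(g["params"])) for g in optimizer_groups])
# two groups per input group: one with weight_decay=0.0 (biases/BN), one with the given weight decay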

This implementation is taken from timm's github: https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py

PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb

This optimizer code was adapted from the following (starting with latest):

* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
* https://github.com/cybertronai/pytorch-lamb

Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.

In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.

Original copyrights for above sources are below.

Modifications Copyright 2021 Ross Wightman

Copyright (c) 2021, Habana Labs Ltd. All rights reserved.

Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

MIT License

Copyright (c) 2019 cybertronai

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

Lamb

Bases: Optimizer

Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py

LAMB was proposed in Large Batch Optimization for Deep Learning: Training BERT in 76 minutes_.

Arguments:

- params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
- lr (float, optional): learning rate. (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its norm. (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve numerical stability. (default: 1e-8)
- weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
- grad_averaging (bool, optional): whether to apply (1-beta2) to grad when calculating running averages of gradient. (default: True)
- max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
- trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
- always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 weight decay parameter (default: False)

.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: https://arxiv.org/abs/1904.00962
.. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ

Source code in src/super_gradients/training/utils/optimizers/lamb.py
@register_optimizer(Optimizers.LAMB)
class Lamb(Optimizer):
    """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
    reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py

    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging (bool, optional): whether apply (1-beta2) to grad when
            calculating running averages of gradient. (default: True)
        max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
        trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
        always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
            weight decay parameter (default: False)

    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
        self,
        params: Union[Iterable[torch.Tensor], Iterable[dict]],
        lr: float = 1e-3,
        bias_correction: bool = True,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.01,
        grad_averaging: bool = True,
        max_grad_norm: float = 1.0,
        trust_clip: bool = False,
        always_adapt: bool = False,
    ):
        defaults = dict(
            lr=lr,
            bias_correction=bias_correction,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            grad_averaging=grad_averaging,
            max_grad_norm=max_grad_norm,
            trust_clip=trust_clip,
            always_adapt=always_adapt,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Optional[callable] = None) -> torch.Tensor:
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        device = self.param_groups[0]["params"][0].device
        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
        global_grad_norm = torch.zeros(1, device=device)
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("Lamb does not support sparse gradients, consider SparseAdam instead.")
                global_grad_norm.add_(grad.pow(2).sum())

        global_grad_norm = torch.sqrt(global_grad_norm)
        # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
        # scalar types properly https://github.com/pytorch/pytorch/issues/9190
        max_grad_norm = torch.tensor(self.defaults["max_grad_norm"], device=device)
        clip_global_grad_norm = torch.where(global_grad_norm > max_grad_norm, global_grad_norm / max_grad_norm, one_tensor)

        for group in self.param_groups:
            bias_correction = 1 if group["bias_correction"] else 0
            beta1, beta2 = group["betas"]
            grad_averaging = 1 if group["grad_averaging"] else 0
            beta3 = 1 - beta1 if grad_averaging else 1.0

            # assume same step across group now to simplify things
            # per parameter step can be easily support by making it tensor, or pass list into kernel
            if "step" in group:
                group["step"] += 1
            else:
                group["step"] = 1

            if bias_correction:
                bias_correction1 = 1 - beta1 ** group["step"]
                bias_correction2 = 1 - beta2 ** group["step"]
            else:
                bias_correction1, bias_correction2 = 1.0, 1.0

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.div_(clip_global_grad_norm)
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=beta3)  # m_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
                update = (exp_avg / bias_correction1).div_(denom)

                weight_decay = group["weight_decay"]
                if weight_decay != 0:
                    update.add_(p, alpha=weight_decay)

                if weight_decay != 0 or group["always_adapt"]:
                    # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
                    # excluded from weight decay, unless always_adapt == True, then always enabled.
                    w_norm = p.norm(2.0)
                    g_norm = update.norm(2.0)
                    # FIXME nested where required since logical and/or not working in PT XLA
                    trust_ratio = torch.where(
                        w_norm > 0,
                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
                        one_tensor,
                    )
                    if group["trust_clip"]:
                        # LAMBC trust clipping, upper bound fixed at one
                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
                    update.mul_(trust_ratio)

                p.add_(update, alpha=-group["lr"])

        return loss

step(closure=None)

Performs a single optimization step.

Parameters:

Name Type Description Default
closure Optional[callable]

A closure that reevaluates the model and returns the loss.

None

Source code in src/super_gradients/training/utils/optimizers/lamb.py
@torch.no_grad()
def step(self, closure: Optional[callable] = None) -> torch.Tensor:
    """Performs a single optimization step.
    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    device = self.param_groups[0]["params"][0].device
    one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
    global_grad_norm = torch.zeros(1, device=device)
    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError("Lamb does not support sparse gradients, consider SparseAdam instad.")
            global_grad_norm.add_(grad.pow(2).sum())

    global_grad_norm = torch.sqrt(global_grad_norm)
    # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
    # scalar types properly https://github.com/pytorch/pytorch/issues/9190
    max_grad_norm = torch.tensor(self.defaults["max_grad_norm"], device=device)
    clip_global_grad_norm = torch.where(global_grad_norm > max_grad_norm, global_grad_norm / max_grad_norm, one_tensor)

    for group in self.param_groups:
        bias_correction = 1 if group["bias_correction"] else 0
        beta1, beta2 = group["betas"]
        grad_averaging = 1 if group["grad_averaging"] else 0
        beta3 = 1 - beta1 if grad_averaging else 1.0

        # assume the same step across the group for now to simplify things
        # a per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
        if "step" in group:
            group["step"] += 1
        else:
            group["step"] = 1

        if bias_correction:
            bias_correction1 = 1 - beta1 ** group["step"]
            bias_correction2 = 1 - beta2 ** group["step"]
        else:
            bias_correction1, bias_correction2 = 1.0, 1.0

        for p in group["params"]:
            if p.grad is None:
                continue
            grad = p.grad.div_(clip_global_grad_norm)
            state = self.state[p]

            # State initialization
            if len(state) == 0:
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p)
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(p)

            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

            # Decay the first and second moment running average coefficient
            exp_avg.mul_(beta1).add_(grad, alpha=beta3)  # m_t
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t

            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
            update = (exp_avg / bias_correction1).div_(denom)

            weight_decay = group["weight_decay"]
            if weight_decay != 0:
                update.add_(p, alpha=weight_decay)

            if weight_decay != 0 or group["always_adapt"]:
                # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
                # excluded from weight decay, unless always_adapt == True, then always enabled.
                w_norm = p.norm(2.0)
                g_norm = update.norm(2.0)
                # FIXME nested where required since logical and/or not working in PT XLA
                trust_ratio = torch.where(
                    w_norm > 0,
                    torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
                    one_tensor,
                )
                if group["trust_clip"]:
                    # LAMBC trust clipping, upper bound fixed at one
                    trust_ratio = torch.minimum(trust_ratio, one_tensor)
                update.mul_(trust_ratio)

            p.add_(update, alpha=-group["lr"])

    return loss
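
A minimal usage sketch of this optimizer in a plain training loop is shown below. It assumes the class is exposed as Lamb in the module listed above (super_gradients.training.utils.optimizers.lamb); the model, data and hyperparameter values are placeholders.

import torch
import torch.nn as nn

from super_gradients.training.utils.optimizers.lamb import Lamb  # assumed import path, matching the source file above

# Toy model and data, for illustration only
model = nn.Linear(16, 4)
optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)

x = torch.randn(8, 16)
y = torch.randn(8, 4)

for _ in range(3):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()  # applies the global gradient-norm clipping and layer-wise trust ratio shown above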

PyTorch implementation of the Lion optimizer. Code adapted from: https://github.com/google/automl/blob/master/lion/lion_pytorch.py

Lion

Bases: Optimizer

Implements the Lion algorithm. Generally, it is recommended to divide the lr used by AdamW by 10 and multiply the weight decay by 10.

Source code in src/super_gradients/training/utils/optimizers/lion.py
@register_optimizer(Optimizers.LION)
class Lion(Optimizer):
    r"""Implements Lion algorithm.
    Generally, it is recommended to divide lr used by AdamW by 10 and multiply the weight decay by 10.
    """

    def __init__(
        self,
        params: Union[Iterable[torch.Tensor], Iterable[dict]],
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        """
        Initialize the hyperparameters.

        :param params:          Iterable of parameters to optimize or dicts defining parameter groups
        :param lr:              Learning rate (default: 1e-4)
        :param betas:           Coefficients used for computing running averages of gradient and its square (default: (0.9, 0.99))
        :param weight_decay:    Weight decay coefficient (default: 0)
        """

        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Optional[callable] = None) -> torch.Tensor:
        """
        Perform a single optimization step.

        :param closure: A closure that reevaluates the model and returns the loss.
        :return: Loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                # Perform stepweight decay
                p.data.mul_(1 - group["lr"] * group["weight_decay"])

                grad = p.grad
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p)

                exp_avg = state["exp_avg"]
                beta1, beta2 = group["betas"]

                # Weight update
                update = exp_avg * beta1 + grad * (1 - beta1)
                p.add_(torch.sign(update), alpha=-group["lr"])
                # Decay the momentum running average coefficient
                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss
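
The update rule implemented above can be summarized in a few lines of plain tensor code. The standalone sketch below mirrors the logic of step() for a single parameter; all names and values are illustrative, not part of the library.

import torch

lr, wd = 1e-4, 1.0           # Lion typically pairs a smaller lr with a larger weight decay than AdamW
beta1, beta2 = 0.9, 0.99

p = torch.randn(10)          # parameter
grad = torch.randn(10)       # its gradient
exp_avg = torch.zeros(10)    # momentum state ("exp_avg" in the code above)

p.mul_(1 - lr * wd)                              # decoupled ("stepweight") decay
update = exp_avg * beta1 + grad * (1 - beta1)    # interpolate momentum and current gradient
p.add_(torch.sign(update), alpha=-lr)            # sign update: only the direction is used, not the magnitude
exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)  # refresh momentum with the slower beta2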

__init__(params, lr=0.0001, betas=(0.9, 0.99), weight_decay=0.0)

Initialize the hyperparameters.

Parameters:

Name Type Description Default
params Union[Iterable[torch.Tensor], Iterable[dict]]

Iterable of parameters to optimize or dicts defining parameter groups

required
lr float

Learning rate (default: 1e-4)

0.0001
betas Tuple[float, float]

Coefficients used for computing running averages of gradient and its square (default: (0.9, 0.99))

(0.9, 0.99)
weight_decay float

Weight decay coefficient (default: 0)

0.0
Source code in src/super_gradients/training/utils/optimizers/lion.py
def __init__(
    self,
    params: Union[Iterable[torch.Tensor], Iterable[dict]],
    lr: float = 1e-4,
    betas: Tuple[float, float] = (0.9, 0.99),
    weight_decay: float = 0.0,
):
    """
    Initialize the hyperparameters.

    :param params:          Iterable of parameters to optimize or dicts defining parameter groups
    :param lr:              Learning rate (default: 1e-4)
    :param betas:           Coefficients used for computing running averages of gradient and its square (default: (0.9, 0.99))
    :param weight_decay:    Weight decay coefficient (default: 0)
    """

    if not 0.0 <= lr:
        raise ValueError("Invalid learning rate: {}".format(lr))
    if not 0.0 <= betas[0] < 1.0:
        raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
    if not 0.0 <= betas[1] < 1.0:
        raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
    defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
    super().__init__(params, defaults)

step(closure=None)

Perform a single optimization step.

Parameters:

Name Type Description Default
closure Optional[callable]

A closure that reevaluates the model and returns the loss.

None

Returns:

Type Description
torch.Tensor

Loss.

Source code in src/super_gradients/training/utils/optimizers/lion.py
@torch.no_grad()
def step(self, closure: Optional[callable] = None) -> torch.Tensor:
    """
    Perform a single optimization step.

    :param closure: A closure that reevaluates the model and returns the loss.
    :return: Loss.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue

            # Perform stepweight decay
            p.data.mul_(1 - group["lr"] * group["weight_decay"])

            grad = p.grad
            state = self.state[p]
            # State initialization
            if len(state) == 0:
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p)

            exp_avg = state["exp_avg"]
            beta1, beta2 = group["betas"]

            # Weight update
            update = exp_avg * beta1 + grad * (1 - beta1)
            p.add_(torch.sign(update), alpha=-group["lr"])
            # Decay the momentum running average coefficient
            exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

    return loss

RMSpropTF

Bases: Optimizer

Implements the RMSprop algorithm (TensorFlow-style epsilon).

NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before the sqrt and a few other modifications to closer match TensorFlow for matching hyper-parameters. Noteworthy changes include:

1. Epsilon applied inside the square root
2. square_avg initialized to ones
3. LR scaling of the update accumulated in the momentum buffer

Proposed by G. Hinton in his course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>. The centered version first appears in Generating Sequences With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>.

Source code in src/super_gradients/training/utils/optimizers/rmsprop_tf.py
@register_optimizer(Optimizers.RMS_PROP_TF)
class RMSpropTF(Optimizer):
    """Implements RMSprop algorithm (TensorFlow style epsilon)
    NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt
    and a few other modifications to closer match Tensorflow for matching hyper-params.
    Noteworthy changes include:
    1. Epsilon applied inside square-root
    2. square_avg initialized to ones
    3. LR scaling of update accumulated in momentum buffer
    Proposed by G. Hinton in his
    `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
    The centered version first appears in `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_."""

    def __init__(
        self,
        params: Union[Iterable[torch.Tensor], Iterable[dict]],
        lr: float = 1e-2,
        alpha: float = 0.9,
        eps: float = 1e-10,
        weight_decay: float = 0,
        momentum: float = 0.0,
        centered: bool = False,
        decoupled_decay: bool = False,
        lr_in_momentum: bool = True,
    ):
        """RMSprop optimizer that follows the tf's RMSprop characteristics
        :param params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        :param lr (float, optional): learning rate
        :param momentum (float, optional): momentum factor
        :param alpha (float, optional): smoothing (decay) constant
        :param eps (float, optional): term added to the denominator to improve numerical stability
        :param centered (bool, optional) : if ``True``, compute the centered RMSProp, the gradient is normalized by an
         estimation of its variance
        :param weight_decay (float, optional): weight decay (L2 penalty)
        :param decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
        :param lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer update as per
         defaults in Tensorflow
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= momentum:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= alpha:
            raise ValueError("Invalid alpha value: {}".format(alpha))

        defaults = dict(
            lr=lr,
            momentum=momentum,
            alpha=alpha,
            eps=eps,
            centered=centered,
            weight_decay=weight_decay,
            decoupled_decay=decoupled_decay,
            lr_in_momentum=lr_in_momentum,
        )
        super(RMSpropTF, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RMSpropTF, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("momentum", 0)
            group.setdefault("centered", False)

    def step(self, closure: Optional[callable] = None) -> torch.Tensor:  # noqa: C901
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("RMSprop does not support sparse gradients")
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["square_avg"] = torch.ones_like(p.data)  # PyTorch inits to zero
                    if group["momentum"] > 0:
                        state["momentum_buffer"] = torch.zeros_like(p.data)
                    if group["centered"]:
                        state["grad_avg"] = torch.zeros_like(p.data)

                square_avg = state["square_avg"]
                one_minus_alpha = 1.0 - group["alpha"]

                state["step"] += 1

                if group["weight_decay"] != 0:
                    if "decoupled_decay" in group and group["decoupled_decay"]:
                        p.data.add_(-group["weight_decay"], p.data)
                    else:
                        grad = grad.add(group["weight_decay"], p.data)

                # Tensorflow order of ops for updating squared avg
                square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
                # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)  # PyTorch original

                if group["centered"]:
                    grad_avg = state["grad_avg"]
                    grad_avg.add_(one_minus_alpha, grad - grad_avg)
                    # grad_avg.mul_(alpha).add_(1 - alpha, grad)  # PyTorch original
                    avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group["eps"]).sqrt_()  # eps moved in sqrt
                else:
                    avg = square_avg.add(group["eps"]).sqrt_()  # eps moved in sqrt

                if group["momentum"] > 0:
                    buf = state["momentum_buffer"]
                    # Tensorflow accumulates the LR scaling in the momentum buffer
                    if "lr_in_momentum" in group and group["lr_in_momentum"]:
                        buf.mul_(group["momentum"]).addcdiv_(group["lr"], grad, avg)
                        p.data.add_(-buf)
                    else:
                        # PyTorch scales the param update by LR
                        buf.mul_(group["momentum"]).addcdiv_(grad, avg)
                        p.data.add_(-group["lr"], buf)
                else:
                    p.data.addcdiv_(-group["lr"], grad, avg)

        return loss
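
To see why applying epsilon inside the square root matters, the standalone snippet below compares the two denominators for a near-zero second moment; the numbers are illustrative only.

import math

eps = 1e-10
square_avg = 1e-12  # a near-zero squared-gradient average

tf_style = math.sqrt(square_avg + eps)       # eps inside the sqrt, as in RMSpropTF above
pytorch_style = math.sqrt(square_avg) + eps  # eps outside the sqrt, as in stock torch.optim.RMSprop

print(tf_style)       # ~1.0e-05
print(pytorch_style)  # ~1.0e-06, i.e. a ~10x smaller denominator and hence a ~10x larger step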

__init__(params, lr=0.01, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0.0, centered=False, decoupled_decay=False, lr_in_momentum=True)

RMSprop optimizer that follows TensorFlow's RMSprop characteristics.

Parameters:

Name Type Description Default
params Union[Iterable[torch.Tensor], Iterable[dict]]

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate

0.01
alpha float

Smoothing (decay) constant

0.9
eps float

Term added to the denominator to improve numerical stability

1e-10
weight_decay float

Weight decay (L2 penalty)

0
momentum float

Momentum factor

0.0
centered bool

If True, compute the centered RMSprop where the gradient is normalized by an estimation of its variance

False
decoupled_decay bool

Decoupled weight decay as per https://arxiv.org/abs/1711.05101

False
lr_in_momentum bool

Learning rate scaling is included in the momentum buffer update, as per the defaults in TensorFlow

True
Source code in src/super_gradients/training/utils/optimizers/rmsprop_tf.py
def __init__(
    self,
    params: Union[Iterable[torch.Tensor], Iterable[dict]],
    lr: float = 1e-2,
    alpha: float = 0.9,
    eps: float = 1e-10,
    weight_decay: float = 0,
    momentum: float = 0.0,
    centered: bool = False,
    decoupled_decay: bool = False,
    lr_in_momentum: bool = True,
):
    """RMSprop optimizer that follows the tf's RMSprop characteristics
    :param params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
    :param lr (float, optional): learning rate
    :param momentum (float, optional): momentum factor
    :param alpha (float, optional): smoothing (decay) constant
    :param eps (float, optional): term added to the denominator to improve numerical stability
    :param centered (bool, optional) : if ``True``, compute the centered RMSProp, the gradient is normalized by an
     estimation of its variance
    :param weight_decay (float, optional): weight decay (L2 penalty)
    :param decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
    :param lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer update as per
     defaults in Tensorflow
    """
    if not 0.0 <= lr:
        raise ValueError("Invalid learning rate: {}".format(lr))
    if not 0.0 <= eps:
        raise ValueError("Invalid epsilon value: {}".format(eps))
    if not 0.0 <= momentum:
        raise ValueError("Invalid momentum value: {}".format(momentum))
    if not 0.0 <= weight_decay:
        raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
    if not 0.0 <= alpha:
        raise ValueError("Invalid alpha value: {}".format(alpha))

    defaults = dict(
        lr=lr,
        momentum=momentum,
        alpha=alpha,
        eps=eps,
        centered=centered,
        weight_decay=weight_decay,
        decoupled_decay=decoupled_decay,
        lr_in_momentum=lr_in_momentum,
    )
    super(RMSpropTF, self).__init__(params, defaults)

step(closure=None)

Performs a single optimization step.

Parameters:

Name Type Description Default
closure Optional[callable]

A closure that reevaluates the model and returns the loss.

None

Source code in src/super_gradients/training/utils/optimizers/rmsprop_tf.py
def step(self, closure: Optional[callable] = None) -> torch.Tensor:  # noqa: C901
    """Performs a single optimization step.
    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            grad = p.grad.data
            if grad.is_sparse:
                raise RuntimeError("RMSprop does not support sparse gradients")
            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = 0
                state["square_avg"] = torch.ones_like(p.data)  # PyTorch inits to zero
                if group["momentum"] > 0:
                    state["momentum_buffer"] = torch.zeros_like(p.data)
                if group["centered"]:
                    state["grad_avg"] = torch.zeros_like(p.data)

            square_avg = state["square_avg"]
            one_minus_alpha = 1.0 - group["alpha"]

            state["step"] += 1

            if group["weight_decay"] != 0:
                if "decoupled_decay" in group and group["decoupled_decay"]:
                    p.data.add_(-group["weight_decay"], p.data)
                else:
                    grad = grad.add(group["weight_decay"], p.data)

            # Tensorflow order of ops for updating squared avg
            square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
            # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)  # PyTorch original

            if group["centered"]:
                grad_avg = state["grad_avg"]
                grad_avg.add_(one_minus_alpha, grad - grad_avg)
                # grad_avg.mul_(alpha).add_(1 - alpha, grad)  # PyTorch original
                avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group["eps"]).sqrt_()  # eps moved in sqrt
            else:
                avg = square_avg.add(group["eps"]).sqrt_()  # eps moved in sqrt

            if group["momentum"] > 0:
                buf = state["momentum_buffer"]
                # Tensorflow accumulates the LR scaling in the momentum buffer
                if "lr_in_momentum" in group and group["lr_in_momentum"]:
                    buf.mul_(group["momentum"]).addcdiv_(group["lr"], grad, avg)
                    p.data.add_(-buf)
                else:
                    # PyTorch scales the param update by LR
                    buf.mul_(group["momentum"]).addcdiv_(grad, avg)
                    p.data.add_(-group["lr"], buf)
            else:
                p.data.addcdiv_(-group["lr"], grad, avg)

    return loss

DEKRPoseEstimationDecodeCallback

Bases: AbstractPoseEstimationPostPredictionCallback

Class that implements decoding logic of DEKR's model predictions into poses.

Source code in src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py
class DEKRPoseEstimationDecodeCallback(AbstractPoseEstimationPostPredictionCallback):
    """
    Class that implements decoding logic of DEKR's model predictions into poses.
    """

    def __init__(
        self,
        output_stride: int,
        max_num_people: int,
        keypoint_threshold: float,
        nms_threshold: float,
        nms_num_threshold: int,
        apply_sigmoid: bool,
        min_confidence: float = 0.0,
    ):
        """

        :param output_stride: Output stride of the model
        :param int max_num_people: Maximum number of decoded poses
        :param float keypoint_threshold: (float) A minimum score of a pose center keypoint for pose to be considered as a potential candidate
        :param float nms_threshold: The maximum distance between two joints for them to be considered as belonging to the same pose.
                              Given in terms of a percentage of a square root of the area of the pose bounding box.
        :param int nms_num_threshold: Number of joints that must pass the NMS check for the pose to be considered as a valid one.
        :param bool apply_sigmoid: If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not
                              bound to [0..1] range and trained with logits (E.g focal loss)
        :param float min_confidence: Minimum confidence threshold for pose
        """
        super().__init__()
        self.keypoint_threshold = keypoint_threshold
        self.max_num_people = max_num_people
        self.output_stride = output_stride
        self.nms_threshold = nms_threshold
        self.nms_num_threshold = nms_num_threshold
        self.apply_sigmoid = apply_sigmoid
        self.min_confidence = min_confidence

    @torch.no_grad()
    def __call__(self, predictions: Tuple[Tensor, Tensor]) -> List[PoseEstimationPredictions]:
        """

        :param predictions: Tuple (heatmap, offset):
            heatmap - [BatchSize, NumJoints+1,H,W]
            offset - [BatchSize, NumJoints*2,H,W]

        :return: Tuple
        """
        decoded_predictions: List[PoseEstimationPredictions] = []

        heatmap, offset = predictions
        batch_size = len(heatmap)
        for i in range(batch_size):
            poses, scores = self.decode_one_sized_batch(predictions=(heatmap[i : i + 1], offset[i : i + 1]))
            decoded_predictions.append(
                PoseEstimationPredictions(
                    poses=poses[: self.max_num_people],
                    scores=scores[: self.max_num_people],
                    bboxes_xyxy=None,
                )
            )
        return decoded_predictions

    def decode_one_sized_batch(self, predictions: Tuple[Tensor, Tensor]) -> Tuple[np.ndarray, np.ndarray]:
        heatmap, offset = predictions
        posemap = _offset_to_pose(offset)  # [1, 2 * num_joints, H, W]

        if heatmap.size(0) != 1:
            raise RuntimeError("Batch size of 1 is required")

        if self.apply_sigmoid:
            heatmap = heatmap.sigmoid()

        heatmap_sum, poses_sum = aggregate_results(
            heatmap,
            posemap,
            pose_center_score_threshold=self.keypoint_threshold,
            max_num_people=self.max_num_people,
            output_stride=self.output_stride,
        )

        poses, scores = pose_nms(
            heatmap_sum,
            poses_sum,
            max_num_people=self.max_num_people,
            nms_threshold=self.nms_threshold,
            nms_num_threshold=self.nms_num_threshold,
            pose_score_threshold=self.min_confidence,
        )

        if len(poses) != len(scores):
            raise RuntimeError("Decoding error detected. Returned mismatching number of poses/scores")

        return poses, scores
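
An illustrative sketch of invoking this callback on raw DEKR outputs. The tensor shapes follow the docstring above; num_joints=17 (COCO-style poses), the random inputs and the threshold values are placeholders.

import torch

from super_gradients.training.utils.pose_estimation.dekr_decode_callbacks import DEKRPoseEstimationDecodeCallback

num_joints, h, w = 17, 64, 48
callback = DEKRPoseEstimationDecodeCallback(
    output_stride=4,
    max_num_people=30,
    keypoint_threshold=0.05,
    nms_threshold=0.05,
    nms_num_threshold=8,
    apply_sigmoid=True,
)

heatmap = torch.randn(2, num_joints + 1, h, w)  # [BatchSize, NumJoints+1, H, W]
offset = torch.randn(2, num_joints * 2, h, w)   # [BatchSize, NumJoints*2, H, W]

predictions = callback((heatmap, offset))       # one PoseEstimationPredictions per image
print(len(predictions), predictions[0].poses.shape)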

__call__(predictions)

Parameters:

Name Type Description Default
predictions Tuple[Tensor, Tensor]

Tuple (heatmap, offset): heatmap - [BatchSize, NumJoints+1,H,W] offset - [BatchSize, NumJoints*2,H,W]

required

Returns:

Type Description
List[PoseEstimationPredictions]

List of decoded pose predictions, one PoseEstimationPredictions per image in the batch.

Source code in src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py
@torch.no_grad()
def __call__(self, predictions: Tuple[Tensor, Tensor]) -> List[PoseEstimationPredictions]:
    """

    :param predictions: Tuple (heatmap, offset):
        heatmap - [BatchSize, NumJoints+1,H,W]
        offset - [BatchSize, NumJoints*2,H,W]

    :return: Tuple
    """
    decoded_predictions: List[PoseEstimationPredictions] = []

    heatmap, offset = predictions
    batch_size = len(heatmap)
    for i in range(batch_size):
        poses, scores = self.decode_one_sized_batch(predictions=(heatmap[i : i + 1], offset[i : i + 1]))
        decoded_predictions.append(
            PoseEstimationPredictions(
                poses=poses[: self.max_num_people],
                scores=scores[: self.max_num_people],
                bboxes_xyxy=None,
            )
        )
    return decoded_predictions

__init__(output_stride, max_num_people, keypoint_threshold, nms_threshold, nms_num_threshold, apply_sigmoid, min_confidence=0.0)

Parameters:

Name Type Description Default
output_stride int

Output stride of the model

required
max_num_people int

Maximum number of decoded poses

required
keypoint_threshold float

(float) A minimum score of a pose center keypoint for pose to be considered as a potential candidate

required
nms_threshold float

The maximum distance between two joints for them to be considered as belonging to the same pose. Given in terms of a percentage of a square root of the area of the pose bounding box.

required
nms_num_threshold int

Number of joints that must pass the NMS check for the pose to be considered as a valid one.

required
apply_sigmoid bool

If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not bound to [0..1] range and trained with logits (E.g focal loss)

required
min_confidence float

Minimum confidence threshold for pose

0.0
Source code in src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py
def __init__(
    self,
    output_stride: int,
    max_num_people: int,
    keypoint_threshold: float,
    nms_threshold: float,
    nms_num_threshold: int,
    apply_sigmoid: bool,
    min_confidence: float = 0.0,
):
    """

    :param output_stride: Output stride of the model
    :param int max_num_people: Maximum number of decoded poses
    :param float keypoint_threshold: (float) A minimum score of a pose center keypoint for pose to be considered as a potential candidate
    :param float nms_threshold: The maximum distance between two joints for them to be considered as belonging to the same pose.
                          Given in terms of a percentage of a square root of the area of the pose bounding box.
    :param int nms_num_threshold: Number of joints that must pass the NMS check for the pose to be considered as a valid one.
    :param bool apply_sigmoid: If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not
                          bound to [0..1] range and trained with logits (E.g focal loss)
    :param float min_confidence: Minimum confidence threshold for pose
    """
    super().__init__()
    self.keypoint_threshold = keypoint_threshold
    self.max_num_people = max_num_people
    self.output_stride = output_stride
    self.nms_threshold = nms_threshold
    self.nms_num_threshold = nms_num_threshold
    self.apply_sigmoid = apply_sigmoid
    self.min_confidence = min_confidence

aggregate_results(heatmap, posemap, output_stride, pose_center_score_threshold, max_num_people)

Get initial pose proposals and aggregate the results of all scales. Note that this implementation works only for a batch size of 1.

Parameters:

Name Type Description Default
heatmap Tensor

Heatmap at this scale (B, 1+num_joints, w, h)

required
posemap Tensor

Posemap at this scale (B, 2*num_joints, w, h)

required
output_stride int

Ratio of input size / predictions size

required
pose_center_score_threshold float

(float) A minimum score of a pose center keypoint for pose to be considered as a potential candidate

required
max_num_people int

Maximum number of pose proposals to keep (int)

required

Returns:

Type Description
Tuple[Tensor, List[Tensor]]
- heatmap_sum: Sum of the heatmaps (1, 1+num_joints, w, h)
- poses (List): Gather of the pose proposals [B, (num_people, num_joints, 3)]
Source code in src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py
def aggregate_results(
    heatmap: Tensor, posemap: Tensor, output_stride: int, pose_center_score_threshold: float, max_num_people: int
) -> Tuple[Tensor, List[Tensor]]:
    """
    Get initial pose proposals and aggregate the results of all scales.
    Note that this implementation works only for a batch size of 1.

    :param heatmap: Heatmap at this scale (B, 1+num_joints, w, h)
    :param posemap: Posemap at this scale (B, 2*num_joints, w, h)
    :param output_stride: Ratio of input size / predictions size
    :param pose_center_score_threshold: (float) A minimum score of a pose center keypoint for pose to be considered as a potential candidate
    :param max_num_people: (int)

    :return:
        - heatmap_sum: Sum of the heatmaps (1, 1+num_joints, w, h)
        - poses (List): Gather of the pose proposals [B, (num_people, num_joints, 3)]
    """

    poses = []

    h, w = heatmap[0].size(-1), heatmap[0].size(-2)

    heatmap_sum = _up_interpolate(heatmap, size=(int(output_stride * w), int(output_stride * h)))
    center_heatmap = heatmap[0, -1:]
    pose_ind, ctr_score = _get_maximum_from_heatmap(center_heatmap, pose_center_score_threshold=pose_center_score_threshold, max_num_people=max_num_people)
    posemap = posemap[0].permute(1, 2, 0).view(h * w, -1, 2)
    pose = output_stride * posemap[pose_ind]
    ctr_score = ctr_score[:, None].expand(-1, pose.shape[-2])[:, :, None]
    poses.append(torch.cat([pose, ctr_score], dim=2))

    return heatmap_sum, poses

get_locations(output_h, output_w, device)

Generate location map (each pixel contains its own XY coordinate)

Parameters:

Name Type Description Default
output_h int

Feature map height (rows)

required
output_w int

Feature map width (cols)

required
device

Target device to put tensor on

required

Returns:

Type Description

[H * W, 2]

Source code in src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py
def get_locations(output_h: int, output_w: int, device):
    """
    Generate location map (each pixel contains its own XY coordinate)

    :param output_h: Feature map height (rows)
    :param output_w: Feature map width (cols)
    :param device: Target device to put tensor on
    :return: [H * W, 2]
    """
    shifts_x = torch.arange(0, output_w, step=1, dtype=torch.float32, device=device)
    shifts_y = torch.arange(0, output_h, step=1, dtype=torch.float32, device=device)
    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
    shift_x = shift_x.reshape(-1)
    shift_y = shift_y.reshape(-1)
    locations = torch.stack((shift_x, shift_y), dim=1)

    return locations
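
The expected layout for a tiny 2x3 feature map is shown below; the call mirrors the documented signature, and the printed output follows directly from the code above.

from super_gradients.training.utils.pose_estimation.dekr_decode_callbacks import get_locations

print(get_locations(output_h=2, output_w=3, device="cpu"))
# tensor([[0., 0.],
#         [1., 0.],
#         [2., 0.],
#         [0., 1.],
#         [1., 1.],
#         [2., 1.]])  -> shape [H * W, 2]; each row is the (x, y) coordinate of one pixel, in row-major order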

get_reg_poses(offset, num_joints)

Decode offset predictions into absolute locations.

Parameters:

Name Type Description Default
offset Tensor

Tensor of [num_joints*2,H,W] shape with offset predictions for each joint

required
num_joints int

Number of joints

required

Returns:

Type Description

[H * W, num_joints, 2]

Source code in src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py
def get_reg_poses(offset: Tensor, num_joints: int):
    """
    Decode offset predictions into absolute locations.

    :param offset: Tensor of [num_joints*2,H,W] shape with offset predictions for each joint
    :param num_joints: Number of joints
    :return: [H * W, num_joints, 2]
    """
    _, h, w = offset.shape
    offset = offset.permute(1, 2, 0).reshape(h * w, num_joints, 2)
    locations = get_locations(h, w, offset.device)
    locations = locations[:, None, :].expand(-1, num_joints, -1)
    poses = locations - offset

    return poses
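
Since poses = locations - offset, a joint offset of (0.5, -0.25) predicted at pixel (x=2, y=1) decodes to the absolute location (1.5, 1.25). The sketch below checks this for a single-joint 2x3 map; the offset values are made up for illustration.

import torch

from super_gradients.training.utils.pose_estimation.dekr_decode_callbacks import get_reg_poses

num_joints, h, w = 1, 2, 3
offset = torch.zeros(num_joints * 2, h, w)
offset[0, 1, 2] = 0.5    # x-offset at pixel (x=2, y=1)
offset[1, 1, 2] = -0.25  # y-offset at the same pixel

poses = get_reg_poses(offset, num_joints)  # [H * W, num_joints, 2]
print(poses[5, 0])  # flat index 5 is pixel (x=2, y=1) in row-major order -> tensor([1.5000, 1.2500])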

pose_nms(heatmap_avg, poses, max_num_people, nms_threshold, nms_num_threshold, pose_score_threshold)

NMS for the regressed poses results.

Parameters:

Name Type Description Default
heatmap_avg Tensor

Avg of the heatmaps at all scales (1, 1+num_joints, w, h)

required
poses List

Gather of the pose proposals [(num_people, num_joints, 3)]

required
max_num_people int

Maximum number of decoded poses

required
nms_threshold float

The maximum distance between two joints for them to be considered as belonging to the same pose. Given in terms of a percentage of a square root of the area of the pose bounding box.

required
nms_num_threshold int

Number of joints that must pass the NMS check for the pose to be considered as a valid one.

required
pose_score_threshold float

Minimum confidence threshold for pose. Pose with confidence lower than this threshold will be discarded.

required
Source code in src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py
def pose_nms(
    heatmap_avg, poses, max_num_people: int, nms_threshold: float, nms_num_threshold: int, pose_score_threshold: float
) -> Tuple[np.ndarray, np.ndarray]:
    """
    NMS for the regressed poses results.

    :param Tensor heatmap_avg: Avg of the heatmaps at all scales (1, 1+num_joints, w, h)
    :param List poses: Gather of the pose proposals [(num_people, num_joints, 3)]
    :param int max_num_people: Maximum number of decoded poses
    :param float nms_threshold: The maximum distance between two joints for them to be considered as belonging to the same pose.
                          Given in terms of a percentage of a square root of the area of the pose bounding box.
    :param int nms_num_threshold: Number of joints that must pass the NMS check for the pose to be considered as a valid one.
    :param float pose_score_threshold: Minimum confidence threshold for pose. Pose with confidence lower than this threshold will be discarded.

    :return Tuple of (poses, scores)
    """
    assert len(poses) == 1

    pose_score = torch.cat([pose[:, :, 2:] for pose in poses], dim=0)
    pose_coord = torch.cat([pose[:, :, :2] for pose in poses], dim=0)

    num_people, num_joints, _ = pose_coord.shape

    if num_people == 0:
        return np.zeros((0, num_joints, 3), dtype=np.float32), np.zeros((0,), dtype=np.float32)

    heatval = _get_heat_value(pose_coord, heatmap_avg[0])
    heat_score = (torch.sum(heatval, dim=1) / num_joints)[:, 0]

    pose_score = pose_score * heatval
    poses = torch.cat([pose_coord.cpu(), pose_score.cpu()], dim=2)

    keep_pose_inds = _nms_core(pose_coord, heat_score, nms_threshold=nms_threshold, nms_num_threshold=nms_num_threshold)
    poses = poses[keep_pose_inds]
    heat_score = heat_score[keep_pose_inds]

    if len(keep_pose_inds) > max_num_people:
        heat_score, topk_inds = torch.topk(heat_score, max_num_people)
        poses = poses[topk_inds]

    poses = poses.numpy()
    if len(poses):
        scores = poses[:, :, 2].mean(axis=1)

        mask = scores >= pose_score_threshold
        poses = poses[mask]
        scores = scores[mask]
    else:
        return np.zeros((0, num_joints, 3), dtype=np.float32), np.zeros((0,), dtype=np.float32)
    return poses, scores

DEKRVisualizationCallback

Bases: PhaseCallback

A callback that adds a visualization of a batch of pose estimation predictions to context.sg_logger

Parameters:

Name Type Description Default
phase Union[Phase, str]

When to trigger the callback.

required
prefix str

Prefix to add to the log.

required
mean List[float]

Mean to subtract from image.

required
std List[float]

Standard deviation to subtract from image.

required
apply_sigmoid bool

Whether to apply sigmoid to the output.

False
batch_idx int

Batch index to perform visualization for.

0
keypoints_threshold float

Keypoint threshold to use for visualization.

0.01
Source code in src/super_gradients/training/utils/pose_estimation/dekr_visualization_callbacks.py
@register_callback(Callbacks.DEKR_VISUALIZATION)
class DEKRVisualizationCallback(PhaseCallback):
    """
    A callback that adds a visualization of a batch of segmentation predictions to context.sg_logger

    :param phase:                   When to trigger the callback.
    :param prefix:                  Prefix to add to the log.
    :param mean:                    Mean to subtract from image.
    :param std:                     Standard deviation to subtract from image.
    :param apply_sigmoid:           Whether to apply sigmoid to the output.
    :param batch_idx:               Batch index to perform visualization for.
    :param keypoints_threshold:     Keypoint threshold to use for visualization.
    """

    def __init__(
        self,
        phase: Union[Phase, str],
        prefix: str,
        mean: List[float],
        std: List[float],
        apply_sigmoid: bool = False,
        batch_idx: int = 0,
        keypoints_threshold: float = 0.01,
    ):
        super(DEKRVisualizationCallback, self).__init__(phase)
        self.batch_idx = batch_idx
        self.prefix = prefix
        self.mean = np.array(list(map(float, mean))).reshape((1, 1, -1))
        self.std = np.array(list(map(float, std))).reshape((1, 1, -1))
        self.apply_sigmoid = apply_sigmoid
        self.keypoints_threshold = keypoints_threshold

    def denormalize_image(self, image_normalized: Tensor) -> np.ndarray:
        """
        Reverse image normalization image_normalized (image / 255 - mean) / std
        :param image_normalized: normalized [3,H,W]
        :return:
        """

        image_normalized = torch.moveaxis(image_normalized, 0, -1).detach().cpu().numpy()
        image = (image_normalized * self.std + self.mean) * 255
        image = np.clip(image, 0, 255).astype(np.uint8)[..., ::-1]
        return image

    @classmethod
    def visualize_heatmap(self, heatmap: Tensor, apply_sigmoid: bool, dsize, min_value=None, max_value=None, colormap=cv2.COLORMAP_JET):
        if apply_sigmoid:
            heatmap = heatmap.sigmoid()

        if min_value is None:
            min_value = heatmap.min().item()

        if max_value is None:
            max_value = heatmap.max().item()

        heatmap = heatmap.detach().cpu().numpy()
        real_min = heatmap.min()
        real_max = heatmap.max()

        heatmap = np.max(heatmap, axis=0)
        heatmap = (heatmap - min_value) / (1e-8 + max_value - min_value)
        heatmap = np.clip(heatmap, 0, 1)
        heatmap_8u = (heatmap * 255).astype(np.uint8)
        heatmap_bgr = cv2.applyColorMap(heatmap_8u, colormap)
        heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
        if dsize is not None:
            heatmap_rgb = cv2.resize(heatmap_rgb, dsize=dsize)

        cv2.putText(
            heatmap_rgb,
            f"min:{real_min:.3f}",
            (5, 15),
            fontFace=cv2.FONT_HERSHEY_PLAIN,
            color=(255, 255, 255),
            fontScale=0.8,
            thickness=1,
            lineType=cv2.LINE_AA,
        )
        cv2.putText(
            heatmap_rgb,
            f"max:{real_max:.3f}",
            (5, heatmap_rgb.shape[0] - 10),
            cv2.FONT_HERSHEY_PLAIN,
            color=(255, 255, 255),
            fontScale=0.8,
            thickness=1,
            lineType=cv2.LINE_AA,
        )

        return heatmap, heatmap_rgb

    @multi_process_safe
    def __call__(self, context: PhaseContext):
        if context.batch_idx == self.batch_idx:
            batch_imgs = self.visualize_batch(context.inputs, context.preds, context.target)
            batch_imgs = np.stack(batch_imgs)
            tag = self.prefix + str(self.batch_idx) + "_images"
            context.sg_logger.add_images(tag=tag, images=batch_imgs, global_step=context.epoch, data_format="NHWC")

    @torch.no_grad()
    def visualize_batch(self, inputs, predictions, targets):
        num_samples = len(inputs)
        batch_imgs = []

        gt_heatmap, mask, _, _ = targets

        # Check whether model also produce supervised output predictions
        if isinstance(predictions, tuple) and len(predictions) == 2 and torch.is_tensor(predictions[0]) and torch.is_tensor(predictions[1]):
            heatmap, _ = predictions
        else:
            (heatmap, _), (_, _) = predictions

        for i in range(num_samples):
            batch_imgs.append(self.visualize_sample(inputs[i], predicted_heatmap=heatmap[i], target_heatmap=gt_heatmap[i], target_mask=mask[i]))

        return batch_imgs

    def visualize_sample(self, input, predicted_heatmap, target_heatmap, target_mask):
        image_rgb = self.denormalize_image(input)
        dsize = image_rgb.shape[1], image_rgb.shape[0]
        half_size = dsize[0] // 2, dsize[1] // 2

        target_heatmap_f32, target_heatmap_rgb = self.visualize_heatmap(target_heatmap, apply_sigmoid=False, dsize=half_size)
        target_heatmap_f32 = cv2.resize(target_heatmap_f32, dsize=dsize)
        target_heatmap_f32 = np.expand_dims(target_heatmap_f32, -1)

        peaks_heatmap = _hierarchical_pool(predicted_heatmap)[0]
        peaks_heatmap = predicted_heatmap.eq(peaks_heatmap) & (predicted_heatmap > self.keypoints_threshold)

        peaks_heatmap = peaks_heatmap.sum(dim=0, keepdim=False) > 0

        # Apply masking with GT mask to suppress predictions on ignored areas of the image (where target_mask==0)
        flat_target_mask = target_mask.sum(dim=0, keepdim=False) > 0
        peaks_heatmap &= flat_target_mask
        peaks_heatmap = peaks_heatmap.detach().cpu().numpy().astype(np.uint8) * 255

        peaks_heatmap = cv2.applyColorMap(peaks_heatmap, cv2.COLORMAP_JET)
        peaks_heatmap = cv2.cvtColor(peaks_heatmap, cv2.COLOR_BGR2RGB)
        peaks_heatmap = cv2.resize(peaks_heatmap, dsize=half_size)

        _, predicted_heatmap_rgb = self.visualize_heatmap(
            predicted_heatmap, min_value=target_heatmap.min().item(), max_value=target_heatmap.max().item(), apply_sigmoid=self.apply_sigmoid, dsize=half_size
        )

        image_heatmap_overlay = image_rgb * (1 - target_heatmap_f32) + target_heatmap_f32 * cv2.resize(target_heatmap_rgb, dsize=dsize)
        image_heatmap_overlay = image_heatmap_overlay.astype(np.uint8)

        _, target_mask_rgb = self.visualize_heatmap(target_mask, min_value=0, max_value=1, apply_sigmoid=False, dsize=half_size, colormap=cv2.COLORMAP_BONE)

        return np.hstack(
            [
                image_heatmap_overlay,
                np.vstack([target_heatmap_rgb, predicted_heatmap_rgb]),
                np.vstack([target_mask_rgb, peaks_heatmap]),
            ]
        )
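
A minimal construction sketch. The mean/std values below are the common ImageNet statistics and are placeholders (use whatever normalization your pipeline applies), and the Phase import path is an assumption; the callback is typically passed to training through the phase_callbacks list of the training hyperparameters.

from super_gradients.training.utils.callbacks import Phase  # assumed import path for the Phase enum
from super_gradients.training.utils.pose_estimation.dekr_visualization_callbacks import DEKRVisualizationCallback

visualization_callback = DEKRVisualizationCallback(
    phase=Phase.VALIDATION_BATCH_END,  # when to trigger the callback
    prefix="dekr_val_",
    mean=[0.485, 0.456, 0.406],        # placeholder normalization stats
    std=[0.229, 0.224, 0.225],
    apply_sigmoid=False,
    batch_idx=0,
    keypoints_threshold=0.01,
)
# e.g. training_params = {..., "phase_callbacks": [visualization_callback]}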

denormalize_image(image_normalized)

Reverse the image normalization image_normalized = (image / 255 - mean) / std

Parameters:

Name Type Description Default
image_normalized Tensor

normalized [3,H,W]

required

Returns:

Type Description
np.ndarray
Source code in src/super_gradients/training/utils/pose_estimation/dekr_visualization_callbacks.py
def denormalize_image(self, image_normalized: Tensor) -> np.ndarray:
    """
    Reverse image normalization image_normalized (image / 255 - mean) / std
    :param image_normalized: normalized [3,H,W]
    :return:
    """

    image_normalized = torch.moveaxis(image_normalized, 0, -1).detach().cpu().numpy()
    image = (image_normalized * self.std + self.mean) * 255
    image = np.clip(image, 0, 255).astype(np.uint8)[..., ::-1]
    return image
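
The inverse mapping is simply image = (image_normalized * std + mean) * 255. The standalone numpy sketch below round-trips a random image through the same normalization to illustrate it; the statistics are placeholders and the BGR channel flip done by the method is omitted.

import numpy as np

mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, -1))  # placeholder stats
std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, -1))

image = np.random.randint(0, 256, size=(4, 4, 3)).astype(np.float32)
normalized = (image / 255 - mean) / std                      # forward normalization
restored = np.clip((normalized * std + mean) * 255, 0, 255)  # inverse, as in denormalize_image above

assert np.allclose(restored, image, atol=1e-3)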

RescoringPoseEstimationDecodeCallback

A special adapter callback used with PoseEstimationMetrics to feed the outputs of the rescoring model into the metric class.

Source code in src/super_gradients/training/utils/pose_estimation/rescoring_callback.py
class RescoringPoseEstimationDecodeCallback:
    """
    A special adapter callback to be used with PoseEstimationMetrics to use the outputs from rescoring model inside metric class.
    """

    def __init__(self, apply_sigmoid: bool):
        """

        :param apply_sigmoid: If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not
                              bound to [0..1] range and trained with logits (E.g focal loss)
        """
        super().__init__()
        self.apply_sigmoid = apply_sigmoid

    def __call__(self, predictions: Tuple[Tensor, Tensor]) -> List[PoseEstimationPredictions]:
        """ """
        poses, scores = predictions
        if self.apply_sigmoid:
            scores = scores.sigmoid()

        return [PoseEstimationPredictions(poses=poses.cpu().numpy(), scores=scores.squeeze(1).cpu().numpy(), bboxes_xyxy=None)]

__init__(apply_sigmoid)

Parameters:

Name Type Description Default
apply_sigmoid bool

If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not bound to [0..1] range and trained with logits (E.g focal loss)

required
Source code in src/super_gradients/training/utils/pose_estimation/rescoring_callback.py
def __init__(self, apply_sigmoid: bool):
    """

    :param apply_sigmoid: If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not
                          bound to [0..1] range and trained with logits (E.g focal loss)
    """
    super().__init__()
    self.apply_sigmoid = apply_sigmoid

ImagePoseEstimationPrediction dataclass

Bases: ImagePrediction

Object wrapping an image and a pose estimation model's prediction.

:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict

Source code in src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
@dataclass
class ImagePoseEstimationPrediction(ImagePrediction):
    """Object wrapping an image and a detection model's prediction.

    :attr image:        Input image
    :attr predictions:  Predictions of the model
    :attr class_names:  List of the class names to predict
    """

    image: np.ndarray
    prediction: PoseEstimationPrediction

    def draw(
        self,
        edge_colors=None,
        joint_thickness: Optional[int] = None,
        keypoint_colors: Optional[List[Tuple]] = None,
        keypoint_radius: Optional[int] = None,
        box_thickness: Optional[int] = None,
        show_confidence: bool = False,
    ) -> np.ndarray:
        """Draw the predicted bboxes on the image.

        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        :return:                Image with predicted bboxes. Note that this does not modify the original image.
        """
        image = PoseVisualization.draw_poses(
            image=self.image,
            poses=self.prediction.poses,
            scores=self.prediction.scores,
            is_crowd=None,
            boxes=self.prediction.bboxes_xyxy,
            edge_links=self.prediction.edge_links,
            edge_colors=edge_colors or self.prediction.edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors or self.prediction.keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
        )

        return image

    def show(
        self,
        edge_colors=None,
        joint_thickness: Optional[int] = None,
        keypoint_colors: Optional[List[Tuple]] = None,
        keypoint_radius: Optional[int] = None,
        box_thickness: Optional[int] = None,
        show_confidence: bool = False,
    ) -> None:
        """Display the image with predicted bboxes.

        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        """
        image = self.draw(
            edge_colors=edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
            show_confidence=show_confidence,
        )
        show_image(image)

    def save(
        self,
        output_path: str,
        edge_colors=None,
        joint_thickness: Optional[int] = None,
        keypoint_colors: Optional[List[Tuple]] = None,
        keypoint_radius: Optional[int] = None,
        box_thickness: Optional[int] = None,
        show_confidence: bool = False,
    ) -> None:
        """Save the predicted bboxes on the images.

        :param output_path:     Path to the output image file.
        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        """
        image = self.draw(
            edge_colors=edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
            show_confidence=show_confidence,
        )
        save_image(image=image, path=output_path)
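
A minimal usage sketch for this class. The model name, the pretrained-weights identifier, and the exact return type of predict() are assumptions and may differ between releases; for a single image, predict() is assumed here to yield an ImagePoseEstimationPrediction (some versions return a one-element ImagesPoseEstimationPrediction instead).

>>> from super_gradients.common.object_names import Models
>>> from super_gradients.training import models
>>> model = models.get(Models.YOLO_NAS_POSE_L, pretrained_weights="coco_pose")  # assumed model/weights
>>> prediction = model.predict("image.jpg")
>>> prediction.show(keypoint_radius=4, joint_thickness=2)
>>> prediction.save("poses.jpg")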

draw(edge_colors=None, joint_thickness=None, keypoint_colors=None, keypoint_radius=None, box_thickness=None, show_confidence=False)

Draw the predicted bboxes on the image.

Parameters:

Name Type Description Default
edge_colors

Optional list of tuples representing the colors for each joint. If None, default colors are used. If not None the length must be equal to the number of joint links in the skeleton.

None
joint_thickness Optional[int]

(Optional) Thickness of the joint links (in pixels).

None
keypoint_colors Optional[List[Tuple]]

Optional list of tuples representing the colors for each keypoint. If None, default colors are used. If not None the length must be equal to the number of joints in the skeleton.

None
keypoint_radius Optional[int]

(Optional) Radius of the keypoints (in pixels).

None
show_confidence bool

Whether to show confidence scores on the image.

False
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None

Returns:

Type Description
np.ndarray

Image with predicted bboxes. Note that this does not modify the original image.
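
A short sketch of overriding the default colors, assuming `prediction` is an ImagePoseEstimationPrediction and that the model uses a 17-keypoint COCO-style skeleton (both are assumptions; the list length must match the actual number of keypoints).

>>> keypoint_colors = [(0, 255, 0)] * 17                 # one color tuple per keypoint (17 assumes COCO-style skeleton)
>>> annotated = prediction.draw(keypoint_colors=keypoint_colors, keypoint_radius=6)
>>> annotated is prediction.image                        # False: draw() leaves the original image untouched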

Source code in src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
def draw(
    self,
    edge_colors=None,
    joint_thickness: Optional[int] = None,
    keypoint_colors: Optional[List[Tuple]] = None,
    keypoint_radius: Optional[int] = None,
    box_thickness: Optional[int] = None,
    show_confidence: bool = False,
) -> np.ndarray:
    """Draw the predicted bboxes on the image.

    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    :return:                Image with predicted bboxes. Note that this does not modify the original image.
    """
    image = PoseVisualization.draw_poses(
        image=self.image,
        poses=self.prediction.poses,
        scores=self.prediction.scores,
        is_crowd=None,
        boxes=self.prediction.bboxes_xyxy,
        edge_links=self.prediction.edge_links,
        edge_colors=edge_colors or self.prediction.edge_colors,
        joint_thickness=joint_thickness,
        keypoint_colors=keypoint_colors or self.prediction.keypoint_colors,
        keypoint_radius=keypoint_radius,
        box_thickness=box_thickness,
    )

    return image

save(output_path, edge_colors=None, joint_thickness=None, keypoint_colors=None, keypoint_radius=None, box_thickness=None, show_confidence=False)

Save the predicted bboxes on the images.

Parameters:

Name Type Description Default
output_path str

Path to the output image file.

required
edge_colors

Optional list of tuples representing the colors for each joint. If None, default colors are used. If not None the length must be equal to the number of joint links in the skeleton.

None
joint_thickness Optional[int]

(Optional) Thickness of the joint links (in pixels).

None
keypoint_colors Optional[List[Tuple]]

Optional list of tuples representing the colors for each keypoint. If None, default colors are used. If not None the length must be equal to the number of joints in the skeleton.

None
keypoint_radius Optional[int]

(Optional) Radius of the keypoints (in pixels).

None
show_confidence bool

Whether to show confidence scores on the image.

False
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None
Source code in src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
def save(
    self,
    output_path: str,
    edge_colors=None,
    joint_thickness: Optional[int] = None,
    keypoint_colors: Optional[List[Tuple]] = None,
    keypoint_radius: Optional[int] = None,
    box_thickness: Optional[int] = None,
    show_confidence: bool = False,
) -> None:
    """Save the predicted bboxes on the images.

    :param output_path:     Path to the output image file.
    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    """
    image = self.draw(
        edge_colors=edge_colors,
        joint_thickness=joint_thickness,
        keypoint_colors=keypoint_colors,
        keypoint_radius=keypoint_radius,
        box_thickness=box_thickness,
        show_confidence=show_confidence,
    )
    save_image(image=image, path=output_path)

show(edge_colors=None, joint_thickness=None, keypoint_colors=None, keypoint_radius=None, box_thickness=None, show_confidence=False)

Display the image with predicted bboxes.

Parameters:

Name Type Description Default
edge_colors

Optional list of tuples representing the colors for each joint. If None, default colors are used. If not None the length must be equal to the number of joint links in the skeleton.

None
joint_thickness Optional[int]

(Optional) Thickness of the joint links (in pixels).

None
keypoint_colors Optional[List[Tuple]]

Optional list of tuples representing the colors for each keypoint. If None, default colors are used. If not None the length must be equal to the number of joints in the skeleton.

None
keypoint_radius Optional[int]

(Optional) Radius of the keypoints (in pixels).

None
show_confidence bool

Whether to show confidence scores on the image.

False
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None
Source code in src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
def show(
    self,
    edge_colors=None,
    joint_thickness: Optional[int] = None,
    keypoint_colors: Optional[List[Tuple]] = None,
    keypoint_radius: Optional[int] = None,
    box_thickness: Optional[int] = None,
    show_confidence: bool = False,
) -> None:
    """Display the image with predicted bboxes.

    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    """
    image = self.draw(
        edge_colors=edge_colors,
        joint_thickness=joint_thickness,
        keypoint_colors=keypoint_colors,
        keypoint_radius=keypoint_radius,
        box_thickness=box_thickness,
        show_confidence=show_confidence,
    )
    show_image(image)

ImagesPoseEstimationPrediction dataclass

Bases: ImagesPredictions

Object wrapping the list of image pose estimation predictions.

:attr _images_prediction_lst: List of the predictions results

Source code in src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
@dataclass
class ImagesPoseEstimationPrediction(ImagesPredictions):
    """Object wrapping the list of image detection predictions.

    :attr _images_prediction_lst:  List of the predictions results
    """

    _images_prediction_lst: List[ImagePoseEstimationPrediction]

    def show(
        self,
        edge_colors=None,
        joint_thickness: Optional[int] = None,
        keypoint_colors: Optional[List[Tuple]] = None,
        keypoint_radius: Optional[int] = None,
        box_thickness: Optional[int] = None,
        show_confidence: bool = False,
    ) -> None:
        """Display the predicted bboxes on the images.

        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        """
        for prediction in self._images_prediction_lst:
            prediction.show(
                edge_colors=edge_colors,
                joint_thickness=joint_thickness,
                keypoint_colors=keypoint_colors,
                keypoint_radius=keypoint_radius,
                box_thickness=box_thickness,
                show_confidence=show_confidence,
            )

    def save(
        self,
        output_folder: str,
        edge_colors=None,
        joint_thickness: Optional[int] = None,
        keypoint_colors: Optional[List[Tuple]] = None,
        keypoint_radius: Optional[int] = None,
        box_thickness: Optional[int] = None,
        show_confidence: bool = False,
    ) -> None:
        """Save the predicted bboxes on the images.

        :param output_folder:   Folder path, where the images will be saved.
        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        """
        if output_folder:
            os.makedirs(output_folder, exist_ok=True)

        for i, prediction in enumerate(self._images_prediction_lst):
            image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
            prediction.save(
                output_path=image_output_path,
                edge_colors=edge_colors,
                joint_thickness=joint_thickness,
                keypoint_colors=keypoint_colors,
                keypoint_radius=keypoint_radius,
                box_thickness=box_thickness,
                show_confidence=show_confidence,
            )
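
A minimal sketch of working with a batch, assuming predict() on a list of images yields an ImagesPoseEstimationPrediction (the model and file names are placeholders).

>>> predictions = model.predict(["img_0.jpg", "img_1.jpg"])
>>> predictions.save(output_folder="pose_outputs")       # writes pose_outputs/pred_0.jpg, pose_outputs/pred_1.jpg, ...
>>> predictions.show(box_thickness=2)                    # displays each image in turn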

save(output_folder, edge_colors=None, joint_thickness=None, keypoint_colors=None, keypoint_radius=None, box_thickness=None, show_confidence=False)

Save the predicted bboxes on the images.

Parameters:

Name Type Description Default
output_folder str

Folder path, where the images will be saved.

required
edge_colors

Optional list of tuples representing the colors for each joint. If None, default colors are used. If not None the length must be equal to the number of joint links in the skeleton.

None
joint_thickness Optional[int]

(Optional) Thickness of the joint links (in pixels).

None
keypoint_colors Optional[List[Tuple]]

Optional list of tuples representing the colors for each keypoint. If None, default colors are used. If not None the length must be equal to the number of joints in the skeleton.

None
keypoint_radius Optional[int]

(Optional) Radius of the keypoints (in pixels).

None
show_confidence bool

Whether to show confidence scores on the image.

False
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None
Source code in src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
def save(
    self,
    output_folder: str,
    edge_colors=None,
    joint_thickness: Optional[int] = None,
    keypoint_colors: Optional[List[Tuple]] = None,
    keypoint_radius: Optional[int] = None,
    box_thickness: Optional[int] = None,
    show_confidence: bool = False,
) -> None:
    """Save the predicted bboxes on the images.

    :param output_folder:   Folder path, where the images will be saved.
    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    """
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    for i, prediction in enumerate(self._images_prediction_lst):
        image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
        prediction.save(
            output_path=image_output_path,
            edge_colors=edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
            show_confidence=show_confidence,
        )

show(edge_colors=None, joint_thickness=None, keypoint_colors=None, keypoint_radius=None, box_thickness=None, show_confidence=False)

Display the predicted bboxes on the images.

Parameters:

Name Type Description Default
edge_colors

Optional list of tuples representing the colors for each joint. If None, default colors are used. If not None the length must be equal to the number of joint links in the skeleton.

None
joint_thickness Optional[int]

(Optional) Thickness of the joint links (in pixels).

None
keypoint_colors Optional[List[Tuple]]

Optional list of tuples representing the colors for each keypoint. If None, default colors are used. If not None the length must be equal to the number of joints in the skeleton.

None
keypoint_radius Optional[int]

(Optional) Radius of the keypoints (in pixels).

None
show_confidence bool

Whether to show confidence scores on the image.

False
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None
Source code in src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
def show(
    self,
    edge_colors=None,
    joint_thickness: Optional[int] = None,
    keypoint_colors: Optional[List[Tuple]] = None,
    keypoint_radius: Optional[int] = None,
    box_thickness: Optional[int] = None,
    show_confidence: bool = False,
) -> None:
    """Display the predicted bboxes on the images.

    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    """
    for prediction in self._images_prediction_lst:
        prediction.show(
            edge_colors=edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
            show_confidence=show_confidence,
        )

VideoPoseEstimationPrediction dataclass

Bases: VideoPredictions

Object wrapping the list of image pose estimation predictions as a video.

:attr _images_prediction_gen:  Iterator of the prediction results
:attr fps:                     Frames per second of the video
:attr n_frames:                Number of frames in the video

Source code in src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
@dataclass
class VideoPoseEstimationPrediction(VideoPredictions):
    """Object wrapping the list of image detection predictions as a Video.

    :attr _images_prediction_lst:   List of the predictions results
    :att fps:                       Frames per second of the video
    """

    _images_prediction_gen: Iterator[ImagePoseEstimationPrediction]
    fps: int
    n_frames: int

    def draw(
        self,
        edge_colors=None,
        joint_thickness: Optional[int] = None,
        keypoint_colors: Optional[List[Tuple]] = None,
        keypoint_radius: Optional[int] = None,
        box_thickness: Optional[int] = None,
        show_confidence: bool = False,
    ) -> Iterator[np.ndarray]:
        """Draw the predicted bboxes on the images.

        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.

        :return:                Iterator of images with predicted bboxes. Note that this does not modify the original image.
        """

        for result in tqdm(self._images_prediction_gen, total=self.n_frames, desc="Processing Video"):
            yield result.draw(
                edge_colors=edge_colors,
                joint_thickness=joint_thickness,
                keypoint_colors=keypoint_colors,
                keypoint_radius=keypoint_radius,
                box_thickness=box_thickness,
                show_confidence=show_confidence,
            )

    def show(
        self,
        edge_colors=None,
        joint_thickness: Optional[int] = None,
        keypoint_colors: Optional[List[Tuple]] = None,
        keypoint_radius: Optional[int] = None,
        box_thickness: Optional[int] = None,
        show_confidence: bool = False,
    ) -> None:
        """Display the predicted bboxes on the images.

        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        """
        frames = self.draw(
            edge_colors=edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
            show_confidence=show_confidence,
        )
        show_video_from_frames(window_name="Pose Estimation", frames=frames, fps=self.fps)

    def save(
        self,
        output_path: str,
        edge_colors=None,
        joint_thickness: Optional[int] = None,
        keypoint_colors: Optional[List[Tuple]] = None,
        keypoint_radius: Optional[int] = None,
        box_thickness: Optional[int] = None,
        show_confidence: bool = False,
    ) -> None:
        """Save the predicted bboxes on the images.

        :param output_path:     Path to the output video file.
        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        """
        frames = self.draw(
            edge_colors=edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
            show_confidence=show_confidence,
        )
        save_video(output_path=output_path, frames=frames, fps=self.fps)
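
A minimal sketch, assuming predict() on a video file yields a VideoPoseEstimationPrediction (model and paths are placeholders). Because _images_prediction_gen is an iterator, frames are produced lazily and a given object is typically consumed by a single show() or save() call.

>>> video_predictions = model.predict("input.mp4")
>>> video_predictions.fps, video_predictions.n_frames    # metadata carried alongside the frame iterator
>>> video_predictions.save("output.mp4", keypoint_radius=3)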

draw(edge_colors=None, joint_thickness=None, keypoint_colors=None, keypoint_radius=None, box_thickness=None, show_confidence=False)

Draw the predicted bboxes on the images.

Parameters:

Name Type Description Default
edge_colors

Optional list of tuples representing the colors for each joint. If None, default colors are used. If not None the length must be equal to the number of joint links in the skeleton.

None
joint_thickness Optional[int]

(Optional) Thickness of the joint links (in pixels).

None
keypoint_colors Optional[List[Tuple]]

Optional list of tuples representing the colors for each keypoint. If None, default colors are used. If not None the length must be equal to the number of joints in the skeleton.

None
keypoint_radius Optional[int]

(Optional) Radius of the keypoints (in pixels).

None
show_confidence bool

Whether to show confidence scores on the image.

False
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None

Returns:

Type Description
Iterator[np.ndarray]

Iterator of images with predicted bboxes. Note that this does not modify the original image.
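
Because draw() yields annotated frames lazily, you can post-process frames yourself instead of calling show() or save(). A sketch, assuming `video_predictions` is a VideoPoseEstimationPrediction and that the frames' channel order matches what your writer expects:

>>> import cv2
>>> for i, frame in enumerate(video_predictions.draw(joint_thickness=2)):
...     cv2.imwrite(f"frame_{i:05d}.jpg", frame)         # frame is a np.ndarray; channel order follows the library's convention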

Source code in src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
def draw(
    self,
    edge_colors=None,
    joint_thickness: Optional[int] = None,
    keypoint_colors: Optional[List[Tuple]] = None,
    keypoint_radius: Optional[int] = None,
    box_thickness: Optional[int] = None,
    show_confidence: bool = False,
) -> Iterator[np.ndarray]:
    """Draw the predicted bboxes on the images.

    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.

    :return:                Iterator of images with predicted bboxes. Note that this does not modify the original image.
    """

    for result in tqdm(self._images_prediction_gen, total=self.n_frames, desc="Processing Video"):
        yield result.draw(
            edge_colors=edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
            show_confidence=show_confidence,
        )

save(output_path, edge_colors=None, joint_thickness=None, keypoint_colors=None, keypoint_radius=None, box_thickness=None, show_confidence=False)

Save the predicted bboxes on the images.

Parameters:

Name Type Description Default
output_path str

Path to the output video file.

required
edge_colors

Optional list of tuples representing the colors for each joint. If None, default colors are used. If not None the length must be equal to the number of joint links in the skeleton.

None
joint_thickness Optional[int]

(Optional) Thickness of the joint links (in pixels).

None
keypoint_colors Optional[List[Tuple]]

Optional list of tuples representing the colors for each keypoint. If None, default colors are used. If not None the length must be equal to the number of joints in the skeleton.

None
keypoint_radius Optional[int]

(Optional) Radius of the keypoints (in pixels).

None
show_confidence bool

Whether to show confidence scores on the image.

False
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None
Source code in src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
def save(
    self,
    output_path: str,
    edge_colors=None,
    joint_thickness: Optional[int] = None,
    keypoint_colors: Optional[List[Tuple]] = None,
    keypoint_radius: Optional[int] = None,
    box_thickness: Optional[int] = None,
    show_confidence: bool = False,
) -> None:
    """Save the predicted bboxes on the images.

    :param output_path:     Path to the output video file.
    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    """
    frames = self.draw(
        edge_colors=edge_colors,
        joint_thickness=joint_thickness,
        keypoint_colors=keypoint_colors,
        keypoint_radius=keypoint_radius,
        box_thickness=box_thickness,
        show_confidence=show_confidence,
    )
    save_video(output_path=output_path, frames=frames, fps=self.fps)

show(edge_colors=None, joint_thickness=None, keypoint_colors=None, keypoint_radius=None, box_thickness=None, show_confidence=False)

Display the predicted bboxes on the images.

Parameters:

Name Type Description Default
edge_colors

Optional list of tuples representing the colors for each joint. If None, default colors are used. If not None the length must be equal to the number of joint links in the skeleton.

None
joint_thickness Optional[int]

(Optional) Thickness of the joint links (in pixels).

None
keypoint_colors Optional[List[Tuple]]

Optional list of tuples representing the colors for each keypoint. If None, default colors are used. If not None the length must be equal to the number of joints in the skeleton.

None
keypoint_radius Optional[int]

(Optional) Radius of the keypoints (in pixels).

None
show_confidence bool

Whether to show confidence scores on the image.

False
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None
Source code in src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
def show(
    self,
    edge_colors=None,
    joint_thickness: Optional[int] = None,
    keypoint_colors: Optional[List[Tuple]] = None,
    keypoint_radius: Optional[int] = None,
    box_thickness: Optional[int] = None,
    show_confidence: bool = False,
) -> None:
    """Display the predicted bboxes on the images.

    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: (Optional) Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: (Optional) Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    """
    frames = self.draw(
        edge_colors=edge_colors,
        joint_thickness=joint_thickness,
        keypoint_colors=keypoint_colors,
        keypoint_radius=keypoint_radius,
        box_thickness=box_thickness,
        show_confidence=show_confidence,
    )
    show_video_from_frames(window_name="Pose Estimation", frames=frames, fps=self.fps)

ImageClassificationPrediction dataclass

Bases: ImagePrediction

Object wrapping an image and a classification model's prediction.

:attr image:        Input image
:attr predictions:  Predictions of the model
:attr class_names:  List of the class names to predict

Source code in src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class ImageClassificationPrediction(ImagePrediction):
    """Object wrapping an image and a classification model's prediction.

    :attr image:        Input image
    :attr predictions:  Predictions of the model
    :attr class_names:  List of the class names to predict
    """

    image: np.ndarray
    prediction: ClassificationPrediction
    class_names: List[str]

    def draw(self, show_confidence: bool = True) -> np.ndarray:
        """Draw the predicted label on the image.

        :param show_confidence: Whether to show confidence scores on the image.
        :return:                Image with predicted label.
        """

        image = self.image.copy()
        return draw_label(image=image, label=self.class_names[self.prediction.label], confidence=self.prediction.confidence)

    def show(self, show_confidence: bool = True) -> None:
        """Display the image with predicted label.

        :param show_confidence: Whether to show confidence scores on the image.
        """
        # Render the prediction onto a copy of the image and display it
        image = self.draw(show_confidence=show_confidence)
        show_image(image)

    def save(
        self,
        output_path: str,
        show_confidence: bool = True,
    ) -> None:
        """Save the predicted label on the images.

        :param output_path:     Path to the output image file.
        :param show_confidence: Whether to show confidence scores on the image.
        """
        image = self.draw(show_confidence=show_confidence)
        save_image(image=image, path=output_path)
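
A minimal sketch, assuming a classification model obtained via models.get() and that predict() on a single image yields an ImageClassificationPrediction (model name, weights identifier and return type are assumptions):

>>> model = models.get(Models.RESNET18, pretrained_weights="imagenet")
>>> prediction = model.predict("cat.jpg")
>>> prediction.class_names[prediction.prediction.label]     # predicted class name
>>> prediction.save("cat_labeled.jpg")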

draw(show_confidence=True)

Draw the predicted label on the image.

Parameters:

Name Type Description Default
show_confidence bool

Whether to show confidence scores on the image.

True

Returns:

Type Description
np.ndarray

Image with predicted label.

Source code in src/super_gradients/training/utils/predict/prediction_results.py
def draw(self, show_confidence: bool = True) -> np.ndarray:
    """Draw the predicted label on the image.

    :param show_confidence: Whether to show confidence scores on the image.
    :return:                Image with predicted label.
    """

    image = self.image.copy()
    return draw_label(image=image, label=self.class_names[self.prediction.label], confidence=self.prediction.confidence)

save(output_path, show_confidence=True)

Save the predicted label on the images.

Parameters:

Name Type Description Default
output_path str

Path to the output image file.

required
show_confidence bool

Whether to show confidence scores on the image.

True
Source code in src/super_gradients/training/utils/predict/prediction_results.py
def save(
    self,
    output_path: str,
    show_confidence: bool = True,
) -> None:
    """Save the predicted label on the images.

    :param output_path:     Path to the output image file.
    :param show_confidence: Whether to show confidence scores on the image.
    """
    image = self.draw(show_confidence=show_confidence)
    save_image(image=image, path=output_path)

show(show_confidence=True)

Display the image with predicted label.

Parameters:

Name Type Description Default
show_confidence bool

Whether to show confidence scores on the image.

True
Source code in src/super_gradients/training/utils/predict/prediction_results.py
def show(self, show_confidence: bool = True) -> None:
    """Display the image with predicted label.

    :param show_confidence: Whether to show confidence scores on the image.
    """
    # Render the prediction onto a copy of the image and display it
    image = self.draw(show_confidence=show_confidence)
    show_image(image)

ImageDetectionPrediction dataclass

Bases: ImagePrediction

Object wrapping an image and a detection model's prediction.

:attr image:        Input image
:attr predictions:  Predictions of the model
:attr class_names:  List of the class names to predict

Source code in src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class ImageDetectionPrediction(ImagePrediction):
    """Object wrapping an image and a detection model's prediction.

    :attr image:        Input image
    :attr predictions:  Predictions of the model
    :attr class_names:  List of the class names to predict
    """

    image: np.ndarray
    prediction: DetectionPrediction
    class_names: List[str]

    def draw(
        self,
        box_thickness: Optional[int] = None,
        show_confidence: bool = True,
        color_mapping: Optional[List[Tuple[int, int, int]]] = None,
        target_bboxes: Optional[np.ndarray] = None,
        target_bboxes_format: Optional[str] = None,
        target_class_ids: Optional[np.ndarray] = None,
        class_names: Optional[List[str]] = None,
    ) -> np.ndarray:
        """Draw the predicted bboxes on the image.

        :param box_thickness:           (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        :param show_confidence:         Whether to show confidence scores on the image.
        :param color_mapping:           List of tuples representing the colors for each class.
                                        Default is None, which generates a default color mapping based on the number of class names.
        :param target_bboxes:           Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes.
                                        Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image,
                                        or a list of length len(target_bboxes), containing such arrays.
                                        When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one)
        :param target_class_ids:        Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
                                        (image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays.
        :param target_bboxes_format:    Optional[str], bounding box format of target_bboxes, one of
                                        ['xyxy','xywh', 'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh'].
                                        Will raise an error if not None and target_bboxes is None.
        :param class_names:             List of class names to show. By default None, which shows all classes used during training.

        :return:                Image with predicted bboxes. Note that this does not modify the original image.
        """
        image = self.image.copy()

        target_bboxes = target_bboxes if target_bboxes is not None else np.zeros((0, 4))
        target_class_ids = target_class_ids if target_class_ids is not None else np.zeros((0, 1))

        class_names_to_show = class_names if class_names else self.class_names
        class_ids_to_show = [i for i, class_name in enumerate(self.class_names) if class_name in class_names_to_show]
        invalid_class_names_to_show = set(class_names_to_show) - set(self.class_names)
        if len(invalid_class_names_to_show) > 0:
            raise ValueError(
                "`class_names` includes class names that the model was not trained on.\n"
                f"    - Invalid class names:   {list(invalid_class_names_to_show)}\n"
                f"    - Available class names: {list(self.class_names)}"
            )

        bbox_format_factory = BBoxFormatFactory()
        if len(target_bboxes):
            target_bboxes_xyxy = convert_bboxes(
                bboxes=target_bboxes,
                image_shape=self.prediction.image_shape,
                source_format=bbox_format_factory.get(target_bboxes_format),
                target_format=bbox_format_factory.get("xyxy"),
                inplace=False,
            )
        else:
            target_bboxes_xyxy = target_bboxes

        plot_targets = any([len(tbbx) > 0 for tbbx in target_bboxes_xyxy])
        color_mapping = color_mapping or generate_color_mapping(len(self.class_names))

        for pred_i in np.argsort(self.prediction.confidence):

            class_id = int(self.prediction.labels[pred_i])
            if class_id in class_ids_to_show:
                score = "" if not show_confidence else str(round(self.prediction.confidence[pred_i], 2))
                image = draw_bbox(
                    image=image,
                    title=f"{self.class_names[class_id]} {score}",
                    color=color_mapping[class_id],
                    box_thickness=box_thickness,
                    x1=int(self.prediction.bboxes_xyxy[pred_i, 0]),
                    y1=int(self.prediction.bboxes_xyxy[pred_i, 1]),
                    x2=int(self.prediction.bboxes_xyxy[pred_i, 2]),
                    y2=int(self.prediction.bboxes_xyxy[pred_i, 3]),
                )

        if plot_targets:
            target_image = self.image.copy()
            for target_idx in range(len(target_bboxes_xyxy)):
                class_id = int(target_class_ids[target_idx])
                if class_id in class_ids_to_show:
                    target_image = draw_bbox(
                        image=target_image,
                        title=f"{self.class_names[class_id]}",
                        color=color_mapping[class_id],
                        box_thickness=box_thickness,
                        x1=int(target_bboxes_xyxy[target_idx, 0]),
                        y1=int(target_bboxes_xyxy[target_idx, 1]),
                        x2=int(target_bboxes_xyxy[target_idx, 2]),
                        y2=int(target_bboxes_xyxy[target_idx, 3]),
                    )

            height, width, ch = target_image.shape
            new_width, new_height = int(width + width / 20), int(height + height / 8)

            # Create a new canvas with the new width and height.
            canvas_image = np.ones((new_height, new_width, ch), dtype=np.uint8) * 255
            canvas_target = np.ones((new_height, new_width, ch), dtype=np.uint8) * 255

            # Now paste the original images onto the canvases with a padding offset
            padding_top, padding_left = 60, 10

            canvas_image[padding_top : padding_top + height, padding_left : padding_left + width] = image
            canvas_target[padding_top : padding_top + height, padding_left : padding_left + width] = target_image

            img1 = cv2.putText(canvas_image, "Predictions", (int(0.25 * width), 30), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 0))
            img2 = cv2.putText(canvas_target, "Ground Truth", (int(0.25 * width), 30), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 0))

            image = cv2.hconcat((img1, img2))
        return image

    def show(
        self,
        box_thickness: Optional[int] = None,
        show_confidence: bool = True,
        color_mapping: Optional[List[Tuple[int, int, int]]] = None,
        target_bboxes: Optional[np.ndarray] = None,
        target_bboxes_format: Optional[str] = None,
        target_class_ids: Optional[np.ndarray] = None,
        class_names: Optional[List[str]] = None,
    ) -> None:

        """Display the image with predicted bboxes.

        :param box_thickness:           (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        :param show_confidence:         Whether to show confidence scores on the image.
        :param color_mapping:           List of tuples representing the colors for each class.
                                        Default is None, which generates a default color mapping based on the number of class names.
        :param target_bboxes:           Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes.
                                        Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image,
                                        or a list of length len(target_bboxes), containing such arrays.
                                        When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one)
        :param target_class_ids:        Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
                                        (image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays.
        :param target_bboxes_format:    Optional[str], bounding box format of target_bboxes, one of
                                        ['xyxy','xywh', 'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh'].
                                        Will raise an error if not None and target_bboxes is None.
        :param class_names:             List of class names to show. By default None, which shows all classes used during training.
        """
        image = self.draw(
            box_thickness=box_thickness,
            show_confidence=show_confidence,
            color_mapping=color_mapping,
            target_bboxes=target_bboxes,
            target_bboxes_format=target_bboxes_format,
            target_class_ids=target_class_ids,
            class_names=class_names,
        )
        show_image(image)

    def save(
        self,
        output_path: str,
        box_thickness: Optional[int] = None,
        show_confidence: bool = True,
        color_mapping: Optional[List[Tuple[int, int, int]]] = None,
        target_bboxes: Optional[np.ndarray] = None,
        target_bboxes_format: Optional[str] = None,
        target_class_ids: Optional[np.ndarray] = None,
        class_names: Optional[List[str]] = None,
    ) -> None:
        """Save the predicted bboxes on the images.

        :param output_path:             Path to the output image file.
        :param box_thickness:           (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        :param show_confidence:         Whether to show confidence scores on the image.
        :param color_mapping:           List of tuples representing the colors for each class.
                                        Default is None, which generates a default color mapping based on the number of class names.
        :param target_bboxes:           Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes.
                                        Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image,
                                        or a list of length len(target_bboxes), containing such arrays.
                                        When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one)
        :param target_class_ids:        Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
                                        (image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays.
        :param target_bboxes_format:    Optional[str], bounding box format of target_bboxes, one of
                                        ['xyxy','xywh', 'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh'].
                                        Will raise an error if not None and target_bboxes is None.
        :param class_names:             List of class names to show. By default None, which shows all classes used during training.
        """
        image = self.draw(
            box_thickness=box_thickness,
            show_confidence=show_confidence,
            color_mapping=color_mapping,
            target_bboxes=target_bboxes,
            target_bboxes_format=target_bboxes_format,
            target_class_ids=target_class_ids,
            class_names=class_names,
        )
        save_image(image=image, path=output_path)
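
A short sketch of the ground-truth overlay, assuming `detection_prediction` is an ImageDetectionPrediction returned by a detection model's predict(); the box coordinates and class ids below are purely illustrative. When target_bboxes is provided, the result stitches a "Predictions" panel and a "Ground Truth" panel side by side.

>>> import numpy as np
>>> gt_boxes = np.array([[50.0, 60.0, 220.0, 300.0]])        # one ground-truth box in pixel coordinates
>>> gt_classes = np.array([0])
>>> side_by_side = detection_prediction.draw(
...     target_bboxes=gt_boxes,
...     target_bboxes_format="xyxy",
...     target_class_ids=gt_classes,
... )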

draw(box_thickness=None, show_confidence=True, color_mapping=None, target_bboxes=None, target_bboxes_format=None, target_class_ids=None, class_names=None)

Draw the predicted bboxes on the image.

Parameters:

Name Type Description Default
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None
show_confidence bool

Whether to show confidence scores on the image.

True
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
target_bboxes Optional[np.ndarray]

Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image, or a list of length len(target_bboxes), containing such arrays. When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one)

None
target_class_ids Optional[np.ndarray]

Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape (image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays.

None
target_bboxes_format Optional[str]

Optional[str], bounding box format of target_bboxes, one of ['xyxy','xywh', 'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. Will raise an error if not None and target_bboxes is None.

None
class_names Optional[List[str]]

List of class names to show. By default None, which shows all classes used during training.

None

Returns:

Type Description
np.ndarray

Image with predicted bboxes. Note that this does not modify the original image.

Source code in src/super_gradients/training/utils/predict/prediction_results.py
def draw(
    self,
    box_thickness: Optional[int] = None,
    show_confidence: bool = True,
    color_mapping: Optional[List[Tuple[int, int, int]]] = None,
    target_bboxes: Optional[np.ndarray] = None,
    target_bboxes_format: Optional[str] = None,
    target_class_ids: Optional[np.ndarray] = None,
    class_names: Optional[List[str]] = None,
) -> np.ndarray:
    """Draw the predicted bboxes on the image.

    :param box_thickness:           (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    :param show_confidence:         Whether to show confidence scores on the image.
    :param color_mapping:           List of tuples representing the colors for each class.
                                    Default is None, which generates a default color mapping based on the number of class names.
    :param target_bboxes:           Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes.
                                    Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image,
                                    or a list of such arrays, one per image.
                                    When not None, the predictions and the ground truth bounding boxes are plotted side by side (i.e. 2 images stitched as one).
    :param target_class_ids:        Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
                                    (image_i_object_count,) when predicting a single image, or a list of such arrays, one per image (same length as target_bboxes).
    :param target_bboxes_format:    Optional[str], bounding box format of target_bboxes, one of
                                    ['xyxy', 'xywh', 'yxyx', 'cxcywh', 'normalized_xyxy', 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh'].
                                    Raises an error if set while target_bboxes is None.
    :param class_names:             List of class names to show. Defaults to None, which shows all classes used during training.

    :return:                Image with predicted bboxes. Note that this does not modify the original image.
    """
    image = self.image.copy()

    target_bboxes = target_bboxes if target_bboxes is not None else np.zeros((0, 4))
    target_class_ids = target_class_ids if target_class_ids is not None else np.zeros((0, 1))

    class_names_to_show = class_names if class_names else self.class_names
    class_ids_to_show = [i for i, class_name in enumerate(self.class_names) if class_name in class_names_to_show]
    invalid_class_names_to_show = set(class_names_to_show) - set(self.class_names)
    if len(invalid_class_names_to_show) > 0:
        raise ValueError(
            "`class_names` includes class names that the model was not trained on.\n"
            f"    - Invalid class names:   {list(invalid_class_names_to_show)}\n"
            f"    - Available class names: {list(self.class_names)}"
        )

    bbox_format_factory = BBoxFormatFactory()
    if len(target_bboxes):
        target_bboxes_xyxy = convert_bboxes(
            bboxes=target_bboxes,
            image_shape=self.prediction.image_shape,
            source_format=bbox_format_factory.get(target_bboxes_format),
            target_format=bbox_format_factory.get("xyxy"),
            inplace=False,
        )
    else:
        target_bboxes_xyxy = target_bboxes

    plot_targets = any([len(tbbx) > 0 for tbbx in target_bboxes_xyxy])
    color_mapping = color_mapping or generate_color_mapping(len(self.class_names))

    for pred_i in np.argsort(self.prediction.confidence):

        class_id = int(self.prediction.labels[pred_i])
        if class_id in class_ids_to_show:
            score = "" if not show_confidence else str(round(self.prediction.confidence[pred_i], 2))
            image = draw_bbox(
                image=image,
                title=f"{self.class_names[class_id]} {score}",
                color=color_mapping[class_id],
                box_thickness=box_thickness,
                x1=int(self.prediction.bboxes_xyxy[pred_i, 0]),
                y1=int(self.prediction.bboxes_xyxy[pred_i, 1]),
                x2=int(self.prediction.bboxes_xyxy[pred_i, 2]),
                y2=int(self.prediction.bboxes_xyxy[pred_i, 3]),
            )

    if plot_targets:
        target_image = self.image.copy()
        for target_idx in range(len(target_bboxes_xyxy)):
            class_id = int(target_class_ids[target_idx])
            if class_id in class_ids_to_show:
                target_image = draw_bbox(
                    image=target_image,
                    title=f"{self.class_names[class_id]}",
                    color=color_mapping[class_id],
                    box_thickness=box_thickness,
                    x1=int(target_bboxes_xyxy[target_idx, 0]),
                    y1=int(target_bboxes_xyxy[target_idx, 1]),
                    x2=int(target_bboxes_xyxy[target_idx, 2]),
                    y2=int(target_bboxes_xyxy[target_idx, 3]),
                )

        height, width, ch = target_image.shape
        new_width, new_height = int(width + width / 20), int(height + height / 8)

        # Create a new canvas with the new width and height.
        canvas_image = np.ones((new_height, new_width, ch), dtype=np.uint8) * 255
        canvas_target = np.ones((new_height, new_width, ch), dtype=np.uint8) * 255

        # Now replace the center of the canvas with the original image.
        padding_top, padding_left = 60, 10

        canvas_image[padding_top : padding_top + height, padding_left : padding_left + width] = image
        canvas_target[padding_top : padding_top + height, padding_left : padding_left + width] = target_image

        img1 = cv2.putText(canvas_image, "Predictions", (int(0.25 * width), 30), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 0))
        img2 = cv2.putText(canvas_target, "Ground Truth", (int(0.25 * width), 30), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 0))

        image = cv2.hconcat((img1, img2))
    return image
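
For orientation, here is a minimal usage sketch that is not taken from the library documentation: it assumes a detection model obtained via SuperGradients' models.get with COCO-pretrained weights, that model.predict() returns a batch wrapper which can be indexed to get a single ImageDetectionPrediction (in some versions predict() on a single image already returns the per-image object), and that the image path and ground-truth values are placeholders.

import numpy as np
from super_gradients.training import models

# Hypothetical setup: any detection model whose predict() yields ImageDetectionPrediction objects.
model = models.get("yolo_nas_s", pretrained_weights="coco")
image_prediction = model.predict("example.jpg")[0]  # first (and only) image of the batch

# Placeholder ground-truth boxes for the same image, in "xyxy" pixel coordinates.
gt_boxes = np.array([[40.0, 50.0, 200.0, 220.0]])
gt_class_ids = np.array([0])

annotated = image_prediction.draw(
    box_thickness=2,
    target_bboxes=gt_boxes,
    target_bboxes_format="xyxy",
    target_class_ids=gt_class_ids,
)
# `annotated` is an np.ndarray with predictions and ground truth stitched side by side.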

save(output_path, box_thickness=None, show_confidence=True, color_mapping=None, target_bboxes=None, target_bboxes_format=None, target_class_ids=None, class_names=None)

Save the predicted bboxes on the image.

Parameters:

Name Type Description Default
output_path str

Path to the output image file.

required
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None
show_confidence bool

Whether to show confidence scores on the image.

True
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
target_bboxes Optional[np.ndarray]

Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image, or a list of such arrays, one per image. When not None, the predictions and the ground truth bounding boxes are plotted side by side (i.e. 2 images stitched as one).

None
target_class_ids Optional[np.ndarray]

Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape (image_i_object_count,) when predicting a single image, or a list of such arrays, one per image (same length as target_bboxes).

None
target_bboxes_format Optional[str]

Optional[str], bounding box format of target_bboxes, one of ['xyxy', 'xywh', 'yxyx', 'cxcywh', 'normalized_xyxy', 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. Raises an error if set while target_bboxes is None.

None
class_names Optional[List[str]]

List of class names to show. Defaults to None, which shows all classes used during training.

None
Source code in src/super_gradients/training/utils/predict/prediction_results.py
def save(
    self,
    output_path: str,
    box_thickness: Optional[int] = None,
    show_confidence: bool = True,
    color_mapping: Optional[List[Tuple[int, int, int]]] = None,
    target_bboxes: Optional[np.ndarray] = None,
    target_bboxes_format: Optional[str] = None,
    target_class_ids: Optional[np.ndarray] = None,
    class_names: Optional[List[str]] = None,
) -> None:
    """Save the predicted bboxes on the images.

    :param output_path:             Path to the output video file.
    :param box_thickness:           (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    :param show_confidence:         Whether to show confidence scores on the image.
    :param color_mapping:           List of tuples representing the colors for each class.
                                    Default is None, which generates a default color mapping based on the number of class names.
    :param target_bboxes:           Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes.
                                    Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image,
                                    or a list of length len(target_bboxes), containing such arrays.
                                    When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one)
    :param target_class_ids:        Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
                                    (image_i_object_count) when predicting a single image, or a list of length len(target_bboxes), containing such arrays.
    :param target_bboxes_format:    Optional[str], bounding box format of target_bboxes, one of
                                    ['xyxy','xywh', 'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh'].
                                    Will raise an error if not None and target_bboxes is None.
    :param class_names:             List of class names to show. By default, is None which shows all classes using during training.
    """
    image = self.draw(
        box_thickness=box_thickness,
        show_confidence=show_confidence,
        color_mapping=color_mapping,
        target_bboxes=target_bboxes,
        target_bboxes_format=target_bboxes_format,
        target_class_ids=target_class_ids,
        class_names=class_names,
    )
    save_image(image=image, path=output_path)
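
Continuing the hypothetical image_prediction object from the sketch above, the annotated image can be written to disk; the output path and the class-name filter are illustrative only.

# Draw only a subset of classes and write the result to disk.
image_prediction.save(
    output_path="outputs/example_pred.jpg",
    show_confidence=True,
    class_names=["person", "car"],  # must be a subset of the model's class names
)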

show(box_thickness=None, show_confidence=True, color_mapping=None, target_bboxes=None, target_bboxes_format=None, target_class_ids=None, class_names=None)

Display the image with predicted bboxes.

Parameters:

Name Type Description Default
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None
show_confidence bool

Whether to show confidence scores on the image.

True
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
target_bboxes Optional[np.ndarray]

Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image, or a list of such arrays, one per image. When not None, the predictions and the ground truth bounding boxes are plotted side by side (i.e. 2 images stitched as one).

None
target_class_ids Optional[np.ndarray]

Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape (image_i_object_count,) when predicting a single image, or a list of such arrays, one per image (same length as target_bboxes).

None
target_bboxes_format Optional[str]

Optional[str], bounding box format of target_bboxes, one of ['xyxy', 'xywh', 'yxyx', 'cxcywh', 'normalized_xyxy', 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. Raises an error if set while target_bboxes is None.

None
class_names Optional[List[str]]

List of class names to show. Defaults to None, which shows all classes used during training.

None
Source code in src/super_gradients/training/utils/predict/prediction_results.py
def show(
    self,
    box_thickness: Optional[int] = None,
    show_confidence: bool = True,
    color_mapping: Optional[List[Tuple[int, int, int]]] = None,
    target_bboxes: Optional[np.ndarray] = None,
    target_bboxes_format: Optional[str] = None,
    target_class_ids: Optional[np.ndarray] = None,
    class_names: Optional[List[str]] = None,
) -> None:

    """Display the image with predicted bboxes.

    :param box_thickness:           (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    :param show_confidence:         Whether to show confidence scores on the image.
    :param color_mapping:           List of tuples representing the colors for each class.
                                    Default is None, which generates a default color mapping based on the number of class names.
    :param target_bboxes:           Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes.
                                    Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image,
                                    or a list of such arrays, one per image.
                                    When not None, the predictions and the ground truth bounding boxes are plotted side by side (i.e. 2 images stitched as one).
    :param target_class_ids:        Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
                                    (image_i_object_count,) when predicting a single image, or a list of such arrays, one per image (same length as target_bboxes).
    :param target_bboxes_format:    Optional[str], bounding box format of target_bboxes, one of
                                    ['xyxy', 'xywh', 'yxyx', 'cxcywh', 'normalized_xyxy', 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh'].
                                    Raises an error if set while target_bboxes is None.
    :param class_names:             List of class names to show. Defaults to None, which shows all classes used during training.
    """
    image = self.draw(
        box_thickness=box_thickness,
        show_confidence=show_confidence,
        color_mapping=color_mapping,
        target_bboxes=target_bboxes,
        target_bboxes_format=target_bboxes_format,
        target_class_ids=target_class_ids,
        class_names=class_names,
    )
    show_image(image)
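
Likewise, the same hypothetical prediction can be displayed in a window; the single-color mapping below is an arbitrary example, not a recommended setting.

# One RGB tuple per class; here every class is drawn in red for illustration.
image_prediction.show(
    box_thickness=2,
    color_mapping=[(255, 0, 0)] * len(image_prediction.class_names),
)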

ImagePrediction dataclass

Bases: ABC

Object wrapping an image and a model's prediction.

:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict

Source code in src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class ImagePrediction(ABC):
    """Object wrapping an image and a model's prediction.

    :attr image:        Input image
    :attr predictions:  Predictions of the model
    :attr class_names:  List of the class names to predict
    """

    image: np.ndarray
    prediction: Prediction
    class_names: List[str]

    @abstractmethod
    def draw(self, *args, **kwargs) -> np.ndarray:
        """Draw the predictions on the image."""
        pass

    @abstractmethod
    def show(self, *args, **kwargs) -> None:
        """Display the predictions on the image."""
        pass

    @abstractmethod
    def save(self, *args, **kwargs) -> None:
        """Save the predictions on the image."""
        pass
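
Since ImagePrediction only fixes the draw/show/save contract, custom result types can subclass it. The sketch below is hypothetical (it is not part of the library) and uses OpenCV directly instead of the library's own image helpers; it also assumes self.image is an RGB array.

from dataclasses import dataclass

import cv2
import numpy as np

from super_gradients.training.utils.predict.prediction_results import ImagePrediction


@dataclass
class ImagePassthroughPrediction(ImagePrediction):
    """Toy subclass that draws nothing and simply forwards the input image."""

    def draw(self, *args, **kwargs) -> np.ndarray:
        return self.image.copy()

    def show(self, *args, **kwargs) -> None:
        cv2.imshow("prediction", cv2.cvtColor(self.draw(), cv2.COLOR_RGB2BGR))
        cv2.waitKey(0)

    def save(self, output_path: str, *args, **kwargs) -> None:
        cv2.imwrite(output_path, cv2.cvtColor(self.draw(), cv2.COLOR_RGB2BGR))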

draw(*args, **kwargs) abstractmethod

Draw the predictions on the image.

Source code in src/super_gradients/training/utils/predict/prediction_results.py
@abstractmethod
def draw(self, *args, **kwargs) -> np.ndarray:
    """Draw the predictions on the image."""
    pass

save(*args, **kwargs) abstractmethod

Save the predictions on the image.

Source code in src/super_gradients/training/utils/predict/prediction_results.py
@abstractmethod
def save(self, *args, **kwargs) -> None:
    """Save the predictions on the image."""
    pass

show(*args, **kwargs) abstractmethod

Display the predictions on the image.

Source code in src/super_gradients/training/utils/predict/prediction_results.py
@abstractmethod
def show(self, *args, **kwargs) -> None:
    """Display the predictions on the image."""
    pass

ImageSegmentationPrediction dataclass

Bases: ImagePrediction

Object wrapping an image and a segmentation model's prediction.

:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict

Source code in src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class ImageSegmentationPrediction(ImagePrediction):
    """Object wrapping an image and a segmentation model's prediction.

    :attr image:        Input image
    :attr predictions:  Predictions of the model
    :attr class_names:  List of the class names to predict
    """

    image: np.ndarray
    prediction: SegmentationPrediction
    class_names: List[str]

    def draw(self, alpha: float = 0.6, color_mapping: Optional[List[Tuple[int, int, int]]] = None, class_names: Optional[List[str]] = None) -> np.ndarray:
        """Draw the predicted segmentation on the image.

        :param alpha:           Float number between [0,1] denoting the transparency of the masks (0 means full transparency, 1 means opacity).
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        :param class_names:     List of class names to predict (segmentation classes)
        :return:                Image with predicted segmentation. Note that this does not modify the original image.
        """
        image = self.image.copy()
        class_names = class_names or self.class_names
        if len(class_names) == 1:
            class_names = ["background"] + class_names
        color_mapping = color_mapping or generate_color_mapping(len(class_names))

        return overlay_segmentation(
            image=image, pred_mask=self.prediction.segmentation_map, num_classes=len(class_names), alpha=alpha, colors=color_mapping, class_names=class_names
        )

    def show(self, alpha: float = 0.6, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
        """Display the image with segmentation prediction overlay.

        :param alpha:           Float number between [0,1] denoting the transparency of the masks (0 means full transparency, 1 means opacity).
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        """
        image = self.draw(alpha=alpha, color_mapping=color_mapping, class_names=self.class_names)
        show_image(image)

    def save(self, output_path: str, alpha=0.6, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
        """Save the predicted segmentation on the images.

        :param alpha:           Float number between [0,1] denoting the transparency of the masks (0 means full transparency, 1 means opacity).
        :param output_path:     Path to the output file.
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        """
        image = self.draw(alpha=alpha, color_mapping=color_mapping, class_names=self.class_names)
        save_image(image=image, path=output_path)
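
A hedged usage sketch for the segmentation result object; the model name, pretrained weights, and image path are placeholders, and segmentation support in predict() depends on the installed SuperGradients version.

from super_gradients.training import models

# Hypothetical Cityscapes-pretrained segmentation model.
seg_model = models.get("pp_lite_t_seg75", pretrained_weights="cityscapes")
seg_prediction = seg_model.predict("street.jpg")[0]

# Blend the predicted mask over the image and write the overlay to disk.
overlay = seg_prediction.draw(alpha=0.5)
seg_prediction.save(output_path="outputs/street_seg.jpg", alpha=0.5)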

draw(alpha=0.6, color_mapping=None, class_names=None)

Draw the predicted segmentation on the image.

Parameters:

Name Type Description Default
alpha float

Float number between [0,1] denoting the transparency of the masks (0 means full transparency, 1 means opacity).

0.6
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
class_names Optional[List[str]]

List of class names to predict (segmentation classes)

None

Returns:

Type Description
np.ndarray

Image with predicted segmentation. Note that this does not modify the original image.

Source code in src/super_gradients/training/utils/predict/prediction_results.py
def draw(self, alpha: float = 0.6, color_mapping: Optional[List[Tuple[int, int, int]]] = None, class_names: Optional[List[str]] = None) -> np.ndarray:
    """Draw the predicted segmentation on the image.

    :param alpha:           Float number between [0,1] denoting the transparency of the masks (0 means full transparency, 1 means opacity).
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    :param class_names:     List of class names to predict (segmentation classes)
    :return:                Image with predicted segmentation. Note that this does not modify the original image.
    """
    image = self.image.copy()
    class_names = class_names or self.class_names
    if len(class_names) == 1:
        class_names = ["background"] + class_names
    color_mapping = color_mapping or generate_color_mapping(len(class_names))

    return overlay_segmentation(
        image=image, pred_mask=self.prediction.segmentation_map, num_classes=len(class_names), alpha=alpha, colors=color_mapping, class_names=class_names
    )

save(output_path, alpha=0.6, color_mapping=None)

Save the predicted segmentation on the image.

Parameters:

Name Type Description Default
alpha

Float number between [0,1] denoting the transparency of the masks (0 means full transparency, 1 means opacity).

0.6
output_path str

Path to the output file.

required
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
Source code in src/super_gradients/training/utils/predict/prediction_results.py
def save(self, output_path: str, alpha=0.6, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
    """Save the predicted segmentation on the images.

    :param alpha:           Float number between [0,1] denoting the transparency of the masks (0 means full transparency, 1 means opacity).
    :param output_path:     Path to the output file.
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    """
    image = self.draw(alpha=alpha, color_mapping=color_mapping, class_names=self.class_names)
    save_image(image=image, path=output_path)

show(alpha=0.6, color_mapping=None)

Display the image with segmentation prediction overlay.

Parameters:

Name Type Description Default
alpha float

Float number between [0,1] denoting the transparency of the masks (0 means full transparency, 1 means opacity).

0.6
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
Source code in src/super_gradients/training/utils/predict/prediction_results.py
def show(self, alpha: float = 0.6, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
    """Display the image with segmentation prediction overlay.

    :param alpha:           Float number between [0,1] denoting the transparency of the masks (0 means full transparency, 1 means opacity).
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    """
    image = self.draw(alpha=alpha, color_mapping=color_mapping, class_names=self.class_names)
    show_image(image)

ImagesClassificationPrediction dataclass

Bases: ImagesPredictions

Object wrapping the list of image classification predictions.

:attr _images_prediction_lst: List of the predictions results

Source code in src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class ImagesClassificationPrediction(ImagesPredictions):
    """Object wrapping the list of image classification predictions.

    :attr _images_prediction_lst:  List of the predictions results
    """

    _images_prediction_lst: List[ImageClassificationPrediction]

    def show(self, show_confidence: bool = True) -> None:
        """Display the predicted labels on the images.
        :param show_confidence: Whether to show confidence scores on the image.
        """
        for prediction in self._images_prediction_lst:
            prediction.show(show_confidence=show_confidence)

    def save(self, output_folder: str, show_confidence: bool = True) -> None:
        """Save the predicted label on the images.

        :param output_folder:     Folder path, where the images will be saved.
        :param show_confidence: Whether to show confidence scores on the image.
        """
        if output_folder:
            os.makedirs(output_folder, exist_ok=True)

        for i, prediction in enumerate(self._images_prediction_lst):
            image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
            prediction.save(output_path=image_output_path, show_confidence=show_confidence)
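
A short, hypothetical sketch of the batch classification flow; the model name and image paths are placeholders, and classification support in predict() depends on the library version.

from super_gradients.training import models

cls_model = models.get("resnet18", pretrained_weights="imagenet")
cls_predictions = cls_model.predict(["cat.jpg", "dog.jpg"])

cls_predictions.show(show_confidence=True)          # one window per image
cls_predictions.save(output_folder="outputs/cls")   # writes pred_0.jpg, pred_1.jpg, ...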

save(output_folder, show_confidence=True)

Save the predicted label on the images.

Parameters:

Name Type Description Default
output_folder str

Folder path, where the images will be saved.

required
show_confidence bool

Whether to show confidence scores on the image.

True
Source code in src/super_gradients/training/utils/predict/prediction_results.py
def save(self, output_folder: str, show_confidence: bool = True) -> None:
    """Save the predicted label on the images.

    :param output_folder:     Folder path, where the images will be saved.
    :param show_confidence: Whether to show confidence scores on the image.
    """
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    for i, prediction in enumerate(self._images_prediction_lst):
        image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
        prediction.save(output_path=image_output_path, show_confidence=show_confidence)

show(show_confidence=True)

Display the predicted labels on the images.

Parameters:

Name Type Description Default
show_confidence bool

Whether to show confidence scores on the image.

True
Source code in src/super_gradients/training/utils/predict/prediction_results.py
def show(self, show_confidence: bool = True) -> None:
    """Display the predicted labels on the images.
    :param show_confidence: Whether to show confidence scores on the image.
    """
    for prediction in self._images_prediction_lst:
        prediction.show(show_confidence=show_confidence)

ImagesDetectionPrediction dataclass

Bases: ImagesPredictions

Object wrapping the list of image detection predictions.

:attr _images_prediction_lst: List of the predictions results

Source code in src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class ImagesDetectionPrediction(ImagesPredictions):
    """Object wrapping the list of image detection predictions.

    :attr _images_prediction_lst:  List of the predictions results
    """

    _images_prediction_lst: List[ImageDetectionPrediction]

    def show(
        self,
        box_thickness: Optional[int] = None,
        show_confidence: bool = True,
        color_mapping: Optional[List[Tuple[int, int, int]]] = None,
        target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
        target_bboxes_format: Optional[str] = None,
        target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
        class_names: Optional[List[str]] = None,
    ) -> None:
        """Display the predicted bboxes on the images.

        :param box_thickness:           (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        :param show_confidence:         Whether to show confidence scores on the image.
        :param color_mapping:           List of tuples representing the colors for each class.
                                        Default is None, which generates a default color mapping based on the number of class names.
        :param target_bboxes:           Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes.
                                        Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image,
                                        or a list of such arrays, one per image.
                                        When not None, the predictions and the ground truth bounding boxes are plotted side by side (i.e. 2 images stitched as one).
        :param target_class_ids:        Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
                                        (image_i_object_count,) when predicting a single image, or a list of such arrays, one per image (same length as target_bboxes).
        :param target_bboxes_format:    Optional[str], bounding box format of target_bboxes, one of
                                        ['xyxy', 'xywh', 'yxyx', 'cxcywh', 'normalized_xyxy', 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh'].
                                        Raises an error if set while target_bboxes is None.
        :param class_names:             List of class names to show. Defaults to None, which shows all classes used during training.
        """
        target_bboxes, target_class_ids = self._check_target_args(target_bboxes, target_bboxes_format, target_class_ids)

        for prediction, target_bbox, target_class_id in zip(self._images_prediction_lst, target_bboxes, target_class_ids):
            prediction.show(
                box_thickness=box_thickness,
                show_confidence=show_confidence,
                color_mapping=color_mapping,
                target_bboxes=target_bbox,
                target_bboxes_format=target_bboxes_format,
                target_class_ids=target_class_id,
                class_names=class_names,
            )

    def _check_target_args(
        self,
        target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
        target_bboxes_format: Optional[str] = None,
        target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
    ):
        if not (
            (target_bboxes is None and target_bboxes_format is None and target_class_ids is None)
            or (target_bboxes is not None and target_bboxes_format is not None and target_class_ids is not None)
        ):
            raise ValueError("target_bboxes, target_bboxes_format, and target_class_ids should either all be None or all not None.")

        if isinstance(target_bboxes, np.ndarray):
            target_bboxes = [target_bboxes]
        if isinstance(target_class_ids, np.ndarray):
            target_class_ids = [target_class_ids]

        if target_bboxes is not None and target_class_ids is not None and len(target_bboxes) != len(target_class_ids):
            raise ValueError(f"target_bboxes and target_class_ids lengths should be equal, got: {len(target_bboxes)} and {len(target_class_ids)}.")
        if target_bboxes is not None and target_class_ids is not None and len(target_bboxes) != len(self._images_prediction_lst):
            raise ValueError(
                f"target_bboxes and target_class_ids lengths should be equal, to the "
                f"amount of images passed to predict(), got: {len(target_bboxes)} and {len(self._images_prediction_lst)}."
            )
        if target_bboxes is None:
            target_bboxes = [None for _ in range(len(self._images_prediction_lst))]
            target_class_ids = [None for _ in range(len(self._images_prediction_lst))]

        return target_bboxes, target_class_ids

    def save(
        self,
        output_folder: str,
        box_thickness: Optional[int] = None,
        show_confidence: bool = True,
        color_mapping: Optional[List[Tuple[int, int, int]]] = None,
        target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
        target_bboxes_format: Optional[str] = None,
        target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
        class_names: Optional[List[str]] = None,
    ) -> None:
        """Save the predicted bboxes on the images.

        :param output_folder:           Folder path, where the images will be saved.
        :param box_thickness:           (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        :param show_confidence:         Whether to show confidence scores on the image.
        :param color_mapping:           List of tuples representing the colors for each class.
                                        Default is None, which generates a default color mapping based on the number of class names.
        :param target_bboxes:           Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes.
                                        Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image,
                                        or a list of such arrays, one per image.
                                        When not None, the predictions and the ground truth bounding boxes are plotted side by side (i.e. 2 images stitched as one).
        :param target_class_ids:        Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
                                        (image_i_object_count,) when predicting a single image, or a list of such arrays, one per image (same length as target_bboxes).
        :param target_bboxes_format:    Optional[str], bounding box format of target_bboxes, one of
                                        ['xyxy', 'xywh', 'yxyx', 'cxcywh', 'normalized_xyxy', 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh'].
                                        Raises an error if set while target_bboxes is None.
        :param class_names:             List of class names to show. Defaults to None, which shows all classes used during training.
        """
        if output_folder:
            os.makedirs(output_folder, exist_ok=True)

        target_bboxes, target_class_ids = self._check_target_args(target_bboxes, target_bboxes_format, target_class_ids)

        for i, (prediction, target_bbox, target_class_id) in enumerate(zip(self._images_prediction_lst, target_bboxes, target_class_ids)):
            image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
            prediction.save(
                output_path=image_output_path,
                box_thickness=box_thickness,
                show_confidence=show_confidence,
                color_mapping=color_mapping,
                target_bboxes=target_bbox,
                target_bboxes_format=target_bboxes_format,
                target_class_ids=target_class_id,
                class_names=class_names,
            )
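
To illustrate the all-or-none rule enforced by _check_target_args, the sketch below passes per-image ground truth for a two-image batch; the model, image paths, and box values are placeholders.

import numpy as np
from super_gradients.training import models

model = models.get("yolo_nas_s", pretrained_weights="coco")
batch_predictions = model.predict(["img_0.jpg", "img_1.jpg"])

# One ground-truth array per image; the list length must match the number of images.
gt_boxes = [np.array([[10.0, 10.0, 80.0, 90.0]]), np.zeros((0, 4))]
gt_class_ids = [np.array([0]), np.zeros((0,))]

batch_predictions.show(
    target_bboxes=gt_boxes,
    target_bboxes_format="xyxy",
    target_class_ids=gt_class_ids,
)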

save(output_folder, box_thickness=None, show_confidence=True, color_mapping=None, target_bboxes=None, target_bboxes_format=None, target_class_ids=None, class_names=None)

Save the predicted bboxes on the images.

Parameters:

Name Type Description Default
output_folder str

Folder path, where the images will be saved.

required
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None
show_confidence bool

Whether to show confidence scores on the image.

True
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
target_bboxes Optional[Union[np.ndarray, List[np.ndarray]]]

Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image, or a list of such arrays, one per image. When not None, the predictions and the ground truth bounding boxes are plotted side by side (i.e. 2 images stitched as one).

None
target_class_ids Optional[Union[np.ndarray, List[np.ndarray]]]

Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape (image_i_object_count,) when predicting a single image, or a list of such arrays, one per image (same length as target_bboxes).

None
target_bboxes_format Optional[str]

Optional[str], bounding box format of target_bboxes, one of ['xyxy', 'xywh', 'yxyx', 'cxcywh', 'normalized_xyxy', 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. Raises an error if set while target_bboxes is None.

None
class_names Optional[List[str]]

List of class names to show. Defaults to None, which shows all classes used during training.

None
Source code in src/super_gradients/training/utils/predict/prediction_results.py
def save(
    self,
    output_folder: str,
    box_thickness: Optional[int] = None,
    show_confidence: bool = True,
    color_mapping: Optional[List[Tuple[int, int, int]]] = None,
    target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
    target_bboxes_format: Optional[str] = None,
    target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
    class_names: Optional[List[str]] = None,
) -> None:
    """Save the predicted bboxes on the images.

    :param output_folder:           Folder path, where the images will be saved.
    :param box_thickness:           (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    :param show_confidence:         Whether to show confidence scores on the image.
    :param color_mapping:           List of tuples representing the colors for each class.
                                    Default is None, which generates a default color mapping based on the number of class names.
    :param target_bboxes:           Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes.
                                    Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image,
                                    or a list of such arrays, one per image.
                                    When not None, the predictions and the ground truth bounding boxes are plotted side by side (i.e. 2 images stitched as one).
    :param target_class_ids:        Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
                                    (image_i_object_count,) when predicting a single image, or a list of such arrays, one per image (same length as target_bboxes).
    :param target_bboxes_format:    Optional[str], bounding box format of target_bboxes, one of
                                    ['xyxy', 'xywh', 'yxyx', 'cxcywh', 'normalized_xyxy', 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh'].
                                    Raises an error if set while target_bboxes is None.
    :param class_names:             List of class names to show. Defaults to None, which shows all classes used during training.
    """
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    target_bboxes, target_class_ids = self._check_target_args(target_bboxes, target_bboxes_format, target_class_ids)

    for i, (prediction, target_bbox, target_class_id) in enumerate(zip(self._images_prediction_lst, target_bboxes, target_class_ids)):
        image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
        prediction.save(
            output_path=image_output_path,
            box_thickness=box_thickness,
            show_confidence=show_confidence,
            color_mapping=color_mapping,
            target_bboxes=target_bbox,
            target_bboxes_format=target_bboxes_format,
            target_class_ids=target_class_id,
            class_names=class_names,
        )

show(box_thickness=None, show_confidence=True, color_mapping=None, target_bboxes=None, target_bboxes_format=None, target_class_ids=None, class_names=None)

Display the predicted bboxes on the images.

Parameters:

Name Type Description Default
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None
show_confidence bool

Whether to show confidence scores on the image.

True
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
target_bboxes Optional[Union[np.ndarray, List[np.ndarray]]]

Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes. Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image, or a list of such arrays, one per image. When not None, the predictions and the ground truth bounding boxes are plotted side by side (i.e. 2 images stitched as one).

None
target_class_ids Optional[Union[np.ndarray, List[np.ndarray]]]

Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape (image_i_object_count,) when predicting a single image, or a list of such arrays, one per image (same length as target_bboxes).

None
target_bboxes_format Optional[str]

Optional[str], bounding box format of target_bboxes, one of ['xyxy', 'xywh', 'yxyx', 'cxcywh', 'normalized_xyxy', 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. Raises an error if set while target_bboxes is None.

None
class_names Optional[List[str]]

List of class names to show. Defaults to None, which shows all classes used during training.

None
Source code in src/super_gradients/training/utils/predict/prediction_results.py
def show(
    self,
    box_thickness: Optional[int] = None,
    show_confidence: bool = True,
    color_mapping: Optional[List[Tuple[int, int, int]]] = None,
    target_bboxes: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
    target_bboxes_format: Optional[str] = None,
    target_class_ids: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
    class_names: Optional[List[str]] = None,
) -> None:
    """Display the predicted bboxes on the images.

    :param box_thickness:           (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    :param show_confidence:         Whether to show confidence scores on the image.
    :param color_mapping:           List of tuples representing the colors for each class.
                                    Default is None, which generates a default color mapping based on the number of class names.
    :param target_bboxes:           Optional[Union[np.ndarray, List[np.ndarray]]], ground truth bounding boxes.
                                    Can either be an np.ndarray of shape (image_i_object_count, 4) when predicting a single image,
                                    or a list of such arrays, one per image.
                                    When not None, the predictions and the ground truth bounding boxes are plotted side by side (i.e. 2 images stitched as one).
    :param target_class_ids:        Optional[Union[np.ndarray, List[np.ndarray]]], ground truth target class indices. Can either be an np.ndarray of shape
                                    (image_i_object_count,) when predicting a single image, or a list of such arrays, one per image (same length as target_bboxes).
    :param target_bboxes_format:    Optional[str], bounding box format of target_bboxes, one of
                                    ['xyxy', 'xywh', 'yxyx', 'cxcywh', 'normalized_xyxy', 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh'].
                                    Raises an error if set while target_bboxes is None.
    :param class_names:             List of class names to show. Defaults to None, which shows all classes used during training.
    """
    target_bboxes, target_class_ids = self._check_target_args(target_bboxes, target_bboxes_format, target_class_ids)

    for prediction, target_bbox, target_class_id in zip(self._images_prediction_lst, target_bboxes, target_class_ids):
        prediction.show(
            box_thickness=box_thickness,
            show_confidence=show_confidence,
            color_mapping=color_mapping,
            target_bboxes=target_bbox,
            target_bboxes_format=target_bboxes_format,
            target_class_ids=target_class_id,
            class_names=class_names,
        )

ImagesPredictions dataclass

Bases: ABC

Object wrapping the list of image predictions.

:attr _images_prediction_lst: List of results of the run

Source code in src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class ImagesPredictions(ABC):
    """Object wrapping the list of image predictions.

    :attr _images_prediction_lst: List of results of the run
    """

    _images_prediction_lst: List[ImagePrediction]

    def __len__(self) -> int:
        return len(self._images_prediction_lst)

    def __getitem__(self, index: int) -> ImagePrediction:
        return self._images_prediction_lst[index]

    def __iter__(self) -> Iterator[ImagePrediction]:
        return iter(self._images_prediction_lst)

    @abstractmethod
    def show(self, *args, **kwargs) -> None:
        """Display the predictions on the images."""
        pass

    @abstractmethod
    def save(self, *args, **kwargs) -> None:
        """Save the predictions on the images."""
        pass
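
Because ImagesPredictions implements __len__, __getitem__, and __iter__, a batch result behaves like a plain sequence. A small sketch, continuing the hypothetical batch_predictions object from the detection example above:

print(len(batch_predictions))   # number of images in the batch
first = batch_predictions[0]    # an ImagePrediction subclass instance

for single_prediction in batch_predictions:
    # Each element exposes the per-image draw/show/save API documented above.
    single_prediction.show()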

save(*args, **kwargs) abstractmethod

Save the predictions on the images.

Source code in src/super_gradients/training/utils/predict/prediction_results.py
@abstractmethod
def save(self, *args, **kwargs) -> None:
    """Save the predictions on the images."""
    pass

show(*args, **kwargs) abstractmethod

Display the predictions on the images.

Source code in src/super_gradients/training/utils/predict/prediction_results.py
@abstractmethod
def show(self, *args, **kwargs) -> None:
    """Display the predictions on the images."""
    pass

ImagesSegmentationPrediction dataclass

Bases: ImagesPredictions

Object wrapping the list of image segmentation predictions.

:attr _images_prediction_lst: List of the predictions results

Source code in src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class ImagesSegmentationPrediction(ImagesPredictions):
    """Object wrapping the list of image segmentation predictions.

    :attr _images_prediction_lst:  List of the predictions results
    """

    _images_prediction_lst: List[ImageSegmentationPrediction]

    def show(self, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
        """Display the predicted segmentation on the images.

        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        """
        for prediction in self._images_prediction_lst:
            prediction.show(color_mapping=color_mapping)

    def save(self, output_folder: str, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
        """Save the predicted bboxes on the images.

        :param output_folder:     Folder path, where the images will be saved.
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        """
        if output_folder:
            os.makedirs(output_folder, exist_ok=True)

        for i, prediction in enumerate(self._images_prediction_lst):
            image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
            prediction.save(output_path=image_output_path, color_mapping=color_mapping)
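
A minimal sketch for the batch segmentation wrapper, reusing the hypothetical seg_model from the earlier segmentation example; the image paths and output folder are placeholders.

seg_batch = seg_model.predict(["street_0.jpg", "street_1.jpg"])
seg_batch.save(output_folder="outputs/seg")   # writes pred_0.jpg, pred_1.jpg, ...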

save(output_folder, color_mapping=None)

Save the predicted segmentation on the images.

Parameters:

Name Type Description Default
output_folder str

Folder path, where the images will be saved.

required
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
Source code in src/super_gradients/training/utils/predict/prediction_results.py
def save(self, output_folder: str, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
    """Save the predicted bboxes on the images.

    :param output_folder:     Folder path, where the images will be saved.
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    """
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    for i, prediction in enumerate(self._images_prediction_lst):
        image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
        prediction.save(output_path=image_output_path, color_mapping=color_mapping)

show(color_mapping=None)

Display the predicted segmentation on the images.

Parameters:

Name Type Description Default
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
Source code in src/super_gradients/training/utils/predict/prediction_results.py
def show(self, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
    """Display the predicted segmentation on the images.

    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    """
    for prediction in self._images_prediction_lst:
        prediction.show(color_mapping=color_mapping)

VideoDetectionPrediction dataclass

Bases: VideoPredictions

Object wrapping the list of image detection predictions as a Video.

:attr _images_prediction_gen: Iterable object of the predictions results
:attr fps: Frames per second of the video

Source code in src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class VideoDetectionPrediction(VideoPredictions):
    """Object wrapping the list of image detection predictions as a Video.

    :attr _images_prediction_gen:   Iterable object of the predictions results
    :attr fps:                      Frames per second of the video
    """

    _images_prediction_gen: Iterator[ImagePrediction]
    fps: int
    n_frames: int

    def draw(
        self,
        box_thickness: Optional[int] = None,
        show_confidence: bool = True,
        color_mapping: Optional[List[Tuple[int, int, int]]] = None,
        class_names: Optional[List[str]] = None,
    ) -> Iterator[np.ndarray]:
        """Draw the predicted bboxes on the images.

        :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        :param show_confidence: Whether to show confidence scores on the image.
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        :param class_names:     List of class names to show. Defaults to None, which shows all classes used during training.
        :return:                Iterable object of images with predicted bboxes. Note that this does not modify the original image.
        """

        for result in tqdm(self._images_prediction_gen, total=self.n_frames, desc="Processing Video"):
            yield result.draw(
                box_thickness=box_thickness,
                show_confidence=show_confidence,
                color_mapping=color_mapping,
                class_names=class_names,
            )

    def show(
        self,
        box_thickness: Optional[int] = None,
        show_confidence: bool = True,
        color_mapping: Optional[List[Tuple[int, int, int]]] = None,
        class_names: Optional[List[str]] = None,
    ) -> None:
        """Display the predicted bboxes on the images.

        :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        :param show_confidence: Whether to show confidence scores on the image.
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        :param class_names:     List of class names to show. Defaults to None, which shows all classes used during training.
        """
        frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping, class_names=class_names)
        show_video_from_frames(window_name="Detection", frames=frames, fps=self.fps)

    def save(
        self,
        output_path: str,
        box_thickness: Optional[int] = None,
        show_confidence: bool = True,
        color_mapping: Optional[List[Tuple[int, int, int]]] = None,
        class_names: Optional[List[str]] = None,
    ) -> None:
        """Save the predicted bboxes on the images.

        :param output_path:     Path to the output video file.
        :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
        :param show_confidence: Whether to show confidence scores on the image.
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        :param class_names:     List of class names to show. Defaults to None, which shows all classes used during training.
        """
        frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping, class_names=class_names)
        save_video(output_path=output_path, frames=frames, fps=self.fps)
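
A hedged end-to-end sketch for video inference; the file names are placeholders, and both video support in predict() and the exact return type depend on the library version.

from super_gradients.training import models

model = models.get("yolo_nas_s", pretrained_weights="coco")
video_predictions = model.predict("input.mp4")  # expected to yield a VideoDetectionPrediction

# Frames are drawn lazily (see draw() above) and written out at the source fps.
video_predictions.save(output_path="outputs/annotated.mp4", show_confidence=True)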

draw(box_thickness=None, show_confidence=True, color_mapping=None, class_names=None)

Draw the predicted bboxes on the images.

Parameters:

Name Type Description Default
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None
show_confidence bool

Whether to show confidence scores on the image.

True
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
class_names Optional[List[str]]

List of class names to show. Defaults to None, which shows all classes used during training.

None

Returns:

Type Description
Iterator[np.ndarray]

Iterable object of images with predicted bboxes. Note that this does not modify the original image.

Source code in src/super_gradients/training/utils/predict/prediction_results.py (lines 573-596)
def draw(
    self,
    box_thickness: Optional[int] = None,
    show_confidence: bool = True,
    color_mapping: Optional[List[Tuple[int, int, int]]] = None,
    class_names: Optional[List[str]] = None,
) -> Iterator[np.ndarray]:
    """Draw the predicted bboxes on the images.

    :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    :param show_confidence: Whether to show confidence scores on the image.
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    :param class_names:     List of class names to show. Defaults to None, which shows all classes used during training.
    :return:                Iterable object of images with predicted bboxes. Note that this does not modify the original image.
    """

    for result in tqdm(self._images_prediction_gen, total=self.n_frames, desc="Processing Video"):
        yield result.draw(
            box_thickness=box_thickness,
            show_confidence=show_confidence,
            color_mapping=color_mapping,
            class_names=class_names,
        )

save(output_path, box_thickness=None, show_confidence=True, color_mapping=None, class_names=None)

Save the predicted bboxes on the images.

Parameters:

Name Type Description Default
output_path str

Path to the output video file.

required
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None
show_confidence bool

Whether to show confidence scores on the image.

True
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
class_names Optional[List[str]]

List of class names to show. Defaults to None, which shows all classes used during training.

None
Source code in src/super_gradients/training/utils/predict/prediction_results.py (lines 616-634)
def save(
    self,
    output_path: str,
    box_thickness: Optional[int] = None,
    show_confidence: bool = True,
    color_mapping: Optional[List[Tuple[int, int, int]]] = None,
    class_names: Optional[List[str]] = None,
) -> None:
    """Save the predicted bboxes on the images.

    :param output_path:     Path to the output video file.
    :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    :param show_confidence: Whether to show confidence scores on the image.
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    :param class_names:     List of class names to show. Defaults to None, which shows all classes used during training.
    """
    frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping, class_names=class_names)
    save_video(output_path=output_path, frames=frames, fps=self.fps)

show(box_thickness=None, show_confidence=True, color_mapping=None, class_names=None)

Display the predicted bboxes on the images.

Parameters:

Name Type Description Default
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

None
show_confidence bool

Whether to show confidence scores on the image.

True
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
class_names Optional[List[str]]

List of class names to show. Defaults to None, which shows all classes used during training.

None
Source code in src/super_gradients/training/utils/predict/prediction_results.py (lines 598-614)
def show(
    self,
    box_thickness: Optional[int] = None,
    show_confidence: bool = True,
    color_mapping: Optional[List[Tuple[int, int, int]]] = None,
    class_names: Optional[List[str]] = None,
) -> None:
    """Display the predicted bboxes on the images.

    :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    :param show_confidence: Whether to show confidence scores on the image.
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    :param class_names:     List of class names to show. Defaults to None, which shows all classes used during training.
    """
    frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping, class_names=class_names)
    show_video_from_frames(window_name="Detection", frames=frames, fps=self.fps)

VideoPredictions dataclass

Bases: ABC

Object wrapping the list of image predictions as a Video.

:attr _images_prediction_gen: Iterator over the per-frame prediction results
:attr fps: Frames per second of the video

Source code in src/super_gradients/training/utils/predict/prediction_results.py (lines 384-404)
@dataclass
class VideoPredictions(ABC):
    """Object wrapping the list of image predictions as a Video.

    :attr _images_prediction_gen:   Iterator over the per-frame prediction results
    :attr fps:                      Frames per second of the video
    """

    _images_prediction_gen: Iterator[ImagePrediction]
    fps: float
    n_frames: int

    @abstractmethod
    def show(self, *args, **kwargs) -> None:
        """Display the predictions on the video."""
        pass

    @abstractmethod
    def save(self, *args, **kwargs) -> None:
        """Save the predictions on the video."""
        pass

save(*args, **kwargs) abstractmethod

Save the predictions on the video.

Source code in src/super_gradients/training/utils/predict/prediction_results.py (lines 401-404)
@abstractmethod
def save(self, *args, **kwargs) -> None:
    """Save the predictions on the video."""
    pass

show(*args, **kwargs) abstractmethod

Display the predictions on the video.

Source code in src/super_gradients/training/utils/predict/prediction_results.py (lines 396-399)
@abstractmethod
def show(self, *args, **kwargs) -> None:
    """Display the predictions on the video."""
    pass

VideoSegmentationPrediction dataclass

Bases: VideoPredictions

Object wrapping the list of image segmentation predictions as a Video.

:attr _images_prediction_lst: List of the per-frame prediction results
:attr fps: Frames per second of the video

Source code in src/super_gradients/training/utils/predict/prediction_results.py (lines 670-716)
@dataclass
class VideoSegmentationPrediction(VideoPredictions):
    """Object wrapping the list of image segmentation predictions as a Video.

    :attr _images_prediction_lst:   List of the per-frame prediction results
    :attr fps:                      Frames per second of the video
    """

    _images_prediction_lst: List[ImageSegmentationPrediction]
    fps: int

    def draw(self, alpha: float = 0.6, color_mapping: Optional[List[Tuple[int, int, int]]] = None, class_names: Optional[List[str]] = None) -> List[np.ndarray]:
        """Draw the predicted segmentation on the images.

        :param alpha:           Float in [0, 1] denoting the transparency of the masks (0 means fully transparent, 1 means fully opaque).
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        :param class_names:     List of class names to predict (segmentation classes).
        :return:                List of images with predicted segmentation. Note that this does not modify the original image.
        """
        frames_with_segmentation = [result.draw(alpha=alpha, color_mapping=color_mapping, class_names=class_names) for result in self._images_prediction_lst]
        return frames_with_segmentation

    def show(self, alpha: float = 0.6, color_mapping: Optional[List[Tuple[int, int, int]]] = None, class_names: Optional[List[str]] = None) -> None:
        """Display the predicted segmentation on the images.

        :param alpha:           Float in [0, 1] denoting the transparency of the masks (0 means fully transparent, 1 means fully opaque).
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        :param class_names:     List of class names to predict (segmentation classes).
        """
        frames = self.draw(alpha=alpha, color_mapping=color_mapping, class_names=class_names)
        show_video_from_frames(window_name="Segmentation", frames=frames, fps=self.fps)

    def save(
        self, output_path: str, alpha: float = 0.6, color_mapping: Optional[List[Tuple[int, int, int]]] = None, class_names: Optional[List[str]] = None
    ) -> None:
        """Save the predicted bboxes on the images.

        :param output_path:     Path to the output video file.
        :param alpha:           Float number between [0,1] denoting the transparency of the masks (0 means full transparency, 1 means opacity).
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        :param class_names:     List of class names to predict (segmentation classes).
        """
        frames = self.draw(alpha=alpha, color_mapping=color_mapping, class_names=class_names)
        save_video(output_path=output_path, frames=frames, fps=self.fps)

draw(alpha=0.6, color_mapping=None, class_names=None)

Draw the predicted segmentation on the images.

Parameters:

Name Type Description Default
alpha float

Float in [0, 1] denoting the transparency of the masks (0 means fully transparent, 1 means fully opaque).

0.6
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
class_names Optional[List[str]]

List of class names to predict (segmentation classes).

None

Returns:

Type Description
List[np.ndarray]

List of images with predicted segmentation. Note that this does not modify the original image.

Source code in src/super_gradients/training/utils/predict/prediction_results.py (lines 681-691)
def draw(self, alpha: float = 0.6, color_mapping: Optional[List[Tuple[int, int, int]]] = None, class_names: Optional[List[str]] = None) -> List[np.ndarray]:
    """Draw the predicted segmentation on the images.

    :param alpha:           Float in [0, 1] denoting the transparency of the masks (0 means fully transparent, 1 means fully opaque).
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    :param class_names:     List of class names to predict (segmentation classes).
    :return:                List of images with predicted segmentation. Note that this does not modify the original image.
    """
    frames_with_segmentation = [result.draw(alpha=alpha, color_mapping=color_mapping, class_names=class_names) for result in self._images_prediction_lst]
    return frames_with_segmentation

save(output_path, alpha=0.6, color_mapping=None, class_names=None)

Save the predicted segmentation on the images.

Parameters:

Name Type Description Default
output_path str

Path to the output video file.

required
alpha float

Float in [0, 1] denoting the transparency of the masks (0 means fully transparent, 1 means fully opaque).

0.6
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
class_names Optional[List[str]]

List of class names to predict (segmentation classes).

None
Source code in src/super_gradients/training/utils/predict/prediction_results.py (lines 704-716)
def save(
    self, output_path: str, alpha: float = 0.6, color_mapping: Optional[List[Tuple[int, int, int]]] = None, class_names: Optional[List[str]] = None
) -> None:
    """Save the predicted bboxes on the images.

    :param output_path:     Path to the output video file.
    :param alpha:           Float number between [0,1] denoting the transparency of the masks (0 means full transparency, 1 means opacity).
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    :param class_names:     List of class names to predict (segmentation classes).
    """
    frames = self.draw(alpha=alpha, color_mapping=color_mapping, class_names=class_names)
    save_video(output_path=output_path, frames=frames, fps=self.fps)

show(alpha=0.6, color_mapping=None, class_names=None)

Display the predicted segmentation on the images.

Parameters:

Name Type Description Default
alpha float

Float in [0, 1] denoting the transparency of the masks (0 means fully transparent, 1 means fully opaque).

0.6
color_mapping Optional[List[Tuple[int, int, int]]]

List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.

None
class_names Optional[List[str]]

List of class names to predict (segmentation classes).

None
Source code in src/super_gradients/training/utils/predict/prediction_results.py (lines 693-702)
def show(self, alpha: float = 0.6, color_mapping: Optional[List[Tuple[int, int, int]]] = None, class_names: Optional[List[str]] = None) -> None:
    """Display the predicted segmentation on the images.

    :param alpha:           Float in [0, 1] denoting the transparency of the masks (0 means fully transparent, 1 means fully opaque).
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    :param class_names:     List of class names to predict (segmentation classes).
    """
    frames = self.draw(alpha=alpha, color_mapping=color_mapping, class_names=class_names)
    show_video_from_frames(window_name="Segmentation", frames=frames, fps=self.fps)

ClassificationPrediction dataclass

Bases: Prediction

Represents a Classification prediction

Source code in src/super_gradients/training/utils/predict/predictions.py (lines 128-156)
@dataclass
class ClassificationPrediction(Prediction):
    """Represents a Classification prediction"""

    confidence: float
    label: int
    image_shape: Tuple[int, int]

    def __init__(self, confidence: float, label: int, image_shape: Optional[Tuple[int, int]]):
        """

        :param confidence:  Confidence score of the predicted class
        :param label:       Predicted class label (index).
        :param image_shape: Shape of the image the prediction is made on, (H, W).
        """
        self._validate_input(confidence, label)

        self.confidence = confidence
        self.label = label
        self.image_shape = image_shape

    def _validate_input(self, confidence: float, label: int) -> None:
        if not isinstance(confidence, float):
            raise ValueError(f"Argument confidence must be a float, not {type(confidence)}")
        if not isinstance(label, int):
            raise ValueError(f"Argument labels must be an integer, not {type(label)}")

    def __len__(self):
        return len(self.labels)

__init__(confidence, label, image_shape)

Parameters:

Name Type Description Default
confidence float

Confidence score of the predicted class

required
label int

Predicted class label (index).

required
image_shape Optional[Tuple[int, int]]

Shape of the image the prediction is made on, (H, W).

required
Source code in src/super_gradients/training/utils/predict/predictions.py (lines 136-147)
def __init__(self, confidence: float, label: int, image_shape: Optional[Tuple[int, int]]):
    """

    :param confidence:  Confidence score of the predicted class
    :param label:       Predicted class label (index).
    :param image_shape: Shape of the image the prediction is made on, (H, W).
    """
    self._validate_input(confidence, label)

    self.confidence = confidence
    self.label = label
    self.image_shape = image_shape
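
For reference, a minimal construction sketch; the import path follows the source file shown above and the concrete values are arbitrary:

from super_gradients.training.utils.predict.predictions import ClassificationPrediction

prediction = ClassificationPrediction(
    confidence=0.92,         # must be a Python float, otherwise _validate_input raises ValueError
    label=3,                 # must be a Python int
    image_shape=(224, 224),  # (H, W) of the image the prediction was made on
)
print(prediction.label, prediction.confidence)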

DetectionPrediction dataclass

Bases: Prediction

Represents a detection prediction, with bboxes represented in xyxy format.

Source code in src/super_gradients/training/utils/predict/predictions.py (lines 16-65)
@dataclass
class DetectionPrediction(Prediction):
    """Represents a detection prediction, with bboxes represented in xyxy format."""

    bboxes_xyxy: np.ndarray
    confidence: np.ndarray
    labels: np.ndarray

    def __init__(self, bboxes: np.ndarray, bbox_format: str, confidence: np.ndarray, labels: np.ndarray, image_shape: Tuple[int, int]):
        """
        :param bboxes:      BBoxes in the format specified by bbox_format
        :param bbox_format: BBoxes format that can be a string ("xyxy", "cxywh", ...)
        :param confidence:  Confidence scores for each bounding box
        :param labels:      Labels for each bounding box.
        :param image_shape: Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format

        :param target_bboxes: np.ndarray, ground truth bounding boxes as np.ndarray of shape (image_i_object_count, 4)
         When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one)

        :param target_labels: np.ndarray, ground truth target class indices as an np.ndarray of shape (image_i_object_count).

        :param target_bbox_format: str, bounding box format of target_bboxes, one of ['xyxy','xywh',
        'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. Will raise an
        error if not None and target_bboxes is None.
        """
        self._validate_input(bboxes, confidence, labels)

        factory = BBoxFormatFactory()
        bboxes_xyxy = convert_bboxes(
            bboxes=bboxes,
            image_shape=image_shape,
            source_format=factory.get(bbox_format),
            target_format=factory.get("xyxy"),
            inplace=False,
        )

        self.bboxes_xyxy = bboxes_xyxy
        self.confidence = confidence
        self.labels = labels
        self.image_shape = image_shape

    def _validate_input(self, bboxes: np.ndarray, confidence: np.ndarray, labels: np.ndarray) -> None:
        n_bboxes, n_confidences, n_labels = bboxes.shape[0], confidence.shape[0], labels.shape[0]
        if n_bboxes != n_confidences != n_labels:
            raise ValueError(
                f"The number of bounding boxes ({n_bboxes}) does not match the number of confidence scores ({n_confidences}) and labels ({n_labels})."
            )

    def __len__(self):
        return len(self.bboxes_xyxy)

__init__(bboxes, bbox_format, confidence, labels, image_shape)

Parameters:

Name Type Description Default
bboxes np.ndarray

BBoxes in the format specified by bbox_format

required
bbox_format str

BBoxes format that can be a string ("xyxy", "cxywh", ...)

required
confidence np.ndarray

Confidence scores for each bounding box

required
labels np.ndarray

Labels for each bounding box.

required
image_shape Tuple[int, int]

Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format

required
target_bboxes

np.ndarray, ground truth bounding boxes as np.ndarray of shape (image_i_object_count, 4) When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one)

required
target_labels

np.ndarray, ground truth target class indices as an np.ndarray of shape (image_i_object_count).

required
target_bbox_format

str, bounding box format of target_bboxes, one of ['xyxy','xywh', 'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. Will raise an error if not None and target_bboxes is None.

required
Source code in src/super_gradients/training/utils/predict/predictions.py (lines 24-55)
def __init__(self, bboxes: np.ndarray, bbox_format: str, confidence: np.ndarray, labels: np.ndarray, image_shape: Tuple[int, int]):
    """
    :param bboxes:      BBoxes in the format specified by bbox_format
    :param bbox_format: BBoxes format that can be a string ("xyxy", "cxywh", ...)
    :param confidence:  Confidence scores for each bounding box
    :param labels:      Labels for each bounding box.
    :param image_shape: Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format

    :param target_bboxes: np.ndarray, ground truth bounding boxes as np.ndarray of shape (image_i_object_count, 4)
     When not None, will plot the predictions and the ground truth bounding boxes side by side (i.e 2 images stitched as one)

    :param target_labels: np.ndarray, ground truth target class indices as an np.ndarray of shape (image_i_object_count).

    :param target_bbox_format: str, bounding box format of target_bboxes, one of ['xyxy','xywh',
    'yxyx' 'cxcywh' 'normalized_xyxy' 'normalized_xywh', 'normalized_yxyx', 'normalized_cxcywh']. Will raise an
    error if not None and target_bboxes is None.
    """
    self._validate_input(bboxes, confidence, labels)

    factory = BBoxFormatFactory()
    bboxes_xyxy = convert_bboxes(
        bboxes=bboxes,
        image_shape=image_shape,
        source_format=factory.get(bbox_format),
        target_format=factory.get("xyxy"),
        inplace=False,
    )

    self.bboxes_xyxy = bboxes_xyxy
    self.confidence = confidence
    self.labels = labels
    self.image_shape = image_shape
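
A minimal construction sketch; the import path follows the source file shown above and the array values are arbitrary:

import numpy as np
from super_gradients.training.utils.predict.predictions import DetectionPrediction

prediction = DetectionPrediction(
    bboxes=np.array([[10.0, 20.0, 110.0, 220.0]]),  # one box, already in "xyxy"
    bbox_format="xyxy",                             # converted internally via BBoxFormatFactory
    confidence=np.array([0.87]),
    labels=np.array([0]),
    image_shape=(480, 640),                         # (H, W), used for the bbox format conversion
)
print(prediction.bboxes_xyxy.shape, len(prediction))  # (1, 4) 1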

PoseEstimationPrediction dataclass

Bases: Prediction

Represents a pose estimation prediction.

Parameters:

Name Type Description Default
poses np.ndarray

Numpy array of [Num Poses, Num Joints, 2] shape

required
scores np.ndarray

Numpy array of [Num Poses] shape

required
boxes

Numpy array of [Num Poses, 4] shape which represents the bounding boxes of each pose in xyxy format

required
Source code in src/super_gradients/training/utils/predict/predictions.py (lines 68-125)
@dataclass
class PoseEstimationPrediction(Prediction):
    """Represents a pose estimation prediction.

    :param poses:  Numpy array of [Num Poses, Num Joints, 2] shape
    :param scores: Numpy array of [Num Poses] shape
    :param boxes:  Numpy array of [Num Poses, 4] shape which represents the bounding boxes of each pose in xyxy format
    """

    poses: np.ndarray
    scores: np.ndarray
    bboxes_xyxy: Optional[np.ndarray]
    edge_links: np.ndarray
    edge_colors: np.ndarray
    keypoint_colors: np.ndarray
    image_shape: Tuple[int, int]

    def __init__(
        self,
        poses: np.ndarray,
        scores: np.ndarray,
        bboxes_xyxy: Optional[np.ndarray],
        edge_links: np.ndarray,
        edge_colors: np.ndarray,
        keypoint_colors: np.ndarray,
        image_shape: Tuple[int, int],
    ):
        """
        :param poses:       Predicted poses as a numpy array of shape [Num Poses, Num Joints, 2]
        :param scores:      Confidence scores for each pose [Num Poses]
        :param bboxes_xyxy:      Bounding boxes of each pose in xyxy format [Num Poses, 4]
        :param image_shape: Shape of the image the prediction is made on, (H, W).
        """
        self._validate_input(poses, scores, bboxes_xyxy, edge_links, edge_colors, keypoint_colors)
        self.poses = poses
        self.scores = scores
        self.bboxes_xyxy = bboxes_xyxy
        self.edge_links = edge_links
        self.edge_colors = edge_colors
        self.image_shape = image_shape
        self.keypoint_colors = keypoint_colors

    def _validate_input(self, poses: np.ndarray, scores: np.ndarray, bboxes: Optional[np.ndarray], edge_links, edge_colors, keypoint_colors) -> None:
        if not isinstance(poses, np.ndarray):
            raise ValueError(f"Argument poses must be a numpy array, not {type(poses)}")
        if not isinstance(scores, np.ndarray):
            raise ValueError(f"Argument scores must be a numpy array, not {type(scores)}")
        if bboxes is not None and not isinstance(bboxes, np.ndarray):
            raise ValueError(f"Argument bboxes must be a numpy array, not {type(bboxes)}")
        if not isinstance(keypoint_colors, np.ndarray):
            raise ValueError(f"Argument keypoint_colors must be a numpy array, not {type(keypoint_colors)}")
        if len(poses) != len(scores) != len(keypoint_colors):
            raise ValueError(f"The number of poses ({len(poses)}) does not match the number of scores ({len(scores)}).")
        if len(edge_links) != len(edge_colors):
            raise ValueError(f"The number of joint links ({len(edge_links)}) does not match the number of joint colors ({len(edge_colors)}).")

    def __len__(self):
        return len(self.poses)

__init__(poses, scores, bboxes_xyxy, edge_links, edge_colors, keypoint_colors, image_shape)

Parameters:

Name Type Description Default
poses np.ndarray

Predicted poses as a numpy array of shape [Num Poses, Num Joints, 2]

required
scores np.ndarray

Confidence scores for each pose [Num Poses]

required
bboxes_xyxy Optional[np.ndarray]

Bounding boxes of each pose in xyxy format [Num Poses, 4]

required
image_shape Tuple[int, int]

Shape of the image the prediction is made on, (H, W).

required
Source code in src/super_gradients/training/utils/predict/predictions.py (lines 85-108)
def __init__(
    self,
    poses: np.ndarray,
    scores: np.ndarray,
    bboxes_xyxy: Optional[np.ndarray],
    edge_links: np.ndarray,
    edge_colors: np.ndarray,
    keypoint_colors: np.ndarray,
    image_shape: Tuple[int, int],
):
    """
    :param poses:       Predicted poses as a numpy array of shape [Num Poses, Num Joints, 2]
    :param scores:      Confidence scores for each pose [Num Poses]
    :param bboxes_xyxy:      Bounding boxes of each pose in xyxy format [Num Poses, 4]
    :param image_shape: Shape of the image the prediction is made on, (H, W).
    """
    self._validate_input(poses, scores, bboxes_xyxy, edge_links, edge_colors, keypoint_colors)
    self.poses = poses
    self.scores = scores
    self.bboxes_xyxy = bboxes_xyxy
    self.edge_links = edge_links
    self.edge_colors = edge_colors
    self.image_shape = image_shape
    self.keypoint_colors = keypoint_colors
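
A minimal construction sketch for a single pose with a 3-joint skeleton; the import path follows the source file shown above and the array values are arbitrary:

import numpy as np
from super_gradients.training.utils.predict.predictions import PoseEstimationPrediction

prediction = PoseEstimationPrediction(
    poses=np.zeros((1, 3, 2)),                         # [Num Poses, Num Joints, 2]
    scores=np.array([0.9]),                            # [Num Poses]
    bboxes_xyxy=np.array([[0.0, 0.0, 50.0, 80.0]]),    # [Num Poses, 4]; may also be None
    edge_links=np.array([[0, 1], [1, 2]]),             # joint index pairs forming the skeleton
    edge_colors=np.array([[0, 255, 0], [255, 0, 0]]),  # one RGB color per edge link
    keypoint_colors=np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255]]),  # one RGB color per joint
    image_shape=(240, 320),                            # (H, W)
)
print(len(prediction))  # 1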

SegmentationPrediction dataclass

Bases: Prediction

Represents a segmentation prediction.

Source code in src/super_gradients/training/utils/predict/predictions.py (lines 159-181)
@dataclass
class SegmentationPrediction(Prediction):
    """Represents a segmentation prediction."""

    segmentation_map: np.ndarray
    segmentation_map_shape: Tuple[int, int]
    image_shape: Tuple[int, int]

    def __init__(self, segmentation_map: np.ndarray, segmentation_map_shape: Tuple[int, int], image_shape: Tuple[int, int]):
        """
        :param segmentation_map: Segmentation map (prediction) in the shape specified by segmentation_map_shape
        :param segmentation_map_shape: Shape of the prediction (H, W).
        :param image_shape: Shape of the image the prediction is made on, (H, W).
        """
        self._validate_input(segmentation_map_shape, image_shape)

        self.segmentation_map = segmentation_map
        self.segmentation_map_shape = segmentation_map_shape
        self.image_shape = image_shape

    def _validate_input(self, segmentation_map_shape: Tuple[int, int], image_shape: Tuple[int, int]) -> None:
        if segmentation_map_shape[0] != image_shape[0] or segmentation_map_shape[1] != image_shape[1]:
            raise ValueError("The shape of the segmentation map does not match the shape of the input image.")

__init__(segmentation_map, segmentation_map_shape, image_shape)

Parameters:

Name Type Description Default
segmentation_map np.ndarray

Segmentation map (prediction) in the shape specified by segmentation_map_shape

required
segmentation_map_shape Tuple[int, int]

Shape of the prediction (H, W).

required
image_shape Tuple[int, int]

Shape of the image the prediction is made on, (H, W).

required
Source code in src/super_gradients/training/utils/predict/predictions.py (lines 167-177)
def __init__(self, segmentation_map: np.ndarray, segmentation_map_shape: Tuple[int, int], image_shape: Tuple[int, int]):
    """
    :param segmentation_map: Segmentation map (prediction) in the shape specified by segmentation_map_shape
    :param segmentation_map_shape: Shape of the prediction (H, W).
    :param image_shape: Shape of the image the prediction is made on, (H, W).
    """
    self._validate_input(segmentation_map_shape, image_shape)

    self.segmentation_map = segmentation_map
    self.segmentation_map_shape = segmentation_map_shape
    self.image_shape = image_shape
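
A minimal construction sketch; the import path follows the source file shown above:

import numpy as np
from super_gradients.training.utils.predict.predictions import SegmentationPrediction

h, w = 480, 640
prediction = SegmentationPrediction(
    segmentation_map=np.zeros((h, w), dtype=np.uint8),  # per-pixel class indices
    segmentation_map_shape=(h, w),                      # must match image_shape, see _validate_input
    image_shape=(h, w),
)
print(prediction.segmentation_map.shape)  # (480, 640)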

Quantization utilities

Methods are based on: https://github.com/NVIDIA/TensorRT/blob/51a4297753d3e12d0eed864be52400f429a6a94c/tools/pytorch-quantization/examples/torchvision/classification_flow.py#L385

(Licensed under the Apache License, Version 2.0)

QuantizationCalibrator

Source code in src/super_gradients/training/utils/quantization/calibrator.py (lines 27-151)
class QuantizationCalibrator:
    def __init__(self, torch_hist: bool = True, verbose: bool = True) -> None:
        super().__init__()
        self.verbose = verbose
        self.torch_hist = torch_hist

    def calibrate_model(
        self,
        model: torch.nn.Module,
        calib_data_loader: torch.utils.data.DataLoader,
        method: str = "percentile",
        num_calib_batches: int = 2,
        percentile: float = 99.99,
    ):
        """
        Calibrates torch model with quantized modules.

        :param model:               torch.nn.Module, model to perform the calibration on.
        :param calib_data_loader:   torch.utils.data.DataLoader, data loader of the calibration dataset.
                                    Assumes that the first element of the tuple is the input image.
        :param method:              str, One of [percentile, mse, entropy, max].
                                    Statistics method for amax computation of the quantized modules
                                    (Default=percentile).
        :param num_calib_batches:   int, number of batches to collect the statistics from.
        :param percentile:          float, percentile value to use when SgModel.quant_modules_calib_method='percentile'.
                                    Discarded when other methods are used (Default=99.99).

        """

        logging_level = logging.getLogger("absl").getEffectiveLevel()
        if not self.verbose:  # suppress pytorch-quantization spam
            logging.getLogger("absl").setLevel("ERROR")

        acceptable_methods = ["percentile", "mse", "entropy", "max"]
        if method in acceptable_methods:
            with torch.no_grad():
                device = next(model.parameters()).device

                self._collect_stats(model, calib_data_loader, num_batches=num_calib_batches)
                # FOR PERCENTILE WE MUST PASS PERCENTILE VALUE THROUGH KWARGS,
                # SO IT WOULD BE PASSED TO module.load_calib_amax(**kwargs), AND IN OTHER METHODS WE MUST NOT PASS IT.
                if method == "precentile":
                    self._compute_amax(model, method="percentile", percentile=percentile)
                else:
                    self._compute_amax(model, method=method)

                model.to(device)
        else:
            raise ValueError(f"Unsupported quantization calibration method, " f"expected one of: {'.'.join(acceptable_methods)}, however, received: {method}")

        logging.getLogger("absl").setLevel(logging_level)

    def _collect_stats(self, model, data_loader, num_batches):
        """Feed data to the network and collect statistics"""
        local_rank = get_local_rank()
        world_size = get_world_size()

        device = infer_model_device(model)

        # Enable calibrators
        self._enable_calibrators(model)

        # Feed data to the network for collecting stats
        for i, batch in tqdm(enumerate(data_loader), total=num_batches, disable=local_rank > 0):
            if isinstance(batch, (list, tuple)):
                image = batch[0]
            elif torch.is_tensor(batch):
                image = batch
            else:
                raise ValueError("Unsupported batch type")

            if world_size > 1:
                all_batches = [torch.zeros_like(image, device=device) for _ in range(world_size)]
                all_gather(all_batches, image.to(device=device))
            else:
                all_batches = [image]

            for local_image in all_batches:
                model(local_image.to(device=device))
            if i >= num_batches:
                break

        # Disable calibrators
        self._disable_calibrators(model)

    def _disable_calibrators(self, model):
        for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    module.disable_calib()
                    module.enable_quant()
                else:
                    module.enable()

    def reset_calibrators(self, model):
        for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    module._calibrator.reset()  # release memory

    def _enable_calibrators(self, model):
        for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    if isinstance(module._calibrator, calib.HistogramCalibrator):
                        module._calibrator._torch_hist = self.torch_hist  # TensorQuantizer does not expose it as API
                    module.disable_quant()
                    module.enable_calib()
                else:
                    module.disable()

    def _compute_amax(self, model, **kwargs):
        for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    if isinstance(module._calibrator, calib.MaxCalibrator):
                        module.load_calib_amax()
                    else:
                        module.load_calib_amax(**kwargs)

                if hasattr(module, "clip"):
                    module.init_learn_amax()

                if self.verbose:
                    print(f"{name:40}: {module}")

calibrate_model(model, calib_data_loader, method='percentile', num_calib_batches=2, percentile=99.99)

Calibrates torch model with quantized modules.

Parameters:

Name Type Description Default
model torch.nn.Module

torch.nn.Module, model to perform the calibration on.

required
calib_data_loader torch.utils.data.DataLoader

torch.utils.data.DataLoader, data loader of the calibration dataset. Assumes that the first element of the tuple is the input image.

required
method str

str, One of [percentile, mse, entropy, max]. Statistics method for amax computation of the quantized modules (Default=percentile).

'percentile'
num_calib_batches int

int, number of batches to collect the statistics from.

2
percentile float

float, percentile value to use when SgModel.quant_modules_calib_method='percentile'. Discarded when other methods are used (Default=99.99).

99.99
Source code in src/super_gradients/training/utils/quantization/calibrator.py (lines 33-77)
def calibrate_model(
    self,
    model: torch.nn.Module,
    calib_data_loader: torch.utils.data.DataLoader,
    method: str = "percentile",
    num_calib_batches: int = 2,
    percentile: float = 99.99,
):
    """
    Calibrates torch model with quantized modules.

    :param model:               torch.nn.Module, model to perform the calibration on.
    :param calib_data_loader:   torch.utils.data.DataLoader, data loader of the calibration dataset.
                                Assumes that the first element of the tuple is the input image.
    :param method:              str, One of [percentile, mse, entropy, max].
                                Statistics method for amax computation of the quantized modules
                                (Default=percentile).
    :param num_calib_batches:   int, number of batches to collect the statistics from.
    :param percentile:          float, percentile value to use when SgModel.quant_modules_calib_method='percentile'.
                                Discarded when other methods are used (Default=99.99).

    """

    logging_level = logging.getLogger("absl").getEffectiveLevel()
    if not self.verbose:  # suppress pytorch-quantization spam
        logging.getLogger("absl").setLevel("ERROR")

    acceptable_methods = ["percentile", "mse", "entropy", "max"]
    if method in acceptable_methods:
        with torch.no_grad():
            device = next(model.parameters()).device

            self._collect_stats(model, calib_data_loader, num_batches=num_calib_batches)
            # FOR PERCENTILE WE MUST PASS PERCENTILE VALUE THROUGH KWARGS,
            # SO IT WOULD BE PASSED TO module.load_calib_amax(**kwargs), AND IN OTHER METHODS WE MUST NOT PASS IT.
            if method == "precentile":
                self._compute_amax(model, method="percentile", percentile=percentile)
            else:
                self._compute_amax(model, method=method)

            model.to(device)
    else:
        raise ValueError(f"Unsupported quantization calibration method, " f"expected one of: {'.'.join(acceptable_methods)}, however, received: {method}")

    logging.getLogger("absl").setLevel(logging_level)

QuantizedMapping

Bases: nn.Module

This class wraps a float module instance, and defines a mapping from this instance to the corresponding quantized class, with relevant quant descriptors.

Example: self.my_block = QuantizedMapping(float_module=MyBlock(4, n_classes), quantized_target_class=MyQuantizedBlock)

Source code in src/super_gradients/training/utils/quantization/core.py (lines 141-165)
class QuantizedMapping(nn.Module):
    """
    This class wraps a float module instance, and defines a mapping from this instance to the corresponding quantized
    class, with relevant quant descriptors.

    Example:
        self.my_block = QuantizedMapping(float_module=MyBlock(4, n_classes), quantized_target_class=MyQuantizedBlock)
    """

    def __init__(
        self,
        *,
        float_module: nn.Module,
        quantized_target_class: Union[Type[QuantMixin], Type[QuantInputMixin], Type[SGQuantMixin]],
        action=QuantizedMetadata.ReplacementAction.REPLACE,
        input_quant_descriptor: QuantDescriptor = None,
        weights_quant_descriptor: QuantDescriptor = None,
    ) -> None:
        super().__init__()
        self.float_module = float_module
        self.quantized_target_class = quantized_target_class
        self.action = action
        self.input_quant_descriptor = input_quant_descriptor
        self.weights_quant_descriptor = weights_quant_descriptor
        self.forward = float_module.forward
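
A hedged sketch expanding the docstring example with a concrete, importable pair: the wrapped nn.Linear instance is mapped to pytorch-quantization's QuantLinear (a QuantMixin subclass), while other Linear layers in the model are unaffected by this per-instance mapping:

import torch.nn as nn
from pytorch_quantization import quant_nn
from super_gradients.training.utils.quantization.core import QuantizedMapping

class MyHead(nn.Module):
    def __init__(self, in_features: int, n_classes: int):
        super().__init__()
        # Only this Linear instance is mapped; the quantization engine replaces it
        # with quant_nn.QuantLinear according to the (default REPLACE) action.
        self.classifier = QuantizedMapping(
            float_module=nn.Linear(in_features, n_classes),
            quantized_target_class=quant_nn.QuantLinear,
        )

    def forward(self, x):
        return self.classifier(x)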

QuantizedMetadata dataclass

This dataclass is responsible for holding the information regarding float->quantized module relation. It can be both layer-grained and module-grained, e.g., module.backbone.conv1 -> QuantConv2d, nn.Linear -> QuantLinear, etc...

Parameters:

Name Type Description Default
float_source Union[str, Type]

Name of a specific layer (e.g., module.backbone.conv1), or a specific type (e.g., Conv2d) that will be later quantized

required
quantized_target_class Optional[Union[Type[QuantMixin], Type[QuantInputMixin], Type[SGQuantMixin]]]

Quantized type that the source will be converted to

required
action ReplacementAction

how to resolve the conversion, we either:
- SKIP: skip it,
- UNWRAP: unwrap the instance and work with the wrapped one (i.e., we wrap with a mapper),
- REPLACE: replace source with an instance of the quantized type
- REPLACE_AND_RECURE: replace source with an instance of the quantized type, then try to recursively quantize the child modules of that type
- RECURE_AND_REPLACE: recursively quantize the child modules, then replace source with an instance of the quantized type

required
input_quant_descriptor QuantDescriptor

Quantization descriptor for inputs (None will take the default one)

None
weights_quant_descriptor QuantDescriptor

Quantization descriptor for weights (None will take the default one)

None
Source code in src/super_gradients/training/utils/quantization/core.py (lines 95-138)
@dataclass(init=True)
class QuantizedMetadata:
    """
    This dataclass is responsible for holding the information regarding float->quantized module relation.
    It can be both layer-grained and module-grained, e.g.,
    `module.backbone.conv1 -> QuantConv2d`, `nn.Linear -> QuantLinear`, etc...

    :param float_source:          Name of a specific layer (e.g., `module.backbone.conv1`),
                                        or a specific type (e.g., `Conv2d`) that will be later quantized
    :param quantized_target_class: Quantized type that the source will be converted to
    :param action:                     how to resolve the conversion, we either:
                                - SKIP: skip it,
                                - UNWRAP: unwrap the instance and work with the wrapped one
                                  (i.e., we wrap with a mapper),
                                - REPLACE: replace source with an instance of the
                                  quantized type
                                - REPLACE_AND_RECURE: replace source with an instance of the
                                  quantized type, then try to recursively quantize the child modules of that type
                                - RECURE_AND_REPLACE: recursively quantize the child modules, then
                                  replace source with an instance of the quantized type
    :param input_quant_descriptor:     Quantization descriptor for inputs (None will take the default one)
    :param weights_quant_descriptor:   Quantization descriptor for weights (None will take the default one)
    """

    class ReplacementAction(Enum):
        REPLACE = "replace"
        REPLACE_AND_RECURE = "replace_and_recure"
        RECURE_AND_REPLACE = "recure_and_replace"
        UNWRAP = "unwrap"
        SKIP = "skip"

    float_source: Union[str, Type]
    quantized_target_class: Optional[Union[Type[QuantMixin], Type[QuantInputMixin], Type[SGQuantMixin]]]
    action: ReplacementAction
    input_quant_descriptor: QuantDescriptor = None  # default is used if None
    weights_quant_descriptor: QuantDescriptor = None  # default is used if None

    def __post_init__(self):
        if self.action in (
            QuantizedMetadata.ReplacementAction.REPLACE,
            QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE,
            QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE,
        ):
            assert issubclass(self.quantized_target_class, (SGQuantMixin, QuantMixin, QuantInputMixin))
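
A hedged construction sketch: a type-level entry mapping every nn.Conv2d to pytorch-quantization's QuantConv2d (a QuantMixin subclass, so the __post_init__ assertion holds); both descriptors are left as None to use the defaults:

import torch.nn as nn
from pytorch_quantization import quant_nn
from super_gradients.training.utils.quantization.core import QuantizedMetadata

conv_mapping = QuantizedMetadata(
    float_source=nn.Conv2d,                         # quantize every Conv2d ...
    quantized_target_class=quant_nn.QuantConv2d,    # ... into its quantized counterpart
    action=QuantizedMetadata.ReplacementAction.REPLACE,
)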

SGQuantMixin

Bases: nn.Module

A base class for user custom Quantized classes. Every Quantized class must inherit this mixin, which adds the from_float class-method.
NOTES:
- the Quantized class may also inherit from the native QuantMixin or QuantInputMixin
- quant descriptors (for inputs and weights) will be passed as kwargs. The module may ignore them if they are not necessary
- the default implementation of from_float inspects the __init__ args and searches for corresponding properties on the float instance that is passed as argument, e.g., for __init__(self, a) the mechanism will look for float_instance.a and pass that value to the __init__ method

Source code in src/super_gradients/training/utils/quantization/core.py (lines 49-78)
class SGQuantMixin(nn.Module):
    """
    A base class for user custom Quantized classes.
    Every Quantized class must inherit this mixin, which adds `from_float` class-method.
    NOTES:
        * the Quantized class may also inherit from the native `QuantMixin` or `QuantInputMixin`
        * quant descriptors (for inputs and weights) will be passed as `kwargs`. The module may ignore them if they are
          not necessary
        * the default implementation of `from_float` is inspecting the __init__ args, and searching for corresponding
          properties from the float instance that is passed as argument, e.g., for `__init__(self, a)`
          the mechanism will look for `float_instance.a` and pass that value to the `__init__` method
    """

    @classmethod
    def from_float(cls, float_instance, **kwargs):
        required_init_params = list(inspect.signature(cls.__init__).parameters)[1:]  # [0] is self

        # if cls.__init__ has explicit `quant_desc_input` or `quant_desc_weight` - we don't search the state of the
        # float module, because it would not contain this state. these values are injected by the framework
        ignore_init_args = {"quant_desc_input", "quant_desc_weight"}.intersection(set(required_init_params))

        # if cls.__init__ doesn't have neither **kwargs, nor `quant_desc_input` and `quant_desc_weight`,
        # we should also remove these keys from the passed kwargs and make sure there's nothing more!
        if "kwargs" not in required_init_params:
            for arg in ("quant_desc_input", "quant_desc_weight"):
                if arg in ignore_init_args:
                    continue
                kwargs.pop(arg, None)  # we ignore if not existing

        return _from_float(cls, float_instance, ignore_init_args, **kwargs)
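
A hedged sketch of the default from_float behaviour described above: because MyQuantBlock lists in_channels and out_channels in its __init__, the default from_float reads those attributes off the float instance and forwards them. MyFloatBlock and MyQuantBlock are hypothetical illustration classes:

import torch.nn as nn
from pytorch_quantization import quant_nn
from super_gradients.training.utils.quantization.core import SGQuantMixin

class MyFloatBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

class MyQuantBlock(SGQuantMixin):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = quant_nn.QuantConv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

# Default from_float: looks up float_instance.in_channels / .out_channels and calls MyQuantBlock(...)
quantized_block = MyQuantBlock.from_float(MyFloatBlock(16, 32))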

SkipQuantization

Bases: nn.Module

This class wraps a float module instance, and defines that this instance will not be converted to quantized version

Example: self.my_block = SkipQuantization(MyBlock(4, n_classes))

Source code in src/super_gradients/training/utils/quantization/core.py (lines 81-92)
class SkipQuantization(nn.Module):
    """
    This class wraps a float module instance, and defines that this instance will not be converted to quantized version

    Example:
        self.my_block = SkipQuantization(MyBlock(4, n_classes))
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.float_module = module
        self.forward = module.forward

export_quantized_module_to_onnx(model, onnx_filename, input_shape, train=False, to_cpu=True, deepcopy_model=False, **kwargs)

Method for exporting onnx after QAT.

Parameters:

Name Type Description Default
to_cpu bool

transfer model to CPU before converting to ONNX, dirty workaround when model's tensors are on different devices

True
train bool

export model in training mode

False
model torch.nn.Module

torch.nn.Module, model to export

required
onnx_filename str

str, target path for the onnx file.

required
input_shape tuple

tuple, input shape (usually BCHW)

required
deepcopy_model

Whether to export deepcopy(model). Necessary in case further training is performed and prep_model_for_conversion makes the network un-trainable (e.g. RepVGG blocks).

False
Source code in src/super_gradients/training/utils/quantization/export.py (lines 13-61)
@deprecated(
    deprecated_since="3.7.0",
    removed_from="4.0.0",
    target=export_to_onnx,
)
def export_quantized_module_to_onnx(
    model: torch.nn.Module, onnx_filename: str, input_shape: tuple, train: bool = False, to_cpu: bool = True, deepcopy_model=False, **kwargs
):
    """
    Method for exporting onnx after QAT.

    :param to_cpu: transfer model to CPU before converting to ONNX, dirty workaround when model's tensors are on different devices
    :param train: export model in training mode
    :param model: torch.nn.Module, model to export
    :param onnx_filename: str, target path for the onnx file.
    :param input_shape: tuple, input shape (usually BCHW)
    :param deepcopy_model: Whether to export deepcopy(model). Necessary in case further training is performed and
     prep_model_for_conversion makes the network un-trainable (e.g. RepVGG blocks).
    """
    if deepcopy_model:
        model = deepcopy(model)

    use_fb_fake_quant_state = quant_nn.TensorQuantizer.use_fb_fake_quant
    quant_nn.TensorQuantizer.use_fb_fake_quant = True

    # Export ONNX for multiple batch sizes
    logger.info("Creating ONNX file: " + onnx_filename)

    if train:
        training_mode = TrainingMode.TRAINING
        model.train()
    else:
        training_mode = TrainingMode.EVAL
        model.eval()
        if hasattr(model, "prep_model_for_conversion"):
            model.prep_model_for_conversion(**kwargs)

    # workaround when model.prep_model_for_conversion does reparametrization
    # and tensors get scattered to different devices
    if to_cpu:
        export_model = model.cpu()
    else:
        export_model = model

    dummy_input = torch.randn(input_shape, device=next(model.parameters()).device)
    torch.onnx.export(export_model, dummy_input, onnx_filename, verbose=False, opset_version=13, do_constant_folding=True, training=training_mode)

    # Restore functions of quant_nn back as expected
    quant_nn.TensorQuantizer.use_fb_fake_quant = use_fb_fake_quant_state
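
A hedged usage sketch (note the @deprecated decorator above: since 3.7.0 the newer export_to_onnx is the preferred entry point). q_model is assumed to be a calibrated quantized model; the path and input shape are placeholders:

from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx

export_quantized_module_to_onnx(
    model=q_model,
    onnx_filename="model_int8.onnx",
    input_shape=(1, 3, 640, 640),  # BCHW
    train=False,                   # export in eval mode; runs prep_model_for_conversion if available
    to_cpu=True,                   # move tensors to CPU before export
    deepcopy_model=True,           # export a copy so the original model stays trainable
)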

QuantBackboneInternalSkipConnection

Bases: QuantSkipConnection

This is a placeholder module used by the quantization engine only. The module is to be used as a quantized substitute to a skip connection between blocks inside the backbone.

Source code in src/super_gradients/training/utils/quantization/modules/quantized_skip_connections.py (lines 40-45)
@register_quantized_module(float_source=BackboneInternalSkipConnection)
class QuantBackboneInternalSkipConnection(QuantSkipConnection):
    """
    This is a placeholder module used by the quantization engine only.
    The module is to be used as a quantized substitute to a skip connection between blocks inside the backbone.
    """

QuantCrossModelSkipConnection

Bases: QuantSkipConnection

This is a placeholder module used by the quantization engine only. The module is to be used as a quantized substitute to a skip connection between backbone and the head.

Source code in src/super_gradients/training/utils/quantization/modules/quantized_skip_connections.py (lines 56-61)
@register_quantized_module(float_source=CrossModelSkipConnection)
class QuantCrossModelSkipConnection(QuantSkipConnection):
    """
    This is a placeholder module used by the quantization engine only.
    The module is to be used as a quantized substitute to a skip connection between backbone and the head.
    """

QuantHeadInternalSkipConnection

Bases: QuantSkipConnection

This is a placeholder module used by the quantization engine only. The module is to be used as a quantized substitute to a skip connection between blocks inside the head.

Source code in src/super_gradients/training/utils/quantization/modules/quantized_skip_connections.py (lines 48-53)
@register_quantized_module(float_source=HeadInternalSkipConnection)
class QuantHeadInternalSkipConnection(QuantSkipConnection):
    """
    This is a placeholder module used by the quantization engine only.
    The module is to be used as a quantized substitute to a skip connection between blocks inside the head.
    """

QuantResidual

Bases: SGQuantMixin

This is a placeholder module used by the quantization engine only. The module is to be used as a quantized substitute to a residual skip connection within a single block.

Source code in src/super_gradients/training/utils/quantization/modules/quantized_skip_connections.py (lines 16-25)
@register_quantized_module(float_source=Residual)
class QuantResidual(SGQuantMixin):
    """
    This is a placeholder module used by the quantization engine only.
    The module is to be used as a quantized substitute to a residual skip connection within a single block.
    """

    @classmethod
    def from_float(cls, float_instance: Residual, **kwargs):
        return quant_nn.TensorQuantizer(kwargs.get("quant_desc_input"))

QuantSkipConnection

Bases: SGQuantMixin

This is a placeholder module used by the quantization engine only. The module is to be used as a quantized substitute to a skip connection between blocks.

Source code in src/super_gradients/training/utils/quantization/modules/quantized_skip_connections.py (lines 28-37)
@register_quantized_module(float_source=SkipConnection)
class QuantSkipConnection(SGQuantMixin):
    """
    This is a placeholder module used by the quantization engine only.
    The module is to be used as a quantized substitute to a skip connection between blocks.
    """

    @classmethod
    def from_float(cls, float_instance: SkipConnection, **kwargs):
        return quant_nn.TensorQuantizer(kwargs.get("quant_desc_input"))

QuantAttentionRefinementModule

Bases: SGQuantMixin, AttentionRefinementModule

AttentionRefinementModule to apply on the last two backbone stages.

Source code in src/super_gradients/training/utils/quantization/modules/quantized_stdc_blocks.py (lines 48-62)
@register_quantized_module(float_source=AttentionRefinementModule, action=QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE)
class QuantAttentionRefinementModule(SGQuantMixin, AttentionRefinementModule):
    """
    AttentionRefinementModule to apply on the last two backbone stages.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super(QuantAttentionRefinementModule, self).__init__(in_channels=in_channels, out_channels=out_channels)
        self.q_x = Residual()
        self.q_y = Residual()

    def forward(self, x):
        x = self.conv_first(x)
        y = self.attention_block(x)
        return torch.mul(self.q_x(x), self.q_y(y))

QuantBottleneck

Bases: SGQuantMixin

Inserts a quantizer into the shortcut (residual) branch of the bottleneck so that it is quantized as well. NOTE: the float instance itself is modified and returned, so the replacement action must be QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE.

Source code in src/super_gradients/training/utils/quantization/modules/resnet_bottleneck.py
@register_quantized_module(float_source=Bottleneck, action=QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE)
class QuantBottleneck(SGQuantMixin):
    """
    Inserts a quantizer into the shortcut (residual) branch of the bottleneck so that it is quantized as well.
    NOTE: the float instance itself is modified and returned, so the replacement action must be
          QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE
    """

    @classmethod
    def from_float(cls, float_instance: Bottleneck, **kwargs):
        float_instance.shortcut.add_module("residual_quantizer", quant_nn.TensorQuantizer(kwargs.get("quant_desc_input")))
        return float_instance

ptq(model, selective_quantizer, calibration_loader, calibration_method='percentile', calibration_batches=16, calibration_percentile=99.99, calibration_verbose=False)

Perform Post Training Quantization (PTQ) on the model.

Parameters:

Name Type Description Default
model

Input model to quantize. This function always returns a new model, the input model is not modified.

required
selective_quantizer Optional[SelectiveQuantizer]

An instance of SelectiveQuantizer class that defines what modules to quantize.

required
calibration_loader Optional[DataLoader]

An instance of DataLoader that provides calibration data (optional).

required
calibration_method str

(str) Calibration method for quantized models. See QuantizationCalibrator for details.

'percentile'
calibration_batches int

(int) Number of batches to use for calibration. Default is 16.

16
calibration_percentile float

(float) Percentile for percentile calibration method. Default is 99.99.

99.99
calibration_verbose bool

(bool) Whether to enable verbose output during calibration. Default is False.

False

Returns:

Type Description

A quantized model

Source code in src/super_gradients/training/utils/quantization/ptq.py
def ptq(
    model,
    selective_quantizer: Optional[SelectiveQuantizer],
    calibration_loader: Optional[DataLoader],
    calibration_method: str = "percentile",
    calibration_batches: int = 16,
    calibration_percentile: float = 99.99,
    calibration_verbose: bool = False,
):
    """
    Perform Post Training Quantization (PTQ) on the model.

    :param model: Input model to quantize. This function always returns a new model, the input model is not modified.
    :param selective_quantizer:  An instance of SelectiveQuantizer class that defines what modules to quantize.
    :param calibration_loader: An instance of DataLoader that provides calibration data (optional).
    :param calibration_method: (str) Calibration method for quantized models. See QuantizationCalibrator for details.
    :param calibration_batches: (int) Number of batches to use for calibration. Default is 16.
    :param calibration_percentile: (float) Percentile for percentile calibration method. Default is 99.99.
    :param calibration_verbose: (bool) Whether to enable verbose output during calibration. Default is False.
    :return: A quantized model
    """
    contains_quantized_modules = check_model_contains_quantized_modules(model)
    if contains_quantized_modules:
        logger.debug("Model contains quantized modules. Skipping quantization & calibration steps since it is already quantized.")
        return model

    model = copy.deepcopy(model)

    if selective_quantizer is None:
        selective_quantizer = SelectiveQuantizer(
            default_quant_modules_calibrator_weights="max",
            default_quant_modules_calibrator_inputs="histogram",
            default_per_channel_quant_weights=True,
            default_learn_amax=False,
            verbose=True,
        )
    selective_quantizer.quantize_module(model)

    if calibration_loader:
        logger.debug("Calibrating model")
        calibrator = QuantizationCalibrator(verbose=calibration_verbose)
        calibrator.calibrate_model(
            model,
            method=calibration_method,
            calib_data_loader=calibration_loader,
            num_calib_batches=calibration_batches,
            percentile=calibration_percentile,
        )
        logger.debug("Calibrating model complete")
        calibrator.reset_calibrators(model)

    return model
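
Below is a minimal usage sketch of ptq. Here my_model and calib_dataset are hypothetical placeholders for a user-defined float model and calibration dataset, and the import path simply mirrors the source file shown above.

from torch.utils.data import DataLoader

from super_gradients.training.utils.quantization.ptq import ptq

# `my_model` and `calib_dataset` are placeholders for your own float model and calibration dataset.
calibration_loader = DataLoader(calib_dataset, batch_size=32, shuffle=False)

quantized_model = ptq(
    model=my_model,
    selective_quantizer=None,               # None falls back to the default SelectiveQuantizer shown above
    calibration_loader=calibration_loader,  # calibration is skipped when this is None
    calibration_method="percentile",
    calibration_batches=16,
    calibration_percentile=99.99,
)
# The input model is left untouched; `quantized_model` is a deep copy with quantized modules.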

SelectiveQuantizer

Parameters:

Name Type Description Default
custom_mappings dict

custom mappings that extend the default mappings with extra behaviour

None
default_per_channel_quant_weights bool

whether quant module weights should be per channel (default=True)

True
default_quant_modules_calibrator_weights str

default calibrator method for weights (default='max')

'max'
default_quant_modules_calibrator_inputs str

default calibrator method for inputs (default='histogram')

'histogram'
default_learn_amax bool

EXPERIMENTAL! whether quant modules should have learnable amax (default=False)

False
Source code in src/super_gradients/training/utils/quantization/selective_quantization_utils.py
class SelectiveQuantizer:
    """
    :param custom_mappings:                             custom mappings that extend the default mappings with extra behaviour
    :param default_per_channel_quant_weights:           whether quant module weights should be per channel (default=True)
    :param default_quant_modules_calibrator_weights:    default calibrator method for weights (default='max')
    :param default_quant_modules_calibrator_inputs:     default calibrator method for inputs (default='histogram')
    :param default_learn_amax:                          EXPERIMENTAL! whether quant modules should have learnable amax (default=False)
    """

    mapping_instructions: Dict[Union[str, Type], QuantizedMetadata] = {
        **{
            float_type: QuantizedMetadata(
                float_source=float_type,
                quantized_target_class=quantized_target_class,
                action=QuantizedMetadata.ReplacementAction.REPLACE,
            )
            for (float_type, quantized_target_class) in [
                (nn.Conv1d, quant_nn.QuantConv1d),
                (nn.Conv2d, quant_nn.QuantConv2d),
                (nn.Conv3d, quant_nn.QuantConv3d),
                (nn.ConvTranspose1d, quant_nn.QuantConvTranspose1d),
                (nn.ConvTranspose2d, quant_nn.QuantConvTranspose2d),
                (nn.ConvTranspose3d, quant_nn.QuantConvTranspose3d),
                (nn.Linear, quant_nn.Linear),
                (nn.LSTM, quant_nn.LSTM),
                (nn.LSTMCell, quant_nn.LSTMCell),
                (nn.AvgPool1d, quant_nn.QuantAvgPool1d),
                (nn.AvgPool2d, quant_nn.QuantAvgPool2d),
                (nn.AvgPool3d, quant_nn.QuantAvgPool3d),
                (nn.AdaptiveAvgPool1d, quant_nn.QuantAdaptiveAvgPool1d),
                (nn.AdaptiveAvgPool2d, quant_nn.QuantAdaptiveAvgPool2d),
                (nn.AdaptiveAvgPool3d, quant_nn.QuantAdaptiveAvgPool3d),
            ]
        },
        SkipQuantization: QuantizedMetadata(float_source=SkipQuantization, quantized_target_class=None, action=QuantizedMetadata.ReplacementAction.UNWRAP),
    }  # DEFAULT MAPPING INSTRUCTIONS

    def __init__(
        self,
        *,
        custom_mappings: dict = None,
        default_quant_modules_calibrator_weights: str = "max",
        default_quant_modules_calibrator_inputs: str = "histogram",
        default_per_channel_quant_weights: bool = True,
        default_learn_amax: bool = False,
        verbose: bool = True,
    ) -> None:
        super().__init__()
        self.default_quant_modules_calibrator_weights = default_quant_modules_calibrator_weights
        self.default_quant_modules_calibrator_inputs = default_quant_modules_calibrator_inputs
        self.default_per_channel_quant_weights = default_per_channel_quant_weights
        self.default_learn_amax = default_learn_amax
        self.verbose = verbose
        self.mapping_instructions = self.mapping_instructions.copy()
        if custom_mappings is not None:
            self.mapping_instructions.update(custom_mappings)  # OVERRIDE DEFAULT WITH CUSTOM. CUSTOM IS PRIORITIZED

    def _get_default_quant_descriptor(self, for_weights=False):
        methods = {"percentile": "histogram", "mse": "histogram", "entropy": "histogram", "histogram": "histogram", "max": "max"}

        if for_weights:
            axis = 0 if self.default_per_channel_quant_weights else None

            learn_amax = self.default_learn_amax
            if self.default_learn_amax and self.default_per_channel_quant_weights:
                logger.error("Learnable amax is suported only for per-tensor quantization. Disabling it for weights quantization!")
                learn_amax = False

            return QuantDescriptor(calib_method=methods[self.default_quant_modules_calibrator_weights], axis=axis, learn_amax=learn_amax)
        else:
            # activations stay per-tensor by default
            return QuantDescriptor(calib_method=methods[self.default_quant_modules_calibrator_inputs], learn_amax=self.default_learn_amax)

    def register_skip_quantization(self, *, layer_names: Optional[Set[str]] = None):
        if layer_names is not None:
            self.mapping_instructions.update(
                {
                    name: QuantizedMetadata(float_source=name, quantized_target_class=None, action=QuantizedMetadata.ReplacementAction.SKIP)
                    for name in layer_names
                }
            )

    def register_quantization_mapping(
        self, *, layer_names: Set[str], quantized_target_class: Type[SGQuantMixin], input_quant_descriptor=None, weights_quant_descriptor=None
    ):
        self.mapping_instructions.update(
            {
                name: QuantizedMetadata(
                    float_source=name,
                    quantized_target_class=quantized_target_class,
                    action=QuantizedMetadata.ReplacementAction.REPLACE,
                    input_quant_descriptor=input_quant_descriptor,
                    weights_quant_descriptor=weights_quant_descriptor,
                )
                for name in layer_names
            }
        )

    def _preprocess_skips_and_custom_mappings(self, module: nn.Module, nesting: Tuple[str, ...] = ()):
        """
        This pass extracts per-layer mapping instructions keyed by layer name, so they can be applied during
        per-layer processing. Relevant layer-specific mapping instructions are either `SkipQuantization` or
        `QuantizedMapping`, and they are added to the returned mappings.
        """
        mapping_instructions = dict()
        for name, child_module in module.named_children():
            nested_name = ".".join(nesting + (name,))
            if isinstance(child_module, SkipQuantization):
                mapping_instructions[nested_name] = QuantizedMetadata(
                    float_source=nested_name, quantized_target_class=None, action=QuantizedMetadata.ReplacementAction.UNWRAP
                )

            if isinstance(child_module, QuantizedMapping):
                mapping_instructions[nested_name] = QuantizedMetadata(
                    float_source=nested_name,
                    quantized_target_class=child_module.quantized_target_class,
                    input_quant_descriptor=child_module.input_quant_descriptor,
                    weights_quant_descriptor=child_module.weights_quant_descriptor,
                    action=child_module.action,
                )

            if isinstance(child_module, nn.Module):  # recursive call
                mapping_instructions.update(self._preprocess_skips_and_custom_mappings(child_module, nesting + (name,)))

        return mapping_instructions

    def _instantiate_quantized_from_float(self, float_module, metadata, preserve_state_dict):
        base_classes = (QuantMixin, QuantInputMixin, SGQuantMixin)
        if not issubclass(metadata.quantized_target_class, base_classes):
            raise AssertionError(
                f"Quantization suite for {type(float_module).__name__} is invalid. "
                f"{metadata.quantized_target_class.__name__} must inherit one of "
                f"{', '.join(map(lambda _: _.__name__, base_classes))}"
            )

        # USE PROVIDED QUANT DESCRIPTORS, OR DEFAULT IF NONE PROVIDED
        quant_descriptors = dict()
        if issubclass(metadata.quantized_target_class, (SGQuantMixin, QuantMixin, QuantInputMixin)):
            quant_descriptors = {"quant_desc_input": metadata.input_quant_descriptor or self._get_default_quant_descriptor(for_weights=False)}
        if issubclass(metadata.quantized_target_class, (SGQuantMixin, QuantMixin)):
            quant_descriptors.update({"quant_desc_weight": metadata.weights_quant_descriptor or self._get_default_quant_descriptor(for_weights=True)})

        if not hasattr(metadata.quantized_target_class, "from_float"):
            assert isinstance(metadata.quantized_target_class, SGQuantMixin), (
                f"{metadata.quantized_target_class.__name__} must inherit from " f"{SGQuantMixin.__name__}, so that it would include `from_float` class method"
            )

        q_instance = metadata.quantized_target_class.from_float(float_module, **quant_descriptors)

        # MOVE TENSORS TO ORIGINAL DEVICE
        if len(list(float_module.parameters(recurse=False))) > 0:
            q_instance = q_instance.to(next(float_module.parameters(recurse=False)).device)
        elif len(list(float_module.buffers(recurse=False))):
            q_instance = q_instance.to(next(float_module.buffers(recurse=False)).device)

        # COPY STATE DICT IF NEEDED
        if preserve_state_dict:
            # quant state dict may have additional parameters for Clip and strict loading will fail
            # if we find at least one Clip module in q_instance, disable strict loading and hope for the best
            strict_load = True
            for k in q_instance.state_dict().keys():
                if "clip.clip_value_max" in k or "clip.clip_value_min" in k:
                    strict_load = False
                    logger.debug(
                        "Instantiating quant module in non-strict mode leaving Clip parameters non-initilaized. Use QuantizationCalibrator to initialize them."
                    )
                    break

            q_instance.load_state_dict(float_module.state_dict(), strict=strict_load)

        return q_instance

    def _maybe_quantize_one_layer(
        self,
        module: nn.Module,
        child_name: str,
        nesting: Tuple[str, ...],
        child_module: nn.Module,
        mapping_instructions: Dict[Union[str, Type], QuantizedMetadata],
        preserve_state_dict: bool,
    ) -> bool:
        """
        Does the heavy lifting of (maybe) quantizing a layer: creates a quantized instance based on a float instance,
        and replaces it in the "parent" module

        :param module:                  the module we'd like to quantize a specific layer in
        :param child_name:              the attribute name of the layer in the module
        :param nesting:                 the current nesting we're in. Needed to find the appropriate key in the mappings
        :param child_module:            the instance of the float module we'd like to quantize
        :param mapping_instructions:    mapping instructions: how to quantize
        :param preserve_state_dict:     whether to copy the state dict from the float instance to the quantized instance

        :return: a boolean indicates if we found a match and should not continue recursively
        """
        # if we don't have any instruction for the specific layer or the specific type - we continue
        # NOTE! IT IS IMPORTANT TO FIRST PROCESS THE NAME AND ONLY THEN THE TYPE
        for candidate_key in (".".join(nesting + (child_name,)), type(child_module)):
            if candidate_key not in mapping_instructions:
                continue
            metadata: QuantizedMetadata = mapping_instructions[candidate_key]
            if metadata.action == QuantizedMetadata.ReplacementAction.SKIP:
                return True
            elif metadata.action == QuantizedMetadata.ReplacementAction.UNWRAP:
                assert isinstance(child_module, SkipQuantization)
                setattr(module, child_name, child_module.float_module)
                return True
            elif metadata.action in (
                QuantizedMetadata.ReplacementAction.REPLACE,
                QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE,
                QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE,
            ):
                if isinstance(child_module, QuantizedMapping):  # UNWRAP MAPPING
                    child_module = child_module.float_module
                q_instance: nn.Module = self._instantiate_quantized_from_float(
                    float_module=child_module, metadata=metadata, preserve_state_dict=preserve_state_dict
                )

                # ACTUAL REPLACEMENT
                def replace():
                    setattr(module, child_name, q_instance)

                def recurse_quantize():
                    self._quantize_module_aux(
                        module=getattr(module, child_name),
                        mapping_instructions=mapping_instructions,
                        nesting=nesting + (child_name,),
                        preserve_state_dict=preserve_state_dict,
                    )

                if metadata.action == QuantizedMetadata.ReplacementAction.REPLACE:
                    replace()

                elif metadata.action == QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE:
                    replace()
                    recurse_quantize()
                elif metadata.action == QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE:
                    recurse_quantize()
                    replace()
                return True
            else:
                raise NotImplementedError
        return False

    def quantize_module(self, module: nn.Module, *, preserve_state_dict=True):
        per_layer_mappings = self._preprocess_skips_and_custom_mappings(module)
        mapping_instructions = {
            **per_layer_mappings,
            **self.mapping_instructions,
        }  # we first regard the per layer mappings, and then override with the custom mappings in case there is overlap
        logging_level = logging.getLogger("absl").getEffectiveLevel()
        if not self.verbose:  # suppress pytorch-quantization spam
            logging.getLogger("absl").setLevel("ERROR")

        device = next(module.parameters()).device
        self._quantize_module_aux(mapping_instructions=mapping_instructions, module=module, nesting=(), preserve_state_dict=preserve_state_dict)
        module.to(device)

        logging.getLogger("absl").setLevel(logging_level)

    def _quantize_module_aux(self, mapping_instructions, module, nesting, preserve_state_dict):
        for name, child_module in module.named_children():
            found = self._maybe_quantize_one_layer(module, name, nesting, child_module, mapping_instructions, preserve_state_dict)

            # RECURSIVE CALL, to support module_list, sequential, custom (nested) modules
            if not found and isinstance(child_module, nn.Module):
                self._quantize_module_aux(mapping_instructions, child_module, nesting + (name,), preserve_state_dict)
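
A short sketch of using SelectiveQuantizer directly, assuming pytorch-quantization is installed; the toy model and the layer name used for skipping are illustrative only.

import torch.nn as nn

from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer

# Toy model for illustration; in a Sequential the child names are "0", "1", "2", ...
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))

q_util = SelectiveQuantizer(
    default_quant_modules_calibrator_weights="max",
    default_quant_modules_calibrator_inputs="histogram",
    default_per_channel_quant_weights=True,
    verbose=False,
)
q_util.register_skip_quantization(layer_names={"0"})  # keep the first conv in float precision
q_util.quantize_module(model)                         # supported layers are replaced in-place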

register_quantized_module(float_source, action=QuantizedMetadata.ReplacementAction.REPLACE, input_quant_descriptor=None, weights_quant_descriptor=None)

Decorator used to register a Quantized module as a quantized version for Float module

Parameters:

Name Type Description Default
action QuantizedMetadata.ReplacementAction

action to perform on the float_source

QuantizedMetadata.ReplacementAction.REPLACE
float_source Union[str, Type[nn.Module]]

the float module type that is being registered

required
input_quant_descriptor Optional[QuantDescriptor]

the input quantization descriptor

None
weights_quant_descriptor Optional[QuantDescriptor]

the weight quantization descriptor

None
Source code in src/super_gradients/training/utils/quantization/selective_quantization_utils.py
def register_quantized_module(
    float_source: Union[str, Type[nn.Module]],
    action: QuantizedMetadata.ReplacementAction = QuantizedMetadata.ReplacementAction.REPLACE,
    input_quant_descriptor: Optional[QuantDescriptor] = None,
    weights_quant_descriptor: Optional[QuantDescriptor] = None,
) -> Callable:
    """
    Decorator used to register a Quantized module as a quantized version for Float module
    :param action:                      action to perform on the float_source
    :param float_source:                the float module type that is being registered
    :param input_quant_descriptor:      the input quantization descriptor
    :param weights_quant_descriptor:    the weight quantization descriptor
    """

    def decorator(quant_module: Type[SGQuantMixin]) -> Type[SGQuantMixin]:
        if float_source in SelectiveQuantizer.mapping_instructions:
            metadata = SelectiveQuantizer.mapping_instructions[float_source]
            raise ValueError(f"`{float_source}` is already registered with following metadata {metadata}")

        SelectiveQuantizer.mapping_instructions.update(
            {
                float_source: QuantizedMetadata(
                    float_source=float_source,
                    quantized_target_class=quant_module,
                    input_quant_descriptor=input_quant_descriptor,
                    weights_quant_descriptor=weights_quant_descriptor,
                    action=action,
                )
            }
        )
        return quant_module  # this is required since the decorator assigns the result to the `quant_module`

    return decorator
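
A hypothetical registration sketch: MyFloatBlock stands for a float module from your own code base, and the SGQuantMixin import path is assumed to match your SuperGradients installation.

from pytorch_quantization import nn as quant_nn

from super_gradients.training.utils.quantization.selective_quantization_utils import register_quantized_module
# SGQuantMixin is imported from its SuperGradients module (exact path not shown in this section).


@register_quantized_module(float_source=MyFloatBlock)
class QuantMyFloatBlock(SGQuantMixin):
    """Illustrative only: quantize the block's input with a single TensorQuantizer, mirroring QuantResidual above."""

    @classmethod
    def from_float(cls, float_instance: "MyFloatBlock", **kwargs):
        return quant_nn.TensorQuantizer(kwargs.get("quant_desc_input"))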

use_fb_fake_quant

Context manager object to ensure that fake quantization state is preserved

with use_fb_fake_quant(True):
    do_stuff()

Source code in src/super_gradients/training/utils/quantization/use_fb_fake_quant.py
class use_fb_fake_quant:
    """
    Context manager object to ensure that fake quantization
    state is preserved

    >>> with use_fb_fake_quant(True):
    >>>    do_stuff()
    """

    def __init__(self, enabled: bool):
        self.use_fb_fake_quant_state = None
        self.enabled = enabled

    def __enter__(self):
        self.use_fb_fake_quant_state = quant_nn.TensorQuantizer.use_fb_fake_quant
        quant_nn.TensorQuantizer.use_fb_fake_quant = self.enabled
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        quant_nn.TensorQuantizer.use_fb_fake_quant = self.use_fb_fake_quant_state
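
One common pattern (an assumption, not mandated by this class) is to enable fake-quant ops around an ONNX export of an already-quantized model:

import torch

# `quantized_model` is a hypothetical model produced e.g. by ptq(...); the export settings are illustrative.
quantized_model.eval()
dummy_input = torch.randn(1, 3, 224, 224)

with use_fb_fake_quant(True):
    torch.onnx.export(quantized_model, dummy_input, "quantized_model.onnx", opset_version=13)
# On exit, the previous TensorQuantizer.use_fb_fake_quant value is restored.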

DropPath

Bases: nn.Module

Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Intended usage of this block is the following:

class ResNetBlock(nn.Module):
    def __init__(self, ..., drop_path_rate: float):
        self.drop_path = DropPath(drop_path_rate)

    def forward(self, x):
        return x + self.drop_path(self.conv_bn_act(x))

Code taken from TIMM (https://github.com/rwightman/pytorch-image-models) Apache License 2.0

Source code in src/super_gradients/training/utils/regularization_utils.py
class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).

    Intended usage of this block is the following:

    >>> class ResNetBlock(nn.Module):
    >>>   def __init__(self, ..., drop_path_rate:float):
    >>>     self.drop_path = DropPath(drop_path_rate)
    >>>
    >>>   def forward(self, x):
    >>>     return x + self.drop_path(self.conv_bn_act(x))

    Code taken from TIMM (https://github.com/rwightman/pytorch-image-models)
    Apache License 2.0
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        """

        :param drop_prob: Probability of zeroing out individual vector (channel dimension) of each feature map
        :param scale_by_keep: Whether to scale the output by the keep probability. Enabled by default; helps to
                              keep the output mean & std in the same range as without drop path.
        """
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x

        return drop_path(x, self.drop_prob, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob,3):0.3f}"

__init__(drop_prob=0.0, scale_by_keep=True)

Parameters:

Name Type Description Default
drop_prob float

Probability of zeroing out individual vector (channel dimension) of each feature map

0.0
scale_by_keep bool

Whether to scale the output by the keep probability. Enabled by default; helps to keep the output mean & std in the same range as without drop path.

True
Source code in src/super_gradients/training/utils/regularization_utils.py
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
    """

    :param drop_prob: Probability of zeroing out individual vector (channel dimension) of each feature map
    :param scale_by_keep: Whether to scale the output by the keep probability. Enabled by default; helps to
                          keep the output mean & std in the same range as without drop path.
    """
    super(DropPath, self).__init__()
    self.drop_prob = drop_prob
    self.scale_by_keep = scale_by_keep

drop_path(x, drop_prob=0.0, scale_by_keep=True)

Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Source code in src/super_gradients/training/utils/regularization_utils.py
def drop_path(x, drop_prob: float = 0.0, scale_by_keep: bool = True):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
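
A tiny illustrative check of the per-sample behaviour (not from the source):

import torch

from super_gradients.training.utils.regularization_utils import drop_path

x = torch.ones(4, 8, 16, 16)                    # batch of 4 feature maps
out = drop_path(x, drop_prob=0.5, scale_by_keep=True)

# Each sample is either zeroed out entirely or scaled by 1 / keep_prob = 2.0,
# so the expected value of the output matches the input.
per_sample = out.flatten(1).amax(dim=1)         # 0.0 for dropped samples, 2.0 for kept ones
print(per_sample)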

BinarySegmentationVisualization

Source code in src/super_gradients/training/utils/segmentation_utils.py
class BinarySegmentationVisualization:
    @staticmethod
    def _visualize_image(image_np: np.ndarray, pred_mask: torch.Tensor, target_mask: torch.Tensor, image_scale: float, checkpoint_dir: str, image_name: str):
        pred_mask = pred_mask.copy()
        image_np = torch.from_numpy(np.moveaxis(image_np, -1, 0).astype(np.uint8))

        pred_mask = pred_mask[np.newaxis, :, :] > 0.5
        target_mask = target_mask[np.newaxis, :, :].astype(bool)
        tp_mask = np.logical_and(pred_mask, target_mask)
        fp_mask = np.logical_and(pred_mask, np.logical_not(target_mask))
        fn_mask = np.logical_and(np.logical_not(pred_mask), target_mask)
        overlay = torch.from_numpy(np.concatenate([tp_mask, fp_mask, fn_mask]))

        # SWITCH BETWEEN BLUE AND RED IF WE SAVE THE IMAGE ON THE DISC AS OTHERWISE WE CHANGE CHANNEL ORDERING
        colors = ["green", "red", "blue"]
        res_image = draw_segmentation_masks(image_np, overlay, colors=colors).detach().numpy()
        res_image = np.concatenate([res_image[ch, :, :, np.newaxis] for ch in range(3)], 2)
        res_image = cv2.resize(res_image.astype(np.uint8), (0, 0), fx=image_scale, fy=image_scale, interpolation=cv2.INTER_NEAREST)

        if checkpoint_dir is None:
            return res_image
        else:
            cv2.imwrite(os.path.join(checkpoint_dir, str(image_name) + ".jpg"), res_image)

    @staticmethod
    def visualize_batch(
        image_tensor: torch.Tensor,
        pred_mask: torch.Tensor,
        target_mask: torch.Tensor,
        batch_name: Union[int, str],
        checkpoint_dir: str = None,
        undo_preprocessing_func: Callable[[torch.Tensor], np.ndarray] = reverse_imagenet_preprocessing,
        image_scale: float = 1.0,
    ):
        """
        A helper function to visualize binary segmentation results predicted by a network:
        saves images into a given path with a name that is {batch_name}_{image_idx_in_the_batch}.jpg, one batch per call.
        True-positive, false-positive and false-negative pixels are overlaid in green, red and blue, respectively.

        :param image_tensor:            rgb images, (B, H, W, 3)
        :param pred_mask:               prediction mask (logits) in shape [B, 1, H, W]
        :param target_mask:             ground-truth binary mask in shape [B, H, W]
        :param batch_name:              id of the current batch to use for image naming

        :param checkpoint_dir:          a path where images with masks will be saved. if None, the result images will
                                        be returned as a list of numpy image arrays

        :param undo_preprocessing_func: a function to convert preprocessed images tensor into a batch of cv2-like images
        :param image_scale:             scale factor for output image
        """
        image_np = undo_preprocessing_func(image_tensor.detach())
        pred_mask = torch.sigmoid(pred_mask[:, 0, :, :])  # comment out

        out_images = []
        for i in range(image_np.shape[0]):
            preds = pred_mask[i].detach().cpu().numpy()
            targets = target_mask[i].detach().cpu().numpy()

            image_name = "_".join([str(batch_name), str(i)])
            res_image = BinarySegmentationVisualization._visualize_image(image_np[i], preds, targets, image_scale, checkpoint_dir, image_name)
            if res_image is not None:
                out_images.append(res_image)

        return out_images

visualize_batch(image_tensor, pred_mask, target_mask, batch_name, checkpoint_dir=None, undo_preprocessing_func=reverse_imagenet_preprocessing, image_scale=1.0) staticmethod

A helper function to visualize binary segmentation results predicted by a network: saves images into a given path with a name that is {batch_name}_{image_idx_in_the_batch}.jpg, one batch per call. True-positive, false-positive and false-negative pixels are overlaid in green, red and blue, respectively.

Parameters:

Name Type Description Default
image_tensor torch.Tensor

rgb images, (B, H, W, 3)

required
pred_mask torch.Tensor

prediction mask (logits) in shape [B, 1, H, W]

required
target_mask torch.Tensor

ground-truth binary mask in shape [B, H, W]

required
batch_name Union[int, str]

id of the current batch to use for image naming

required
checkpoint_dir str

a path where images with masks will be saved. If None, the result images will be returned as a list of numpy image arrays

None
undo_preprocessing_func Callable[[torch.Tensor], np.ndarray]

a function to convert preprocessed images tensor into a batch of cv2-like images

reverse_imagenet_preprocessing
image_scale float

scale factor for output image

1.0
Source code in src/super_gradients/training/utils/segmentation_utils.py
@staticmethod
def visualize_batch(
    image_tensor: torch.Tensor,
    pred_mask: torch.Tensor,
    target_mask: torch.Tensor,
    batch_name: Union[int, str],
    checkpoint_dir: str = None,
    undo_preprocessing_func: Callable[[torch.Tensor], np.ndarray] = reverse_imagenet_preprocessing,
    image_scale: float = 1.0,
):
    """
    A helper function to visualize binary segmentation results predicted by a network:
    saves images into a given path with a name that is {batch_name}_{image_idx_in_the_batch}.jpg, one batch per call.
    True-positive, false-positive and false-negative pixels are overlaid in green, red and blue, respectively.

    :param image_tensor:            rgb images, (B, H, W, 3)
    :param pred_mask:               prediction mask (logits) in shape [B, 1, H, W]
    :param target_mask:             ground-truth binary mask in shape [B, H, W]
    :param batch_name:              id of the current batch to use for image naming

    :param checkpoint_dir:          a path where images with masks will be saved. if None, the result images will
                                    be returned as a list of numpy image arrays

    :param undo_preprocessing_func: a function to convert preprocessed images tensor into a batch of cv2-like images
    :param image_scale:             scale factor for output image
    """
    image_np = undo_preprocessing_func(image_tensor.detach())
    pred_mask = torch.sigmoid(pred_mask[:, 0, :, :])  # comment out

    out_images = []
    for i in range(image_np.shape[0]):
        preds = pred_mask[i].detach().cpu().numpy()
        targets = target_mask[i].detach().cpu().numpy()

        image_name = "_".join([str(batch_name), str(i)])
        res_image = BinarySegmentationVisualization._visualize_image(image_np[i], preds, targets, image_scale, checkpoint_dir, image_name)
        if res_image is not None:
            out_images.append(res_image)

    return out_images
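
A usage sketch, assuming model is a user-defined binary segmentation network producing [B, 1, H, W] logits:

import torch

from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization

images = torch.randn(2, 3, 256, 256)            # preprocessed (ImageNet-normalized) batch
logits = model(images)                          # hypothetical network output, shape [B, 1, H, W]
targets = torch.randint(0, 2, (2, 256, 256))    # ground-truth binary masks, shape [B, H, W]

visualizations = BinarySegmentationVisualization.visualize_batch(
    image_tensor=images,
    pred_mask=logits,
    target_mask=targets,
    batch_name="val_0",
    checkpoint_dir=None,  # None returns numpy arrays instead of writing .jpg files to disk
)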

forward_with_sliding_window_wrapper(forward, img, sliding_window_stride, sliding_window_crop_size, num_classes)

Inference by sliding-window with overlap. It involves systematically moving a window with a fixed crop-size over the input image. As the window moves across the image, features or patterns within the window are extracted by running a forward pass of the crop image through the net.

If h_crop > h_img or w_crop > w_img, the small patch will be used to decode without padding.

Parameters:

Name Type Description Default
forward Callable[[torch.Tensor], torch.Tensor]

a model's forward function.

required
img torch.Tensor

a batch of images to run the model on using sliding-window inference.

required
sliding_window_stride tuple

(height, width) the stride size between crops for forward with sliding window

required
sliding_window_crop_size tuple

(height, width) the crop size to take from the image for forward with sliding window

required
num_classes int

the number of classes.

required

Returns:

Type Description
torch.Tensor

predictions tensor
Source code in src/super_gradients/training/utils/segmentation_utils.py
def forward_with_sliding_window_wrapper(
    forward: Callable[[torch.Tensor], torch.Tensor], img: torch.Tensor, sliding_window_stride: tuple, sliding_window_crop_size: tuple, num_classes: int
) -> torch.Tensor:
    """
    Inference by sliding-window with overlap. It involves systematically moving a window with a fixed crop-size over
    the input image. As the window moves across the image, features or patterns within the window are extracted by
    running a forward pass of the crop image through the net.

    If h_crop > h_img or w_crop > w_img, the small patch will be used to decode without padding.

    :param forward: a model's forward function.
    :param img: a batch of images to run the model on using sliding-window inference.
    :param sliding_window_stride: (height, width) the stride size between crops for forward with sliding window
    :param sliding_window_crop_size: (height, width) the crop size to take from the image for forward with sliding window
    :param num_classes: the number of classes.

    :return: predictions tensor
    """

    h_stride, w_stride = sliding_window_stride
    h_crop, w_crop = sliding_window_crop_size

    if h_stride > h_crop or w_stride > w_crop:
        raise ValueError("sliding_window_stride cannot be larger than sliding_window_crop_size.")

    batch_size, _, h_img, w_img = img.size()

    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1

    preds = torch.zeros((batch_size, num_classes, h_img, w_img), device=img.device)
    count_mat = torch.zeros((batch_size, 1, h_img, w_img), device=img.device)

    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1 = h_idx * h_stride
            x1 = w_idx * w_stride
            y2 = min(y1 + h_crop, h_img)
            x2 = min(x1 + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = img[:, :, y1:y2, x1:x2]

            crop_logits = forward(crop_img)

            if isinstance(crop_logits, torch.Tensor):
                crop_logits = (crop_logits,)

            crop_logits = crop_logits[0]

            preds[:, :, y1:y2, x1:x2] += crop_logits

            count_mat[:, :, y1:y2, x1:x2] += 1
    preds = preds / count_mat
    return preds
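
A usage sketch with hypothetical crop/stride values and a user-defined segmentation model:

import torch

from super_gradients.training.utils.segmentation_utils import forward_with_sliding_window_wrapper

model.eval()                                    # hypothetical segmentation network with 19 output classes
img = torch.randn(1, 3, 1024, 2048)

with torch.no_grad():
    preds = forward_with_sliding_window_wrapper(
        forward=model.forward,
        img=img,
        sliding_window_stride=(384, 384),
        sliding_window_crop_size=(512, 512),
        num_classes=19,
    )
print(preds.shape)                              # torch.Size([1, 19, 1024, 2048])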

one_hot_to_binary_edge(x, kernel_size, flatten_channels=True)

Utils function to create edge feature maps.

Parameters:

Name Type Description Default
x torch.Tensor

input tensor, must be one_hot tensor with shape [B, C, H, W]

required
kernel_size int

kernel size of the dilation/erosion convolutions. The resulting edge width depends on this argument as follows: edge_width = kernel_size - 1

required
flatten_channels bool

Whether to apply a logical OR across the channels dimension: if at least one class marks a pixel as an edge pixel, the flattened value is 1. If set to False the output tensor shape is [B, C, H, W], otherwise [B, 1, H, W]. Default is True.

True

Returns:

Type Description
torch.Tensor

one_hot edge torch.Tensor.

Source code in src/super_gradients/training/utils/segmentation_utils.py
def one_hot_to_binary_edge(x: torch.Tensor, kernel_size: int, flatten_channels: bool = True) -> torch.Tensor:
    """
    Utils function to create edge feature maps.
    :param x: input tensor, must be one_hot tensor with shape [B, C, H, W]
    :param kernel_size: kernel size of the dilation/erosion convolutions. The resulting edge width depends on this
        argument as follows: `edge_width = kernel_size - 1`
    :param flatten_channels: Whether to apply a logical OR across the channels dimension: if at least one class marks
        a pixel as an edge pixel, the flattened value is 1. If set to `False` the output tensor shape is [B, C, H, W],
        otherwise [B, 1, H, W]. Default is `True`.
    :return: one_hot edge torch.Tensor.
    """
    if kernel_size < 0 or kernel_size % 2 == 0:
        raise ValueError(f"kernel size must be an odd positive values, such as [1, 3, 5, ..], found: {kernel_size}")
    _kernel = torch.ones(x.size(1), 1, kernel_size, kernel_size, dtype=torch.float32, device=x.device)
    padding = (kernel_size - 1) // 2
    # Use replicate padding to prevent class shifting and edge formation at the image boundaries.
    padded_x = F.pad(x.float(), mode="replicate", pad=[padding] * 4)
    # The binary edge feature map is created by subtracting the eroded features from the dilated features.
    # First the positive one value masks are expanded (dilation) by applying a sliding window filter of one values.
    # The resulted output is then clamped to binary format to [0, 1], this way the one-hot boundaries are expanded by
    # (kernel_size - 1) / 2.
    dilation = torch.clamp(F.conv2d(padded_x, _kernel, groups=x.size(1)), 0, 1)
    # Similar to dilation, erosion (can be seen as inverse of dilation) is applied to contract the one-hot features by
    # applying a dilation operation on the inverse of the one-hot features.
    erosion = 1 - torch.clamp(F.conv2d(1 - padded_x, _kernel, groups=x.size(1)), 0, 1)
    # Finally the edge features are the result of subtracting dilation by erosion.
    # i.e for a simple 1D one-hot input:    [0, 0, 0, 1, 1, 1, 0, 0, 0], using sliding kernel with size 3: [1, 1, 1]
    # Dilated features:                     [0, 0, 1, 1, 1, 1, 1, 0, 0]
    # Eroded features:                      [0, 0, 0, 0, 1, 0, 0, 0, 0]
    # Edge features: dilation - erosion:    [0, 0, 1, 1, 0, 1, 1, 0, 0]
    edge = dilation - erosion
    if flatten_channels:
        # use max operator across channels. Equivalent to logical or for input with binary values [0, 1].
        edge = edge.max(dim=1, keepdim=True)[0]
    return edge
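
A small worked example (illustrative only): a single-class 5x5 one-hot map with a 3x3 foreground square.

import torch

from super_gradients.training.utils.segmentation_utils import one_hot_to_binary_edge

x = torch.zeros(1, 1, 5, 5)                     # [B, C, H, W] one-hot map, single class
x[:, :, 1:4, 1:4] = 1.0                         # 3x3 foreground square in the middle

edge = one_hot_to_binary_edge(x, kernel_size=3)
print(edge[0, 0].int())
# With kernel_size=3 the boundary band is (kernel_size - 1) = 2 pixels wide (one inside, one outside),
# so every pixel except the single interior pixel at (2, 2) is marked as an edge.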

reverse_imagenet_preprocessing(im_tensor)

Parameters:

Name Type Description Default
im_tensor torch.Tensor

images in a batch after preprocessing for inference, RGB, (B, C, H, W)

required

Returns:

Type Description
np.ndarray

images in a batch in cv2 format, BGR, (B, H, W, C)

Source code in src/super_gradients/training/utils/segmentation_utils.py
def reverse_imagenet_preprocessing(im_tensor: torch.Tensor) -> np.ndarray:
    """
    :param im_tensor: images in a batch after preprocessing for inference, RGB, (B, C, H, W)
    :return:          images in a batch in cv2 format, BGR, (B, H, W, C)
    """
    im_np = im_tensor.cpu().numpy()
    im_np = im_np[:, ::-1, :, :].transpose(0, 2, 3, 1)
    im_np *= np.array([[[0.229, 0.224, 0.225][::-1]]])
    im_np += np.array([[[0.485, 0.456, 0.406][::-1]]])
    im_np *= 255.0
    return np.ascontiguousarray(im_np, dtype=np.uint8)
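
A quick sketch of the round trip on a synthetic batch (illustrative only):

import torch

from super_gradients.training.utils.segmentation_utils import reverse_imagenet_preprocessing

rgb = torch.rand(2, 3, 224, 224)                                    # synthetic RGB batch in [0, 1]
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
normalized = (rgb - mean) / std                                     # ImageNet-style preprocessing

images_bgr = reverse_imagenet_preprocessing(normalized)
print(images_bgr.shape, images_bgr.dtype)                           # (2, 224, 224, 3) uint8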

target_to_binary_edge(target, num_classes, kernel_size, ignore_index=None, flatten_channels=True)

Utils function to create edge feature maps from target.

Parameters:

Name Type Description Default
target torch.Tensor

Class labels long tensor, with shape [N, H, W]

required
num_classes int

num of classes in datasets excluding ignore label, this is the output channels of the one hot result.

required
kernel_size int

kernel size of the dilation/erosion convolutions. The resulting edge width depends on this argument as follows: edge_width = kernel_size - 1

required
flatten_channels bool

Whether to apply a logical OR across the channels dimension: if at least one class marks a pixel as an edge pixel, the flattened value is 1. If set to False the output tensor shape is [B, C, H, W], otherwise [B, 1, H, W]. Default is True.

True
ignore_index int

the index of the class in the dataset to ignore

None

Returns:

Type Description
torch.Tensor

one_hot edge torch.Tensor.

Source code in src/super_gradients/training/utils/segmentation_utils.py
def target_to_binary_edge(target: torch.Tensor, num_classes: int, kernel_size: int, ignore_index: int = None, flatten_channels: bool = True) -> torch.Tensor:
    """
    Utils function to create edge feature maps from target.
    :param target: Class labels long tensor, with shape [N, H, W]
    :param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one hot
        result.
    :param kernel_size: kernel size of the dilation/erosion convolutions. The resulting edge width depends on this
        argument as follows: `edge_width = kernel_size - 1`
    :param flatten_channels: Whether to apply a logical OR across the channels dimension: if at least one class marks
        a pixel as an edge pixel, the flattened value is 1. If set to `False` the output tensor shape is [B, C, H, W],
        otherwise [B, 1, H, W]. Default is `True`.
    :param ignore_index: the index of the class in the dataset to ignore

    :return: one_hot edge torch.Tensor.
    """
    one_hot = to_one_hot(target, num_classes=num_classes, ignore_index=ignore_index)
    return one_hot_to_binary_edge(one_hot, kernel_size=kernel_size, flatten_channels=flatten_channels)

to_one_hot(target, num_classes, ignore_index=None)

Target label to one_hot tensor. labels and ignore_index must be consecutive numbers.

Parameters:

Name Type Description Default
target torch.Tensor

Class labels long tensor, with shape [N, H, W]

required
num_classes int

num of classes in datasets excluding ignore label, this is the output channels of the one hot result.

required
ignore_index int

the index of the class in the dataset to ignore

None

Returns:

Type Description

one hot tensor with shape [N, num_classes, H, W]

Source code in src/super_gradients/training/utils/segmentation_utils.py
def to_one_hot(target: torch.Tensor, num_classes: int, ignore_index: int = None):
    """
    Target label to one_hot tensor. labels and ignore_index must be consecutive numbers.
    :param target: Class labels long tensor, with shape [N, H, W]
    :param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one hot
        result.
    :param ignore_index: the index of the class in the dataset to ignore
    :return: one hot tensor with shape [N, num_classes, H, W]
    """
    num_classes = num_classes if ignore_index is None else num_classes + 1

    one_hot = F.one_hot(target, num_classes).permute((0, 3, 1, 2))

    if ignore_index is not None:
        # remove ignore_index channel
        one_hot = torch.cat([one_hot[:, :ignore_index], one_hot[:, ignore_index + 1 :]], dim=1)

    return one_hot
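
A minimal example showing how the ignore label is dropped from the one-hot channels (labels chosen for illustration):

import torch

from super_gradients.training.utils.segmentation_utils import to_one_hot

target = torch.tensor([[[0, 1], [2, 3]]])       # [N, H, W]; labels 0..2 are classes, 3 is the ignore label
one_hot = to_one_hot(target, num_classes=3, ignore_index=3)

print(one_hot.shape)                            # torch.Size([1, 3, 2, 2])
print(one_hot[0, :, 1, 1])                      # tensor([0, 0, 0]): the ignored pixel gets no class channel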

ImprovementType

Bases: Enum

Type of improvement compared to previous value, i.e. if the value is better, worse or the same.

Difference with "increase": If a loss goes from 1 to 0.5, the value is smaller (decreased), but the result is better (improvement). For accuracy from 1 to 0.5, the value is smaller, but this time the result decreased, because greater is better.

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
class ImprovementType(Enum):
    """Type of improvement compared to previous value, i.e. if the value is better, worse or the same.

    Difference with "increase":
        If a loss goes from 1 to 0.5, the value is smaller (decreased), but the result is better (improvement).
        For accuracy going from 1 to 0.5, the value is smaller, but this time the result got worse, because greater is better.
    """

    IS_BETTER = "better"
    IS_WORSE = "worse"
    IS_SAME = "same"
    NONE = "none"

    def to_color(self) -> Union[str, None]:
        """Get the color representing the current improvement type"""
        if self == ImprovementType.IS_SAME:
            return "white"
        elif self == ImprovementType.IS_BETTER:
            return "green"
        elif self == ImprovementType.IS_WORSE:
            return "red"
        else:
            return None

to_color()

Get the color representing the current improvement type

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def to_color(self) -> Union[str, None]:
    """Get the color representing the current improvement type"""
    if self == ImprovementType.IS_SAME:
        return "white"
    elif self == ImprovementType.IS_BETTER:
        return "green"
    elif self == ImprovementType.IS_WORSE:
        return "red"
    else:
        return None

IncreaseType

Bases: Enum

Type of increase compared to previous value, i.e. if the value is greater, smaller or the same.

Difference with "improvement": If a loss goes from 1 to 0.5, the value is smaller (decreased), but the result is better (improvement). For accuracy from 1 to 0.5, the value is smaller, but this time the result decreased, because greater is better.

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
class IncreaseType(Enum):
    """Type of increase compared to previous value, i.e. if the value is greater, smaller or the same.

    Difference with "improvement":
        If a loss goes from 1 to 0.5, the value is smaller (decreased), but the result is better (improvement).
        For accuracy going from 1 to 0.5, the value is smaller, but this time the result got worse, because greater is better.
    """

    NONE = "none"
    IS_GREATER = "greater"
    IS_SMALLER = "smaller"
    IS_EQUAL = "equal"

    def to_symbol(self) -> str:
        """Get the symbol representing the current increase type"""
        if self == IncreaseType.NONE:
            return ""
        elif self == IncreaseType.IS_GREATER:
            return "↗"
        elif self == IncreaseType.IS_SMALLER:
            return "↘"
        else:
            return "="

to_symbol()

Get the symbol representing the current increase type

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def to_symbol(self) -> str:
    """Get the symbol representing the current increase type"""
    if self == IncreaseType.NONE:
        return ""
    elif self == IncreaseType.IS_GREATER:
        return "↗"
    elif self == IncreaseType.IS_SMALLER:
        return "↘"
    else:
        return "="

MonitoredValue dataclass

Store a value and some indicators relative to its past iterations.

The value can be a metric/loss, and the iteration can be epochs/batch.

Parameters:

Name Type Description Default
name str

Name of the metric

required
greater_is_better Optional[bool]

True, a greater value is considered better. ex: (greater_is_better=True) For Accuracy 1 is greater and therefore better than 0.4 ex: (greater_is_better=False) For Loss 1 is greater and therefore worse than 0.4 None when unknown

None
current Optional[float]

Current value of the metric

None
previous Optional[float]

Value of the metric in previous iteration

None
best Optional[float]

Value of the metric in best iteration (best according to greater_is_better)

None
change_from_previous Optional[float]

Change compared to previous iteration value

None
change_from_best Optional[float]

Change compared to best iteration value

None
Source code in src/super_gradients/training/utils/sg_trainer_utils.py
@dataclass
class MonitoredValue:
    """Store a value and some indicators relative to its past iterations.

    The value can be a metric/loss, and the iteration can be epochs/batch.

    :param name:                    Name of the metric
    :param greater_is_better:       True, a greater value is considered better.
                                      ex: (greater_is_better=True) For Accuracy 1 is greater and therefore better than 0.4
                                      ex: (greater_is_better=False) For Loss 1 is greater and therefore worse than 0.4
                                    None when unknown
    :param current:                 Current value of the metric
    :param previous:                Value of the metric in previous iteration
    :param best:                    Value of the metric in best iteration (best according to greater_is_better)
    :param change_from_previous:    Change compared to previous iteration value
    :param change_from_best:        Change compared to best iteration value
    """

    name: str
    greater_is_better: Optional[bool] = None
    current: Optional[float] = None
    previous: Optional[float] = None
    best: Optional[float] = None
    change_from_previous: Optional[float] = None
    change_from_best: Optional[float] = None

    @property
    def has_increased_from_previous(self) -> IncreaseType:
        """Type of increase compared to previous value, i.e. if the value is greater, smaller or the same."""
        return self._get_increase_type(self.change_from_previous)

    @property
    def has_improved_from_previous(self) -> ImprovementType:
        """Type of improvement compared to previous value, i.e. if the value is better, worse or the same."""
        return self._get_improvement_type(delta=self.change_from_previous)

    @property
    def has_increased_from_best(self) -> IncreaseType:
        """Type of increase compared to best value, i.e. if the value is greater, smaller or the same."""
        return self._get_increase_type(self.change_from_best)

    @property
    def has_improved_from_best(self) -> ImprovementType:
        """Type of improvement compared to best value, i.e. if the value is better, worse or the same."""
        return self._get_improvement_type(delta=self.change_from_best)

    def _get_increase_type(self, delta: float) -> IncreaseType:
        """Type of increase, i.e. if the value is greater, smaller or the same."""
        if delta is None:
            return IncreaseType.NONE
        if delta > 0:
            return IncreaseType.IS_GREATER
        elif delta < 0:
            return IncreaseType.IS_SMALLER
        else:
            return IncreaseType.IS_EQUAL

    def _get_improvement_type(self, delta: float) -> ImprovementType:
        """Type of improvement, i.e. if value is better, worse or the same."""
        if self.greater_is_better is None or delta is None:
            return ImprovementType.NONE
        has_increased, has_decreased = delta > 0, delta < 0
        if has_increased and self.greater_is_better or has_decreased and not self.greater_is_better:
            return ImprovementType.IS_BETTER
        elif has_increased and not self.greater_is_better or has_decreased and self.greater_is_better:
            return ImprovementType.IS_WORSE
        else:
            return ImprovementType.IS_SAME

has_improved_from_best: ImprovementType property

Type of improvement compared to best value, i.e. if the value is better, worse or the same.

has_improved_from_previous: ImprovementType property

Type of improvement compared to previous value, i.e. if the value is better, worse or the same.

has_increased_from_best: IncreaseType property

Type of increase compared to best value, i.e. if the value is greater, smaller or the same.

has_increased_from_previous: IncreaseType property

Type of increase compared to previous value, i.e. if the value is greater, smaller or the same.
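
A small illustrative sketch of how these indicators behave for a loss value (greater_is_better=False); the numbers are made up:

from super_gradients.training.utils.sg_trainer_utils import MonitoredValue

loss = MonitoredValue(
    name="loss",
    greater_is_better=False,
    current=0.42,
    previous=0.50,
    best=0.45,
    change_from_previous=-0.08,
    change_from_best=-0.03,
)
print(loss.has_increased_from_previous.to_symbol())  # "↘" - the raw value went down
print(loss.has_improved_from_best.to_color())        # "green" - a lower loss is an improvement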

add_log_to_file(filename, results_titles_list, results_values_list, epoch, max_epochs)

Add a message to the log file

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def add_log_to_file(filename, results_titles_list, results_values_list, epoch, max_epochs):
    """Add a message to the log file"""
    # NOTE: opening and closing the file every time is inefficient. It is done for experimental purposes
    with open(filename, "a") as f:
        f.write("\nEpoch (%d/%d)  - " % (epoch, max_epochs))
        for result_title, result_value in zip(results_titles_list, results_values_list):
            if isinstance(result_value, torch.Tensor):
                result_value = result_value.item()
            f.write(result_title + ": " + str(result_value) + "\t")

display_epoch_summary(epoch, n_digits, monitored_values_dict)

Display a summary of loss/metric of interest, for a given epoch.

Parameters:

Name Type Description Default
epoch int

The epoch number.

required
n_digits int

number of digits to display on screen for float values

required
monitored_values_dict Dict[str, Dict[str, MonitoredValue]]

Dict of Dict. The first one represents the split, and the second one a loss/metric.

required
Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def display_epoch_summary(epoch: int, n_digits: int, monitored_values_dict: Dict[str, Dict[str, MonitoredValue]]) -> None:
    """Display a summary of loss/metric of interest, for a given epoch.

    :param epoch: The epoch number.
    :param n_digits: number of digits to display on screen for float values
    :param monitored_values_dict: Dict of Dict. The first one represents the split, and the second one a loss/metric.
    """

    def _format_to_str(val: Optional[float]) -> str:
        return str(round(val, n_digits)) if val is not None else "None"

    def _generate_tree(value_name: str, monitored_value: MonitoredValue) -> Tree:
        """Generate a tree that represents the stats of a given loss/metric."""

        current = _format_to_str(monitored_value.current)
        root_id = str(hash(f"{value_name} = {current}")) + str(random.random())

        tree = Tree()
        tree.create_node(tag=f"{value_name.capitalize()} = {current}", identifier=root_id)

        if monitored_value.previous is not None:
            previous = _format_to_str(monitored_value.previous)
            best = _format_to_str(monitored_value.best)
            change_from_previous = _format_to_str(monitored_value.change_from_previous)
            change_from_best = _format_to_str(monitored_value.change_from_best)

            diff_with_prev_colored = colored(
                text=f"{monitored_value.has_increased_from_previous.to_symbol()} {change_from_previous}",
                color=monitored_value.has_improved_from_previous.to_color(),
            )
            diff_with_best_colored = colored(
                text=f"{monitored_value.has_increased_from_best.to_symbol()} {change_from_best}", color=monitored_value.has_improved_from_best.to_color()
            )

            tree.create_node(tag=f"Epoch N-1      = {previous:6} ({diff_with_prev_colored:8})", identifier=f"0_previous_{root_id}", parent=root_id)
            tree.create_node(tag=f"Best until now = {best:6} ({diff_with_best_colored:8})", identifier=f"1_best_{root_id}", parent=root_id)
        return tree

    summary_tree = Tree()
    summary_tree.create_node(f"SUMMARY OF EPOCH {epoch}", "Summary")

    for split, monitored_values in monitored_values_dict.items():
        if len(monitored_values):
            split_tree = Tree()
            split_tree.create_node(split, split)
            for name, value in monitored_values.items():
                split_tree.paste(split, new_tree=_generate_tree(name, monitored_value=value))
            summary_tree.paste("Summary", split_tree)

    print("===========================================================")
    summary_tree.show(key=False)
    print("===========================================================")

get_callable_param_names(obj)

Get the param names of a given callable (function, class, ...)

Parameters:

Name Type Description Default
obj callable

Object to inspect

required

Returns:

Type Description
Tuple[str]

Param names of that object

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def get_callable_param_names(obj: callable) -> Tuple[str]:
    """Get the param names of a given callable (function, class, ...)
    :param obj: Object to inspect
    :return: Param names of that object
    """
    return tuple(inspect.signature(obj).parameters)
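
A quick illustration (the function below is hypothetical, not part of the library):

def train_step(model, inputs, targets, lr=0.01):
    ...

get_callable_param_names(train_step)  # ('model', 'inputs', 'targets', 'lr')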

get_lr_info(model, param_groups)

Generate a string with information about the model and learning rates for each parameter group.

Parameters:

Name Type Description Default
model nn.Module

(nn.Module): The PyTorch model.

required
param_groups List[Dict[str, Union[str, float, List[tuple]]]]

(List[Dict[str, Union[str, float, List[tuple]]]]): List of dictionaries containing information about each parameter group, including the group name, learning rate, and named parameters.

required

Returns:

Type Description
str

A formatted string with information about the model and learning rates.
Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def get_lr_info(model: nn.Module, param_groups: List[Dict[str, Union[str, float, List[tuple]]]]) -> str:
    """
    Generate a string with information about the model and learning rates for each parameter group.

    :param model: (nn.Module): The PyTorch model.
    :param param_groups: (List[Dict[str, Union[str, float, List[tuple]]]]): List of dictionaries containing information about
            each parameter group, including the group name, learning rate, and named parameters.

    Returns:
        str: A formatted string with information about the model and learning rates.
    """
    total_params = sum(p.numel() for p in model.parameters())
    optimized_params = sum(p.numel() for group in param_groups for p in group["params"])

    info_str = f"    - Model: {type(model).__name__}  ({_format_value(total_params)} parameters"
    info_str += f", {_format_value(optimized_params)} optimized)\n"
    info_str += "    - Learning Rates and Weight Decays:\n"

    lr_wd_groups = {}
    max_group_name_len = max(len(group["name"]) for group in param_groups)
    max_lr_len = 0
    max_wd_len = 0

    for group in param_groups:
        group_name = group["name"]
        group_lr = group["lr"]
        group_wd = group["weight_decay"]
        group_params = sum(p.numel() for p in group["params"])

        if group_name not in lr_wd_groups:
            lr_wd_groups[group_name] = {"params": 0, "lr_params": {}, "wd_params": {}}
        lr_wd_groups[group_name]["params"] += group_params

        if group_lr not in lr_wd_groups[group_name]["lr_params"].keys():
            lr_wd_groups[group_name]["lr_params"][group_lr] = group_params
        else:
            lr_wd_groups[group_name]["lr_params"][group_lr] += group_params
        max_lr_len = max(max_lr_len, len(f"LR: {group_lr} ({_format_value(group_params)} parameters)"))

        if group_wd not in lr_wd_groups[group_name]["wd_params"].keys():
            lr_wd_groups[group_name]["wd_params"][group_wd] = group_params
        else:
            lr_wd_groups[group_name]["wd_params"][group_wd] += group_params
        max_wd_len = max(max_wd_len, len(f"WD: {group_wd}, ({_format_value(group_params)} parameters)"))

    for group_name, info in lr_wd_groups.items():
        all_group_params = info["params"]
        lr_str = ", ".join([f"LR: {lr_val} ({_format_value(lr_params)} parameters)" for lr_val, lr_params in info["lr_params"].items()])
        wd_str = ", ".join([f"WD: {wd_val}, ({_format_value(wd_params)} parameters)" for wd_val, wd_params in info["wd_params"].items()])

        # Calculate padding for alignment
        padding_len = max_group_name_len - len(group_name)
        padded_group_name = f"{group_name}:{' ' * padding_len}"

        lr_padding = " " * (max_lr_len - len(lr_str))
        wd_padding = " " * (max_wd_len - len(wd_str))

        info_str += f"      - {padded_group_name} ({_format_value(all_group_params)} parameters). {lr_str}{lr_padding} {wd_str}{wd_padding}\n"

    return info_str
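
A minimal sketch of a call; the param_groups below are hand-built dicts carrying the keys the function reads ("name", "lr", "weight_decay", "params"), which in practice usually come from the optimizer setup:

import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 2))
param_groups = [{"name": "default", "lr": 0.01, "weight_decay": 1e-4, "params": list(model.parameters())}]
print(get_lr_info(model, param_groups))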

init_summary_writer(tb_dir, checkpoint_loaded, user_prompt=False)

Remove previous tensorboard files from directory and launch a tensor board process

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def init_summary_writer(tb_dir, checkpoint_loaded, user_prompt=False):
    """Remove previous tensorboard files from directory and launch a tensor board process"""
    # If the training is from scratch, Walk through destination folder and delete existing tensorboard logs
    user = ""
    if not checkpoint_loaded:
        for filename in os.listdir(tb_dir):
            if "events" in filename:
                if not user_prompt:
                    logger.debug('"{}" will not be deleted'.format(filename))
                    continue

                while True:
                    # Verify with user before deleting old tensorboard files
                    user = input('\nOLDER TENSORBOARD FILES EXIST IN EXPERIMENT FOLDER:\n"{}"\n' "DO YOU WANT TO DELETE THEM? [y/n]".format(filename))
                    if user == "y":
                        os.remove("{}/{}".format(tb_dir, filename))
                        print("DELETED: {}!".format(filename))
                        break
                    elif user == "n":
                        print('"{}" will not be deleted'.format(filename))
                        break
                    print("Unknown answer...")

    # Launch a tensorboard process
    return SummaryWriter(tb_dir)

launch_tensorboard_process(checkpoints_dir_path, sleep_postpone=True, port=None)

Launch a TensorBoard process. Default behavior is to scan ports 6006-6016 and use the first free one, unless a port is defined by the user. Returns a tuple of (tensorboard process, port).

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def launch_tensorboard_process(checkpoints_dir_path: str, sleep_postpone: bool = True, port: int = None) -> Tuple[Process, int]:
    """
    launch_tensorboard_process - Default behavior is to scan all free ports from 6006-6016 and try using them
                                 unless port is defined by the user
        :param checkpoints_dir_path:
        :param sleep_postpone:
        :param port:
        :return: tuple of tb process, port
    """
    logdir_path = str(Path(checkpoints_dir_path).parent.absolute())
    tb_cmd = "tensorboard --logdir=" + logdir_path + " --bind_all"
    if port is not None:
        tb_ports = [port]
    else:
        tb_ports = range(6006, 6016)

    for tb_port in tb_ports:
        if not try_port(tb_port):
            continue
        else:
            print("Starting Tensor-Board process on port: " + str(tb_port))
            tensor_board_process = Process(target=os.system, args=([tb_cmd + " --port=" + str(tb_port)]))
            tensor_board_process.daemon = True
            tensor_board_process.start()

            # LET THE TENSORBOARD PROCESS START
            if sleep_postpone:
                time.sleep(3)
            return tensor_board_process, tb_port

    # RETURNING IRRELEVANT VALUES
    print("Failed to initialize Tensor-Board process on port: " + ", ".join(map(str, tb_ports)))
    return None, -1
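
Illustrative usage (the path below is a placeholder; it is assumed to point at an existing checkpoints directory):

tb_process, tb_port = launch_tensorboard_process("/path/to/experiment/checkpoints", sleep_postpone=False)
if tb_process is not None:
    print(f"TensorBoard is running on port {tb_port}")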

log_main_training_params(multi_gpu, num_gpus, batch_size, batch_accumulate, train_dataset_length, train_dataloader_len, model, param_groups, max_train_batches=None)

Log training parameters

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def log_main_training_params(
    multi_gpu: MultiGPUMode,
    num_gpus: int,
    batch_size: int,
    batch_accumulate: int,
    train_dataset_length: int,
    train_dataloader_len: int,
    model: nn.Module,
    param_groups: List[Dict[str, Union[str, float, List[tuple]]]],
    max_train_batches: Optional[int] = None,
):
    """Log training parameters"""

    iterations_per_epoch = int(train_dataloader_len) if max_train_batches is None else max_train_batches
    gradients_updates_per_epoch = int(iterations_per_epoch / batch_accumulate)
    what_used_str = "len(train_loader)" if max_train_batches is None else "max_train_batches"

    msg = (
        "TRAINING PARAMETERS:\n"
        f"    - Mode:                         {multi_gpu.name if multi_gpu else 'Single GPU'}\n"
        f"    - Number of GPUs:               {num_gpus if 'cuda' in device_config.device else 0:<10} ({torch.cuda.device_count()} available on the machine)\n"
        f"    - Full dataset size:            {train_dataset_length:<10} (len(train_set))\n"
        f"    - Batch size per GPU:           {batch_size:<10} (batch_size)\n"
        f"    - Batch Accumulate:             {batch_accumulate:<10} (batch_accumulate)\n"
        f"    - Total batch size:             {num_gpus * batch_size:<10} (num_gpus * batch_size)\n"
        f"    - Effective Batch size:         {num_gpus * batch_size * batch_accumulate:<10} (num_gpus * batch_size * batch_accumulate)\n"
        f"    - Iterations per epoch:         {iterations_per_epoch:<10} ({what_used_str})\n"
        f"    - Gradient updates per epoch:   {gradients_updates_per_epoch:<10} ({what_used_str} / batch_accumulate)\n"
    )
    msg += get_lr_info(model, param_groups)

    logger.info(msg)

    if max_train_batches:
        logger.warning(f"max_train_batch is set to {max_train_batches}. This limits the number of iterations per epoch and gradient updates per epoch.")

log_uncaught_exceptions(logger)

Makes logger log uncaught exceptions

Parameters:

Name Type Description Default
logger

logging.Logger

required

Returns:

Type Description

None

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def log_uncaught_exceptions(logger):
    """
    Makes logger log uncaught exceptions
    :param logger: logging.Logger

    :return: None
    """

    def log_exceptook(excepthook: Callable) -> Callable:
        """Wrapping function that logs exceptions that are not KeyboardInterrupt"""

        def handle_exception(exc_type, exc_value, exc_traceback):
            if not issubclass(exc_type, KeyboardInterrupt):
                logger.error("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))
            excepthook(exc_type, exc_value, exc_traceback)
            return

        return handle_exception

    sys.excepthook = log_exceptook(sys.excepthook)

parse_args(cfg, arg_names)

Parse args from a config. Unlike get_param(), only parameters that appear in the config will override default params from the function's signature.

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def parse_args(cfg, arg_names: Union[Sequence[str], callable]) -> dict:
    """
    parse args from a config.
    unlike get_param(), in this case only parameters that appear in the config will override default params from the function's signature
    """
    if not isinstance(arg_names, Sequence):
        arg_names = get_callable_param_names(arg_names)

    kwargs_dict = {}
    for arg_name in arg_names:
        if hasattr(cfg, arg_name) and getattr(cfg, arg_name) is not None:
            kwargs_dict[arg_name] = getattr(cfg, arg_name)
    return kwargs_dict
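
For example (a sketch using a plain namespace in place of a real config object):

from types import SimpleNamespace

cfg = SimpleNamespace(lr=0.1, batch_size=None, momentum=0.9)
parse_args(cfg, arg_names=["lr", "batch_size", "weight_decay"])  # {'lr': 0.1} - None values and missing entries are skipped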

try_port(port)

try_port - Helper method for tensorboard port binding

Parameters:

Name Type Description Default
port required

Returns:

Type Description
Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def try_port(port):
    """
    try_port - Helper method for tensorboard port binding
    :param port:
    :return:
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    is_port_available = False
    try:
        sock.bind(("localhost", port))
        is_port_available = True

    except Exception as ex:
        print("Port " + str(port) + " is in use" + str(ex))

    sock.close()
    return is_port_available
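
For example:

if try_port(6006):
    print("Port 6006 is free and can be used for TensorBoard")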

unpack_batch_items(batch_items)

Adds support for unpacking batch items in train/validation loop.

Parameters:

Name Type Description Default
batch_items Union[tuple, torch.Tensor]

(Union[tuple, torch.Tensor]) returned by the data loader, which is expected to be in one of the following formats:

1. torch.Tensor or tuple, s.t. inputs = batch_items[0], targets = batch_items[1] and len(batch_items) = 2
2. tuple: (inputs, targets, additional_batch_items)

where inputs are fed to the network, targets are their corresponding labels and additional_batch_items is a dictionary (format {additional_batch_item_i_name: additional_batch_item_i ...}) which can be accessed through the phase context under the attribute additional_batch_item_i_name, using a phase callback.

required

Returns:

Type Description

inputs, target, additional_batch_items

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def unpack_batch_items(batch_items: Union[tuple, torch.Tensor]):
    """
    Adds support for unpacking batch items in train/validation loop.

    :param batch_items: (Union[tuple, torch.Tensor]) returned by the data loader, which is expected to be in one of
         the following formats:
            1. torch.Tensor or tuple, s.t inputs = batch_items[0], targets = batch_items[1] and len(batch_items) = 2
            2. tuple: (inputs, targets, additional_batch_items)

         where inputs are fed to the network, targets are their corresponding labels and additional_batch_items is a
         dictionary (format {additional_batch_item_i_name: additional_batch_item_i ...}) which can be accessed through
         the phase context under the attribute additional_batch_item_i_name, using a phase callback.


    :return: inputs, target, additional_batch_items
    """
    additional_batch_items = {}
    if len(batch_items) == 2:
        inputs, target = batch_items

    elif len(batch_items) == 3:
        inputs, target, additional_batch_items = batch_items

    else:
        raise UnsupportedBatchItemsFormat(batch_items)

    return inputs, target, additional_batch_items
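
A small sketch with dummy tensors, covering both supported formats:

images, labels = torch.zeros(4, 3, 32, 32), torch.zeros(4, dtype=torch.long)
inputs, targets, extras = unpack_batch_items((images, labels))  # extras == {}
inputs, targets, extras = unpack_batch_items((images, labels, {"img_paths": ["a.jpg"] * 4}))  # extras is the extra dict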

update_monitored_value(previous_monitored_value, new_value)

Update the given MonitoredValue object (could be a loss or a metric) with the new value

Parameters:

Name Type Description Default
previous_monitored_value MonitoredValue

The stats about the value that is monitored throughout epochs.

required
new_value float

The value of the current epoch that will be used to update previous_monitored_value

required

Returns:

Type Description
MonitoredValue
Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def update_monitored_value(previous_monitored_value: MonitoredValue, new_value: float) -> MonitoredValue:
    """Update the given ValueToMonitor object (could be a loss or a metric) with the new value

    :param previous_monitored_value: The stats about the value that is monitored throughout epochs.
    :param new_value: The value of the current epoch that will be used to update previous_monitored_value
    :return:
    """
    previous_value, previous_best_value = previous_monitored_value.current, previous_monitored_value.best
    name, greater_is_better = previous_monitored_value.name, previous_monitored_value.greater_is_better

    if previous_best_value is None:
        previous_best_value = previous_value
    elif greater_is_better:
        previous_best_value = max(previous_value, previous_best_value)
    else:
        previous_best_value = min(previous_value, previous_best_value)

    if previous_value is None:
        change_from_previous = None
        change_from_best = None
    else:
        change_from_previous = new_value - previous_value
        change_from_best = new_value - previous_best_value

    return MonitoredValue(
        name=name,
        current=new_value,
        previous=previous_value,
        best=previous_best_value,
        change_from_previous=change_from_previous,
        change_from_best=change_from_best,
        greater_is_better=greater_is_better,
    )
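
A minimal sketch of feeding a value epoch by epoch (constructing MonitoredValue directly with name and greater_is_better is an assumption based on the fields shown earlier):

loss_stats = MonitoredValue(name="loss", greater_is_better=False)
loss_stats = update_monitored_value(loss_stats, new_value=0.50)  # first epoch: change stats are still None
loss_stats = update_monitored_value(loss_stats, new_value=0.45)  # best == 0.5, change_from_previous ~= -0.05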

update_monitored_values_dict(monitored_values_dict, new_values_dict)

Update the given dict of MonitoredValue objects (losses or metrics) with their new values

Parameters:

Name Type Description Default
monitored_values_dict Dict[str, MonitoredValue]

Dict mapping value names to their stats throughout epochs.

required
new_values_dict Dict[str, float]

Dict mapping value names to their new (i.e. current epoch) value.

required

Returns:

Type Description
Dict[str, MonitoredValue]

Updated monitored_values_dict

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def update_monitored_values_dict(monitored_values_dict: Dict[str, MonitoredValue], new_values_dict: Dict[str, float]) -> Dict[str, MonitoredValue]:
    """Update the given ValueToMonitor object (could be a loss or a metric) with the new value

    :param monitored_values_dict: Dict mapping value names to their stats throughout epochs.
    :param new_values_dict: Dict mapping value names to their new (i.e. current epoch) value.
    :return: Updated monitored_values_dict
    """
    relevant_keys = set(new_values_dict.keys()).intersection(monitored_values_dict.keys())
    for monitored_value_name in relevant_keys:
        previous_value = monitored_values_dict[monitored_value_name]
        monitored_values_dict[monitored_value_name] = update_monitored_value(
            new_value=new_values_dict[monitored_value_name],
            previous_monitored_value=previous_value,
        )
    return monitored_values_dict

write_hpms(writer, hpmstructs=[], special_conf={})

Stores the training and dataset hyper params in the tensorboard file

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def write_hpms(writer, hpmstructs=[], special_conf={}):
    """Stores the training and dataset hyper params in the tensorboard file"""
    hpm_string = ""
    for hpm in hpmstructs:
        for key, val in hpm.__dict__.items():
            hpm_string += "{}: {}  \n  ".format(key, val)
    for key, val in special_conf.items():
        hpm_string += "{}: {}  \n  ".format(key, val)
    writer.add_text("Hyper_parameters", hpm_string)
    writer.flush()

write_training_results(writer, results_titles_list, results_values_list, epoch)

Stores the training and validation loss and accuracy for current epoch in a tensorboard file

Source code in src/super_gradients/training/utils/sg_trainer_utils.py
def write_training_results(writer, results_titles_list, results_values_list, epoch):
    """Stores the training and validation loss and accuracy for current epoch in a tensorboard file"""
    for res_key, res_val in zip(results_titles_list, results_values_list):
        # USE ONLY LOWER-CASE LETTERS AND REPLACE SPACES WITH '_' TO AVOID MANY TITLES FOR THE SAME KEY
        corrected_res_key = res_key.lower().replace(" ", "_")
        writer.add_scalar(corrected_res_key, res_val, epoch)
    writer.flush()

DefaultBoxes

Bases: object

Default Boxes (a.k.a. anchor boxes or prior boxes) used by the SSD model

Source code in src/super_gradients/training/utils/ssd_utils.py
class DefaultBoxes(object):
    """
    Default Boxes (a.k.a. anchor boxes or prior boxes) used by the SSD model
    """

    def __init__(self, fig_size: int, feat_size: List[int], scales: List[int], aspect_ratios: List[List[int]], scale_xy=0.1, scale_wh=0.2):
        """
        For each feature map i (each predicting level, grids) the anchors (a.k.a. default boxes) will be:
        [
            [s, s], [sqrt(s * s_next), sqrt(s * s_next)],
            [s * sqrt(alpha1), s / sqrt(alpha1)], [s / sqrt(alpha1), s * sqrt(alpha1)],
            ...
            [s * sqrt(alphaN), s / sqrt(alphaN)], [s / sqrt(alphaN), s * sqrt(alphaN)]
        ] / fig_size
        where:
            * s = scale[i] - this level's scale
            * s_next = scale[i + 1] - next level's scale
            * alpha1, ... alphaN - this level's alphas, e.g. [2, 3]
            * fig_size - input image resolution

        Because of division by image resolution, the anchors will be in image coordinates normalized to [0, 1]

        :param fig_size:        input image resolution
        :param feat_size:       resolution of all feature maps with predictions (grids)
        :param scales:          anchor sizes in pixels for each feature level;
                                one value per level will be used to generate anchors based on the formula above
        :param aspect_ratios:   lists of alpha values for each feature map
        :param scale_xy:        predicted boxes will be with a factor scale_xy
                                so will be multiplied by scale_xy during post-prediction processing;
                                e.g. scale 0.1 means that prediction will be 10 times bigger
                                (improves predictions quality)
        :param scale_wh:        same logic as in scale_xy, but for width and height.
        """
        self.feat_size = feat_size
        self.fig_size = fig_size

        self.scale_xy_ = scale_xy
        self.scale_wh_ = scale_wh
        # According to https://github.com/weiliu89/caffe
        # Calculation method slightly different from paper
        self.scales = scales
        self.aspect_ratios = aspect_ratios

        self.default_boxes = []
        self.num_anchors = []
        # size of feature and number of feature
        for idx, sfeat in enumerate(self.feat_size):

            sk1 = scales[idx]
            sk2 = scales[idx + 1]
            sk3 = sqrt(sk1 * sk2)
            all_sizes = [(sk1, sk1), (sk3, sk3)]

            for alpha in aspect_ratios[idx]:
                w, h = sk1 * sqrt(alpha), sk1 / sqrt(alpha)
                all_sizes.append((w, h))
                all_sizes.append((h, w))

            all_sizes = np.array(all_sizes) / fig_size
            self.num_anchors.append(len(all_sizes))
            for w, h in all_sizes:
                for i, j in itertools.product(range(sfeat), repeat=2):
                    cx, cy = (j + 0.5) / sfeat, (i + 0.5) / sfeat
                    self.default_boxes.append((cx, cy, w, h))

        self.dboxes = torch.tensor(self.default_boxes, dtype=torch.float)
        self.dboxes.clamp_(min=0, max=1)

        # For IoU calculation
        self.dboxes_xyxy = self.dboxes.clone()
        self.dboxes_xyxy[:, 0] = self.dboxes[:, 0] - 0.5 * self.dboxes[:, 2]
        self.dboxes_xyxy[:, 1] = self.dboxes[:, 1] - 0.5 * self.dboxes[:, 3]
        self.dboxes_xyxy[:, 2] = self.dboxes[:, 0] + 0.5 * self.dboxes[:, 2]
        self.dboxes_xyxy[:, 3] = self.dboxes[:, 1] + 0.5 * self.dboxes[:, 3]

    @property
    def scale_xy(self):
        return self.scale_xy_

    @property
    def scale_wh(self):
        return self.scale_wh_

    def __call__(self, order="xyxy"):
        if order == "xyxy":
            return self.dboxes_xyxy
        if order == "xywh":
            return self.dboxes

__init__(fig_size, feat_size, scales, aspect_ratios, scale_xy=0.1, scale_wh=0.2)

For each feature map i (each predicting level, grids) the anchors (a.k.a. default boxes) will be:

[
    [s, s], [sqrt(s * s_next), sqrt(s * s_next)],
    [s * sqrt(alpha1), s / sqrt(alpha1)], [s / sqrt(alpha1), s * sqrt(alpha1)],
    ...
    [s * sqrt(alphaN), s / sqrt(alphaN)], [s / sqrt(alphaN), s * sqrt(alphaN)]
] / fig_size

where:
    * s = scale[i] - this level's scale
    * s_next = scale[i + 1] - next level's scale
    * alpha1, ... alphaN - this level's alphas, e.g. [2, 3]
    * fig_size - input image resolution

Because of division by image resolution, the anchors will be in image coordinates normalized to [0, 1]

Parameters:

Name Type Description Default
fig_size int

input image resolution

required
feat_size List[int]

resolution of all feature maps with predictions (grids)

required
scales List[int]

anchor sizes in pixels for each feature level; one value per level will be used to generate anchors based on the formula above

required
aspect_ratios List[List[int]]

lists of alpha values for each feature map

required
scale_xy

predicted boxes will be with a factor scale_xy so will be multiplied by scale_xy during post-prediction processing; e.g. scale 0.1 means that prediction will be 10 times bigger (improves predictions quality)

0.1
scale_wh

same logic as in scale_xy, but for width and height.

0.2
Source code in src/super_gradients/training/utils/ssd_utils.py
def __init__(self, fig_size: int, feat_size: List[int], scales: List[int], aspect_ratios: List[List[int]], scale_xy=0.1, scale_wh=0.2):
    """
    For each feature map i (each predicting level, grids) the anchors (a.k.a. default boxes) will be:
    [
        [s, s], [sqrt(s * s_next), sqrt(s * s_next)],
        [s * sqrt(alpha1), s / sqrt(alpha1)], [s / sqrt(alpha1), s * sqrt(alpha1)],
        ...
        [s * sqrt(alphaN), s / sqrt(alphaN)], [s / sqrt(alphaN), s * sqrt(alphaN)]
    ] / fig_size
    where:
        * s = scale[i] - this level's scale
        * s_next = scale[i + 1] - next level's scale
        * alpha1, ... alphaN - this level's alphas, e.g. [2, 3]
        * fig_size - input image resolution

    Because of division by image resolution, the anchors will be in image coordinates normalized to [0, 1]

    :param fig_size:        input image resolution
    :param feat_size:       resolution of all feature maps with predictions (grids)
    :param scales:          anchor sizes in pixels for each feature level;
                            one value per level will be used to generate anchors based on the formula above
    :param aspect_ratios:   lists of alpha values for each feature map
    :param scale_xy:        predicted boxes will be with a factor scale_xy
                            so will be multiplied by scale_xy during post-prediction processing;
                            e.g. scale 0.1 means that prediction will be 10 times bigger
                            (improves predictions quality)
    :param scale_wh:        same logic as in scale_xy, but for width and height.
    """
    self.feat_size = feat_size
    self.fig_size = fig_size

    self.scale_xy_ = scale_xy
    self.scale_wh_ = scale_wh
    # According to https://github.com/weiliu89/caffe
    # Calculation method slightly different from paper
    self.scales = scales
    self.aspect_ratios = aspect_ratios

    self.default_boxes = []
    self.num_anchors = []
    # size of feature and number of feature
    for idx, sfeat in enumerate(self.feat_size):

        sk1 = scales[idx]
        sk2 = scales[idx + 1]
        sk3 = sqrt(sk1 * sk2)
        all_sizes = [(sk1, sk1), (sk3, sk3)]

        for alpha in aspect_ratios[idx]:
            w, h = sk1 * sqrt(alpha), sk1 / sqrt(alpha)
            all_sizes.append((w, h))
            all_sizes.append((h, w))

        all_sizes = np.array(all_sizes) / fig_size
        self.num_anchors.append(len(all_sizes))
        for w, h in all_sizes:
            for i, j in itertools.product(range(sfeat), repeat=2):
                cx, cy = (j + 0.5) / sfeat, (i + 0.5) / sfeat
                self.default_boxes.append((cx, cy, w, h))

    self.dboxes = torch.tensor(self.default_boxes, dtype=torch.float)
    self.dboxes.clamp_(min=0, max=1)

    # For IoU calculation
    self.dboxes_xyxy = self.dboxes.clone()
    self.dboxes_xyxy[:, 0] = self.dboxes[:, 0] - 0.5 * self.dboxes[:, 2]
    self.dboxes_xyxy[:, 1] = self.dboxes[:, 1] - 0.5 * self.dboxes[:, 3]
    self.dboxes_xyxy[:, 2] = self.dboxes[:, 0] + 0.5 * self.dboxes[:, 2]
    self.dboxes_xyxy[:, 3] = self.dboxes[:, 1] + 0.5 * self.dboxes[:, 3]
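
An illustrative SSD300-style configuration (the numbers below are an assumption used for the sketch, not taken from this page); note that scales needs one extra entry to provide s_next for the last level:

dboxes = DefaultBoxes(
    fig_size=300,
    feat_size=[38, 19, 10, 5, 3, 1],
    scales=[21, 45, 99, 153, 207, 261, 315],
    aspect_ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]],
)
anchors_xyxy = dboxes(order="xyxy")  # normalized corner-format anchors in [0, 1]
anchors_xywh = dboxes(order="xywh")  # the same anchors in center-size format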

SSDPostPredictCallback

Bases: DetectionPostPredictionCallback

post prediction callback module to convert and filter predictions coming from the SSD net to a format used by all other detection models

Source code in src/super_gradients/training/utils/ssd_utils.py
class SSDPostPredictCallback(DetectionPostPredictionCallback):
    """
    post prediction callback module to convert and filter predictions coming from the SSD net to a format
    used by all other detection models
    """

    def __init__(
        self,
        conf: float = 0.001,
        iou: float = 0.6,
        classes: list = None,
        max_predictions: int = 300,
        nms_type: NMS_Type = NMS_Type.ITERATIVE,
        multi_label_per_box=True,
    ):
        """
        Predictions of SSD contain unnormalized probabilities for a background class,
        together with confidences for all the dataset classes. Background will be utilized and discarded,
        so this callback will return 0-based classes without background
        :param conf: confidence threshold
        :param iou: IoU threshold
        :param classes: (optional list) filter by class
        :param nms_type: the type of nms to use (iterative or matrix)
        :param multi_label_per_box: controls whether to decode multiple labels per box.
                                    True - each anchor can produce multiple labels of different classes
                                           that pass confidence threshold check (default).
                                    False - each anchor can produce only one label of the class with the highest score.
        """
        super(SSDPostPredictCallback, self).__init__()
        self.conf = conf
        self.iou = iou
        self.nms_type = nms_type
        self.classes = classes
        self.max_predictions = max_predictions

        self.multi_label_per_box = multi_label_per_box

    def forward(self, predictions, device=None):
        nms_input = predictions[0]
        if self.nms_type == NMS_Type.ITERATIVE:
            nms_res = non_max_suppression(
                nms_input, conf_thres=self.conf, iou_thres=self.iou, multi_label_per_box=self.multi_label_per_box, with_confidence=True
            )
        else:
            nms_res = matrix_non_max_suppression(nms_input, conf_thres=self.conf, max_num_of_detections=self.max_predictions)

        return self._filter_max_predictions(nms_res)

    def _filter_max_predictions(self, res: List) -> List:
        res[:] = [im[: self.max_predictions] if (im is not None and im.shape[0] > self.max_predictions) else im for im in res]
        return res

__init__(conf=0.001, iou=0.6, classes=None, max_predictions=300, nms_type=NMS_Type.ITERATIVE, multi_label_per_box=True)

Predictions of SSD contain unnormalized probabilities for a background class, together with confidences for all the dataset classes. Background will be utilized and discarded, so this callback will return 0-based classes without background

Parameters:

Name Type Description Default
conf float

confidence threshold

0.001
iou float

IoU threshold

0.6
classes list

(optional list) filter by class

None
nms_type NMS_Type

the type of nms to use (iterative or matrix)

NMS_Type.ITERATIVE
multi_label_per_box

controls whether to decode multiple labels per box. True - each anchor can produce multiple labels of different classes that pass confidence threshold check (default). False - each anchor can produce only one label of the class with the highest score.

True
Source code in src/super_gradients/training/utils/ssd_utils.py
def __init__(
    self,
    conf: float = 0.001,
    iou: float = 0.6,
    classes: list = None,
    max_predictions: int = 300,
    nms_type: NMS_Type = NMS_Type.ITERATIVE,
    multi_label_per_box=True,
):
    """
    Predictions of SSD contain unnormalized probabilities for a background class,
    together with confidences for all the dataset classes. Background will be utilized and discarded,
    so this callback will return 0-based classes without background
    :param conf: confidence threshold
    :param iou: IoU threshold
    :param classes: (optional list) filter by class
    :param nms_type: the type of nms to use (iterative or matrix)
    :param multi_label_per_box: controls whether to decode multiple labels per box.
                                True - each anchor can produce multiple labels of different classes
                                       that pass confidence threshold check (default).
                                False - each anchor can produce only one label of the class with the highest score.
    """
    super(SSDPostPredictCallback, self).__init__()
    self.conf = conf
    self.iou = iou
    self.nms_type = nms_type
    self.classes = classes
    self.max_predictions = max_predictions

    self.multi_label_per_box = multi_label_per_box
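
Construction is straightforward; the thresholds below are illustrative only, and the callback is then applied to the raw SSD predictions during validation or inference post-processing:

post_prediction_callback = SSDPostPredictCallback(conf=0.01, iou=0.5, max_predictions=200)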

AverageMeter

A class to calculate the average of a metric, for each batch during training/testing

Source code in src/super_gradients/training/utils/utils.py
class AverageMeter:
    """A class to calculate the average of a metric, for each batch
    during training/testing"""

    def __init__(self):
        self._sum = None
        self._count = 0

    def update(self, value: Union[float, tuple, list, torch.Tensor], batch_size: int):

        if not isinstance(value, torch.Tensor):
            value = torch.tensor(value)

        if self._sum is None:
            self._sum = value * batch_size
        else:
            self._sum += value * batch_size

        self._count += batch_size

    @property
    def average(self):
        if self._sum is None:
            return 0
        return ((self._sum / self._count).__float__()) if self._sum.dim() < 1 else tuple((self._sum / self._count).cpu().numpy())
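
For example, averaging a per-batch loss weighted by batch size:

loss_meter = AverageMeter()
loss_meter.update(torch.tensor(0.7), batch_size=32)
loss_meter.update(torch.tensor(0.5), batch_size=32)
loss_meter.average  # ~0.6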

HpmStruct

Source code in src/super_gradients/training/utils/utils.py
class HpmStruct:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        self.schema = None

    def set_schema(self, schema: dict):
        self.schema = schema

    def override(self, **entries):
        recursive_override(self.__dict__, entries)

    def to_dict(self, include_schema=True) -> dict:
        """Convert this HpmStruct instance into a dict.
        :param include_schema: If True, also return the field "schema"
        :return: Dict representation of this HpmStruct instance.
        """
        out_dict = self.__dict__.copy()
        if not include_schema:
            out_dict.pop("schema")
        return out_dict

    def validate(self):
        """
        Validate the current dict values according to the provided schema
        :raises
            `AttributeError` if schema was not set
            `jsonschema.exceptions.ValidationError` if the instance is invalid
            `jsonschema.exceptions.SchemaError` if the schema itself is invalid
        """
        if self.schema is None:
            raise AttributeError("schema was not set")
        else:
            validate(self.__dict__, self.schema)

to_dict(include_schema=True)

Convert this HpmStruct instance into a dict.

Parameters:

Name Type Description Default
include_schema

If True, also return the field "schema"

True

Returns:

Type Description
dict

Dict representation of this HpmStruct instance.

Source code in src/super_gradients/training/utils/utils.py
def to_dict(self, include_schema=True) -> dict:
    """Convert this HpmStruct instance into a dict.
    :param include_schema: If True, also return the field "schema"
    :return: Dict representation of this HpmStruct instance.
    """
    out_dict = self.__dict__.copy()
    if not include_schema:
        out_dict.pop("schema")
    return out_dict

validate()

Validate the current dict values according to the provided schema

Source code in src/super_gradients/training/utils/utils.py
def validate(self):
    """
    Validate the current dict values according to the provided schema
    :raises
        `AttributeError` if schema was not set
        `jsonschema.exceptions.ValidationError` if the instance is invalid
        `jsonschema.exceptions.SchemaError` if the schema itself is invalid
    """
    if self.schema is None:
        raise AttributeError("schema was not set")
    else:
        validate(self.__dict__, self.schema)
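
A minimal sketch: wrap hyper-parameters, override one of them, and export back to a dict:

arch_params = HpmStruct(num_classes=80, dropout=0.0)
arch_params.override(dropout=0.1)
arch_params.to_dict(include_schema=False)  # {'num_classes': 80, 'dropout': 0.1}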

Timer

A class to measure time handling both GPU & CPU processes Returns time in milliseconds

Source code in src/super_gradients/training/utils/utils.py
class Timer:
    """A class to measure time handling both GPU & CPU processes
    Returns time in milliseconds"""

    def __init__(self, device: str):
        """
        :param device: str
            'cpu'\'cuda'
        """
        self.on_gpu = device == "cuda"
        # On GPU time is measured using cuda.events
        if self.on_gpu:
            self.starter = torch.cuda.Event(enable_timing=True)
            self.ender = torch.cuda.Event(enable_timing=True)
        # On CPU time is measured using time
        else:
            self.starter, self.ender = 0, 0

    def start(self):
        if self.on_gpu:
            self.starter.record()
        else:
            self.starter = time.time()

    def stop(self):
        if self.on_gpu:
            self.ender.record()
            torch.cuda.synchronize()
            timer = self.starter.elapsed_time(self.ender)
        else:
            # Time measures in seconds -> convert to milliseconds
            timer = (time.time() - self.starter) * 1000

        # Return time in milliseconds
        return timer

__init__(device)

Parameters:

Name Type Description Default
device str

'cpu' or 'cuda'

required
Source code in src/super_gradients/training/utils/utils.py
def __init__(self, device: str):
    """
    :param device: str
        'cpu'\'cuda'
    """
    self.on_gpu = device == "cuda"
    # On GPU time is measured using cuda.events
    if self.on_gpu:
        self.starter = torch.cuda.Event(enable_timing=True)
        self.ender = torch.cuda.Event(enable_timing=True)
    # On CPU time is measured using time
    else:
        self.starter, self.ender = 0, 0
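
Typical usage (do_inference below is a hypothetical placeholder for the code being timed):

timer = Timer(device="cpu")  # or "cuda" when timing GPU work
timer.start()
do_inference()
elapsed_ms = timer.stop()  # elapsed time in milliseconds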

arch_params_deprecated(func)

Since initialization of arch_params is deprecated and will be removed, this decorator will be used to wrap the __init__ function of some models. It will unwrap the parameters of the function and will log a warning.

Source code in src/super_gradients/training/utils/utils.py
def arch_params_deprecated(func):
    """
    Since initialization of arch_params is deprecated and will be removed, this decorator will be used to wrap the _init_
    function of some models. It will unwrap the parameters of the function and will log a warning.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):

        func_args = inspect.getfullargspec(func).args
        _args = []
        if "arch_params" in kwargs:
            _arch_params = kwargs.get("arch_params", kwargs).to_dict()
        elif len(args) > 1 and isinstance(args[1], HpmStruct):
            _arch_params = args[1].to_dict()  # when called from inheritance
            _args.append(args[0])
        elif len(args) > 0 and isinstance(args[0], HpmStruct):
            _arch_params = args[0].to_dict()
        else:
            return func(*args, **kwargs)

        _kwargs = dict()
        for param_name in func_args:
            if param_name in _arch_params:
                _kwargs[param_name] = _arch_params[param_name]
            if param_name in kwargs:
                _kwargs[param_name] = kwargs[param_name]

        logger.warning(
            f"The {func.__qualname__} received `arch_params` argument which is deprecated and will be removed in next versions. "
            f"Please change the signature of the __init__ method to take explicit list arguments instead: "
            f"{func.__qualname__}({', '.join(_kwargs.keys())})"
        )
        return func(*_args, **_kwargs)

    return wrapper

check_img_size_divisibility(img_size, stride=32)

Parameters:

Name Type Description Default
img_size int

Int, the size of the image (H or W).

required
stride int

Int, the number to check if img_size is divisible by.

32

Returns:

Type Description
Tuple[bool, Optional[Tuple[int, int]]]

(True, None) if img_size is divisible by stride, (False, Suggestions) if it's not. Note: Suggestions are the two closest numbers to img_size that are divisible by stride. For example if img_size=321, stride=32, it will return (False,(352, 320)).

Source code in src/super_gradients/training/utils/utils.py
def check_img_size_divisibility(img_size: int, stride: int = 32) -> Tuple[bool, Optional[Tuple[int, int]]]:
    """
    :param img_size: Int, the size of the image (H or W).
    :param stride: Int, the number to check if img_size is divisible by.
    :return: (True, None) if img_size is divisible by stride, (False, Suggestions) if it's not.
        Note: Suggestions are the two closest numbers to img_size that *are* divisible by stride.
        For example if img_size=321, stride=32, it will return (False,(352, 320)).
    """
    new_size = make_divisible(img_size, int(stride))
    if new_size != img_size:
        return False, (new_size, make_divisible(img_size, int(stride), ceil=False))
    else:
        return True, None
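
For example:

check_img_size_divisibility(640, stride=32)  # (True, None)
check_img_size_divisibility(321, stride=32)  # (False, (352, 320)) - the two closest divisible sizes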

check_model_contains_quantized_modules(model)

Check if the model contains any quantized modules.

Parameters:

Name Type Description Default
model nn.Module

Model to check.

required

Returns:

Type Description
bool

True if the model contains any quantized modules, False otherwise.

Source code in src/super_gradients/training/utils/utils.py
def check_model_contains_quantized_modules(model: nn.Module) -> bool:
    """
    Check if the model contains any quantized modules.
    :param model: Model to check.
    :return: True if the model contains any quantized modules, False otherwise.
    """
    try:
        from pytorch_quantization.nn.modules._utils import QuantMixin
    except ImportError:
        # If pytorch_quantization is not installed then by definition the model cannot contain any quantized modules
        return False

    from super_gradients.training.utils.quantization.core import SGQuantMixin

    model = unwrap_model(model)
    for m in model.modules():
        if isinstance(m, (QuantMixin, SGQuantMixin)):
            return True

    return False

check_models_have_same_weights(model_1, model_2, skip_bn_stats=False)

Checks whether two networks have the same weights

Parameters:

Name Type Description Default
model_1 torch.nn.Module

Net to be checked

required
model_2 torch.nn.Module

Net to be checked

required
skip_bn_stats bool

bool, whether to skip batch normalization related stats

False

Returns:

Type Description

True iff the two networks have the same weights

Source code in src/super_gradients/training/utils/utils.py
def check_models_have_same_weights(model_1: torch.nn.Module, model_2: torch.nn.Module, skip_bn_stats: bool = False):
    """
    Checks whether two networks have the same weights

    :param model_1: Net to be checked
    :param model_2: Net to be checked
    :param skip_bn_stats: bool, whether to skip batch normalization related stats

    :return: True iff the two networks have the same weights
    """
    bn_stats_layer_names = ["running_var", "running_mean", "num_batches_tracked"]
    model_1, model_2 = model_1.to("cpu"), model_2.to("cpu")
    models_differ = 0
    for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
        if torch.equal(key_item_1[1], key_item_2[1]):
            pass
        elif (
            skip_bn_stats
            and any(bn_lname in key_item_1[0] for bn_lname in bn_stats_layer_names)
            and any(bn_lname in key_item_2[0] for bn_lname in bn_stats_layer_names)
        ):
            pass
        else:
            models_differ += 1
            if key_item_1[0] == key_item_2[0]:
                print(f"Layer names match but layers have different weights for layers: {key_item_1[0]}")
    if models_differ == 0:
        return True
    else:
        return False

convert_to_tensor(array, dtype=None, device=None)

Converts numpy arrays and lists to Torch tensors before calculating losses

Parameters:

Name Type Description Default
array

torch.tensor / Numpy array / List

required
Source code in src/super_gradients/training/utils/utils.py
def convert_to_tensor(array, dtype=None, device=None):
    """Converts numpy arrays and lists to Torch tensors before calculation losses
    :param array: torch.tensor / Numpy array / List
    """
    if not torch.is_tensor(array):
        if isinstance(array, np.ndarray):
            return torch.from_numpy(array).to(device=device, dtype=dtype)
        else:
            return torch.tensor(array, device=device, dtype=dtype)
    else:
        return array.to(device=device, dtype=dtype)

download_and_untar_from_url(urls, dir='.')

Download a file from url and untar.

Parameters:

Name Type Description Default
urls List[str]

Url to download the file from.

required
dir Union[str, Path]

Destination directory.

'.'
Source code in src/super_gradients/training/utils/utils.py
def download_and_untar_from_url(urls: List[str], dir: Union[str, Path] = "."):
    """
    Download a file from url and untar.

    :param urls:    Url to download the file from.
    :param dir:     Destination directory.
    """
    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)

    for url in urls:
        url_path = Path(url)
        filepath = dir / url_path.name

        if url_path.is_file():
            url_path.rename(filepath)
        elif not filepath.exists():
            logger.info(f"Downloading {url} to {filepath}...")
            torch.hub.download_url_to_file(url, str(filepath), progress=True)

        modes = {".tar.gz": "r:gz", ".tar": "r:"}
        assert filepath.suffix in modes.keys(), f"{filepath} has {filepath.suffix} suffix which is not supported"

        logger.info(f"Extracting to {dir}...")
        safe_untar(filepath, dir)
        filepath.unlink()

download_and_unzip_from_url(url, dir='.', unzip=True, delete=True)

Downloads a zip file from url to dir, and unzips it.

Parameters:

Name Type Description Default
url

Url to download the file from.

required
dir

Destination directory.

'.'
unzip

Whether to unzip the downloaded file.

True
delete

Whether to delete the zip file. Used to download VOC. Source: https://github.com/ultralytics/yolov5/blob/master/data/VOC.yaml

True
Source code in src/super_gradients/training/utils/utils.py
def download_and_unzip_from_url(url, dir=".", unzip=True, delete=True):
    """
    Downloads a zip file from url to dir, and unzips it.

    :param url: Url to download the file from.
    :param dir: Destination directory.
    :param unzip: Whether to unzip the downloaded file.
    :param delete: Whether to delete the zip file.

    used to download VOC.

    Source:
    https://github.com/ultralytics/yolov5/blob/master/data/VOC.yaml
    """

    def download_one(url, dir):
        # Download 1 file
        f = dir / Path(url).name  # filename
        if Path(url).is_file():  # exists in current path
            Path(url).rename(f)  # move to dir
        elif not f.exists():
            print(f"Downloading {url} to {f}...")
            torch.hub.download_url_to_file(url, f, progress=True)  # torch download
        if unzip and f.suffix in (".zip", ".gz"):
            print(f"Unzipping {f}...")
            if f.suffix == ".zip":
                ZipFile(f).extractall(path=dir)  # unzip
            elif f.suffix == ".gz":
                os.system(f"tar xfz {f} --directory {f.parent}")  # unzip
            if delete:
                f.unlink()  # remove zip

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    for u in [url] if isinstance(url, (str, Path)) else url:
        download_one(u, dir)

empty_list()

Instantiate an empty list. This is a workaround to generate a list with a function call in hydra, instead of the "[]".

Source code in src/super_gradients/training/utils/utils.py
def empty_list():
    """Instantiate an empty list. This is a workaround to generate a list with a function call in hydra, instead of the "[]"."""
    return list()

ensure_is_tuple_of_two(inputs)

Checks input and converts it to a tuple of length two. If input is None returns None.

Parameters:

Name Type Description Default
inputs Union[Any, Iterable[Any], None]

Input argument, either a number or a tuple of two numbers.

required

Returns:

Type Description
Union[Tuple[Any, Any], None]

Tuple of two numbers if input is not None, otherwise - None.

Source code in src/super_gradients/training/utils/utils.py
def ensure_is_tuple_of_two(inputs: Union[Any, Iterable[Any], None]) -> Union[Tuple[Any, Any], None]:
    """
    Checks input and converts it to a tuple of length two. If input is None returns None.
    :param inputs: Input argument, either a number or a tuple of two numbers.
    :return: Tuple of two numbers if input is not None, otherwise - None.
    """
    if inputs is None:
        return None

    if isinstance(inputs, typing.Iterable) and not isinstance(inputs, str):
        a, b = inputs
        return a, b

    return inputs, inputs
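
For example:

ensure_is_tuple_of_two(416)         # (416, 416)
ensure_is_tuple_of_two((416, 512))  # (416, 512)
ensure_is_tuple_of_two(None)        # None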

exif_size(image)

Get the size of image.

Parameters:

Name Type Description Default
image Image

The image to get size from

required

Returns:

Type Description
Tuple[int, int]

(height, width)

Source code in src/super_gradients/training/utils/utils.py
def exif_size(image: Image) -> Tuple[int, int]:
    """Get the size of image.
    :param image:   The image to get size from
    :return:        (height, width)
    """

    orientation_key = get_orientation_key()

    image_size = image.size
    try:
        exif_data = image._getexif()
        if exif_data is not None:
            rotation = dict(exif_data.items())[orientation_key]
            # ROTATION 270
            if rotation == 6:
                image_size = (image_size[1], image_size[0])
            # ROTATION 90
            elif rotation == 8:
                image_size = (image_size[1], image_size[0])
    except Exception as ex:
        print("Caught Exception trying to rotate: " + str(image) + str(ex))
    width, height = image_size
    return height, width

fuzzy_idx_in_list(name, lst)

Returns the index of name in lst, insensitive to case and to non-word symbols.

Parameters:

Name Type Description Default
name str

str, the name to be searched in lst.

required
lst List[str]

List[str], the list as described above.

required

Returns:

Type Description
int

int, index of name in lst in the matter discussed above.

Source code in src/super_gradients/training/utils/utils.py
293
294
295
296
297
298
299
300
301
302
303
304
305
def fuzzy_idx_in_list(name: str, lst: List[str]) -> int:
    """
    Returns the index of name in lst, with non sensitivity to symbols, uppercase and lowercase.
    :param name: str, the name to be searched in lst.
    :param lst: List[str], the list as described above.
    :return: int, index of name in lst in the matter discussed above.
    """
    fuzzy_name = fuzzy_str(name)
    fuzzy_list = [fuzzy_str(x) for x in lst]
    if fuzzy_name in fuzzy_list:
        return fuzzy_list.index(fuzzy_name)
    else:
        raise IndexError(f"Value `{name}` not found in the list `{lst}`. Please check the spelling.")

fuzzy_keys(params)

Returns params.key() removing leading and trailing white space, lower-casing and dropping symbols.

Parameters:

Name Type Description Default
params Mapping

Mapping, the mapping containing the keys to be returned.

required

Returns:

Type Description
List[str]

List[str], list of keys as discussed above.

Source code in src/super_gradients/training/utils/utils.py
246
247
248
249
250
251
252
def fuzzy_keys(params: Mapping) -> List[str]:
    """
    Returns params.key() removing leading and trailing white space, lower-casing and dropping symbols.
    :param params: Mapping, the mapping containing the keys to be returned.
    :return: List[str], list of keys as discussed above.
    """
    return [fuzzy_str(str(s)) for s in params.keys()]

fuzzy_str(s)

Returns s with leading and trailing whitespace removed, lower-cased, and with non-word characters dropped (except for '/').

Parameters:

Name Type Description Default
s str

str, string to apply the manipulation discussed above.

required

Returns:

Type Description

str, s after the manipulation discussed above.

Source code in src/super_gradients/training/utils/utils.py
255
256
257
258
259
260
261
def fuzzy_str(s: str):
    """
    Returns s removing leading and trailing white space, lower-casing and drops non word chars (except for '/')
    :param s: str, string to apply the manipulation discussed above.
    :return: str, s after the manipulation discussed above.
    """
    return re.sub(r"[^\w|\/]", "", s).replace("_", "").lower()

generate_batch(iterable, batch_size)

Batch data into tuples of length batch_size. The last batch may be shorter.

Source code in src/super_gradients/training/utils/utils.py
633
634
635
636
637
638
639
640
641
def generate_batch(iterable: Iterable, batch_size: int) -> Iterable:
    """Batch data into tuples of length n. The last batch may be shorter."""
    it = iter(iterable)
    while True:
        batch = tuple(islice(it, batch_size))
        if batch:
            yield batch
        else:
            return
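
An illustrative sketch (import path assumed from the source location above): the generator yields fixed-size tuples, with a shorter final batch.

from super_gradients.training.utils.utils import generate_batch

batches = list(generate_batch(range(7), batch_size=3))
# [(0, 1, 2), (3, 4, 5), (6,)] - the last batch is shorter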

get_filename_suffix_by_framework(framework)

Return the file extension of framework.

Parameters:

Name Type Description Default
framework str

(str)

required

Returns:

Type Description

(str) the suffix for the specific framework

Source code in src/super_gradients/training/utils/utils.py
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
def get_filename_suffix_by_framework(framework: str):
    """
    Return the file extension of framework.

    :param framework: (str)
    :return: (str) the suffix for the specific framework
    """
    frameworks_dict = {
        "TENSORFLOW1": ".pb",
        "TENSORFLOW2": ".zip",
        "PYTORCH": ".pth",
        "ONNX": ".onnx",
        "TENSORRT": ".pkl",
        "OPENVINO": ".pkl",
        "TORCHSCRIPT": ".pth",
        "TVM": "",
        "KERAS": ".h5",
        "TFLITE": ".tflite",
    }

    if framework.upper() not in frameworks_dict.keys():
        raise ValueError(f"Unsupported framework: {framework}")

    return frameworks_dict[framework.upper()]
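
A short usage sketch (import path assumed from the source location above); note that the lookup upper-cases the framework name, so it is case-insensitive.

from super_gradients.training.utils.utils import get_filename_suffix_by_framework

get_filename_suffix_by_framework("onnx")     # '.onnx'
get_filename_suffix_by_framework("PyTorch")  # '.pth'
get_filename_suffix_by_framework("caffe")    # raises ValueError: Unsupported framework: caffe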

get_fuzzy_attr(params, name)

Returns an attribute (same functionality as getattr), but insensitive to case and symbols.

Parameters:

Name Type Description Default
params Any

Any, the object in which to look for the attribute name.

required
name str

str, the attribute of param to be returned.

required

Returns:

Type Description

Any, the attribute value, or None when no fuzzy match of the attribute is found.

Source code in src/super_gradients/training/utils/utils.py
283
284
285
286
287
288
289
290
def get_fuzzy_attr(params: Any, name: str):
    """
    Returns attribute (same functionality as getattr), but non sensitive to symbols, uppercase and lowercase.
    :param params: Any, any object which wed looking for the attribute name in.
    :param name: str, the attribute of param to be returned.
    :return: Any, the attribute value or None when not fuzzy matching of the attribute is found
    """
    return getattr(params, _get_fuzzy_attr_map(params)[fuzzy_str(name)])

get_fuzzy_mapping_param(name, params)

Returns the parameter value whose key matches name, insensitive to case and symbols.

Parameters:

Name Type Description Default
name str

str, the key in params which is fuzzy-matched and returned.

required
params Mapping

Mapping, the mapping containing param.

required

Returns:

Type Description
Source code in src/super_gradients/training/utils/utils.py
272
273
274
275
276
277
278
279
280
def get_fuzzy_mapping_param(name: str, params: Mapping):
    """
    Returns parameter value, with key=name with no sensitivity to lowercase, uppercase and symbols.
    :param name: str, the key in params which is fuzzy-matched and retruned.
    :param params: Mapping, the mapping containing param.
    :return:
    """
    fuzzy_params = {fuzzy_str(str(key)): params[key] for key in params.keys()}
    return fuzzy_params[fuzzy_str(name)]

get_image_size_from_path(img_path)

Get the image size of an image at a specific path

Source code in src/super_gradients/training/utils/utils.py
614
615
616
617
def get_image_size_from_path(img_path: str) -> Tuple[int, int]:
    """Get the image size of an image at a specific path"""
    with open(img_path, "rb") as f:
        return exif_size(Image.open(f))

get_orientation_key() cached

Get the orientation key according to PIL, which is useful to get the image size for instance

Returns:

Type Description
int

Orientation key according to PIL

Source code in src/super_gradients/training/utils/utils.py
580
581
582
583
584
585
586
@lru_cache(None)
def get_orientation_key() -> int:
    """Get the orientation key according to PIL, which is useful to get the image size for instance
    :return: Orientation key according to PIL"""
    for key, value in ExifTags.TAGS.items():
        if value == "Orientation":
            return key

get_param(params, name, default_val=None)

Retrieves a param from a parameter object/dict. If the parameter does not exist, returns default_val. If default_val is a dictionary and a value is found in params, the function returns the default dictionary with its internal values overridden by the found value. IMPORTANT: matching is not sensitive to case or symbols.

For example, with default_opt_params = {'lr': 0.1, 'momentum': 0.99, 'alpha': 0.001} and training_params = {'optimizer_params': {'lr': 0.0001}, 'batch': 32, ...}, calling get_param(training_params, name='OptimizerParams', default_val=default_opt_params) returns {'lr': 0.0001, 'momentum': 0.99, 'alpha': 0.001}.

Parameters:

Name Type Description Default
params

an object (typically HpmStruct) or a dict holding the params

required
name

name of the searched parameter

required
default_val

assumed to be the same type as the value searched in the params

None

Returns:

Type Description

the found value, or default if not found

Source code in src/super_gradients/training/utils/utils.py
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
def get_param(params, name, default_val=None):
    """
    Retrieves a param from a parameter object/dict . If the parameter does not exist, will return default_val.
    In case the default_val is of type dictionary, and a value is found in the params - the function
    will return the default value dictionary with internal values overridden by the found value
    IMPORTANT: Not sensitive to lowercase, uppercase and symbols.

    i.e.
    default_opt_params = {'lr':0.1, 'momentum':0.99, 'alpha':0.001}
    training_params = {'optimizer_params': {'lr':0.0001}, 'batch': 32 .... }
    get_param(training_params, name='OptimizerParams', default_val=default_opt_params)
    will return {'lr':0.0001, 'momentum':0.99, 'alpha':0.001}


    :param params:      an object (typically HpmStruct) or a dict holding the params
    :param name:        name of the searched parameter
    :param default_val: assumed to be the same type as the value searched in the params
    :return:            the found value, or default if not found
    """
    if isinstance(params, Mapping):
        if name in params:
            param_val = params[name]

        elif fuzzy_str(name) in fuzzy_keys(params):
            param_val = get_fuzzy_mapping_param(name, params)

        else:
            param_val = default_val
    elif hasattr(params, name):
        param_val = getattr(params, name)
    elif _has_fuzzy_attr(params, name):
        param_val = get_fuzzy_attr(params, name)
    else:
        param_val = default_val

    if isinstance(default_val, Mapping):
        return {**default_val, **param_val}
    else:
        return param_val
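
A short sketch of the fuzzy lookup and the dictionary-merge behaviour described above (import path assumed from the source location):

from super_gradients.training.utils.utils import get_param

training_params = {"optimizer_params": {"lr": 0.0001}, "batch_size": 32}

get_param(training_params, "batch_size", default_val=64)  # 32 (exact match)
get_param(training_params, "BatchSize", default_val=64)   # 32 (fuzzy match: case and symbols ignored)
get_param(training_params, "OptimizerParams", default_val={"lr": 0.1, "momentum": 0.9})
# {'lr': 0.0001, 'momentum': 0.9} - dict defaults are merged, found values win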

infer_model_device(model)

Get the device where the model's parameters are stored. This function returns device of the first parameter of the model, assuming there is no cross-device parameter movement inside the model.

Parameters:

Name Type Description Default
model nn.Module

Model to get the device from.

required

Returns:

Type Description
Optional[torch.device]

Device where the model's parameters are stored. The function may return None if the model has no parameters or buffers.

Source code in src/super_gradients/training/utils/utils.py
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
def infer_model_device(model: nn.Module) -> Optional[torch.device]:
    """
    Get the device where the model's parameters are stored.
    This function returns device of the first parameter of the model, assuming there is no
    cross-device parameter movement inside the model.
    :param model: Model to get the device from.
    :return: Device where the model's parameters are stored.
             The function may return None if the model has no parameters or buffers.
    """
    try:
        first_parameter = next(iter(model.parameters()))
        return first_parameter.device
    except StopIteration:
        try:
            first_buffer = next(iter(model.buffers()))
            return first_buffer.device
        except StopIteration:
            return None

infer_model_dtype(model)

Get the dtype of the model's parameters. This function returns the dtype of the first parameter of the model, assuming all parameters and buffers share the same dtype.

Parameters:

Name Type Description Default
model nn.Module

Model to get the device from.

required

Returns:

Type Description
Optional[torch.dtype]

Dtype of the model's parameters. The function may return None if the model has no parameters or buffers.

Source code in src/super_gradients/training/utils/utils.py
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
def infer_model_dtype(model: nn.Module) -> Optional[torch.dtype]:
    """
    Get the dtype of the model's parameters.
    This function returns the dtype of the first parameter of the model, assuming all
    parameters and buffers share the same dtype.
    :param model: Model to get the dtype from.
    :return: Dtype of the model's parameters.
             The function may return None if the model has no parameters or buffers.
    """
    try:
        first_parameter = next(iter(model.parameters()))
        return first_parameter.dtype
    except StopIteration:
        try:
            first_buffer = next(iter(model.buffers()))
            return first_buffer.dtype
        except StopIteration:
            return None
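
A minimal sketch for both helpers (import path assumed from the source location above), using a small torch module:

from torch import nn
from super_gradients.training.utils.utils import infer_model_device, infer_model_dtype

model = nn.Linear(8, 4)
infer_model_device(model)          # device(type='cpu')
infer_model_dtype(model)           # torch.float32
infer_model_device(nn.Identity())  # None - no parameters or buffers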

load_func(dotpath)

Load a function from a module; the function name is the right-most segment of the dot-path.

Used for passing functions (without calling them) in yaml files.

Parameters:

Name Type Description Default
dotpath str

path to module.

required

Returns:

Type Description

a python function

Source code in src/super_gradients/training/utils/utils.py
391
392
393
394
395
396
397
398
399
400
401
402
def load_func(dotpath: str):
    """
    load function in module.  function is right-most segment.

    Used for passing functions (without calling them) in yaml files.

    :param dotpath: path to module.
    :return: a python function
    """
    module_, func = dotpath.rsplit(".", maxsplit=1)
    m = import_module(module_)
    return getattr(m, func)
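
An illustrative sketch (import path assumed from the source location above), resolving a function from its dot-path without calling it:

from super_gradients.training.utils.utils import load_func

softmax = load_func("torch.nn.functional.softmax")
# equivalent to: from torch.nn.functional import softmax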

make_divisible(x, divisor, ceil=True)

Returns x rounded to a multiple of divisor. If ceil=True, the closest multiple greater than or equal to x is returned; if ceil=False, the closest multiple less than or equal to x.

Source code in src/super_gradients/training/utils/utils.py
554
555
556
557
558
559
560
561
562
def make_divisible(x: int, divisor: int, ceil: bool = True) -> int:
    """
    Returns x evenly divisible by divisor.
    If ceil=True it will return the closest larger number to the original x, and ceil=False the closest smaller number.
    """
    if ceil:
        return math.ceil(x / divisor) * divisor
    else:
        return math.floor(x / divisor) * divisor
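
A short sketch of the rounding behaviour (import path assumed from the source location above):

from super_gradients.training.utils.utils import make_divisible

make_divisible(30, 8)              # 32 - rounds up by default
make_divisible(30, 8, ceil=False)  # 24 - rounds down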

move_state_dict_to_device(model_sd, device)

Moving model state dict tensors to target device (cuda or cpu)

Parameters:

Name Type Description Default
model_sd

model state dict

required
device

either cuda or cpu

required
Source code in src/super_gradients/training/utils/utils.py
365
366
367
368
369
370
371
372
373
def move_state_dict_to_device(model_sd, device):
    """
    Moving model state dict tensors to target device (cuda or cpu)
    :param model_sd: model state dict
    :param device: either cuda or cpu
    """
    for k, v in model_sd.items():
        model_sd[k] = v.to(device)
    return model_sd

override_default_params_without_nones(params, default_params)

Helper method for overriding default dictionary's entries excluding entries with None values.

Parameters:

Name Type Description Default
params Dict

dict, output dictionary which will take the defaults.

required
default_params Mapping

dict, dictionary for the defaults.

required

Returns:

Type Description
Dict

dict, params after the override.

Source code in src/super_gradients/training/utils/utils.py
620
621
622
623
624
625
626
627
628
629
630
def override_default_params_without_nones(params: Dict, default_params: Mapping) -> Dict:
    """
    Helper method for overriding default dictionary's entries excluding entries with None values.
    :param params: dict, output dictionary which will take the defaults.
    :param default_params: dict, dictionary for the defaults.
    :return: dict, params after manipulation,
    """
    for key, val in default_params.items():
        if key not in params.keys() or params[key] is None:
            params[key] = val
    return params
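
An illustrative sketch (import path assumed from the source location above); missing keys and keys set to None take the default values, and the params dict is modified in place and returned:

from super_gradients.training.utils.utils import override_default_params_without_nones

params = {"lr": None, "batch_size": 16}
defaults = {"lr": 0.01, "batch_size": 32, "momentum": 0.9}
override_default_params_without_nones(params, defaults)
# {'lr': 0.01, 'batch_size': 16, 'momentum': 0.9}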

random_seed(is_ddp, device, seed)

Sets random seed of numpy, torch and random.

When using ddp a seed will be set for each process according to its local rank derived from the device number.

Parameters:

Name Type Description Default
is_ddp

bool, will set different random seed for each process when using ddp.

required
device

'cuda', 'cpu' or 'cuda:<device_number>'

required
seed

int, random seed to be set

required
Source code in src/super_gradients/training/utils/utils.py
376
377
378
379
380
381
382
383
384
385
386
387
388
def random_seed(is_ddp, device, seed):
    """
    Sets random seed of numpy, torch and random.

    When using ddp a seed will be set for each process according to its local rank derived from the device number.
    :param is_ddp: bool, will set different random seed for each process when using ddp.
    :param device: 'cuda','cpu', 'cuda:<device_number>'
    :param seed: int, random seed to be set
    """
    rank = 0 if not is_ddp else int(device.split(":")[1])
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)

resolve_torch_device(device)

Resolve the specified torch device. It accepts either a string or a torch.device object.

This function takes the provided device identifier and returns a corresponding torch.device object, which represents the device where a torch.Tensor will be allocated.

Parameters:

Name Type Description Default
device Union[str, torch.device]

A string or torch.device object representing the device (e.g., 'cpu', 'cuda', 'cuda:0').

required

Returns:

Type Description
torch.device

A torch.device object representing the resolved device.

Example:

>>> torch.cuda.set_device(5)
>>> str(resolve_torch_device("cuda"))
'cuda:5'

Source code in src/super_gradients/training/utils/utils.py
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
def resolve_torch_device(device: Union[str, torch.device]) -> torch.device:
    """
    Resolve the specified torch device. It accepts either a string or a torch.device object.

    This function takes the provided device identifier and returns a corresponding torch.device object,
    which represents the device where a torch.Tensor will be allocated.

    :param device: A string or torch.device object representing the device (e.g., 'cpu', 'cuda', 'cuda:0').
    :return: A torch.device object representing the resolved device.

    Example:
        >>> torch.cuda.set_device(5)
        >>> str(resolve_torch_device("cuda"))
        'cuda:5'
    """
    return torch.zeros([], device=device).device

safe_untar(tar_file, extract_path)

Protect against the Tar Slip vulnerability. Calling extractall to extract all files from a tar file without sanitization may result in files outside the destination directory being overwritten, leading to an arbitrary file write. CVE-2007-4559 https://nvd.nist.gov/vuln/detail/CVE-2007-4559

Source code in src/super_gradients/training/utils/utils.py
512
513
514
515
516
517
518
519
520
521
522
523
def safe_untar(tar_file, extract_path):
    """
    Protect against Tar Slip vulnerability.
    Calling extractall to extract all files from a tar file without sanitization
    may result files outside destination directory to be overwritten, resulting in an arbitrary file write.
    CVE-2007-4559 https://nvd.nist.gov/vuln/detail/CVE-2007-4559
    """
    with tarfile.TarFile(tar_file, "r") as tf:
        for member in tf:
            file_path = os.path.realpath(os.path.join(extract_path, member.name))
            if file_path.startswith(os.path.realpath(extract_path)):
                tf.extract(member, extract_path)
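
A minimal usage sketch (import path assumed from the source location above; the archive and destination paths below are placeholders for illustration only). Members whose resolved path falls outside the destination directory are silently skipped, which blocks path-traversal entries such as "../../etc/passwd":

from super_gradients.training.utils.utils import safe_untar

safe_untar("/tmp/dataset.tar", "/tmp/dataset")  # hypothetical paths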

tensor_container_to_device(obj, device, non_blocking=True, detach=False)

Recursively send compounded objects to device (sending all tensors to device and maintaining structure)

Parameters:

Name Type Description Default
device str

device to send the tensors to

required
non_blocking

used for DistributedDataParallel

True
detach bool

detach the tensors from the graph

False
Source code in src/super_gradients/training/utils/utils.py
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
def tensor_container_to_device(obj: Union[torch.Tensor, tuple, list, dict], device: str, non_blocking=True, detach: bool = False):
    """
    Recursively send compounded objects to device (sending all tensors to device and maintaining structure)
    :param obj           the object to send to device (list / tuple / tensor / dict)
    :param device:       device to send the tensors to
    :param non_blocking: used for DistributedDataParallel
    :param detach:       detach the tensors from the graph
    :returns             an object with the same structure (tensors, lists, tuples) with the device pointers (like
                         the return value of Tensor.to(device)
    """
    if isinstance(obj, torch.Tensor):
        if detach:
            obj = obj.detach()
        return obj.to(device, non_blocking=non_blocking)
    elif isinstance(obj, tuple):
        return tuple(tensor_container_to_device(x, device, non_blocking=non_blocking, detach=detach) for x in obj)
    elif isinstance(obj, list):
        return [tensor_container_to_device(x, device, non_blocking=non_blocking, detach=detach) for x in obj]
    elif isinstance(obj, (dict, typing.Mapping)):
        return {k: tensor_container_to_device(v, device, non_blocking=non_blocking, detach=detach) for k, v in obj.items()}
    else:
        return obj
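
A short sketch (import path assumed from the source location above) showing that the nested structure is preserved while every tensor is detached and moved:

import torch
from super_gradients.training.utils.utils import tensor_container_to_device

batch = {"image": torch.rand(2, 3, 32, 32), "targets": [torch.zeros(2, 5), torch.zeros(2, 5)]}
batch = tensor_container_to_device(batch, device="cpu", detach=True)
# batch keeps the same dict/list structure; each tensor is now detached and on the target device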

unwrap_model(model)

Get the real model from a model wrapper (DataParallel, DistributedDataParallel)

Parameters:

Name Type Description Default
model Union[nn.Module, nn.DataParallel, DistributedDataParallel]

The model, possibly wrapped by DataParallel or DistributedDataParallel.

required

Returns:

Type Description
nn.Module
Source code in src/super_gradients/training/utils/utils.py
103
104
105
106
107
108
109
110
111
112
113
114
def unwrap_model(model: Union[nn.Module, nn.DataParallel, DistributedDataParallel]) -> nn.Module:
    """
    Get the real model from a model wrapper (DataParallel, DistributedDataParallel)

    :param model:
    :return:
    """
    if is_model_wrapped(model):
        return model.module
    elif isinstance(model, nn.Module):
        return model
    raise ValueError(f"Unknown model type: {type(model)}")
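
A minimal sketch (import path assumed from the source location above):

from torch import nn
from super_gradients.training.utils.utils import unwrap_model

model = nn.Linear(8, 4)
wrapped = nn.DataParallel(model)

unwrap_model(wrapped) is model  # True - returns the underlying module
unwrap_model(model) is model    # True - plain modules are returned as-is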

draw_label(image, label, confidence)

Draw a label and confidence on an image.

Parameters:

Name Type Description Default
image np.ndarray

The image on which to draw the label and confidence, in RGB format, and Channel Last (H, W, C)

required
label str

The label to draw.

required
confidence float

The confidence of the label.

required
Source code in src/super_gradients/training/utils/visualization/classification.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def draw_label(image: np.ndarray, label: str, confidence: float) -> np.ndarray:
    """Draw a label and confidence on an image.
    :param image:       The image on which to draw the label and confidence, in RGB format, and Channel Last (H, W, C)
    :param label:       The label to draw.
    :param confidence:  The confidence of the label.
    """

    # Format confidence as a percentage
    confidence_str = f"{confidence * 100:.3f}%"

    # Use a slightly smaller font scale and a moderate thickness
    fontScale = 0.8
    thickness = 1

    # Define additional spacing between the two lines
    line_spacing = 5

    # Determine the size of the label and confidence text
    label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale, thickness)[0]
    confidence_size = cv2.getTextSize(confidence_str, cv2.FONT_HERSHEY_SIMPLEX, fontScale, thickness)[0]

    # Determine the size of the bounding rectangle
    text_width = max(label_size[0], confidence_size[0])
    text_height = label_size[1] + confidence_size[1] + thickness * 3 + line_spacing

    # Calculate the position to draw the label, centered horizontally and at the top
    start_x = (image.shape[1] - text_width) // 2
    start_y = 5

    # Draw a filled rectangle with transparency as the background for the label
    overlay = image.copy()
    bg_color = (255, 255, 255)  # White
    bg_start = (start_x, start_y)
    bg_end = (start_x + text_width, start_y + text_height)
    cv2.rectangle(overlay, bg_start, bg_end, bg_color, thickness=-1)

    alpha = 0.6
    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)

    # Center the label and confidence text within the bounding rectangle, with additional spacing
    text_color = (0, 0, 0)  # Black
    cv2.putText(
        image,
        label,
        (start_x + (text_width - label_size[0]) // 2, start_y + label_size[1]),
        cv2.FONT_HERSHEY_SIMPLEX,
        fontScale,
        text_color,
        thickness,
        lineType=cv2.LINE_AA,
    )
    cv2.putText(
        image,
        confidence_str,
        (start_x + (text_width - confidence_size[0]) // 2, start_y + label_size[1] + confidence_size[1] + thickness + line_spacing),
        cv2.FONT_HERSHEY_SIMPLEX,
        fontScale,
        text_color,
        thickness,
        lineType=cv2.LINE_AA,
    )

    return image

DepthVisualization

Source code in src/super_gradients/training/utils/visualization/depth_estimation.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class DepthVisualization:
    @staticmethod
    def process_depth_map_for_visualization(
        depth_map: np.ndarray,
        color_scheme: Optional[int] = None,
        drop_extreme_percentage: float = 0,
        inverse: bool = False,
        ignored_val=None,
    ) -> np.ndarray:
        """
        Process a depth map for visualization.

        This method enhances the visual representation of a depth map by:
        1. Clipping extreme values based on the specified percentage.
        2. Normalizing the depth map to the 0-255 range.
        3. Optionally inverting the depth map (inversion is done as 1 / depth).
        4. Applying a color map using OpenCV's applyColorMap.

        :param depth_map:               Input depth map as a NumPy array.
        :param color_scheme:            OpenCV color scheme for the depth map visualization. If not specified:
                                        - If `inverse=True`, the default is COLORMAP_VIRIDIS.
                                        - If `inverse=False`, the default is COLORMAP_MAGMA.
        :param drop_extreme_percentage: Percentage of extreme values to drop.
        :param inverse:                 Apply inversion (1 / depth) if True.

        :return:                        Processed colormap of the depth map for visualization.
        """
        if ignored_val is not None:
            ignored_mask = depth_map != ignored_val

        if inverse:
            depth_map = 1 / depth_map

        # Drop extreme values
        if drop_extreme_percentage > 0:
            if ignored_val is not None:
                min_val = np.percentile(depth_map[ignored_mask], drop_extreme_percentage)
                max_val = np.percentile(depth_map[ignored_mask], 100 - drop_extreme_percentage)
            else:
                min_val = np.percentile(depth_map, drop_extreme_percentage)
                max_val = np.percentile(depth_map, 100 - drop_extreme_percentage)

            depth_map = np.clip(depth_map, min_val, max_val)
        else:
            if ignored_val is not None:
                min_val = depth_map[ignored_mask].min()
                max_val = depth_map[ignored_mask].max()
            else:
                min_val = depth_map.min()
                max_val = depth_map.max()

        # Normalize to 0-255
        depth_map = ((depth_map - min_val) / (max_val - min_val) * 255).astype(np.uint8)

        # Determine the default color scheme
        default_color_scheme = cv2.COLORMAP_VIRIDIS if inverse else cv2.COLORMAP_MAGMA

        # Apply colormap
        colormap = cv2.applyColorMap(depth_map, color_scheme if color_scheme is not None else default_color_scheme)

        if ignored_val is not None:
            colormap[~ignored_mask] = (127, 127, 127)

        # Convert BGR to RGB
        colormap_rgb = cv2.cvtColor(colormap, cv2.COLOR_BGR2RGB)

        return colormap_rgb

process_depth_map_for_visualization(depth_map, color_scheme=None, drop_extreme_percentage=0, inverse=False, ignored_val=None) staticmethod

Process a depth map for visualization.

This method enhances the visual representation of a depth map by:

1. Clipping extreme values based on the specified percentage.
2. Normalizing the depth map to the 0-255 range.
3. Optionally inverting the depth map (inversion is done as 1 / depth).
4. Applying a color map using OpenCV's applyColorMap.

Parameters:

Name Type Description Default
depth_map np.ndarray

Input depth map as a NumPy array.

required
color_scheme Optional[int]

OpenCV color scheme for the depth map visualization. If not specified:
- If inverse=True, the default is COLORMAP_VIRIDIS.
- If inverse=False, the default is COLORMAP_MAGMA.

None
drop_extreme_percentage float

Percentage of extreme values to drop.

0
inverse bool

Apply inversion (1 / depth) if True.

False

Returns:

Type Description
np.ndarray

Processed colormap of the depth map for visualization.

Source code in src/super_gradients/training/utils/visualization/depth_estimation.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
@staticmethod
def process_depth_map_for_visualization(
    depth_map: np.ndarray,
    color_scheme: Optional[int] = None,
    drop_extreme_percentage: float = 0,
    inverse: bool = False,
    ignored_val=None,
) -> np.ndarray:
    """
    Process a depth map for visualization.

    This method enhances the visual representation of a depth map by:
    1. Clipping extreme values based on the specified percentage.
    2. Normalizing the depth map to the 0-255 range.
    3. Optionally inverting the depth map (inversion is done as 1 / depth).
    4. Applying a color map using OpenCV's applyColorMap.

    :param depth_map:               Input depth map as a NumPy array.
    :param color_scheme:            OpenCV color scheme for the depth map visualization. If not specified:
                                    - If `inverse=True`, the default is COLORMAP_VIRIDIS.
                                    - If `inverse=False`, the default is COLORMAP_MAGMA.
    :param drop_extreme_percentage: Percentage of extreme values to drop.
    :param inverse:                 Apply inversion (1 / depth) if True.

    :return:                        Processed colormap of the depth map for visualization.
    """
    if ignored_val is not None:
        ignored_mask = depth_map != ignored_val

    if inverse:
        depth_map = 1 / depth_map

    # Drop extreme values
    if drop_extreme_percentage > 0:
        if ignored_val is not None:
            min_val = np.percentile(depth_map[ignored_mask], drop_extreme_percentage)
            max_val = np.percentile(depth_map[ignored_mask], 100 - drop_extreme_percentage)
        else:
            min_val = np.percentile(depth_map, drop_extreme_percentage)
            max_val = np.percentile(depth_map, 100 - drop_extreme_percentage)

        depth_map = np.clip(depth_map, min_val, max_val)
    else:
        if ignored_val is not None:
            min_val = depth_map[ignored_mask].min()
            max_val = depth_map[ignored_mask].max()
        else:
            min_val = depth_map.min()
            max_val = depth_map.max()

    # Normalize to 0-255
    depth_map = ((depth_map - min_val) / (max_val - min_val) * 255).astype(np.uint8)

    # Determine the default color scheme
    default_color_scheme = cv2.COLORMAP_VIRIDIS if inverse else cv2.COLORMAP_MAGMA

    # Apply colormap
    colormap = cv2.applyColorMap(depth_map, color_scheme if color_scheme is not None else default_color_scheme)

    if ignored_val is not None:
        colormap[~ignored_mask] = (127, 127, 127)

    # Convert BGR to RGB
    colormap_rgb = cv2.cvtColor(colormap, cv2.COLOR_BGR2RGB)

    return colormap_rgb
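
A minimal usage sketch with a synthetic depth map (import path assumed from the source location above):

import numpy as np
from super_gradients.training.utils.visualization.depth_estimation import DepthVisualization

depth = np.random.uniform(0.5, 80.0, size=(240, 320)).astype(np.float32)  # synthetic depth values
colored = DepthVisualization.process_depth_map_for_visualization(depth, drop_extreme_percentage=2.0)
# colored is an RGB uint8 image of shape (240, 320, 3), ready to display or save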

draw_bbox(image, title, color, box_thickness, x1, y1, x2, y2)

Draw a bounding box on an image.

Parameters:

Name Type Description Default
image np.ndarray

Image on which to draw the bounding box.

required
color Tuple[int, int, int]

RGB values of the color of the bounding box.

required
title Optional[str]

Title to display inside the bounding box.

required
box_thickness Optional[int]

Thickness of the bounding box border.

required
x1 int

x-coordinate of the top-left corner of the bounding box.

required
y1 int

y-coordinate of the top-left corner of the bounding box.

required
x2 int

x-coordinate of the bottom-right corner of the bounding box.

required
y2 int

y-coordinate of the bottom-right corner of the bounding box.

required
Source code in src/super_gradients/training/utils/visualization/detection.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def draw_bbox(
    image: np.ndarray,
    title: Optional[str],
    color: Tuple[int, int, int],
    box_thickness: Optional[int],
    x1: int,
    y1: int,
    x2: int,
    y2: int,
) -> np.ndarray:
    """Draw a bounding box on an image.

    :param image:           Image on which to draw the bounding box.
    :param color:           RGB values of the color of the bounding box.
    :param title:           Title to display inside the bounding box.
    :param box_thickness:   Thickness of the bounding box border.
    :param x1:              x-coordinate of the top-left corner of the bounding box.
    :param y1:              y-coordinate of the top-left corner of the bounding box.
    :param x2:              x-coordinate of the bottom-right corner of the bounding box.
    :param y2:              y-coordinate of the bottom-right corner of the bounding box.
    """

    if box_thickness is None:
        box_thickness = get_recommended_box_thickness(x1=x1, y1=y1, x2=x2, y2=y2)

    # Draw bbox
    overlay = image.copy()
    overlay = cv2.rectangle(overlay, (x1, y1), (x2, y2), color, box_thickness)

    if title is not None or title != "":
        # Adapt font size to image shape.
        # This is required because small images require small font size, but this makes the title look bad,
        # so when possible we increase the font size to a more appropriate value.

        font_size = get_recommended_text_size(x1=x1, y1=y1, x2=x2, y2=y2)
        overlay = draw_text_box(image=overlay, text=title, x=x1, y=y1, font=2, font_size=font_size, background_color=color, thickness=1)

    return cv2.addWeighted(overlay, 0.75, image, 0.25, 0)
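
A minimal usage sketch on a blank canvas (import path assumed from the source location above); passing box_thickness=None lets the helper pick a thickness suited to the box size:

import numpy as np
from super_gradients.training.utils.visualization.detection import draw_bbox

image = np.zeros((480, 640, 3), dtype=np.uint8)
image = draw_bbox(image=image, title="car 0.92", color=(0, 255, 0), box_thickness=None, x1=100, y1=80, x2=360, y2=300)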

get_recommended_box_thickness(x1, y1, x2, y2)

Get a nice box thickness for a given bounding box.

Source code in src/super_gradients/training/utils/visualization/detection.py
48
49
50
51
52
53
54
55
56
57
58
59
def get_recommended_box_thickness(x1: int, y1: int, x2: int, y2: int) -> int:
    """Get a nice box thickness for a given bounding box."""
    bbox_width = x2 - x1
    bbox_height = y2 - y1
    diag_length = np.sqrt(bbox_width**2 + bbox_height**2)

    if diag_length <= 100:
        return 1
    elif diag_length <= 200:
        return 2
    else:
        return 3

get_recommended_text_size(x1, y1, x2, y2)

Get a nice text size for a given bounding box.

Source code in src/super_gradients/training/utils/visualization/detection.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def get_recommended_text_size(x1: int, y1: int, x2: int, y2: int) -> float:
    """Get a nice text size for a given bounding box."""
    bbox_width = x2 - x1
    bbox_height = y2 - y1
    diag_length = np.sqrt(bbox_width**2 + bbox_height**2)

    # This follows the heuristic (defined after some visual experiments):
    # - diag_length=100 -> base_font_size=0.4 (min text size)
    # - diag_length=300 -> base_font_size=0.7 (max text size)
    font_size = diag_length * 0.0015 + 0.25
    font_size = max(0.4, font_size)  # Min = 0.4
    font_size = min(0.7, font_size)  # Max = 0.7

    return font_size

LabelInfo dataclass

Hold information about labels.

:attr name: Label name. :attr color: Color of the label. :attr text_size: Size of the label text.

Source code in src/super_gradients/training/utils/visualization/legend.py
14
15
16
17
18
19
20
21
22
23
24
25
@dataclass
class LabelInfo:
    """Hold information about labels.

    :attr name: Label name.
    :attr color: Color of the label.
    :attr text_size: Size of the label text.
    """

    name: str
    color: Tuple[int, int, int]
    text_size: Tuple[int, int]

Row dataclass

Represent a row of labels.

Source code in src/super_gradients/training/utils/visualization/legend.py
28
29
30
31
32
33
@dataclass
class Row:
    """Represent a row of labels."""

    labels: List[LabelInfo]
    total_width: int

add_to_row_or_create_new(rows, label, image_width)

Adds a label to a row or creates a new row if the current one is full.

Parameters:

Name Type Description Default
rows List[Row]

Existing rows of labels.

required
label LabelInfo

Label to add.

required
image_width int

Width of the image.

required

Returns:

Type Description
List[Row]

Updated rows of labels.

Source code in src/super_gradients/training/utils/visualization/legend.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def add_to_row_or_create_new(rows: List[Row], label: LabelInfo, image_width: int) -> List[Row]:
    """Adds a label to a row or creates a new row if the current one is full.

    :param rows: Existing rows of labels.
    :param label: Label to add.
    :param image_width: Width of the image.
    :return: Updated rows of labels.
    """
    if not rows or rows[-1].total_width + label.text_size[0] + 2 * MARGIN_SPACE > image_width:
        # create a new row and initialize total width
        rows.append(Row([label], label.text_size[0] + 2 * MARGIN_SPACE))
    else:
        # append label to existing row and add to total width
        rows[-1].labels.append(label)
        rows[-1].total_width += label.text_size[0] + MARGIN_SPACE
    return rows

draw_label_on_canvas(canvas, label, position, font_size)

Draws a label on the canvas.

Parameters:

Name Type Description Default
canvas np.ndarray

The canvas to draw on.

required
label LabelInfo

The label to draw.

required
position Tuple[int, int]

Position to draw the label.

required
font_size int

Font size of the label.

required

Returns:

Type Description
Tuple[np.ndarray, int]

The updated canvas and horizontal position for next label.

Source code in src/super_gradients/training/utils/visualization/legend.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def draw_label_on_canvas(canvas: np.ndarray, label: LabelInfo, position: Tuple[int, int], font_size: int) -> Tuple[np.ndarray, int]:
    """Draws a label on the canvas.

    :param canvas: The canvas to draw on.
    :param label: The label to draw.
    :param position: Position to draw the label.
    :param font_size: Font size of the label.
    :return: The updated canvas and horizontal position for next label.
    """
    upper_left = (position[0] - MARGIN_SPACE // 2, position[1] - label.text_size[1] - MARGIN_SPACE // 2)
    lower_right = (position[0] + label.text_size[0] + MARGIN_SPACE // 2, position[1] + MARGIN_SPACE // 2)
    canvas = cv2.rectangle(canvas, upper_left, lower_right, label.color, -1)
    canvas = cv2.putText(canvas, label.name, position, FONT, font_size, best_text_color(label.color), LINE_TYPE, lineType=cv2.LINE_AA)
    return canvas, position[0] + label.text_size[0] + MARGIN_SPACE

draw_legend_on_canvas(image, class_color_tuples)

Draws a legend on the canvas.

Parameters:

Name Type Description Default
image np.ndarray

The image to draw the legend on.

required
class_color_tuples Iterable[Tuple[str, Tuple[int, int, int]]]

Iterable of tuples containing class name and its color.

required

Returns:

Type Description
np.ndarray

The canvas with the legend drawn.

Source code in src/super_gradients/training/utils/visualization/legend.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
def draw_legend_on_canvas(image: np.ndarray, class_color_tuples: Iterable[Tuple[str, Tuple[int, int, int]]]) -> np.ndarray:
    """Draws a legend on the canvas.

    :param image: The image to draw the legend on.
    :param class_color_tuples: Iterable of tuples containing class name and its color.
    :return: The canvas with the legend drawn.
    """
    sorted_labels = get_sorted_labels(class_color_tuples)
    label_rows = get_label_rows(sorted_labels, image.shape[1])

    canvas_height = (sorted_labels[0].text_size[1] + MARGIN_SPACE) * len(label_rows)
    canvas = np.ones((canvas_height, image.shape[1], 3), dtype=np.uint8) * 255

    vertical_position = sorted_labels[0].text_size[1] + MARGIN_SPACE // 2

    for row in label_rows:
        horizontal_position = MARGIN_SPACE
        for label in row.labels:
            canvas, horizontal_position = draw_label_on_canvas(canvas, label, (horizontal_position, vertical_position), INITIAL_FONT_SIZE)
        vertical_position += sorted_labels[0].text_size[1] + MARGIN_SPACE

    return canvas
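
A minimal usage sketch (import path assumed from the source location above); the returned legend strip has the same width as the input image and can be stacked below it:

import numpy as np
from super_gradients.training.utils.visualization.legend import draw_legend_on_canvas

image = np.zeros((480, 640, 3), dtype=np.uint8)
legend = draw_legend_on_canvas(image, [("car", (0, 255, 0)), ("person", (255, 0, 0))])
image_with_legend = np.vstack([image, legend])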

get_label_info(name, color)

Creates a LabelInfo object for a given name and color.

Parameters:

Name Type Description Default
name str

Label name.

required
color Tuple[int, int, int]

Label color.

required

Returns:

Type Description
LabelInfo

An object of LabelInfo.

Source code in src/super_gradients/training/utils/visualization/legend.py
45
46
47
48
49
50
51
52
def get_label_info(name: str, color: Tuple[int, int, int]) -> LabelInfo:
    """Creates a LabelInfo object for a given name and color.

    :param name: Label name.
    :param color: Label color.
    :return: An object of LabelInfo.
    """
    return LabelInfo(name, color, get_text_size(name))

get_label_rows(labels, image_width)

Arranges labels in rows to fit into the image.

Parameters:

Name Type Description Default
labels List[LabelInfo]

List of labels.

required
image_width int

Width of the image.

required

Returns:

Type Description
List[Row]

List of label rows.

Source code in src/super_gradients/training/utils/visualization/legend.py
83
84
85
86
87
88
89
90
91
92
93
def get_label_rows(labels: List[LabelInfo], image_width: int) -> List[Row]:
    """Arranges labels in rows to fit into the image.

    :param labels: List of labels.
    :param image_width: Width of the image.
    :return: List of label rows.
    """
    rows = []
    for label in labels:
        rows = add_to_row_or_create_new(rows, label, image_width)
    return rows

get_sorted_labels(class_color_tuples)

Sorts and creates LabelInfo for class-color tuples.

Parameters:

Name Type Description Default
class_color_tuples Sequence[Tuple[str, Tuple[int, int, int]]]

Tuples of class names and associated colors.

required

Returns:

Type Description
List[LabelInfo]

A sorted list of LabelInfo objects.

Source code in src/super_gradients/training/utils/visualization/legend.py
73
74
75
76
77
78
79
80
def get_sorted_labels(class_color_tuples: Sequence[Tuple[str, Tuple[int, int, int]]]) -> List[LabelInfo]:
    """Sorts and creates LabelInfo for class-color tuples.

    :param class_color_tuples: Tuples of class names and associated colors.
    :return: A sorted list of LabelInfo objects.
    """
    sorted_classes = sorted(class_color_tuples, key=lambda x: x[0])
    return [get_label_info(name, color) for name, color in sorted_classes]

get_text_size(text)

Calculate the size of a given text using the CV2 getTextSize function.

Parameters:

Name Type Description Default
text str

Input text.

required

Returns:

Type Description
Tuple[int, int]

A tuple of width and height of the text box.

Source code in src/super_gradients/training/utils/visualization/legend.py
36
37
38
39
40
41
42
def get_text_size(text: str) -> Tuple[int, int]:
    """Calculate the size of a given text using the CV2 getTextSize function.

    :param text: Input text.
    :return: A tuple of width and height of the text box.
    """
    return cv2.getTextSize(text, FONT, INITIAL_FONT_SIZE, LINE_TYPE)[0]

PoseVisualization

Source code in src/super_gradients/training/utils/visualization/pose_estimation.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
class PoseVisualization:
    @classmethod
    def draw_poses(
        self,
        *,
        image: np.ndarray,
        poses: np.ndarray,
        boxes: Optional[np.ndarray],
        scores: Optional[np.ndarray],
        is_crowd: Optional[np.ndarray],
        edge_links: Union[np.ndarray, List[Tuple[int, int]]],
        edge_colors: Union[None, np.ndarray, List[Tuple[int, int, int]]],
        keypoint_colors: Union[None, np.ndarray, List[Tuple[int, int, int]]],
        show_keypoint_confidence: bool = False,
        joint_thickness: Optional[int] = None,
        box_thickness: Optional[int] = None,
        keypoint_radius: Optional[int] = None,
        keypoint_confidence_threshold: float = 0.5,
    ):
        """
        Draw multiple poses on an image.
        :param image: Image on which to draw the poses. This image will not be modified, instead a new image will be returned.
        :param poses: Predicted poses. Shape [Num Poses, Num Joints, 2] or [Num Poses, Num Joints, 3] if confidence scores are available.
        :param boxes: Optional bounding boxes for each pose. Shape [Num Poses, 4] in XYXY format.
        :param scores: Optional confidence scores for each pose. Shape [Num Poses]
        :param is_crowd: Optional array of booleans indicating whether each pose is crowd or not. Shape [Num Poses]
        :param edge_links: Array of [Num Links, 2] containing the links between joints to draw.
        :param edge_colors: Array of shape [Num Links, 3] or list of tuples containing the (r,g,b) colors for each joint link.
        :param keypoint_colors: Array of shape [Num Joints, 3] or list of tuples containing the (r,g,b) colors for each keypoint.
        :param show_keypoint_confidence: Whether to show the confidence score for each keypoint individually.
        :param keypoint_confidence_threshold: A minimal confidence score for individual keypoint to be drawn.
        :param joint_thickness: (Optional) Thickness of the joint links
        :return: A new image with the poses drawn on it.
        """
        if boxes is not None and len(boxes) != len(poses):
            raise ValueError("boxes and poses must have the same length")
        if scores is not None and len(scores) != len(poses):
            raise ValueError("conf and poses must have the same length")
        if is_crowd is not None and len(is_crowd) != len(poses):
            raise ValueError("is_crowd and poses must have the same length")

        # For visualization purposes, sort poses by confidence starting from the least confident
        if scores is not None:
            order = np.argsort(scores)
            poses = poses[order]
            scores = scores[order]
            if boxes is not None:
                boxes = boxes[order]
            if is_crowd is not None:
                is_crowd = is_crowd[order]

        res_image = image.copy()
        num_poses = len(poses)

        for pose_index in range(num_poses):

            if boxes is not None:
                x1 = int(boxes[pose_index][0])
                y1 = int(boxes[pose_index][1])
                x2 = int(boxes[pose_index][2])
                y2 = int(boxes[pose_index][3])

                current_box_thickness = box_thickness or get_recommended_box_thickness(x1, y1, x2, y2)
                current_joint_thickness = joint_thickness or current_box_thickness
                current_keypoint_radius = keypoint_radius or math.ceil(current_box_thickness * 3 / 2)
            else:
                current_joint_thickness = 2
                current_keypoint_radius = 3
                current_box_thickness = 2

            res_image = draw_skeleton(
                image=res_image,
                keypoints=poses[pose_index],
                score=scores[pose_index] if scores is not None else None,
                edge_links=edge_links,
                edge_colors=edge_colors,
                joint_thickness=current_joint_thickness,
                keypoint_colors=keypoint_colors,
                keypoint_radius=current_keypoint_radius,
                show_confidence=scores is not None and boxes is None,
                show_keypoint_confidence=show_keypoint_confidence,
                box_thickness=current_box_thickness,
                keypoint_confidence_threshold=keypoint_confidence_threshold,
            )

            if boxes is not None:
                x1 = int(boxes[pose_index][0])
                y1 = int(boxes[pose_index][1])
                x2 = int(boxes[pose_index][2])
                y2 = int(boxes[pose_index][3])
                title = ""
                if scores is not None:
                    title += f"{scores[pose_index]:.2f}"
                if is_crowd is not None:
                    title += f"Crowd {is_crowd[pose_index]}"

                res_image = draw_bbox(
                    image=res_image,
                    x1=x1,
                    y1=y1,
                    x2=x2,
                    y2=y2,
                    color=(255, 255, 255),
                    title=title,
                    box_thickness=current_box_thickness,
                )

        return res_image

draw_poses(*, image, poses, boxes, scores, is_crowd, edge_links, edge_colors, keypoint_colors, show_keypoint_confidence=False, joint_thickness=None, box_thickness=None, keypoint_radius=None, keypoint_confidence_threshold=0.5) classmethod

Draw multiple poses on an image.

Parameters:

Name Type Description Default
image np.ndarray

Image on which to draw the poses. This image will not be modified, instead a new image will be returned.

required
poses np.ndarray

Predicted poses. Shape [Num Poses, Num Joints, 2] or [Num Poses, Num Joints, 3] if confidence scores are available.

required
boxes Optional[np.ndarray]

Optional bounding boxes for each pose. Shape [Num Poses, 4] in XYXY format.

required
scores Optional[np.ndarray]

Optional confidence scores for each pose. Shape [Num Poses]

required
is_crowd Optional[np.ndarray]

Optional array of booleans indicating whether each pose is crowd or not. Shape [Num Poses]

required
edge_links Union[np.ndarray, List[Tuple[int, int]]]

Array of [Num Links, 2] containing the links between joints to draw.

required
edge_colors Union[None, np.ndarray, List[Tuple[int, int, int]]]

Array of shape [Num Links, 3] or list of tuples containing the (r,g,b) colors for each joint link.

required
keypoint_colors Union[None, np.ndarray, List[Tuple[int, int, int]]]

Array of shape [Num Joints, 3] or list of tuples containing the (r,g,b) colors for each keypoint.

required
show_keypoint_confidence bool

Whether to show the confidence score for each keypoint individually.

False
keypoint_confidence_threshold float

A minimal confidence score for individual keypoint to be drawn.

0.5
joint_thickness Optional[int]

(Optional) Thickness of the joint links

None

Returns:

Type Description

A new image with the poses drawn on it.

Source code in src/super_gradients/training/utils/visualization/pose_estimation.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
@classmethod
def draw_poses(
    self,
    *,
    image: np.ndarray,
    poses: np.ndarray,
    boxes: Optional[np.ndarray],
    scores: Optional[np.ndarray],
    is_crowd: Optional[np.ndarray],
    edge_links: Union[np.ndarray, List[Tuple[int, int]]],
    edge_colors: Union[None, np.ndarray, List[Tuple[int, int, int]]],
    keypoint_colors: Union[None, np.ndarray, List[Tuple[int, int, int]]],
    show_keypoint_confidence: bool = False,
    joint_thickness: Optional[int] = None,
    box_thickness: Optional[int] = None,
    keypoint_radius: Optional[int] = None,
    keypoint_confidence_threshold: float = 0.5,
):
    """
    Draw multiple poses on an image.
    :param image: Image on which to draw the poses. This image will not be modified, instead a new image will be returned.
    :param poses: Predicted poses. Shape [Num Poses, Num Joints, 2] or [Num Poses, Num Joints, 3] if confidence scores are available.
    :param boxes: Optional bounding boxes for each pose. Shape [Num Poses, 4] in XYXY format.
    :param scores: Optional confidence scores for each pose. Shape [Num Poses]
    :param is_crowd: Optional array of booleans indicating whether each pose is crowd or not. Shape [Num Poses]
    :param edge_links: Array of [Num Links, 2] containing the links between joints to draw.
    :param edge_colors: Array of shape [Num Links, 3] or list of tuples containing the (r,g,b) colors for each joint link.
    :param keypoint_colors: Array of shape [Num Joints, 3] or list of tuples containing the (r,g,b) colors for each keypoint.
    :param show_keypoint_confidence: Whether to show the confidence score for each keypoint individually.
    :param keypoint_confidence_threshold: A minimal confidence score for individual keypoint to be drawn.
    :param joint_thickness: (Optional) Thickness of the joint links
    :return: A new image with the poses drawn on it.
    """
    if boxes is not None and len(boxes) != len(poses):
        raise ValueError("boxes and poses must have the same length")
    if scores is not None and len(scores) != len(poses):
        raise ValueError("conf and poses must have the same length")
    if is_crowd is not None and len(is_crowd) != len(poses):
        raise ValueError("is_crowd and poses must have the same length")

    # For visualization purposes, sort poses by confidence starting from the least confident
    if scores is not None:
        order = np.argsort(scores)
        poses = poses[order]
        scores = scores[order]
        if boxes is not None:
            boxes = boxes[order]
        if is_crowd is not None:
            is_crowd = is_crowd[order]

    res_image = image.copy()
    num_poses = len(poses)

    for pose_index in range(num_poses):

        if boxes is not None:
            x1 = int(boxes[pose_index][0])
            y1 = int(boxes[pose_index][1])
            x2 = int(boxes[pose_index][2])
            y2 = int(boxes[pose_index][3])

            current_box_thickness = box_thickness or get_recommended_box_thickness(x1, y1, x2, y2)
            current_joint_thickness = joint_thickness or current_box_thickness
            current_keypoint_radius = keypoint_radius or math.ceil(current_box_thickness * 3 / 2)
        else:
            current_joint_thickness = 2
            current_keypoint_radius = 3
            current_box_thickness = 2

        res_image = draw_skeleton(
            image=res_image,
            keypoints=poses[pose_index],
            score=scores[pose_index] if scores is not None else None,
            edge_links=edge_links,
            edge_colors=edge_colors,
            joint_thickness=current_joint_thickness,
            keypoint_colors=keypoint_colors,
            keypoint_radius=current_keypoint_radius,
            show_confidence=scores is not None and boxes is None,
            show_keypoint_confidence=show_keypoint_confidence,
            box_thickness=current_box_thickness,
            keypoint_confidence_threshold=keypoint_confidence_threshold,
        )

        if boxes is not None:
            x1 = int(boxes[pose_index][0])
            y1 = int(boxes[pose_index][1])
            x2 = int(boxes[pose_index][2])
            y2 = int(boxes[pose_index][3])
            title = ""
            if scores is not None:
                title += f"{scores[pose_index]:.2f}"
            if is_crowd is not None:
                title += f"Crowd {is_crowd[pose_index]}"

            res_image = draw_bbox(
                image=res_image,
                x1=x1,
                y1=y1,
                x2=x2,
                y2=y2,
                color=(255, 255, 255),
                title=title,
                box_thickness=current_box_thickness,
            )

    return res_image
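
A minimal sketch with a tiny 3-joint skeleton (import path assumed from the source location above; real pose models would supply their own edge_links and colors):

import numpy as np
from super_gradients.training.utils.visualization.pose_estimation import PoseVisualization

image = np.zeros((480, 640, 3), dtype=np.uint8)
poses = np.array([[[100, 100, 0.9], [150, 160, 0.8], [200, 220, 0.7]]], dtype=np.float32)  # [1 pose, 3 joints, (x, y, conf)]

result = PoseVisualization.draw_poses(
    image=image,
    poses=poses,
    boxes=np.array([[80, 80, 220, 240]], dtype=np.float32),
    scores=np.array([0.95]),
    is_crowd=None,
    edge_links=[(0, 1), (1, 2)],
    edge_colors=[(0, 255, 0), (0, 0, 255)],
    keypoint_colors=[(255, 0, 0)] * 3,
    show_keypoint_confidence=True,
)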

draw_skeleton(image, keypoints, score, edge_links, edge_colors, joint_thickness, keypoint_colors, keypoint_radius, show_confidence, box_thickness, keypoint_confidence_threshold=0.0, show_keypoint_confidence=False)

Draw a skeleton on an image.

Parameters:

Name Type Description Default
image np.ndarray

Input image (will not be modified)

required
keypoints np.ndarray

Array of [Num Joints, 2] or [Num Joints, 3] containing the keypoints to draw. The first two values are the x and y coordinates; the optional third value is the keypoint confidence score, used for thresholding and for the per-keypoint confidence display.

required
score float

Confidence score of the whole pose

required
edge_links Union[None, np.ndarray, List[Tuple[int, int]]]

Array of [Num Links, 2] containing the links between joints to draw. Can be None, in which case no links will be drawn.

required
edge_colors Union[None, np.ndarray, List[Tuple[int, int, int]]]

Array of shape [Num Links, 3] or list of tuples containing the (r,g,b) colors for each joint link.

required
joint_thickness int

(Optional) Thickness of the joint links

required
keypoint_colors Union[None, np.ndarray, List[Tuple[int, int, int]]]

Array of shape [Num Joints, 3] or list of tuples containing the (r,g,b) colors for each keypoint.

required
keypoint_radius int

(Optional) Radius of the keypoints (in pixels)

required
show_confidence bool

Whether to show the bounding box around the pose and confidence score on top of it.

required
box_thickness Optional[int]

(Optional) Thickness of bounding boxes. If None, will adapt to the box size.

required
keypoint_confidence_threshold float

If keypoints contain confidence scores (shape [Num Joints, 3]), only keypoints with a confidence score greater than this threshold are drawn.

0.0
show_keypoint_confidence bool

Whether to show the confidence score for each keypoint individually.

False

Returns:

Type Description

A new image with the skeleton drawn on it

Source code in src/super_gradients/training/utils/visualization/pose_estimation.py
def draw_skeleton(
    image: np.ndarray,
    keypoints: np.ndarray,
    score: float,
    edge_links: Union[None, np.ndarray, List[Tuple[int, int]]],
    edge_colors: Union[None, np.ndarray, List[Tuple[int, int, int]]],
    joint_thickness: int,
    keypoint_colors: Union[None, np.ndarray, List[Tuple[int, int, int]]],
    keypoint_radius: int,
    show_confidence: bool,
    box_thickness: Optional[int],
    keypoint_confidence_threshold: float = 0.0,
    show_keypoint_confidence: bool = False,
):
    """
    Draw a skeleton on an image.

    :param image:           Input image (will not be modified)
    :param keypoints:       Array of [Num Joints, 2] or [Num Joints, 3] containing the keypoints to draw.
                            The first two values are the x and y coordinates; the optional third value is the keypoint
                            confidence score, used for thresholding and for the per-keypoint confidence display.
    :param score:           Confidence score of the whole pose
    :param edge_links:      Array of [Num Links, 2] containing the links between joints to draw. Can be None, in which case no links will be drawn.
    :param edge_colors:     Array of shape [Num Links, 3] or list of tuples containing the (r,g,b) colors for each joint link.
    :param joint_thickness: (Optional) Thickness of the joint links
    :param keypoint_colors: Array of shape [Num Joints, 3] or list of tuples containing the (r,g,b) colors for each keypoint.
    :param keypoint_radius: (Optional) Radius of the keypoints (in pixels)
    :param show_confidence: Whether to show the bounding box around the pose and confidence score on top of it.
    :param box_thickness:   (Optional) Thickness of bounding boxes. If None, will adapt to the box size.
    :param keypoint_confidence_threshold: If keypoints contain confidence scores (shape [Num Joints, 3]), only keypoints
                            with a confidence score greater than this threshold are drawn.
    :param show_keypoint_confidence: Whether to show the confidence score for each keypoint individually.


    :return: A new image with the skeleton drawn on it
    """
    if edge_links is not None and edge_colors is not None and len(edge_links) != len(edge_colors):
        raise ValueError("edge_colors and edge_links must have the same length")

    if edge_colors is None and edge_links is not None:
        edge_colors = [(255, 0, 255)] * len(edge_links)

    if keypoint_colors is None:
        keypoint_colors = [(0, 255, 0)] * len(keypoints)

    if len(keypoints.shape) != 2 or keypoints.shape[1] not in (2, 3):
        raise ValueError(f"Argument keypoints must be a 2D array of shape [Num Joints, 2] or [Num Joints, 3], got input of shape {keypoints.shape}")

    if keypoints.shape[1] == 3:
        keypoint_scores = keypoints[..., 2]
        keypoints = keypoints[..., 0:2].astype(int)
    else:
        # If keypoints contain no scores, assign every keypoint a score above keypoint_confidence_threshold so that all of them are drawn
        keypoint_scores = np.ones(len(keypoints)) + keypoint_confidence_threshold
        keypoints = keypoints[..., 0:2].astype(int)
        show_keypoint_confidence = False

    keypoints_to_show_mask = keypoint_scores > keypoint_confidence_threshold

    pose_center = keypoints[keypoints_to_show_mask].mean(axis=0)
    direction_from_center = keypoints - pose_center
    direction_from_center /= np.linalg.norm(direction_from_center, axis=1, ord=2, keepdims=True) + 1e-9

    overlay = image.copy()

    # Use keypoint_score for the per-keypoint confidence so that the pose-level `score` argument is not shadowed
    for keypoint, keypoint_score, direction, show, color in zip(keypoints, keypoint_scores, direction_from_center, keypoints_to_show_mask, keypoint_colors):
        if not show:
            continue
        x, y = keypoint
        x = int(x)
        y = int(y)
        color = tuple(map(int, color))
        cv2.circle(overlay, center=(x, y), radius=keypoint_radius, color=color, thickness=-1, lineType=cv2.LINE_AA)

        # Draw confidence score for each keypoint individually
        if show_keypoint_confidence:
            center_of_score = keypoint + direction * 16
            text = f"{score:.2f}"
            (w, h), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

            cx, cy = center_of_score
            if direction[0] < -0.5:
                x = int(cx - w)
            elif direction[0] > 0.5:
                x = int(cx)
            else:
                x = int(cx - w // 2)

            y = int(cy + h // 2)
            cv2.putText(overlay, text, org=(x, y), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5, color=(250, 250, 250), thickness=1, lineType=cv2.LINE_AA)

    if edge_links is not None:
        for (kp1, kp2), color in zip(edge_links, edge_colors):
            show = keypoints_to_show_mask[kp1] and keypoints_to_show_mask[kp2]
            if not show:
                continue
            p1 = tuple(map(int, keypoints[kp1]))
            p2 = tuple(map(int, keypoints[kp2]))
            color = tuple(map(int, color))
            cv2.line(overlay, p1, p2, color=color, thickness=joint_thickness, lineType=cv2.LINE_AA)

    confident_keypoints = keypoints[keypoints_to_show_mask]

    if show_confidence and len(confident_keypoints):
        x, y, w, h = cv2.boundingRect(confident_keypoints)
        overlay = draw_bbox(overlay, title=f"{score:.2f}", box_thickness=box_thickness, color=(255, 0, 255), x1=x, y1=y, x2=x + w, y2=y + h)

    return cv2.addWeighted(overlay, 0.75, image, 0.25, 0)
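A minimal usage sketch for draw_skeleton with a synthetic three-joint pose. The import path is inferred from the source location above, and all keypoint, link, and color values below are illustrative assumptions:

import numpy as np
from super_gradients.training.utils.visualization.pose_estimation import draw_skeleton  # path assumed from the source location above

image = np.zeros((240, 320, 3), dtype=np.uint8)                           # blank canvas
keypoints = np.array([[100, 60, 0.9], [120, 120, 0.8], [80, 120, 0.4]])   # [x, y, confidence] per joint
edge_links = [(0, 1), (0, 2)]                                             # pairs of joint indices to connect

drawn = draw_skeleton(
    image=image,
    keypoints=keypoints,
    score=0.85,                          # whole-pose confidence (shown only when show_confidence=True)
    edge_links=edge_links,
    edge_colors=None,                    # falls back to magenta for every link
    joint_thickness=2,
    keypoint_colors=None,                # falls back to green for every keypoint
    keypoint_radius=3,
    show_confidence=False,               # True would also draw a bounding box around the pose with the pose score
    box_thickness=2,
    keypoint_confidence_threshold=0.5,   # the third keypoint (0.4) and its link are skipped
    show_keypoint_confidence=True,       # prints each keypoint's confidence next to it
)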

overlay_segmentation(image, pred_mask, num_classes, alpha, colors=None, class_names=None)

Overlay a segmentation prediction mask on an image.

Parameters:

Name Type Description Default
image np.ndarray

Image on which to draw the segmentation.

required
pred_mask torch.Tensor

Predicted segmentation mask (per-pixel class indices) to overlay on the image.

required
num_classes int

Number of classes in the segmentation task.

required
alpha float

Float number in [0, 1] denoting the transparency of the masks (0 means fully transparent, 1 means fully opaque).

required
colors Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]]

List containing the colors of the masks or single color for all masks. By default, random colors are generated for each mask.

None
class_names Optional[List[str]]

List containing the class names used for model training (e.g. the Cityscapes classes).

None
Source code in src/super_gradients/training/utils/visualization/segmentation.py
def overlay_segmentation(
    image: np.ndarray,
    pred_mask: torch.Tensor,
    num_classes: int,
    alpha: float,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
    class_names: Optional[List[str]] = None,
) -> np.ndarray:
    """Draw a bounding box on an image.

    :param image:           Image on which to draw the segmentation.
    :param pred_mask:       Image on which to draw the segmentation.
    :param num_classes:     Image on which to draw the segmentation.
    :param alpha:           Float number between [0,1] denoting the transparency of the masks (0 means full transparency, 1 means opacity).
    :param colors:          List containing the colors of the masks or single color for all masks. By default, random colors are generated for each mask.
    :param class_names:     List containing the class names of cityscapes classes used for model training
    """
    if class_names is None:
        class_names = [f"class_{i}" for i in range(num_classes)]
    if colors is None:
        colors = generate_color_mapping(num_classes)
    if len(colors) != num_classes:
        raise ValueError(f"Number of colors ({len(colors)}) should be equal to the number of classes ({num_classes})")
    if len(class_names) != num_classes:
        raise ValueError(f"Number of class names ({len(class_names)}) should be equal to the number of classes ({num_classes})")

    overlay = image.copy()
    overlay = torch.from_numpy(overlay.transpose(2, 0, 1))  # HWC -> CHW, as expected by draw_segmentation_masks
    segmentation_mask = torch.from_numpy(np.expand_dims(pred_mask, axis=0))
    one_hot_prediction_masks = to_one_hot(target=segmentation_mask, num_classes=num_classes)
    segmentation_overlay = draw_segmentation_masks(overlay, masks=one_hot_prediction_masks.squeeze(0).bool(), alpha=alpha, colors=colors)

    segmentation_prediction = np.array(segmentation_overlay.detach().permute(1, 2, 0))

    # Collect the set of (class name, color) pairs that appear in the image, to be shown in the legend
    classes_in_image_with_color: Set[Tuple[str, Tuple]] = set()

    for idx, class_name in enumerate(class_names):
        color = colors[idx]
        if torch.any(one_hot_prediction_masks[0, idx, :, :]):
            classes_in_image_with_color.add((class_name, color))

    canvas = draw_legend_on_canvas(image=segmentation_prediction, class_color_tuples=classes_in_image_with_color)
    segmentation_prediction = np.concatenate((segmentation_prediction, canvas), axis=0)

    return segmentation_prediction
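A small usage sketch, with the import path inferred from the source location above. Although the annotation says torch.Tensor, the body runs pred_mask through np.expand_dims and torch.from_numpy, so a NumPy array of per-pixel class indices (shape [H, W]) is the safest input; the class names below are illustrative:

import numpy as np
from super_gradients.training.utils.visualization.segmentation import overlay_segmentation  # path assumed from the source location above

image = np.random.randint(0, 255, size=(128, 128, 3), dtype=np.uint8)   # H x W x 3 image
pred_mask = np.random.randint(0, 3, size=(128, 128))                    # per-pixel class indices for 3 classes

result = overlay_segmentation(
    image=image,
    pred_mask=pred_mask,
    num_classes=3,
    alpha=0.5,                               # 0 = fully transparent masks, 1 = fully opaque
    colors=None,                             # falls back to generate_color_mapping(num_classes)
    class_names=["road", "car", "person"],   # one name per class, shown in the legend strip
)
# `result` is the overlaid image with a legend canvas concatenated below it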

best_text_color(background_color)

Determine the best color for text to be visible on a given background color.

Parameters:

Name Type Description Default
background_color Tuple[int, int, int]

RGB values of the background color.

required

Returns:

Type Description
Tuple[int, int, int]

RGB values of the best text color for the given background color.

Source code in src/super_gradients/training/utils/visualization/utils.py
def best_text_color(background_color: Tuple[int, int, int]) -> Tuple[int, int, int]:
    """Determine the best color for text to be visible on a given background color.

    :param background_color: RGB values of the background color.
    :return: RGB values of the best text color for the given background color.
    """

    # If the brightness is greater than 0.5, use black text; otherwise, use white text.
    if compute_brightness(background_color) > 0.5:
        return (0, 0, 0)  # Black
    else:
        return (255, 255, 255)  # White

compute_brightness(color)

Computes the brightness of a given color in RGB format. From https://alienryderflex.com/hsp.html

Parameters:

Name Type Description Default
color Tuple[int, int, int]

A tuple of three integers representing the RGB values of the color.

required

Returns:

Type Description
float

The brightness of the color.

Source code in src/super_gradients/training/utils/visualization/utils.py
def compute_brightness(color: Tuple[int, int, int]) -> float:
    """Computes the brightness of a given color in RGB format. From https://alienryderflex.com/hsp.html

    :param color: A tuple of three integers representing the RGB values of the color.
    :return: The brightness of the color.
    """
    return (0.299 * color[0] + 0.587 * color[1] + 0.114 * color[2]) / 255
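A quick illustration of how compute_brightness and best_text_color work together (values computed with the corrected blue-channel term above):

# Bright yellow background: (0.299*255 + 0.587*255 + 0.114*0) / 255 ≈ 0.886 > 0.5 -> black text
print(best_text_color((255, 255, 0)))   # (0, 0, 0)

# Dark blue background: (0.299*0 + 0.587*0 + 0.114*139) / 255 ≈ 0.062 <= 0.5 -> white text
print(best_text_color((0, 0, 139)))     # (255, 255, 255)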

draw_text_box(image, text, x, y, font, font_size, background_color, thickness=1)

Draw a text inside a box

Parameters:

Name Type Description Default
image np.ndarray

The image on which to draw the text box.

required
text str

The text to display in the text box.

required
x int

The x-coordinate of the top-left corner of the text box.

required
y int

The y-coordinate of the bottom of the text box (the box extends upward from this coordinate).

required
font int

The font to use for the text.

required
font_size float

The size of the font to use.

required
background_color Tuple[int, int, int]

The color of the text box and text as a tuple of three integers representing RGB values.

required
thickness int

The thickness of the text.

1

Returns:

Type Description
np.ndarray

Image with the text inside the box.

Source code in src/super_gradients/training/utils/visualization/utils.py
def draw_text_box(
    image: np.ndarray,
    text: str,
    x: int,
    y: int,
    font: int,
    font_size: float,
    background_color: Tuple[int, int, int],
    thickness: int = 1,
) -> np.ndarray:
    """Draw a text inside a box

    :param image:               The image on which to draw the text box.
    :param text:                The text to display in the text box.
    :param x:                   The x-coordinate of the top-left corner of the text box.
    :param y:                   The y-coordinate of the bottom of the text box (the box extends upward from this coordinate).
    :param font:                The font to use for the text.
    :param font_size:           The size of the font to use.
    :param background_color:    The color of the text box and text as a tuple of three integers representing RGB values.
    :param thickness:           The thickness of the text.
    :return: Image with the text inside the box.
    """
    text_color = best_text_color(background_color)
    (text_width, text_height), baseline = cv2.getTextSize(text, font, font_size, thickness)
    text_left_offset = 7

    image = cv2.rectangle(image, (x, y), (x + text_width + text_left_offset, y - text_height - int(15 * font_size)), background_color, -1)
    image = cv2.putText(image, text, (x + text_left_offset, y - int(10 * font_size)), font, font_size, text_color, thickness, lineType=cv2.LINE_AA)
    return image
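A short sketch of draw_text_box on a plain canvas; the font constant, coordinates, and label are illustrative, and the import path is inferred from the source location above:

import numpy as np
import cv2
from super_gradients.training.utils.visualization.utils import draw_text_box  # path assumed from the source location above

canvas = np.full((100, 300, 3), 40, dtype=np.uint8)   # dark gray canvas
labeled = draw_text_box(
    image=canvas,
    text="person 0.92",
    x=10,
    y=40,                                # the box extends upward from this y coordinate
    font=cv2.FONT_HERSHEY_SIMPLEX,
    font_size=0.6,
    background_color=(255, 255, 0),      # bright background, so best_text_color picks black text
    thickness=1,
)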

generate_color_mapping(num_classes)

Generate a unique BGR color for each class

Parameters:

Name Type Description Default
num_classes int

The number of classes in the dataset.

required

Returns:

Type Description
List[Tuple[int, ...]]

List of BGR colors, one per class.

Source code in src/super_gradients/training/utils/visualization/utils.py
def generate_color_mapping(num_classes: int) -> List[Tuple[int, ...]]:
    """Generate a unique BGR color for each class

    :param num_classes: The number of classes in the dataset.
    :return:            List of BGR colors, one per class.
    """
    cmap = plt.cm.get_cmap("gist_rainbow", num_classes)
    colors = [cmap(i, bytes=True)[:3][::-1] for i in range(num_classes)]
    return [tuple(int(v) for v in c) for c in colors]
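For example, three classes yield three evenly spaced colors from the gist_rainbow colormap; the [::-1] reverses the channel order, so each tuple comes out as (b, g, r):

colors = generate_color_mapping(num_classes=3)
print(len(colors))   # 3
print(colors[0])     # a (b, g, r) tuple of ints in [0, 255]

Note that plt.cm.get_cmap is deprecated in recent Matplotlib releases; matplotlib.colormaps["gist_rainbow"].resampled(num_classes) is the suggested replacement there.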

ModelWeightAveraging

Utils class for managing the averaging of the best several snapshots into a single model. A snapshot dictionary file and the average model will be saved / updated at every epoch and evaluated only when training is completed. The snapshot file will only be deleted upon completing the training. The snapshot dict will be managed on cpu.

Source code in src/super_gradients/training/utils/weight_averaging_utils.py
class ModelWeightAveraging:
    """
    Utils class for managing the averaging of the best several snapshots into a single model.
    A snapshot dictionary file and the average model will be saved / updated at every epoch and evaluated only when
    training is completed. The snapshot file will only be deleted upon completing the training.
    The snapshot dict will be managed on cpu.
    """

    def __init__(
        self,
        ckpt_dir: str,
        greater_is_better: bool,
        metric_to_watch: str,
        load_checkpoint: bool = False,
        number_of_models_to_average: int = 10,
    ):
        """
        Init the ModelWeightAveraging
        :param ckpt_dir:                    The directory where the checkpoints are saved
        :param greater_is_better:           Whether a higher value of the watched metric is better
        :param metric_to_watch:             Name of the metric to monitor (e.g. loss or accuracy); identical to the metric that determines the best model
        :param load_checkpoint:             Whether to load a pre-existing snapshot dict.
        :param number_of_models_to_average: Number of models to average
        """

        self.averaging_snapshots_file = os.path.join(ckpt_dir, "averaging_snapshots.pkl")
        self.number_of_models_to_average = number_of_models_to_average
        self.metric_to_watch = metric_to_watch
        self.greater_is_better = greater_is_better

        # if continuing training, load the previous snapshot dict if it exists
        if load_checkpoint and ckpt_dir is not None and os.path.isfile(self.averaging_snapshots_file):
            averaging_snapshots_dict = read_ckpt_state_dict(self.averaging_snapshots_file)
        else:
            averaging_snapshots_dict = {"snapshot" + str(i): None for i in range(self.number_of_models_to_average)}
            # if a greater metric is better, initialize with -inf; otherwise initialize with +inf
            if self.greater_is_better:
                averaging_snapshots_dict["snapshots_metric"] = -1 * np.inf * np.ones(self.number_of_models_to_average)
            else:
                averaging_snapshots_dict["snapshots_metric"] = np.inf * np.ones(self.number_of_models_to_average)

            torch.save(averaging_snapshots_dict, self.averaging_snapshots_file)

    def update_snapshots_dict(self, model: nn.Module, validation_results_dict: Mapping[str, float]):
        """
        Update the snapshot dict and return it for saving
        :param model: the latest model
        :param validation_results_dict: performance of the latest model
        """
        averaging_snapshots_dict = self._get_averaging_snapshots_dict()

        # IF THE CURRENT MODEL IS BETTER, IT REPLACES THE WORST SNAPSHOT IN THE METRIC LIST AND THE AVERAGE IS UPDATED
        require_update, update_ind = self._is_better(averaging_snapshots_dict, validation_results_dict)
        if require_update:
            # moving state dict to cpu
            new_sd = unwrap_model(model).state_dict()
            new_sd = move_state_dict_to_device(new_sd, "cpu")

            averaging_snapshots_dict["snapshot" + str(update_ind)] = new_sd
            averaging_snapshots_dict["snapshots_metric"][update_ind] = float(validation_results_dict[self.metric_to_watch])

        return averaging_snapshots_dict

    def get_average_model(self, model, validation_results_dict=None) -> Mapping[str, torch.Tensor]:
        """
        Returns the state dict of the averaged model
        :param model: will be used to determine arch
        :param validation_results_dict: if provided, will update the average model before returning
        """
        # If validation tuple is provided, update the average model
        if validation_results_dict is not None:
            averaging_snapshots_dict = self.update_snapshots_dict(model, validation_results_dict)
        else:
            averaging_snapshots_dict = self._get_averaging_snapshots_dict()

        torch.save(averaging_snapshots_dict, self.averaging_snapshots_file)
        average_model_sd = averaging_snapshots_dict["snapshot0"]
        for n_model in range(1, self.number_of_models_to_average):
            if averaging_snapshots_dict["snapshot" + str(n_model)] is not None:
                net_sd = averaging_snapshots_dict["snapshot" + str(n_model)]
                # RUNNING MEAN OVER THE STORED SNAPSHOTS
                for key in average_model_sd:
                    average_model_sd[key] = torch.true_divide(average_model_sd[key] * n_model + net_sd[key], (n_model + 1))

        return average_model_sd

    def cleanup(self):
        """
        Delete snapshot file when reaching the last epoch
        """
        os.remove(self.averaging_snapshots_file)

    def _is_better(
        self, averaging_snapshots_dict: Mapping[str, Any], validation_results_dict: Mapping[str, Union[float, torch.Tensor]]
    ) -> Tuple[bool, Optional[int]]:
        """
        Determines if the new model is better according to the specified metrics
        :param averaging_snapshots_dict: snapshot dict
        :param validation_results_dict:  latest model performance
        :return: Tuple (bool, index) whether first item is True if the new model is better and False otherwise;
                 Second item is the index in the averaging_snapshots_dict to which the new model should be saved
        """
        snapshot_metric_array = averaging_snapshots_dict["snapshots_metric"]
        val = float(validation_results_dict[self.metric_to_watch])

        if not np.isfinite(val):
            return False, None

        if self.greater_is_better:
            update_ind = np.argmin(snapshot_metric_array)
        else:
            update_ind = np.argmax(snapshot_metric_array)

        if (self.greater_is_better and val > snapshot_metric_array[update_ind]) or (not self.greater_is_better and val < snapshot_metric_array[update_ind]):
            return True, update_ind

        return False, None

    def _get_averaging_snapshots_dict(self):
        return torch.load(self.averaging_snapshots_file, map_location="cpu")
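A minimal sketch of driving ModelWeightAveraging from a training loop; the checkpoint directory, metric key, and the train_one_epoch/evaluate helpers are hypothetical placeholders:

weight_averager = ModelWeightAveraging(
    ckpt_dir="/tmp/run_ckpts",             # hypothetical checkpoint directory (must exist)
    greater_is_better=True,                # e.g. watching a top-1 accuracy metric
    metric_to_watch="Accuracy",            # hypothetical key in the validation results dict
    load_checkpoint=False,
    number_of_models_to_average=10,
)

for epoch in range(num_epochs):            # num_epochs, train_one_epoch and evaluate are hypothetical
    train_one_epoch(model)
    validation_results = evaluate(model)   # e.g. {"Accuracy": 0.81, ...}
    averaged_sd = weight_averager.get_average_model(model, validation_results_dict=validation_results)
    torch.save({"net": averaged_sd}, "/tmp/run_ckpts/average_model.pth")   # persist the running average

weight_averager.cleanup()                   # deletes averaging_snapshots.pkl once training is done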

__init__(ckpt_dir, greater_is_better, metric_to_watch, load_checkpoint=False, number_of_models_to_average=10)

Init the ModelWeightAveraging

Parameters:

Name Type Description Default
ckpt_dir str

The directory where the checkpoints are saved

required
greater_is_better bool

Whether a higher value of the watched metric is better

required
metric_to_watch str

Name of the metric to monitor (e.g. loss or accuracy); identical to the metric that determines the best model

required
load_checkpoint bool

Whether to load pre-existing snapshot dict.

False
number_of_models_to_average int

Number of models to average

10
Source code in src/super_gradients/training/utils/weight_averaging_utils.py
def __init__(
    self,
    ckpt_dir: str,
    greater_is_better: bool,
    metric_to_watch: str,
    load_checkpoint: bool = False,
    number_of_models_to_average: int = 10,
):
    """
    Init the ModelWeightAveraging
    :param ckpt_dir:                    The directory where the checkpoints are saved
    :param greater_is_better:           Whether a higher value of the watched metric is better
    :param metric_to_watch:             Name of the metric to monitor (e.g. loss or accuracy); identical to the metric that determines the best model
    :param load_checkpoint:             Whether to load a pre-existing snapshot dict.
    :param number_of_models_to_average: Number of models to average
    """

    self.averaging_snapshots_file = os.path.join(ckpt_dir, "averaging_snapshots.pkl")
    self.number_of_models_to_average = number_of_models_to_average
    self.metric_to_watch = metric_to_watch
    self.greater_is_better = greater_is_better

    # if continuing training, load the previous snapshot dict if it exists
    if load_checkpoint and ckpt_dir is not None and os.path.isfile(self.averaging_snapshots_file):
        averaging_snapshots_dict = read_ckpt_state_dict(self.averaging_snapshots_file)
    else:
        averaging_snapshots_dict = {"snapshot" + str(i): None for i in range(self.number_of_models_to_average)}
        # if a greater metric is better, initialize with -inf; otherwise initialize with +inf
        if self.greater_is_better:
            averaging_snapshots_dict["snapshots_metric"] = -1 * np.inf * np.ones(self.number_of_models_to_average)
        else:
            averaging_snapshots_dict["snapshots_metric"] = np.inf * np.ones(self.number_of_models_to_average)

        torch.save(averaging_snapshots_dict, self.averaging_snapshots_file)

cleanup()

Delete snapshot file when reaching the last epoch

Source code in src/super_gradients/training/utils/weight_averaging_utils.py
def cleanup(self):
    """
    Delete snapshot file when reaching the last epoch
    """
    os.remove(self.averaging_snapshots_file)

get_average_model(model, validation_results_dict=None)

Returns the state dict of the averaged model

Parameters:

Name Type Description Default
model

will be used to determine arch

required
validation_results_dict

if provided, will update the average model before returning

None
Source code in src/super_gradients/training/utils/weight_averaging_utils.py
def get_average_model(self, model, validation_results_dict=None) -> Mapping[str, torch.Tensor]:
    """
    Returns the state dict of the averaged model
    :param model: will be used to determine arch
    :param validation_results_dict: if provided, will update the average model before returning
    """
    # If validation tuple is provided, update the average model
    if validation_results_dict is not None:
        averaging_snapshots_dict = self.update_snapshots_dict(model, validation_results_dict)
    else:
        averaging_snapshots_dict = self._get_averaging_snapshots_dict()

    torch.save(averaging_snapshots_dict, self.averaging_snapshots_file)
    average_model_sd = averaging_snapshots_dict["snapshot0"]
    for n_model in range(1, self.number_of_models_to_average):
        if averaging_snapshots_dict["snapshot" + str(n_model)] is not None:
            net_sd = averaging_snapshots_dict["snapshot" + str(n_model)]
            # RUNNING MEAN OVER THE STORED SNAPSHOTS
            for key in average_model_sd:
                average_model_sd[key] = torch.true_divide(average_model_sd[key] * n_model + net_sd[key], (n_model + 1))

    return average_model_sd
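The loop above keeps a running (incremental) mean over the stored snapshots: after folding in snapshot n, each parameter equals the mean of snapshots 0..n, because avg_n = (avg_{n-1} * n + snapshot_n) / (n + 1). A quick numeric check of the recurrence:

import torch

values = [torch.tensor(2.0), torch.tensor(4.0), torch.tensor(9.0)]   # one parameter across 3 snapshots
avg = values[0]
for n_model in range(1, len(values)):
    avg = torch.true_divide(avg * n_model + values[n_model], n_model + 1)
print(avg)   # tensor(5.) == (2 + 4 + 9) / 3

Note that if an intermediate snapshot slot is still None it is skipped, which skews the weighting slightly toward the earlier snapshots.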

update_snapshots_dict(model, validation_results_dict)

Update the snapshot dict and return it for saving

Parameters:

Name Type Description Default
model nn.Module

the latest model

required
validation_results_dict Mapping[str, float]

performance of the latest model

required
Source code in src/super_gradients/training/utils/weight_averaging_utils.py
def update_snapshots_dict(self, model: nn.Module, validation_results_dict: Mapping[str, float]):
    """
    Update the snapshot dict and return it for saving
    :param model: the latest model
    :param validation_results_dict: performance of the latest model
    """
    averaging_snapshots_dict = self._get_averaging_snapshots_dict()

    # IF THE CURRENT MODEL IS BETTER, IT REPLACES THE WORST SNAPSHOT IN THE METRIC LIST AND THE AVERAGE IS UPDATED
    require_update, update_ind = self._is_better(averaging_snapshots_dict, validation_results_dict)
    if require_update:
        # moving state dict to cpu
        new_sd = unwrap_model(model).state_dict()
        new_sd = move_state_dict_to_device(new_sd, "cpu")

        averaging_snapshots_dict["snapshot" + str(update_ind)] = new_sd
        averaging_snapshots_dict["snapshots_metric"][update_ind] = float(validation_results_dict[self.metric_to_watch])

    return averaging_snapshots_dict