
Utils

NormalizationAdapter

Bases: torch.nn.Module

Denormalizes input by mean_original, std_original, then normalizes by mean_required, std_required.

Used in knowledge distillation (KD) training when the teacher expects data normalized by mean_required, std_required.

mean_original, std_original, mean_required, std_required are all list-like objects whose length equals the number of input channels.
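The adapter's two per-channel constants follow from undoing one normalization and applying the other. In the derivation below, r is a raw input value, μ_o and σ_o stand for mean_original and std_original, and μ_r and σ_r stand for mean_required and std_required. These symbols are introduced here only for brevity:

$$
\begin{aligned}
x &= \frac{r - \mu_o}{\sigma_o}, \qquad y = \frac{r - \mu_r}{\sigma_r}, \qquad r = \sigma_o x + \mu_o \\
y &= \frac{\sigma_o x + \mu_o - \mu_r}{\sigma_r} = \left(x + \frac{\mu_o - \mu_r}{\sigma_o}\right)\frac{\sigma_o}{\sigma_r}
\end{aligned}
$$

The offset inside the parentheses corresponds to `additive` and the trailing factor corresponds to `multiplier`. This is why `forward` computes the whole conversion as `(x + additive) * multiplier`.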

Source code in src/super_gradients/modules/utils.py
```python
import torch


class NormalizationAdapter(torch.nn.Module):
    """
    Denormalizes input by mean_original, std_original, then normalizes by mean_required, std_required.

    Used in KD training where the teacher expects data normalized by mean_required, std_required.

    mean_original, std_original, mean_required, std_required are all list-like objects whose length equals the
    number of input channels.
    """

    def __init__(self, mean_original, std_original, mean_required, std_required):
        super().__init__()
        # Reshape each per-channel list to (C, 1, 1) so it broadcasts over (N, C, H, W) inputs.
        mean_original = torch.tensor(mean_original).unsqueeze(-1).unsqueeze(-1)
        std_original = torch.tensor(std_original).unsqueeze(-1).unsqueeze(-1)
        mean_required = torch.tensor(mean_required).unsqueeze(-1).unsqueeze(-1)
        std_required = torch.tensor(std_required).unsqueeze(-1).unsqueeze(-1)

        # Denormalize + renormalize folded into one affine map: y = (x + additive) * multiplier.
        self.additive = torch.nn.Parameter((mean_original - mean_required) / std_original)
        self.multiplier = torch.nn.Parameter(std_original / std_required)

    def forward(self, x):
        x = (x + self.additive) * self.multiplier
        return x
```
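You can check the adapter by comparing it with an explicit denormalize-then-renormalize on random data. The sketch below reproduces the same arithmetic in a standalone module. The class `FixedNormalizationAdapter` is hypothetical and is not part of the library. It differs from the class above in one assumed design choice: it stores the constants with `register_buffer` rather than `torch.nn.Parameter`. The ImageNet-style statistics and the 0.5/0.5 teacher statistics are illustrative values only.

```python
import torch
from torch import nn


class FixedNormalizationAdapter(nn.Module):
    """Hypothetical buffer-based variant of NormalizationAdapter (same math, non-trainable)."""

    def __init__(self, mean_original, std_original, mean_required, std_required):
        super().__init__()
        # Reshape each per-channel list to (C, 1, 1) so it broadcasts over (N, C, H, W).
        mo, so, mr, sr = (
            torch.tensor(v, dtype=torch.float32).view(-1, 1, 1)
            for v in (mean_original, std_original, mean_required, std_required)
        )
        # Buffers are saved in state_dict and moved by .to(), but are not returned by parameters().
        self.register_buffer("additive", (mo - mr) / so)
        self.register_buffer("multiplier", so / sr)

    def forward(self, x):
        return (x + self.additive) * self.multiplier


# Illustrative statistics: ImageNet mean/std for the student, 0.5/0.5 for the teacher.
mean_o, std_o = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
mean_r, std_r = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]

raw = torch.rand(2, 3, 8, 8)  # raw images in [0, 1]
mo, so, mr, sr = (torch.tensor(v).view(1, -1, 1, 1) for v in (mean_o, std_o, mean_r, std_r))
student_input = (raw - mo) / so  # data as normalized for the student
teacher_expected = (raw - mr) / sr  # data as the teacher expects it

adapter = FixedNormalizationAdapter(mean_o, std_o, mean_r, std_r)
adapted = adapter(student_input)
print(torch.allclose(adapted, teacher_expected, atol=1e-5))  # True
print(len(list(adapter.parameters())))  # 0
```

`NormalizationAdapter` wraps its constants in `torch.nn.Parameter`. As a result, they appear in `parameters()` and receive gradients. The buffer variant keeps the mapping fixed, which may be preferable if the adapter should never be optimized.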

replace_activations(module, new_activation, activations_to_replace)

Recursively goes through module and replaces each activation whose type is in activations_to_replace with a copy of new_activation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| module | nn.Module | a module that will be changed inplace | required |
| new_activation | nn.Module | a sample of a new activation (will be copied) | required |
| activations_to_replace | List[type] | types of activations to replace, each must be a subclass of nn.Module | required |
Source code in src/super_gradients/modules/utils.py
```python
from typing import List

from torch import nn


def replace_activations(module: nn.Module, new_activation: nn.Module, activations_to_replace: List[type]):
    """
    Recursively goes through module and replaces each activation in activations_to_replace with a copy of new_activation.
    :param module:                  a module that will be changed in place
    :param new_activation:          a sample of a new activation (will be copied)
    :param activations_to_replace:  types of activations to replace, each must be a subclass of nn.Module
    """
    # check arguments once before the recursion
    assert isinstance(new_activation, nn.Module), "new_activation should be nn.Module"
    assert all(
        isinstance(t, type) and issubclass(t, nn.Module) for t in activations_to_replace
    ), "activations_to_replace should be types that are subclasses of nn.Module"

    # do the replacement (_replace_activations_recursive is a private helper not shown in this excerpt)
    _replace_activations_recursive(module, new_activation, activations_to_replace)
```
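The recursive helper `_replace_activations_recursive` is not shown above. The sketch below shows one plausible way such a recursion could work. It is a hypothetical standalone function named `replace_activations_sketch`, not the library's implementation. It walks `named_children()`, replaces every child whose type matches with a `copy.deepcopy` of `new_activation`, and recurses into all other children.

```python
import copy
from typing import List

from torch import nn


def replace_activations_sketch(module: nn.Module, new_activation: nn.Module, activations_to_replace: List[type]) -> None:
    """Hypothetical re-implementation: replaces matching activations in place."""
    types = tuple(activations_to_replace)
    for name, child in module.named_children():
        if isinstance(child, types):
            # Each replacement gets its own copy, so no two layers share one module instance.
            setattr(module, name, copy.deepcopy(new_activation))
        else:
            # Descend into containers (Sequential, ModuleList, custom blocks).
            replace_activations_sketch(child, new_activation, activations_to_replace)


model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 8, 3), nn.ReLU6(), nn.Sigmoid()),
)
replace_activations_sketch(model, nn.SiLU(), [nn.ReLU, nn.ReLU6])
# model[1] and model[2][1] are now separate nn.SiLU instances; nn.Sigmoid is untouched.
```

Because the sketch uses `isinstance`, it also replaces subclasses of the listed types. The library's helper may match types differently.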