
Multi output modules

MultiOutputModule

Bases: nn.Module, SupportsReplaceInputChannels

This module wraps a container nn.Module (such as Module, Sequential, or ModuleList) and extracts multiple outputs from its inner modules on each forward() call, returning them as a list of output tensors.

Note: the default output of the wrapped module is not added to the output list by default. To get the default output in the outputs list, explicitly include its path in the output_paths parameter.

For example, for the module:

    Sequential(
      (0): Sequential(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )                                         ===================================>>
      (1): InvertedResidual(
        (conv): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)              ===================================>>
          (3): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )

and paths:

    [0, [1, 'conv', 2]]

the extracted outputs are marked with arrows.
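A minimal usage sketch, assuming MultiOutputModule is importable from the source path shown below; the backbone is a toy container that mirrors the printout above, and all layer names and shapes are purely illustrative:

from collections import OrderedDict

import torch
from torch import nn

from super_gradients.modules.multi_output_modules import MultiOutputModule

# Toy container mirroring the structure printed above.
backbone = nn.Sequential(
    nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(32),
        nn.ReLU6(inplace=True),
    ),
    nn.Sequential(OrderedDict([
        ("conv", nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1, groups=32, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True),
            nn.Conv2d(32, 16, kernel_size=1, bias=False),
            nn.BatchNorm2d(16),
        )),
    ])),
)

# Extract the output of block 0 and of block 1 -> 'conv' -> 2 (the two arrows above).
wrapped = MultiOutputModule(backbone, output_paths=[0, [1, "conv", 2]])

features = wrapped(torch.randn(1, 3, 64, 64))
print([f.shape for f in features])  # two tensors of shape (1, 32, 32, 32), in forward-execution order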

Source code in src/super_gradients/modules/multi_output_modules.py
class MultiOutputModule(nn.Module, SupportsReplaceInputChannels):
    """
    This module wraps around a container nn.Module (such as Module, Sequential and ModuleList) and allows extracting
    multiple outputs from its inner modules on each forward() call (as a list of output tensors)
    note: the default output of the wrapped module will not be added to the output list by default. To get
    the default output in the outputs list, explicitly include its path in the @output_paths parameter

    i.e.
    for module:
        Sequential(
          (0): Sequential(
            (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )                                         ===================================>>
          (1): InvertedResidual(
            (conv): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU6(inplace=True)              ===================================>>
              (3): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
        )
    and paths:
        [0, [1, 'conv', 2]]
    the outputs are marked with arrows
    """

    def __init__(self, module: nn.Module, output_paths: list, prune: bool = True):
        """
        :param module: The wrapped container module
        :param output_paths: a list of lists or keys containing the canonical paths to the outputs
        i.e. [3, [4, 'conv', 5], 7] will extract outputs of layers 3, 7 and 4->conv->5
        """
        super().__init__()
        self.output_paths = output_paths
        self._modules["0"] = module
        self._outputs_lists = {}

        for path in output_paths:
            child = self._get_recursive(module, path)
            child.register_forward_hook(hook=self.save_output_hook)

        # PRUNE THE MODULE TO SUPPORT ALL PROVIDED OUTPUT_PATHS BUT REMOVE ALL REDUNDANT LAYERS
        if prune:
            self._prune(module, output_paths)

    def save_output_hook(self, _, input, output):
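        # Key the collected outputs by the input tensor's device, so that forward passes running
        # on different devices (e.g. nn.DataParallel replicas) fill separate output lists.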
        self._outputs_lists[input[0].device].append(output)

    def forward(self, x) -> list:
        self._outputs_lists[x.device] = []
        self._modules["0"](x)
        outputs = self._outputs_lists[x.device]
        self._outputs_lists = {}
        return outputs

    def _get_recursive(self, module: nn.Module, path) -> nn.Module:
        """recursively look for a module using a path"""
        if not isinstance(path, (list, ListConfig)):
            return module._modules[str(path)]
        elif len(path) == 1:
            return module._modules[str(path[0])]
        else:
            return self._get_recursive(module._modules[str(path[0])], path[1:])

    def _prune(self, module: nn.Module, output_paths: list):
        """
        Recursively prune the module to support all provided output_paths but remove all redundant layers
        """
        last_index = -1
        last_key = None

        # look for the last key from all paths
        for path in output_paths:
            key = path[0] if isinstance(path, (list, ListConfig)) else path
            index = list(module._modules).index(str(key))
            if index > last_index:
                last_index = index
                last_key = key

        module._modules = self._slice_odict(module._modules, 0, last_index + 1)

        next_level_paths = []
        for path in output_paths:
            if isinstance(path, (list, ListConfig)) and path[0] == last_key and len(path) > 1:
                next_level_paths.append(path[1:])

        if len(next_level_paths) > 0:
            self._prune(module._modules[str(last_key)], next_level_paths)

    def _slice_odict(self, odict: OrderedDict, start: int, end: int):
        """Slice an OrderedDict in the same logic list,tuple... are sliced"""
        return OrderedDict([(k, v) for (k, v) in odict.items() if k in list(odict.keys())[start:end]])

    def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
        module = self._modules["0"]
        if isinstance(module, SupportsReplaceInputChannels):
            module.replace_input_channels(in_channels=in_channels, compute_new_weights_fn=compute_new_weights_fn)
        else:
            raise NotImplementedError(f"`{module.__class__.__name__}` does not support `replace_input_channels`")

    def get_input_channels(self) -> int:
        module = self._modules["0"]
        if isinstance(module, SupportsReplaceInputChannels):
            return module.get_input_channels()
        else:
            raise NotImplementedError(f"`{module.__class__.__name__}` does not support `get_input_channels`")

__init__(module, output_paths, prune=True)

Parameters:

Name          Type        Description                                                                Default
module        nn.Module   The wrapped container module.                                              required
output_paths  list        A list of lists or keys containing the canonical paths to the outputs,     required
                          e.g. [3, [4, 'conv', 5], 7] will extract the outputs of layers 3, 7
                          and 4 -> conv -> 5.
prune         bool        Whether to prune the wrapped module so that layers not needed to produce   True
                          the requested outputs are removed.
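A short sketch of the pruning behaviour (prune=True is the default): layers of the wrapped container that come after the deepest module required by output_paths are dropped. The backbone below and its layer sizes are purely illustrative.

import torch
from torch import nn

from super_gradients.modules.multi_output_modules import MultiOutputModule

backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.ReLU(),
)

wrapped = MultiOutputModule(backbone, output_paths=[1])  # only layers 0 and 1 are needed
inner = next(wrapped.children())                         # the wrapped (and now pruned) container
print(len(inner))                                        # 2 -- layers 2 and 3 were dropped
print([o.shape for o in wrapped(torch.randn(1, 3, 32, 32))])  # [torch.Size([1, 8, 32, 32])]

Note that _prune reassigns the _modules of the container that was passed in, so the original backbone object itself ends up trimmed; wrap a copy (e.g. via copy.deepcopy) if the full container is still needed elsewhere.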
Source code in src/super_gradients/modules/multi_output_modules.py
def __init__(self, module: nn.Module, output_paths: list, prune: bool = True):
    """
    :param module: The wrapped container module
    :param output_paths: a list of lists or keys containing the canonical paths to the outputs
    i.e. [3, [4, 'conv', 5], 7] will extract outputs of layers 3, 7 and 4->conv->5
    """
    super().__init__()
    self.output_paths = output_paths
    self._modules["0"] = module
    self._outputs_lists = {}

    for path in output_paths:
        child = self._get_recursive(module, path)
        child.register_forward_hook(hook=self.save_output_hook)

    # PRUNE THE MODULE TO SUPPORT ALL PROVIDED OUTPUT_PATHS BUT REMOVE ALL REDUNDANT LAYERS
    if prune:
        self._prune(module, output_paths)