Multi output modules
MultiOutputModule
Bases: nn.Module, SupportsReplaceInputChannels
This module wraps a container nn.Module (such as Module, Sequential or ModuleList) and extracts multiple outputs from its inner modules on each forward() call, returning them as a list of output tensors.
Note: the default output of the wrapped module is not added to the output list. To include it, explicitly add its path to the output_paths parameter.
For example, given the module:

    Sequential(
      (0): Sequential(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )  ===================================>>
      (1): InvertedResidual(
        (conv): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)  ===================================>>
          (3): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )

and output_paths [0, [1, 'conv', 2]], the extracted outputs are those of the layers marked with arrows: block 0 as a whole, and layer 1 -> conv -> 2.
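The behaviour described above can be illustrated with PyTorch forward hooks. This is a minimal sketch under stated assumptions, not the library's implementation: MultiOutputSketch, path_to_name and Block are hypothetical names, the hook-based collection is one possible technique, the prune option is not modelled, and nn.Module.get_submodule requires PyTorch 1.9 or newer.

```python
from typing import List, Sequence, Union

import torch
from torch import nn

Key = Union[int, str]
Path = Union[Key, Sequence[Key]]


def path_to_name(path: Path) -> str:
    """Hypothetical helper: turn a path such as 0 or [1, 'conv', 2] into a dotted name such as '1.conv.2'."""
    keys = path if isinstance(path, (list, tuple)) else [path]
    return ".".join(str(key) for key in keys)


class MultiOutputSketch(nn.Module):
    """Hypothetical stand-in for MultiOutputModule that collects outputs with forward hooks."""

    def __init__(self, module: nn.Module, output_paths: List[Path]):
        super().__init__()
        self.module = module
        self._outputs: List[torch.Tensor] = []
        for path in output_paths:
            # get_submodule walks dotted names; Sequential children are registered as "0", "1", ...
            submodule = module.get_submodule(path_to_name(path))
            submodule.register_forward_hook(self._save_output)

    def _save_output(self, module: nn.Module, inputs, output: torch.Tensor) -> None:
        # Outputs are appended in execution order, which may differ from output_paths order
        self._outputs.append(output)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        self._outputs = []
        self.module(x)  # the wrapped module's own (final) output is discarded
        return self._outputs


class Block(nn.Module):
    """Toy block exposing a named 'conv' Sequential, loosely mirroring InvertedResidual.conv."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(8, 6, 3, padding=1), nn.ReLU(), nn.Conv2d(6, 4, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 8, 3, stride=2, padding=1), nn.ReLU()),
    Block(),
)
wrapped = MultiOutputSketch(model, [0, [1, "conv", 1]])
outputs = wrapped(torch.randn(2, 3, 32, 32))
print([tuple(o.shape) for o in outputs])  # [(2, 8, 16, 16), (2, 6, 16, 16)]
```

Hooks keep the wrapped module's forward() untouched, which is why the final output is simply dropped unless its own path is listed.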
Source code in src/super_gradients/modules/multi_output_modules.py
__init__(module, output_paths, prune=True)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module | nn.Module | The wrapped container module. | required |
output_paths | list | A list of keys or lists of keys giving the canonical paths to the outputs, e.g. [3, [4, 'conv', 5], 7] extracts the outputs of layers 3, 7 and 4 -> conv -> 5. | required |
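One way to read the output_paths format is sketched below in plain Python. It assumes, without confirmation from the source, that integer keys index into sequential containers and string keys name attributes; normalize_path and resolve are hypothetical helpers, and nested lists plus SimpleNamespace stand in for real containers.

```python
from types import SimpleNamespace
from typing import Any, List, Sequence, Union

Key = Union[int, str]


def normalize_path(path: Union[Key, Sequence[Key]]) -> List[Key]:
    """Treat a bare key such as 3 as the one-element path [3]."""
    if isinstance(path, (list, tuple)):
        return list(path)
    return [path]


def resolve(root: Any, path: Union[Key, Sequence[Key]]) -> Any:
    """Walk the path: integer keys index into sequences, string keys read attributes."""
    node = root
    for key in normalize_path(path):
        node = node[key] if isinstance(key, int) else getattr(node, key)
    return node


# Toy container: position 4 holds an object with a 'conv' sequence, like a block with a conv attribute
layers = ["l0", "l1", "l2", "l3", SimpleNamespace(conv=["c0", "c1", "c2", "c3", "c4", "c5"]), "l5", "l6", "l7"]
paths = [3, [4, "conv", 5], 7]
print([resolve(layers, p) for p in paths])  # ['l3', 'c5', 'l7']
```

Because nn.Sequential and nn.ModuleList support integer indexing and named submodules are attributes, the same walk should also apply to real PyTorch containers.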