Weight replacement utils
replace_conv2d_input_channels(conv, in_channels, fn=None)
Instantiate a new Conv2d module with the same attributes as the input Conv2d, except for the number of input channels.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
conv | nn.Conv2d | Conv2d to replace the input channels in. | required |
in_channels | int | New number of input channels. | required |
fn | Optional[Callable[[nn.Conv2d, int], nn.Conv2d]] | (Optional) Function to instantiate the new Conv2d. By default, it will initialize the new weights with the same mean and std as the original weights. | None |
Returns:
Type | Description |
---|---|
nn.Module | Conv2d with new number of input channels. |
Source code in latest/src/super_gradients/modules/weight_replacement_utils.py
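The `fn` hook only needs to match `Callable[[nn.Conv2d, int], nn.Conv2d]`. As a minimal sketch, assuming the library calls it as `fn(conv, in_channels)`, the hypothetical `zero_init_fn` below builds the replacement Conv2d by copying the overlapping slice of the original weights and zero-filling any extra input channels, instead of using the default mean/std initialization. It is an illustration, not super_gradients' own code, and it assumes `groups=1`.

```python
import torch
from torch import nn


def zero_init_fn(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    """Hypothetical `fn` for replace_conv2d_input_channels (assumes groups=1).

    Builds a Conv2d with the same attributes as `conv` except for
    `in_channels`, copies the overlapping input-channel slice of the
    original weights, and leaves any extra input channels at zero.
    """
    new_conv = nn.Conv2d(
        in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
        device=conv.weight.device,
        dtype=conv.weight.dtype,
    )
    with torch.no_grad():
        new_conv.weight.zero_()
        # Weight shape is (out_channels, in_channels // groups, kH, kW).
        n = min(conv.weight.shape[1], new_conv.weight.shape[1])
        new_conv.weight[:, :n] = conv.weight[:, :n]
        if conv.bias is not None:
            # The bias does not depend on the number of input channels.
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Usage: adapt an RGB stem conv to RGB-D input.
rgb_conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
rgbd_conv = zero_init_fn(rgb_conv, 4)
print(rgbd_conv.weight.shape)  # torch.Size([16, 4, 3, 3])
```

Because the extra channel starts at zero, the adapted conv initially produces the same output as the original one on the first three channels. Under the assumption above, it could be passed as `replace_conv2d_input_channels(conv, 4, fn=zero_init_fn)`.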
replace_conv2d_input_channels_with_random_weights(conv, in_channels)
Replace the input channels of the given Conv2d, initializing its weights randomly. The returned module has the same device and dtype as the original module. Random weights are initialized with the same mean and std as the original weights.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
conv | nn.Conv2d | Conv2d to replace the input channels in. | required |
in_channels | int | New number of input channels. | required |
Returns:
Type | Description |
---|---|
nn.Conv2d | Conv2d with new number of input channels. |
Source code in latest/src/super_gradients/modules/weight_replacement_utils.py
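To make the mean/std-matched initialization concrete, here is a minimal PyTorch sketch. It assumes every weight of the new module is resampled from a normal distribution with the original weights' mean and std; whether the real function keeps some of the original channels is not stated here. The name `random_init_conv2d` is hypothetical, and copying the bias unchanged is an assumption of this sketch.

```python
import torch
from torch import nn


def random_init_conv2d(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    """Hypothetical sketch: new Conv2d with `in_channels` inputs and random weights.

    Assumption: every weight is drawn from N(mean, std) of the original
    weights; the bias (which does not depend on in_channels) is copied.
    """
    weight = conv.weight.detach()
    new_conv = nn.Conv2d(
        in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
        # Keep the original module's device and dtype.
        device=weight.device,
        dtype=weight.dtype,
    )
    mean, std = weight.mean(), weight.std()
    with torch.no_grad():
        # randn_like samples N(0, 1) with the same shape, device and dtype;
        # scaling and shifting gives N(mean, std).
        new_conv.weight.copy_(torch.randn_like(new_conv.weight) * std + mean)
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
wider = random_init_conv2d(conv, 6)
print(wider.weight.shape)  # torch.Size([64, 6, 3, 3])
```

Passing `device` and `dtype` to the `nn.Conv2d` constructor and sampling with `torch.randn_like` keeps the new module on the original device and in the original dtype, which is the contract the description above states.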