Skip to content

Weight replacement utils

replace_conv2d_input_channels(conv, in_channels, fn=None)

Instantiate a new Conv2d module with the same attributes as the input Conv2d, except for the number of input channels.

Parameters:

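| Name | Type | Description | Default |
|------|------|-------------|---------|
| `conv` | `nn.Conv2d` | Conv2d to replace the input channels in. | required |
| `in_channels` | `int` | New number of input channels. | required |
| `fn` | `Optional[Callable[[nn.Conv2d, int], nn.Conv2d]]` | (Optional) Function `fn(conv, in_channels)` that instantiates the new Conv2d. By default, the weights of the retained input channels are copied and any added input channels are initialized randomly with the same mean and std as the original weights. | `None` |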

Returns:

| Type | Description |
|------|-------------|
| `nn.Module` | Conv2d with the new number of input channels. |

Source code in `latest/src/super_gradients/modules/weight_replacement_utils.py` (lines 9–21):
def replace_conv2d_input_channels(conv: nn.Conv2d, in_channels: int, fn: Optional[Callable[[nn.Conv2d, int], nn.Conv2d]] = None) -> nn.Module:
    """Instantiate a new Conv2d module with the same attributes as the input Conv2d, except for the input channels.

    :param conv:        Conv2d to replace the input channels in.
    :param in_channels: New number of input channels.
    :param fn:          (Optional) Function called as ``fn(conv, in_channels)`` to instantiate the new Conv2d.
                        If None, ``replace_conv2d_input_channels_with_random_weights`` is used: weights of the
                        retained input channels are copied and any added input channels are initialized
                        randomly with the same mean and std as the original weights.
    :return:            Conv2d with new number of input channels (whatever ``fn`` returns, when given).
    """
    # Compare against None explicitly: a callable object (e.g. a module defining __len__) may be falsy
    # and must still be honoured when the caller passes it.
    if fn is not None:
        return fn(conv, in_channels)
    return replace_conv2d_input_channels_with_random_weights(conv=conv, in_channels=in_channels)

replace_conv2d_input_channels_with_random_weights(conv, in_channels)

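Replace the input channels in the input Conv2d, initializing any new channels with random weights. The weights of the input channels that are kept are copied from the original module. Added channels are drawn from a normal distribution with the same mean and std as the original weights. The returned module has the same device and dtype as the original module.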

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `conv` | `nn.Conv2d` | Conv2d to replace the input channels in. | required |
| `in_channels` | `int` | New number of input channels. | required |

Returns:

| Type | Description |
|------|-------------|
| `nn.Conv2d` | Conv2d with the new number of input channels. |

Source code in `latest/src/super_gradients/modules/weight_replacement_utils.py` (lines 24–65):
def replace_conv2d_input_channels_with_random_weights(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    """
    Replace the input channels in the input Conv2d, initializing any new channels with random weights.
    Returned module will have the same device and dtype as the original module.

    Weights of the input channels that are kept are copied from ``conv``:
    - if ``in_channels`` is smaller, only the first channels of each group are kept;
    - if it is larger, the extra channels of each group are drawn from a normal distribution with the
      same mean and std as the original weights.
    The bias (if any) does not depend on the input channels and is copied unchanged. All other
    constructor attributes, including ``padding_mode``, are preserved. The returned module does not
    share storage with ``conv``.

    :param conv:        Conv2d to replace the input channels in.
    :param in_channels: New number of input channels. Must be divisible by ``conv.groups``.
    :return:            Conv2d with new number of input channels.
    :raises ValueError: If ``in_channels`` is not divisible by ``conv.groups``.
    """
    if in_channels % conv.groups != 0:
        raise ValueError(
            f"Incompatible number of input channels ({in_channels}) with the number of groups ({conv.groups}). "
            f"The number of input channels must be divisible by the number of groups."
        )

    new_conv = nn.Conv2d(
        in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
        device=conv.weight.device,
        dtype=conv.weight.dtype,
    )

    # Conv2d weights have shape (out_channels, in_channels // groups, kH, kW): dim 1 counts input
    # channels *per group*, so slicing must use per-group sizes, not the total channel counts.
    old_per_group = conv.in_channels // conv.groups
    new_per_group = in_channels // conv.groups
    kept = min(old_per_group, new_per_group)

    with torch.no_grad():
        # copy_ into the freshly allocated parameter (instead of rebinding .data to a slice of the
        # original) so the new module owns its storage.
        new_conv.weight[:, :kept].copy_(conv.weight[:, :kept])

        if new_per_group > old_per_group:
            mean = conv.weight.mean().item()
            # std() of a single element is NaN, which normal_ rejects; fall back to a zero spread.
            std = conv.weight.std().item() if conv.weight.numel() > 1 else 0.0
            # Pad the remaining channels of every group with random weights
            torch.nn.init.normal_(new_conv.weight[:, old_per_group:], mean=mean, std=std)

        if conv.bias is not None:
            # Output channels are unchanged, so the original (possibly pretrained) bias is still valid.
            new_conv.bias.copy_(conv.bias)

    return new_conv