
Conv BN Act block

ConvBNAct

Bases: nn.Module, SupportsReplaceInputChannels

Class for a Convolution2d-BatchNorm2d-Activation layer. The default behaviour is Conv-BN-Act. To exclude the BatchNorm module, pass `use_normalization=False`; to exclude the activation, pass `activation_type=None`. For the convolution arguments see `nn.Conv2d`; for the batchnorm arguments see `nn.BatchNorm2d`.
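A minimal usage sketch (the import path `super_gradients.modules` is assumed from the source path below; PyTorch is required):

import torch
from torch import nn

from super_gradients.modules import ConvBNAct  # import path assumed from the source path below

# Full Conv-BN-Act block: 3x3 convolution, BatchNorm2d, then ReLU.
block = ConvBNAct(in_channels=3, out_channels=16, kernel_size=3, padding=1, activation_type=nn.ReLU)

# Conv-Act only (skip BatchNorm) and Conv-BN only (skip activation).
conv_act = ConvBNAct(3, 16, kernel_size=3, padding=1, activation_type=nn.ReLU, use_normalization=False)
conv_bn = ConvBNAct(3, 16, kernel_size=3, padding=1, activation_type=None)

# Extra keyword arguments for the activation go through activation_kwargs.
leaky = ConvBNAct(3, 16, kernel_size=3, padding=1, activation_type=nn.LeakyReLU, activation_kwargs={"negative_slope": 0.1})

x = torch.randn(2, 3, 32, 32)
print(block(x).shape)  # torch.Size([2, 16, 32, 32])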

Source code in src/super_gradients/modules/conv_bn_act_block.py
from typing import Callable, Optional, Tuple, Type, Union

from torch import nn

# Note: this import path for SupportsReplaceInputChannels is assumed from the
# package layout; the snippet below starts partway into the source file.
from super_gradients.module_interfaces import SupportsReplaceInputChannels


class ConvBNAct(nn.Module, SupportsReplaceInputChannels):
    """
    Class for Convolution2d-Batchnorm2d-Activation layer.
        Default behaviour is Conv-BN-Act. To exclude Batchnorm module use
        `use_normalization=False`, to exclude activation use `activation_type=None`.
    For convolution arguments documentation see `nn.Conv2d`.
    For batchnorm arguments documentation see `nn.BatchNorm2d`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        padding: Union[int, Tuple[int, int]],
        activation_type: Type[nn.Module],
        stride: Union[int, Tuple[int, int]] = 1,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        use_normalization: bool = True,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None,
        activation_kwargs=None,
    ):

        super().__init__()
        if activation_kwargs is None:
            activation_kwargs = {}

        self.seq = nn.Sequential()
        self.seq.add_module(
            "conv",
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
                padding_mode=padding_mode,
            ),
        )

        if use_normalization:
            self.seq.add_module(
                "bn",
                nn.BatchNorm2d(out_channels, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats, device=device, dtype=dtype),
            )
        if activation_type is not None:
            self.seq.add_module("act", activation_type(**activation_kwargs))

    def forward(self, x):
        return self.seq(x)

    def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
        # Swap the first convolution for one that accepts `in_channels` input channels,
        # optionally re-mapping its weights with `compute_new_weights_fn`.
        from super_gradients.modules.weight_replacement_utils import replace_conv2d_input_channels

        self.seq[0] = replace_conv2d_input_channels(conv=self.seq[0], in_channels=in_channels, fn=compute_new_weights_fn)

    def get_input_channels(self) -> int:
        # Number of input channels expected by the first convolution.
        return self.seq[0].in_channels
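For illustration, a hedged sketch of adapting an existing block to a different number of input channels via replace_input_channels; leaving compute_new_weights_fn as None delegates the weight re-mapping to replace_conv2d_input_channels (not shown on this page):

import torch
from torch import nn

from super_gradients.modules import ConvBNAct  # import path assumed

block = ConvBNAct(in_channels=3, out_channels=16, kernel_size=3, padding=1, activation_type=nn.ReLU)
print(block.get_input_channels())  # 3

# Re-wire the first convolution to accept 4-channel input (e.g. RGB + depth).
block.replace_input_channels(in_channels=4)
print(block.get_input_channels())  # 4

x = torch.randn(2, 4, 32, 32)
print(block(x).shape)  # torch.Size([2, 16, 32, 32])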