Skip to content

Sampling

make_downsample_module(in_channels, stride, downsample_mode)

Factory method for creating down-sampling modules.

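Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `in_channels` | `int` | Number of input channels (used by the anti-alias mode). | required |
| `stride` | `int` | Down-sampling factor. | required |
| `downsample_mode` | `Union[str, DownSampleMode]` | See `DownSampleMode` for supported options. | required |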

Returns:

| Type | Description |
| --- | --- |
| `nn.Module` | The down-sampling module. |

Source code in `V3_1/src/super_gradients/modules/sampling.py` (lines 36–47):
def make_downsample_module(in_channels: int, stride: int, downsample_mode: Union[str, DownSampleMode]):
    """
    Build a down-sampling module for the requested mode.

    :param in_channels: number of input channels; only forwarded to AntiAliasDownsample.
    :param stride: down-sampling factor; for max-pooling it is used as both kernel size and stride.
    :param downsample_mode: a DownSampleMode member or its string value; see DownSampleMode for supported options.
    :return: nn.Module
    :raises NotImplementedError: if the mode matches none of the supported options.
    """
    mode_name = downsample_mode
    if isinstance(mode_name, DownSampleMode):
        # Normalise enum members to their string value so both input forms share one comparison path.
        mode_name = mode_name.value

    if mode_name == DownSampleMode.ANTI_ALIAS.value:
        return AntiAliasDownsample(in_channels, stride)
    if mode_name == DownSampleMode.MAX_POOL.value:
        return nn.MaxPool2d(kernel_size=stride, stride=stride)

    raise NotImplementedError(f"DownSample mode: `{mode_name}` is not supported.")

make_upsample_module(scale_factor, upsample_mode, align_corners=None)

Factory method for creating upsampling modules.

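Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `scale_factor` | `int` | Upsample scale factor. | required |
| `upsample_mode` | `Union[str, UpsampleMode]` | See `UpsampleMode` for supported options. | required |
| `align_corners` | `Optional[bool]` | Forwarded to `nn.Upsample` for the bilinear and bicubic modes only. | `None` |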

Returns:

| Type | Description |
| --- | --- |
| `nn.Module` | The upsampling module. |

Source code in `V3_1/src/super_gradients/modules/sampling.py` (lines 10–33):
def make_upsample_module(scale_factor: int, upsample_mode: Union[str, UpsampleMode], align_corners: Optional[bool] = None):
    """
    Build an upsampling module for the requested mode.

    :param scale_factor: upsample scale factor.
    :param upsample_mode: an UpsampleMode member or its string value; see UpsampleMode for supported options.
    :param align_corners: forwarded to nn.Upsample for the bilinear and bicubic modes; not used by any other mode.
    :return: nn.Module
    :raises NotImplementedError: if the mode matches none of the supported options.
    """
    if isinstance(upsample_mode, UpsampleMode):
        # Normalise enum members to their string value; plain strings are compared as-is.
        upsample_mode = upsample_mode.value

    if upsample_mode == UpsampleMode.NEAREST.value:
        # align_corners is deliberately omitted: passing it together with nearest mode raises a ValueError.
        return nn.Upsample(scale_factor=scale_factor, mode=upsample_mode)

    interpolating_modes = (UpsampleMode.BILINEAR.value, UpsampleMode.BICUBIC.value)
    if upsample_mode in interpolating_modes:
        return nn.Upsample(scale_factor=scale_factor, mode=upsample_mode, align_corners=align_corners)

    if upsample_mode == UpsampleMode.PIXEL_SHUFFLE.value:
        # Project-level PixelShuffle, distinct from torch's nn.PixelShuffle used for NN_PIXEL_SHUFFLE below.
        return PixelShuffle(upscale_factor=scale_factor)

    if upsample_mode == UpsampleMode.NN_PIXEL_SHUFFLE.value:
        return nn.PixelShuffle(upscale_factor=scale_factor)

    raise NotImplementedError(f"Upsample mode: `{upsample_mode}` is not supported.")