Skip to content

Sampling

make_downsample_module(in_channels, stride, downsample_mode)

Factory method for creating down-sampling modules.

Parameters:

Name Type Description Default
downsample_mode Union[str, DownSampleMode]

see DownSampleMode for supported options.

required

Returns:

Type Description

nn.Module

Source code in src/super_gradients/modules/sampling.py
83
84
85
86
87
88
89
90
91
92
93
94
def make_downsample_module(in_channels: int, stride: int, downsample_mode: Union[str, DownSampleMode]):
    """
    Build the down-sampling layer that corresponds to the requested mode.

    :param in_channels: passed to AntiAliasDownsample; not used by the max-pool option.
    :param stride: passed to AntiAliasDownsample, or used as both kernel size and stride of nn.MaxPool2d.
    :param downsample_mode: see DownSampleMode for supported options. Either the enum member or its value.
    :return: nn.Module
    :raises NotImplementedError: when the mode matches none of the handled options.
    """
    # Work with the raw enum value from here on, so members and plain values compare the same way.
    if isinstance(downsample_mode, DownSampleMode):
        downsample_mode = downsample_mode.value

    if downsample_mode == DownSampleMode.ANTI_ALIAS.value:
        module = AntiAliasDownsample(in_channels, stride)
    elif downsample_mode == DownSampleMode.MAX_POOL.value:
        module = nn.MaxPool2d(kernel_size=stride, stride=stride)
    else:
        raise NotImplementedError(f"DownSample mode: `{downsample_mode}` is not supported.")
    return module

make_upsample_module(scale_factor, upsample_mode, align_corners=None)

Factory method for creating upsampling modules.

Parameters:

Name Type Description Default
scale_factor int

upsample scale factor

required
upsample_mode Union[str, UpsampleMode]

see UpsampleMode for supported options.

required

Returns:

Type Description

nn.Module

Source code in src/super_gradients/modules/sampling.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def make_upsample_module(scale_factor: int, upsample_mode: Union[str, UpsampleMode], align_corners: Optional[bool] = None):
    """
    Build the upsampling layer that corresponds to the requested mode.

    :param scale_factor: upsample scale factor
    :param upsample_mode: see UpsampleMode for supported options. Either the enum member or its value.
    :param align_corners: forwarded to nn.Upsample for the bilinear and bicubic modes only; ignored otherwise.
    :return: nn.Module
    :raises NotImplementedError: when the mode matches none of the handled options.
    """
    # Work with the raw enum value from here on, so members and plain values compare the same way.
    if isinstance(upsample_mode, UpsampleMode):
        upsample_mode = upsample_mode.value

    if upsample_mode == UpsampleMode.NEAREST.value:
        # align_corners is intentionally left out: passing it together with nearest mode leads to a ValueError.
        return nn.Upsample(scale_factor=scale_factor, mode=upsample_mode)

    if upsample_mode in (UpsampleMode.BILINEAR.value, UpsampleMode.BICUBIC.value):
        return nn.Upsample(scale_factor=scale_factor, mode=upsample_mode, align_corners=align_corners)

    if upsample_mode == UpsampleMode.PIXEL_SHUFFLE.value:
        # The PixelShuffle name in scope of this module (as opposed to torch's nn.PixelShuffle below).
        return PixelShuffle(upscale_factor=scale_factor)

    if upsample_mode == UpsampleMode.NN_PIXEL_SHUFFLE.value:
        return nn.PixelShuffle(upscale_factor=scale_factor)

    raise NotImplementedError(f"Upsample mode: `{upsample_mode}` is not supported.")

make_upsample_module_with_explicit_channels(in_channels, out_channels, scale_factor, upsample_mode, align_corners=None)

Factory method for creating an upsampling module with explicit control of in/out channels. Unlike make_upsample_module, this method lets you specify the desired number of output channels, which is useful for upsampling with pixel shuffle and transposed convolutions.

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
scale_factor int

Upsample scale factor

required
upsample_mode UpsampleMode

The desired mode of upsampling.

required
align_corners Optional[bool]

See nn.Upsample for details.

None

Returns:

Type Description
nn.Module

Created upsampling module.

Source code in src/super_gradients/modules/sampling.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def make_upsample_module_with_explicit_channels(
    in_channels: int, out_channels: int, scale_factor: int, upsample_mode: Union[str, UpsampleMode], align_corners: Optional[bool] = None
) -> nn.Module:
    """
    Factory method for creating an upsampling module with explicit control of in/out channels.
    Unlike `make_upsample_module`, this method lets the caller specify the desired number of output channels,
    which is useful for upsampling with pixel shuffle and transposed convolutions.

    When the chosen upsampling operation cannot change the channel count by itself, a bias-free 1x1 convolution
    is placed before it (wrapped together in `nn.Sequential`) so that the result has `out_channels` channels.

    :param in_channels:   Number of input channels
    :param out_channels:  Number of output channels
    :param scale_factor:  Upsample scale factor
    :param upsample_mode: The desired mode of upsampling. Either an `UpsampleMode` member or its value.
    :param align_corners: See `nn.Upsample` for details. Only forwarded for the bilinear and bicubic modes.
    :return:              Created upsampling module.
    :raises NotImplementedError: When the mode matches none of the handled options.
    """
    # Normalize to the raw enum value so that enum members and plain values take the same branches
    # and nn.Upsample always receives the plain value as `mode` rather than an enum member.
    upsample_mode = upsample_mode.value if isinstance(upsample_mode, UpsampleMode) else upsample_mode

    projection_before_upsample = None

    if upsample_mode in (UpsampleMode.NEAREST.value, UpsampleMode.BILINEAR.value, UpsampleMode.BICUBIC.value):
        if upsample_mode == UpsampleMode.NEAREST.value:
            # Prevent ValueError when passing align_corners with nearest mode.
            module = nn.Upsample(scale_factor=scale_factor, mode=upsample_mode)
        else:
            module = nn.Upsample(scale_factor=scale_factor, mode=upsample_mode, align_corners=align_corners)

        # nn.Upsample keeps the channel count, so a 1x1 projection is needed whenever the counts differ.
        if in_channels != out_channels:
            projection_before_upsample = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)

    elif upsample_mode in (UpsampleMode.PIXEL_SHUFFLE.value, UpsampleMode.NN_PIXEL_SHUFFLE.value):
        # Pixel shuffle divides the channel count by scale_factor**2, so its input must carry that many more channels.
        channels_before_shuffle = out_channels * scale_factor**2
        if in_channels != channels_before_shuffle:
            projection_before_upsample = nn.Conv2d(in_channels, channels_before_shuffle, kernel_size=1, stride=1, padding=0, bias=False)

        if upsample_mode == UpsampleMode.PIXEL_SHUFFLE.value:
            module = PixelShuffle(upscale_factor=scale_factor)
        else:
            module = nn.PixelShuffle(upscale_factor=scale_factor)

    elif upsample_mode == UpsampleMode.CONV_TRANSPOSE.value:
        # The transposed convolution maps in_channels to out_channels on its own; no projection required.
        module = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor)
    else:
        raise NotImplementedError(f"Upsample mode: `{upsample_mode}` is not supported.")

    if projection_before_upsample is not None:
        module = nn.Sequential(projection_before_upsample, module)

    return module