Head replacement utils

replace_num_classes_with_random_weights(module, num_classes)

Replace the number of classes in the module with random weights. This is useful for replacing the output layer of a detection/classification head. This implementation supports Conv2d and Linear layers. The returned module will have the same device and dtype as the original module, and its random weights are initialized with the same mean and std as the original weights.

Parameters:

Name          Type                                      Description                                    Default
module        Union[nn.Conv2d, nn.Linear, nn.Module]    Module to replace the number of classes in.    required
num_classes   int                                       New number of classes.                         required

Returns:

Type        Description
nn.Module   The replacement module with num_classes outputs, on the same device and dtype as the original.
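
For example, to adapt a pretrained classification head to a new label set, the final Linear layer can be swapped in place. A minimal sketch, assuming a toy two-layer model and a 10-class target (both illustrative, not part of this API):

from torch import nn
from super_gradients.modules.head_replacement_utils import replace_num_classes_with_random_weights

# Hypothetical classifier whose last layer predicts 1000 classes
model = nn.Sequential(nn.Flatten(), nn.Linear(512, 1000))

# Swap the head for a freshly initialized 10-class Linear layer
model[1] = replace_num_classes_with_random_weights(model[1], num_classes=10)
assert model[1].out_features == 10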

Source code in src/super_gradients/modules/head_replacement_utils.py
from typing import Union

import torch
from torch import nn


def replace_num_classes_with_random_weights(module: Union[nn.Conv2d, nn.Linear, nn.Module], num_classes: int) -> nn.Module:
    """
    Replace the number of classes in the module with random weights.
    This is useful for replacing the output layer of a detection/classification head.
    This implementation support Conv2d and Linear layers.
    Returned module will have the same device and dtype as the original module.
    Random weights are initialized with the same mean and std as the original weights.

    :param module: (nn.Module) Module to replace the number of classes in.
    :param num_classes: New number of classes.
    :return: nn.Module
    """
    if isinstance(module, nn.Conv2d):
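        # Re-create the conv layer with num_classes output channels, copying geometry, device and dtype from the original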
        new_module = nn.Conv2d(
            module.in_channels,
            num_classes,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=module.bias is not None,
            device=module.weight.device,
            dtype=module.weight.dtype,
        )
        torch.nn.init.normal_(new_module.weight, mean=module.weight.mean().item(), std=module.weight.std().item())
        if module.bias is not None:
            torch.nn.init.normal_(new_module.bias, mean=module.bias.mean().item(), std=module.bias.std(dim=0).item())

        return new_module
    elif isinstance(module, nn.Linear):
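        # Re-create the linear layer with num_classes outputs, copying device and dtype from the original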
        new_module = nn.Linear(module.in_features, num_classes, device=module.weight.device, dtype=module.weight.dtype, bias=module.bias is not None)
        torch.nn.init.normal_(new_module.weight, mean=module.weight.mean().item(), std=module.weight.std().item())
        if module.bias is not None:
            torch.nn.init.normal_(new_module.bias, mean=module.bias.mean().item(), std=module.bias.std(dim=0).item())

        return new_module
    else:
        raise ValueError(f"Module {module} does not support replacing the number of classes")
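
A similar sketch for a convolutional detection head; the 256-to-80 channel shape is an illustrative assumption. Device, dtype and weight statistics carry over from the original layer:

from torch import nn
from super_gradients.modules.head_replacement_utils import replace_num_classes_with_random_weights

# Hypothetical 1x1 classification conv predicting 80 classes
head = nn.Conv2d(256, 80, kernel_size=1)

new_head = replace_num_classes_with_random_weights(head, num_classes=20)
assert new_head.out_channels == 20
assert new_head.weight.dtype == head.weight.dtype  # device and dtype are preserved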