Head replacement utils
replace_num_classes_with_random_weights(module, num_classes)
Changes the number of classes a module predicts by returning a replacement layer with randomly initialized weights. This is useful for replacing the output layer of a detection or classification head. The implementation supports Conv2d and Linear layers. The returned module has the same device and dtype as the original module, and its random weights are initialized with the same mean and std as the original weights.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module | Union[nn.Conv2d, nn.Linear, nn.Module] | (nn.Module) Module to replace the number of classes in. | required |
num_classes | int | New number of classes. | required |
Returns:
Type | Description |
---|---|
nn.Module | New module with num_classes outputs, on the same device and with the same dtype as the original module. |
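The behavior described above can be illustrated with a minimal sketch. This is not the library's source: `resize_output_layer_sketch` is a hypothetical name, and the choice to leave the bias at PyTorch's default initialization is an assumption, since the description only mentions the weights.

```python
import torch
from torch import nn


def resize_output_layer_sketch(module: nn.Module, num_classes: int) -> nn.Module:
    """Hypothetical illustration of the documented behavior, not the library code."""
    weight = module.weight
    has_bias = module.bias is not None

    if isinstance(module, nn.Conv2d):
        # Rebuild the convolution with num_classes output channels,
        # keeping every other hyperparameter, the device and the dtype.
        new_module = nn.Conv2d(
            module.in_channels,
            num_classes,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=has_bias,
            padding_mode=module.padding_mode,
            device=weight.device,
            dtype=weight.dtype,
        )
    elif isinstance(module, nn.Linear):
        new_module = nn.Linear(
            module.in_features,
            num_classes,
            bias=has_bias,
            device=weight.device,
            dtype=weight.dtype,
        )
    else:
        raise TypeError(f"Unsupported module type: {type(module).__name__}")

    # Draw the new weights from a normal distribution whose mean and std
    # match the original weights. The bias (if any) keeps PyTorch's default
    # initialization, which is an assumption of this sketch.
    with torch.no_grad():
        new_module.weight.normal_(mean=weight.mean().item(), std=weight.std().item())

    return new_module


# Example: shrink an 80-class 1x1 convolutional head to 3 classes.
head = nn.Conv2d(256, 80, kernel_size=1)
new_head = resize_output_layer_sketch(head, num_classes=3)
print(new_head.weight.shape)  # torch.Size([3, 256, 1, 1])
```

Matching the mean and std of the original weights keeps the new layer's output scale roughly comparable to the old one, which is a plausible reason for this initialization when fine-tuning a pretrained head on a new label set.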
Source code in modules/head_replacement_utils.py