Bases: nn.Module
Equivalent to nn.PixelShuffle.
nn.PixelShuffle module is translated to DepthToSpace
layer in ONNX, some compilation frameworks (i.e tflite),
doesn't support this layer. In that case this module should be used, it's translated to
reshape / transpose / reshape operations in ONNX graph.
Source code in V3_1/src/super_gradients/modules/pixel_shuffle.py
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 | class PixelShuffle(nn.Module):
"""
Equivalent to nn.PixelShuffle.
nn.PixelShuffle module is translated to `DepthToSpace` layer in ONNX, some compilation frameworks (i.e tflite),
doesn't support this layer. In that case this module should be used, it's translated to
reshape / transpose / reshape operations in ONNX graph.
"""
def __init__(self, upscale_factor: int):
super().__init__()
self.scale = upscale_factor
def forward(self, x: torch.Tensor):
b, c, h, w = x.size()
x = x.reshape(b, torch.div(c, self.scale * self.scale, rounding_mode="trunc"), self.scale, self.scale, h, w)
x = x.permute(0, 1, 4, 2, 5, 3)
x = x.reshape(b, torch.div(c, self.scale * self.scale, rounding_mode="trunc"), h * self.scale, w * self.scale)
return x
|