channelshuffle.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. from .module import Module
  2. from .. import functional as F
  3. from torch import Tensor
  4. class ChannelShuffle(Module):
  5. r"""Divide the channels in a tensor of shape :math:`(*, C , H, W)`
  6. into g groups and rearrange them as :math:`(*, C \frac g, g, H, W)`,
  7. while keeping the original tensor shape.
  8. Args:
  9. groups (int): number of groups to divide channels in.
  10. Examples::
  11. >>> channel_shuffle = nn.ChannelShuffle(2)
  12. >>> input = torch.randn(1, 4, 2, 2)
  13. >>> print(input)
  14. [[[[1, 2],
  15. [3, 4]],
  16. [[5, 6],
  17. [7, 8]],
  18. [[9, 10],
  19. [11, 12]],
  20. [[13, 14],
  21. [15, 16]],
  22. ]]
  23. >>> output = channel_shuffle(input)
  24. >>> print(output)
  25. [[[[1, 2],
  26. [3, 4]],
  27. [[9, 10],
  28. [11, 12]],
  29. [[5, 6],
  30. [7, 8]],
  31. [[13, 14],
  32. [15, 16]],
  33. ]]
  34. """
  35. __constants__ = ['groups']
  36. groups: int
  37. def __init__(self, groups: int) -> None:
  38. super(ChannelShuffle, self).__init__()
  39. self.groups = groups
  40. def forward(self, input: Tensor) -> Tensor:
  41. return F.channel_shuffle(input, self.groups)
  42. def extra_repr(self) -> str:
  43. return 'groups={}'.format(self.groups)