shufflenet.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # Module caffe2.python.models.shufflenet
  2. from caffe2.python import brew
  3. """
Utility for creating ShuffleNet
"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" by Ma et al., 2018
  6. """
  7. OUTPUT_CHANNELS = {
  8. '0.5x': [24, 48, 96, 192, 1024],
  9. '1.0x': [24, 116, 232, 464, 1024],
  10. '1.5x': [24, 176, 352, 704, 1024],
  11. '2.0x': [24, 244, 488, 976, 2048],
  12. }
  13. class ShuffleNetV2Builder():
  14. def __init__(
  15. self,
  16. model,
  17. data,
  18. num_input_channels,
  19. num_labels,
  20. num_groups=2,
  21. width='1.0x',
  22. is_test=False,
  23. detection=False,
  24. bn_epsilon=1e-5,
  25. ):
  26. self.model = model
  27. self.prev_blob = data
  28. self.num_input_channels = num_input_channels
  29. self.num_labels = num_labels
  30. self.num_groups = num_groups
  31. self.output_channels = OUTPUT_CHANNELS[width]
  32. self.stage_repeats = [3, 7, 3]
  33. self.is_test = is_test
  34. self.detection = detection
  35. self.bn_epsilon = bn_epsilon
  36. def create(self):
  37. in_channels = self.output_channels[0]
  38. self.prev_blob = brew.conv(self.model, self.prev_blob, 'stage1_conv',
  39. self.num_input_channels, in_channels,
  40. weight_init=("MSRAFill", {}),
  41. kernel=3, stride=2)
  42. self.prev_blob = brew.max_pool(self.model, self.prev_blob,
  43. 'stage1_pool', kernel=3, stride=2)
  44. # adds stage#{2,3,4}; see table 5 of the ShufflenetV2 paper.
  45. for idx, (out_channels, n_repeats) in enumerate(zip(
  46. self.output_channels[1:4], self.stage_repeats
  47. )):
  48. prefix = 'stage{}_stride{}'.format(idx + 2, 2)
  49. self.add_spatial_ds_unit(prefix, in_channels, out_channels)
  50. in_channels = out_channels
  51. for i in range(n_repeats):
  52. prefix = 'stage{}_stride{}_repeat{}'.format(
  53. idx + 2, 1, i + 1
  54. )
  55. self.add_basic_unit(prefix, in_channels)
  56. self.last_conv = brew.conv(self.model, self.prev_blob, 'conv5',
  57. in_channels, self.output_channels[4],
  58. kernel=1)
  59. self.avg_pool = self.model.AveragePool(self.last_conv, 'avg_pool',
  60. kernel=7)
  61. self.last_out = brew.fc(self.model,
  62. self.avg_pool,
  63. 'last_out_L{}'.format(self.num_labels),
  64. self.output_channels[4],
  65. self.num_labels)
  66. # spatial down sampling unit with stride=2
  67. def add_spatial_ds_unit(self, prefix, in_channels, out_channels, stride=2):
  68. right = left = self.prev_blob
  69. out_channels = out_channels // 2
  70. # Enlarge the receptive field for detection task
  71. if self.detection:
  72. left = self.add_detection_unit(left, prefix + '_left_detection',
  73. in_channels, in_channels)
  74. left = self.add_dwconv3x3_bn(left, prefix + 'left_dwconv',
  75. in_channels, stride)
  76. left = self.add_conv1x1_bn(left, prefix + '_left_conv1', in_channels,
  77. out_channels)
  78. if self.detection:
  79. right = self.add_detection_unit(right, prefix + '_right_detection',
  80. in_channels, in_channels)
  81. right = self.add_conv1x1_bn(right, prefix + '_right_conv1',
  82. in_channels, out_channels)
  83. right = self.add_dwconv3x3_bn(right, prefix + '_right_dwconv',
  84. out_channels, stride)
  85. right = self.add_conv1x1_bn(right, prefix + '_right_conv2',
  86. out_channels, out_channels)
  87. self.prev_blob = brew.concat(self.model, [right, left],
  88. prefix + '_concat')
  89. self.prev_blob = self.model.net.ChannelShuffle(
  90. self.prev_blob, prefix + '_ch_shuffle',
  91. group=self.num_groups, kernel=1
  92. )
  93. # basic unit with stride=1
  94. def add_basic_unit(self, prefix, in_channels, stride=1):
  95. in_channels = in_channels // 2
  96. left = prefix + '_left'
  97. right = prefix + '_right'
  98. self.model.net.Split(self.prev_blob, [left, right])
  99. if self.detection:
  100. right = self.add_detection_unit(right, prefix + '_right_detection',
  101. in_channels, in_channels)
  102. right = self.add_conv1x1_bn(right, prefix + '_right_conv1',
  103. in_channels, in_channels)
  104. right = self.add_dwconv3x3_bn(right, prefix + '_right_dwconv',
  105. in_channels, stride)
  106. right = self.add_conv1x1_bn(right, prefix + '_right_conv2',
  107. in_channels, in_channels)
  108. self.prev_blob = brew.concat(self.model, [right, left],
  109. prefix + '_concat')
  110. self.prev_blob = self.model.net.ChannelShuffle(
  111. self.prev_blob, prefix + '_ch_shuffle',
  112. group=self.num_groups, kernel=1
  113. )
  114. # helper functions to create net's units
  115. def add_detection_unit(self, prev_blob, prefix, in_channels, out_channels,
  116. kernel=3, pad=1):
  117. out_blob = brew.conv(self.model, prev_blob, prefix + '_conv',
  118. in_channels, out_channels, kernel=kernel,
  119. weight_init=("MSRAFill", {}),
  120. group=in_channels, pad=pad)
  121. out_blob = brew.spatial_bn(self.model, out_blob, prefix + '_bn',
  122. out_channels, epsilon=self.bn_epsilon,
  123. is_test=self.is_test)
  124. return out_blob
  125. def add_conv1x1_bn(self, prev_blob, blob, in_channels, out_channels):
  126. prev_blob = brew.conv(self.model, prev_blob, blob, in_channels,
  127. out_channels, kernel=1,
  128. weight_init=("MSRAFill", {}))
  129. prev_blob = brew.spatial_bn(self.model, prev_blob, prev_blob + '_bn',
  130. out_channels,
  131. epsilon=self.bn_epsilon,
  132. is_test=self.is_test)
  133. prev_blob = brew.relu(self.model, prev_blob, prev_blob)
  134. return prev_blob
  135. def add_dwconv3x3_bn(self, prev_blob, blob, channels, stride):
  136. prev_blob = brew.conv(self.model, prev_blob, blob, channels,
  137. channels, kernel=3,
  138. weight_init=("MSRAFill", {}),
  139. stride=stride, group=channels, pad=1)
  140. prev_blob = brew.spatial_bn(self.model, prev_blob,
  141. prev_blob + '_bn',
  142. channels,
  143. epsilon=self.bn_epsilon,
  144. is_test=self.is_test)
  145. return prev_blob
  146. def create_shufflenet(
  147. model,
  148. data,
  149. num_input_channels,
  150. num_labels,
  151. label=None,
  152. is_test=False,
  153. no_loss=False,
  154. ):
  155. builder = ShuffleNetV2Builder(model, data, num_input_channels,
  156. num_labels,
  157. is_test=is_test)
  158. builder.create()
  159. if no_loss:
  160. return builder.last_out
  161. if (label is not None):
  162. (softmax, loss) = model.SoftmaxWithLoss(
  163. [builder.last_out, label],
  164. ["softmax", "loss"],
  165. )
  166. return (softmax, loss)