resnet.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. ## @package resnet
  2. # Module caffe2.python.models.resnet
  3. from caffe2.python import brew
  4. import logging
  5. '''
  6. Utility for creating ResNe(X)t
  7. "Deep Residual Learning for Image Recognition" by He, Zhang et al. 2015
  8. "Aggregated Residual Transformations for Deep Neural Networks" by Xie et al. 2016
  9. '''
  class ResNetBuilder():
      '''
      Helper class for constructing residual blocks.

      The builder remembers the most recently produced blob in ``prev_blob``
      and appends every new operator after it. ``comp_count`` counts the
      residual components finished so far and ``comp_idx`` counts the
      convolutions/sums added inside the current component; both counters
      are only used to build unique blob names.
      '''

      def __init__(
          self,
          model,
          prev_blob,
          no_bias,
          is_test,
          bn_epsilon=1e-5,
          bn_momentum=0.9,
      ):
          '''
          Args:
              model: model object handed to every ``brew`` helper call.
              prev_blob: blob the first added operator will read from.
              no_bias: truthy to create convolutions without a bias term;
                  stored as 1 or 0.
              is_test: forwarded to every spatial batch norm.
              bn_epsilon: epsilon used by ``add_spatial_bn`` and by the
                  bottleneck shortcut batch norm.
              bn_momentum: momentum used by ``add_spatial_bn`` and by the
                  bottleneck shortcut batch norm.
          '''
          self.model = model
          self.comp_count = 0
          self.comp_idx = 0
          self.prev_blob = prev_blob
          self.is_test = is_test
          self.bn_epsilon = bn_epsilon
          self.bn_momentum = bn_momentum
          self.no_bias = 1 if no_bias else 0

      def add_conv(
          self,
          in_filters,
          out_filters,
          kernel,
          stride=1,
          group=1,
          pad=0,
      ):
          '''
          Append a MSRA-initialised convolution after ``prev_blob``.

          Returns the new blob, which also becomes ``prev_blob``.
          '''
          self.comp_idx += 1
          self.prev_blob = brew.conv(
              self.model,
              self.prev_blob,
              'comp_%d_conv_%d' % (self.comp_count, self.comp_idx),
              in_filters,
              out_filters,
              weight_init=("MSRAFill", {}),
              kernel=kernel,
              stride=stride,
              group=group,
              pad=pad,
              no_bias=self.no_bias,
          )
          return self.prev_blob

      def add_relu(self):
          '''Append a ReLU that writes back into ``prev_blob`` (in-place).'''
          self.prev_blob = brew.relu(
              self.model,
              self.prev_blob,
              self.prev_blob,  # in-place
          )
          return self.prev_blob

      def add_spatial_bn(self, num_filters):
          '''
          Append a spatial batch norm after ``prev_blob``.

          ``comp_idx`` is not advanced here, so the blob name carries the
          index of the convolution that precedes it.
          '''
          self.prev_blob = brew.spatial_bn(
              self.model,
              self.prev_blob,
              'comp_%d_spatbn_%d' % (self.comp_count, self.comp_idx),
              num_filters,
              epsilon=self.bn_epsilon,
              momentum=self.bn_momentum,
              is_test=self.is_test,
          )
          return self.prev_blob

      def _add_shortcut_projection(
          self,
          shortcut_blob,
          in_filters,
          out_filters,
          stride,
          spatial_batch_norm,
          **bn_kwargs
      ):
          '''
          Project the shortcut with a 1x1 convolution (plus optional batch
          norm) so that it can be summed with the residual branch.

          ``bn_kwargs`` are passed to ``brew.spatial_bn`` unchanged so each
          block type keeps its own batch-norm settings.
          '''
          projected = brew.conv(
              self.model,
              shortcut_blob,
              'shortcut_projection_%d' % self.comp_count,
              in_filters,
              out_filters,
              weight_init=("MSRAFill", {}),
              kernel=1,
              stride=stride,
              no_bias=self.no_bias,
          )
          if spatial_batch_norm:
              projected = brew.spatial_bn(
                  self.model,
                  projected,
                  'shortcut_projection_%d_spatbn' % self.comp_count,
                  out_filters,
                  is_test=self.is_test,
                  **bn_kwargs
              )
          return projected

      def _close_component(self, shortcut_blob, residual_blob):
          '''
          Sum the shortcut with the residual branch, apply ReLU and advance
          the component counter.
          '''
          self.prev_blob = brew.sum(
              self.model, [shortcut_blob, residual_blob],
              'comp_%d_sum_%d' % (self.comp_count, self.comp_idx)
          )
          self.comp_idx += 1
          self.add_relu()
          # Keep track of number of high level components of this ResNetBuilder
          self.comp_count += 1

      def add_bottleneck(
          self,
          input_filters,   # num of feature maps from preceding layer
          base_filters,    # num of filters internally in the component
          output_filters,  # num of feature maps to output
          stride=1,
          group=1,
          spatial_batch_norm=True,
      ):
          '''
          Add a "bottleneck" component as described in He et al. Figure 3
          (right): 1x1 conv, 3x3 conv, 1x1 conv, summed with the shortcut.

          ``stride`` is applied to the 3x3 convolution and to the shortcut
          projection; ``group`` is applied to the 3x3 convolution only.

          Returns ``output_filters`` so callers can chain components.
          '''
          self.comp_idx = 0
          shortcut_blob = self.prev_blob

          # 1x1
          self.add_conv(
              input_filters,
              base_filters,
              kernel=1,
              stride=1,
          )
          if spatial_batch_norm:
              self.add_spatial_bn(base_filters)
          self.add_relu()

          # 3x3 (note the pad, required for keeping dimensions)
          self.add_conv(
              base_filters,
              base_filters,
              kernel=3,
              stride=stride,
              group=group,
              pad=1,
          )
          if spatial_batch_norm:
              self.add_spatial_bn(base_filters)
          self.add_relu()

          # 1x1
          last_conv = self.add_conv(base_filters, output_filters, kernel=1)
          if spatial_batch_norm:
              last_conv = self.add_spatial_bn(output_filters)

          # Summation with input signal (shortcut).
          # When the number of feature maps mismatch between the input
          # and output (this usually happens when the residual stage
          # changes), we need to do a projection for the shortcut.
          if output_filters != input_filters:
              shortcut_blob = self._add_shortcut_projection(
                  shortcut_blob,
                  input_filters,
                  output_filters,
                  stride,
                  spatial_batch_norm,
                  epsilon=self.bn_epsilon,
                  momentum=self.bn_momentum,
              )

          self._close_component(shortcut_blob, last_conv)
          return output_filters

      def add_simple_block(
          self,
          input_filters,
          num_filters,
          down_sampling=False,
          spatial_batch_norm=True
      ):
          '''
          Add a basic residual block: two 3x3 convolutions summed with the
          shortcut.

          When ``down_sampling`` is truthy the first convolution and the
          shortcut projection use stride 2, otherwise stride 1.

          Returns ``num_filters``, mirroring ``add_bottleneck``.
          '''
          self.comp_idx = 0
          shortcut_blob = self.prev_blob
          # Truthiness instead of ``is False`` so that 0 / None do not
          # accidentally enable down-sampling.
          stride = 2 if down_sampling else 1

          # 3x3
          self.add_conv(
              input_filters,
              num_filters,
              kernel=3,
              stride=stride,
              pad=1
          )
          if spatial_batch_norm:
              self.add_spatial_bn(num_filters)
          self.add_relu()

          last_conv = self.add_conv(num_filters, num_filters, kernel=3, pad=1)
          if spatial_batch_norm:
              last_conv = self.add_spatial_bn(num_filters)

          # Increase of dimensions, need a projection for the shortcut
          if num_filters != input_filters:
              # NOTE(review): this shortcut batch norm uses a fixed
              # epsilon=1e-3 and no explicit momentum, unlike the bottleneck
              # shortcut which uses self.bn_epsilon / self.bn_momentum.
              # Kept as-is to preserve existing numerics.
              shortcut_blob = self._add_shortcut_projection(
                  shortcut_blob,
                  input_filters,
                  num_filters,
                  stride,
                  spatial_batch_norm,
                  epsilon=1e-3,
              )

          self._close_component(shortcut_blob, last_conv)
          return num_filters
  def create_resnet_32x32(
      model, data, num_input_channels, num_groups, num_labels, is_test=False
  ):
      '''
      Create residual net for smaller images (sec 4.2 of He et al. (2015)).

      The network is a 3x3 stem convolution with batch norm and ReLU,
      followed by three stages of simple residual blocks with 16, 32 and 64
      filters, an average pool, a fully connected layer and a softmax.

      num_groups = 'n' in the paper; each stage adds 2 * num_groups simple
      blocks. The first block of the second and third stage down-samples.

      Returns the softmax blob.
      '''
      # Stem: conv1 -> batch norm -> relu (there is no max-pool here)
      brew.conv(
          model, data, 'conv1', num_input_channels, 16, kernel=3, stride=1
      )
      brew.spatial_bn(
          model, 'conv1', 'conv1_spatbn', 16, epsilon=1e-3, is_test=is_test
      )
      brew.relu(model, 'conv1_spatbn', 'relu1')

      # Filter widths of the three residual stages (sec 4.2)
      stage_filters = [16, 32, 64]
      blocks_per_stage = 2 * num_groups
      builder = ResNetBuilder(model, 'relu1', no_bias=0, is_test=is_test)

      in_filters = 16
      for stage_idx, out_filters in enumerate(stage_filters):
          for block_idx in range(blocks_per_stage):
              opens_stage = block_idx == 0
              builder.add_simple_block(
                  # Only the first block of a stage sees the previous width
                  in_filters if opens_stage else out_filters,
                  out_filters,
                  down_sampling=(opens_stage and stage_idx > 0),
              )
          in_filters = out_filters

      # Final layers
      brew.average_pool(
          model, builder.prev_blob, 'final_avg', kernel=8, stride=1
      )
      brew.fc(model, 'final_avg', 'last_out', 64, num_labels)
      return brew.softmax(model, 'last_out', 'softmax')
  # Number of residual blocks in each of the four stages, keyed by the
  # network depth accepted as ``num_layers`` by create_resnext.
  RESNEXT_BLOCK_CONFIG = {
      18: (2, 2, 2, 2),
      34: (3, 4, 6, 3),
      50: (3, 4, 6, 3),
      101: (3, 4, 23, 3),
      152: (3, 8, 36, 3),
      200: (3, 24, 36, 3),
  }

  # Stride of the first block in each of the four stages.
  RESNEXT_STRIDES = [1, 2, 2, 2]

  # Module logger; configured at import time and set to DEBUG verbosity.
  logging.basicConfig()
  log = logging.getLogger("resnext_builder")
  log.setLevel(logging.DEBUG)
  # The conv1 and final_avg kernel/stride args provide a basic mechanism for
  # adapting the network to different sizes of input images.
  def create_resnext(
      model,
      data,
      num_input_channels,
      num_labels,
      num_layers,
      num_groups,
      num_width_per_group,
      label=None,
      is_test=False,
      no_loss=False,
      no_bias=1,
      conv1_kernel=7,
      conv1_stride=2,
      final_avg_kernel=7,
      log=None,
      bn_epsilon=1e-5,
      bn_momentum=0.9,
  ):
      '''
      Build a ResNe(X)t: stem convolution, max-pool, four stages of
      bottleneck blocks, average pool and a fully connected layer.

      Args:
          model: model object handed to every ``brew`` helper call.
          data: input blob.
          num_input_channels: channels of ``data``.
          num_labels: output size of the final fully connected layer.
          num_layers: network depth; must be a key of RESNEXT_BLOCK_CONFIG.
              Depths 18 and 34 use the narrower filter widths
              [64, 64, 128, 256, 512] but are still built from bottleneck
              blocks.
          num_groups: ``group`` of the 3x3 convolution in every bottleneck.
          num_width_per_group: together with ``num_groups`` gives the inner
              width of the first stage; the inner width doubles per stage.
          label: label blob; when given (and ``no_loss`` is false) a
              SoftmaxWithLoss is added.
          is_test: forwarded to every spatial batch norm.
          no_loss: when truthy, return the fully connected output directly.
          no_bias: forwarded to the convolutions.
          conv1_kernel, conv1_stride: kernel and stride of the stem
              convolution. The stem pad is fixed at 3.
          final_avg_kernel: kernel passed to the final average pool.
              NOTE(review): the pool is created with global_pooling=True,
              so this value is presumably ignored - confirm.
          log: logger used to report an unsupported ``num_layers``; defaults
              to the "resnext_builder" logger.
          bn_epsilon, bn_momentum: batch-norm settings used for the stem and
              for every residual block.

      Returns:
          The fully connected output blob if ``no_loss`` is truthy, a
          ``(softmax, loss)`` tuple if ``label`` is given, otherwise the
          softmax blob.

      Raises:
          KeyError: if ``num_layers`` is not in RESNEXT_BLOCK_CONFIG.
      '''
      if log is None:
          # The parameter shadows the module-level logger, so fetch the same
          # logger by name instead of calling .error on None.
          log = logging.getLogger("resnext_builder")

      if num_layers not in RESNEXT_BLOCK_CONFIG:
          log.error("{}-layer is invalid for resnext config".format(num_layers))

      # Raises KeyError for an unsupported depth (reported just above).
      num_blocks = RESNEXT_BLOCK_CONFIG[num_layers]
      strides = RESNEXT_STRIDES
      num_filters = [64, 256, 512, 1024, 2048]

      if num_layers in [18, 34]:
          num_filters = [64, 64, 128, 256, 512]

      # the number of features before the last FC layer
      num_features = num_filters[-1]

      # conv1 + batch norm + relu + maxpool
      conv_blob = brew.conv(
          model,
          data,
          'conv1',
          num_input_channels,
          num_filters[0],
          weight_init=("MSRAFill", {}),
          kernel=conv1_kernel,
          stride=conv1_stride,
          pad=3,
          no_bias=no_bias
      )

      bn_blob = brew.spatial_bn(
          model,
          conv_blob,
          'conv1_spatbn_relu',
          num_filters[0],
          epsilon=bn_epsilon,
          momentum=bn_momentum,
          is_test=is_test
      )
      relu_blob = brew.relu(model, bn_blob, bn_blob)  # in-place
      max_pool = brew.max_pool(model, relu_blob, 'pool1', kernel=3, stride=2, pad=1)

      # Residual blocks; the caller's batch-norm settings are forwarded so
      # the blocks agree with the stem batch norm above.
      builder = ResNetBuilder(
          model,
          max_pool,
          no_bias=no_bias,
          is_test=is_test,
          bn_epsilon=bn_epsilon,
          bn_momentum=bn_momentum,
      )

      inner_dim = num_groups * num_width_per_group

      # 4 different kinds of residual blocks
      for residual_idx in range(4):
          residual_num = num_blocks[residual_idx]
          residual_stride = strides[residual_idx]
          dim_in = num_filters[residual_idx]

          for blk_idx in range(residual_num):
              dim_in = builder.add_bottleneck(
                  dim_in,
                  inner_dim,
                  num_filters[residual_idx + 1],  # dim out
                  # only the first block of a stage applies the stage stride
                  stride=residual_stride if blk_idx == 0 else 1,
                  group=num_groups,
              )

          inner_dim *= 2

      # Final layers
      final_avg = brew.average_pool(
          model,
          builder.prev_blob,
          'final_avg',
          kernel=final_avg_kernel,
          stride=1,
          global_pooling=True,
      )

      last_out = brew.fc(
          model, final_avg, 'last_out_L{}'.format(num_labels), num_features, num_labels
      )

      if no_loss:
          return last_out

      # If we create model for training, use softmax-with-loss
      if label is not None:
          (softmax, loss) = model.SoftmaxWithLoss(
              [last_out, label],
              ["softmax", "loss"],
          )
          return (softmax, loss)

      # For inference, we just return softmax
      return brew.softmax(model, last_out, "softmax")
  # The conv1 and final_avg kernel/stride args provide a basic mechanism for
  # adapting resnet50 for different sizes of input images.
  def create_resnet50(
      model,
      data,
      num_input_channels,
      num_labels,
      label=None,
      is_test=False,
      no_loss=False,
      no_bias=0,
      conv1_kernel=7,
      conv1_stride=2,
      final_avg_kernel=7,
  ):
      '''
      Build ResNet-50 by delegating to create_resnext.

      resnet50 is a special case for ResNeXt50-1x64d: 50 layers, a single
      group and a width of 64 per group. All other arguments are passed
      through unchanged, and the return value is whatever create_resnext
      returns.
      '''
      resnext_options = dict(
          # fixed ResNet-50 topology
          num_layers=50,
          num_groups=1,
          num_width_per_group=64,
          # caller-controlled options
          label=label,
          is_test=is_test,
          no_loss=no_loss,
          no_bias=no_bias,
          conv1_kernel=conv1_kernel,
          conv1_stride=conv1_stride,
          final_avg_kernel=final_avg_kernel,
      )
      return create_resnext(
          model, data, num_input_channels, num_labels, **resnext_options
      )