cnn.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. ## @package cnn
  2. # Module caffe2.python.cnn
  3. from caffe2.python import brew, workspace
  4. from caffe2.python.model_helper import ModelHelper
  5. from caffe2.proto import caffe2_pb2
  6. import logging
  7. class CNNModelHelper(ModelHelper):
  8. """A helper model so we can write CNN models more easily, without having to
  9. manually define parameter initializations and operators separately.
  10. """
  11. def __init__(self, order="NCHW", name=None,
  12. use_cudnn=True, cudnn_exhaustive_search=False,
  13. ws_nbytes_limit=None, init_params=True,
  14. skip_sparse_optim=False,
  15. param_model=None):
  16. logging.warning(
  17. "[====DEPRECATE WARNING====]: you are creating an "
  18. "object from CNNModelHelper class which will be deprecated soon. "
  19. "Please use ModelHelper object with brew module. For more "
  20. "information, please refer to caffe2.ai and python/brew.py, "
  21. "python/brew_test.py for more information."
  22. )
  23. cnn_arg_scope = {
  24. 'order': order,
  25. 'use_cudnn': use_cudnn,
  26. 'cudnn_exhaustive_search': cudnn_exhaustive_search,
  27. }
  28. if ws_nbytes_limit:
  29. cnn_arg_scope['ws_nbytes_limit'] = ws_nbytes_limit
  30. super(CNNModelHelper, self).__init__(
  31. skip_sparse_optim=skip_sparse_optim,
  32. name="CNN" if name is None else name,
  33. init_params=init_params,
  34. param_model=param_model,
  35. arg_scope=cnn_arg_scope,
  36. )
  37. self.order = order
  38. self.use_cudnn = use_cudnn
  39. self.cudnn_exhaustive_search = cudnn_exhaustive_search
  40. self.ws_nbytes_limit = ws_nbytes_limit
  41. if self.order != "NHWC" and self.order != "NCHW":
  42. raise ValueError(
  43. "Cannot understand the CNN storage order %s." % self.order
  44. )
  45. def ImageInput(self, blob_in, blob_out, use_gpu_transform=False, **kwargs):
  46. return brew.image_input(
  47. self,
  48. blob_in,
  49. blob_out,
  50. order=self.order,
  51. use_gpu_transform=use_gpu_transform,
  52. **kwargs
  53. )
  54. def VideoInput(self, blob_in, blob_out, **kwargs):
  55. return brew.video_input(
  56. self,
  57. blob_in,
  58. blob_out,
  59. **kwargs
  60. )
  61. def PadImage(self, blob_in, blob_out, **kwargs):
  62. # TODO(wyiming): remove this dummy helper later
  63. self.net.PadImage(blob_in, blob_out, **kwargs)
  64. def ConvNd(self, *args, **kwargs):
  65. return brew.conv_nd(
  66. self,
  67. *args,
  68. use_cudnn=self.use_cudnn,
  69. order=self.order,
  70. cudnn_exhaustive_search=self.cudnn_exhaustive_search,
  71. ws_nbytes_limit=self.ws_nbytes_limit,
  72. **kwargs
  73. )
  74. def Conv(self, *args, **kwargs):
  75. return brew.conv(
  76. self,
  77. *args,
  78. use_cudnn=self.use_cudnn,
  79. order=self.order,
  80. cudnn_exhaustive_search=self.cudnn_exhaustive_search,
  81. ws_nbytes_limit=self.ws_nbytes_limit,
  82. **kwargs
  83. )
  84. def ConvTranspose(self, *args, **kwargs):
  85. return brew.conv_transpose(
  86. self,
  87. *args,
  88. use_cudnn=self.use_cudnn,
  89. order=self.order,
  90. cudnn_exhaustive_search=self.cudnn_exhaustive_search,
  91. ws_nbytes_limit=self.ws_nbytes_limit,
  92. **kwargs
  93. )
  94. def GroupConv(self, *args, **kwargs):
  95. return brew.group_conv(
  96. self,
  97. *args,
  98. use_cudnn=self.use_cudnn,
  99. order=self.order,
  100. cudnn_exhaustive_search=self.cudnn_exhaustive_search,
  101. ws_nbytes_limit=self.ws_nbytes_limit,
  102. **kwargs
  103. )
  104. def GroupConv_Deprecated(self, *args, **kwargs):
  105. return brew.group_conv_deprecated(
  106. self,
  107. *args,
  108. use_cudnn=self.use_cudnn,
  109. order=self.order,
  110. cudnn_exhaustive_search=self.cudnn_exhaustive_search,
  111. ws_nbytes_limit=self.ws_nbytes_limit,
  112. **kwargs
  113. )
  114. def FC(self, *args, **kwargs):
  115. return brew.fc(self, *args, **kwargs)
  116. def PackedFC(self, *args, **kwargs):
  117. return brew.packed_fc(self, *args, **kwargs)
  118. def FC_Prune(self, *args, **kwargs):
  119. return brew.fc_prune(self, *args, **kwargs)
  120. def FC_Decomp(self, *args, **kwargs):
  121. return brew.fc_decomp(self, *args, **kwargs)
  122. def FC_Sparse(self, *args, **kwargs):
  123. return brew.fc_sparse(self, *args, **kwargs)
  124. def Dropout(self, *args, **kwargs):
  125. return brew.dropout(
  126. self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
  127. )
  128. def LRN(self, *args, **kwargs):
  129. return brew.lrn(
  130. self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
  131. )
  132. def Softmax(self, *args, **kwargs):
  133. return brew.softmax(self, *args, use_cudnn=self.use_cudnn, **kwargs)
  134. def SpatialBN(self, *args, **kwargs):
  135. return brew.spatial_bn(self, *args, order=self.order, **kwargs)
  136. def SpatialGN(self, *args, **kwargs):
  137. return brew.spatial_gn(self, *args, order=self.order, **kwargs)
  138. def InstanceNorm(self, *args, **kwargs):
  139. return brew.instance_norm(self, *args, order=self.order, **kwargs)
  140. def Relu(self, *args, **kwargs):
  141. return brew.relu(
  142. self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
  143. )
  144. def PRelu(self, *args, **kwargs):
  145. return brew.prelu(self, *args, **kwargs)
  146. def Concat(self, *args, **kwargs):
  147. return brew.concat(self, *args, order=self.order, **kwargs)
  148. def DepthConcat(self, *args, **kwargs):
  149. """The old depth concat function - we should move to use concat."""
  150. print("DepthConcat is deprecated. use Concat instead.")
  151. return self.Concat(*args, **kwargs)
  152. def Sum(self, *args, **kwargs):
  153. return brew.sum(self, *args, **kwargs)
  154. def Transpose(self, *args, **kwargs):
  155. return brew.transpose(self, *args, use_cudnn=self.use_cudnn, **kwargs)
  156. def Iter(self, *args, **kwargs):
  157. return brew.iter(self, *args, **kwargs)
  158. def Accuracy(self, *args, **kwargs):
  159. return brew.accuracy(self, *args, **kwargs)
  160. def MaxPool(self, *args, **kwargs):
  161. return brew.max_pool(
  162. self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
  163. )
  164. def MaxPoolWithIndex(self, *args, **kwargs):
  165. return brew.max_pool_with_index(self, *args, order=self.order, **kwargs)
  166. def AveragePool(self, *args, **kwargs):
  167. return brew.average_pool(
  168. self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
  169. )
  170. @property
  171. def XavierInit(self):
  172. return ('XavierFill', {})
  173. def ConstantInit(self, value):
  174. return ('ConstantFill', dict(value=value))
  175. @property
  176. def MSRAInit(self):
  177. return ('MSRAFill', {})
  178. @property
  179. def ZeroInit(self):
  180. return ('ConstantFill', {})
  181. def AddWeightDecay(self, weight_decay):
  182. return brew.add_weight_decay(self, weight_decay)
  183. @property
  184. def CPU(self):
  185. device_option = caffe2_pb2.DeviceOption()
  186. device_option.device_type = caffe2_pb2.CPU
  187. return device_option
  188. @property
  189. def GPU(self, gpu_id=0):
  190. device_option = caffe2_pb2.DeviceOption()
  191. device_option.device_type = workspace.GpuDeviceType
  192. device_option.device_id = gpu_id
  193. return device_option