fc.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. ## @package fc
  2. # Module caffe2.python.layers.fc
  3. from caffe2.python.helpers.arg_scope import get_current_scope
  4. from caffe2.python import schema
  5. from caffe2.python.layers.layers import ModelLayer
  6. from caffe2.python.layers.sampling_trainable_mixin import SamplingTrainableMixin
  7. import math
  8. import numpy as np
  9. def get_fc_predictor_version(fc_version):
  10. assert fc_version in ["fp32", "fp16"], (
  11. "Only support fp32 and fp16 for the fully connected layer "
  12. "in the predictor net, the provided FC precision is {}".format(fc_version)
  13. )
  14. return fc_version
  15. class FC(SamplingTrainableMixin, ModelLayer):
  16. def __init__(self, model, input_record, output_dims, weight_init=None,
  17. bias_init=None, weight_optim=None, bias_optim=None, name='fc',
  18. weight_reg=None, bias_reg=None, clip_param=None,
  19. max_fc_size=None, axis=1, transposed=False,
  20. uniform_weight_init_scale_numerator=1.0,
  21. **kwargs):
  22. super(FC, self).__init__(model, name, input_record, **kwargs)
  23. assert isinstance(input_record, schema.Scalar), (
  24. "Incorrect input type {}".format(input_record))
  25. assert len(input_record.field_types()[0].shape) > 0, (
  26. "FC expects limited dimensions of the input tensor")
  27. assert axis >= 1, "axis {} should >= 1.".format(axis)
  28. self.axis = axis
  29. input_dims = np.prod(input_record.field_types()[0].shape[axis - 1:])
  30. assert input_dims > 0, (
  31. "FC expects input dimensions > 0, got {}".format(input_dims))
  32. self.clip_args = None
  33. if (clip_param is not None):
  34. assert len(clip_param) == 2, (
  35. 'clip_param must be a tuple / list '
  36. 'of length 2 and in the form of (clip_min, clip max)'
  37. )
  38. clip_min, clip_max = clip_param
  39. assert clip_min is not None or clip_max is not None, (
  40. 'clip_min, and clip_max in clip_param cannot both be None'
  41. )
  42. assert (
  43. (clip_min is None or clip_max is None) or clip_min < clip_max
  44. ), (
  45. 'clip_param = [clip_min, clip_max] must have clip_min < clip_max'
  46. )
  47. self.clip_args = {}
  48. if clip_min is not None:
  49. self.clip_args['min'] = clip_min
  50. if clip_max is not None:
  51. self.clip_args['max'] = clip_max
  52. if uniform_weight_init_scale_numerator is None:
  53. uniform_weight_init_scale_numerator = 1.0
  54. scale = math.sqrt(uniform_weight_init_scale_numerator / input_dims)
  55. weight_init = weight_init if weight_init else (
  56. 'UniformFill', {'min': -scale, 'max': scale})
  57. bias_init = bias_init if bias_init else (
  58. 'UniformFill', {'min': -scale, 'max': scale})
  59. self.output_dim_vec = FC.calculate_fc_output_dims(
  60. max_fc_size, input_dims, output_dims)
  61. self.transposed = transposed
  62. if self.output_dim_vec is None or len(self.output_dim_vec) == 1:
  63. weight_shape = [input_dims, output_dims] if transposed else [output_dims, input_dims]
  64. self.w = self.create_param(param_name='w',
  65. shape=weight_shape,
  66. initializer=weight_init,
  67. optimizer=weight_optim,
  68. regularizer=weight_reg)
  69. self.b = self.create_param(param_name='b',
  70. shape=[output_dims, ],
  71. initializer=bias_init,
  72. optimizer=bias_optim,
  73. regularizer=bias_reg)
  74. else:
  75. self.w_vec = []
  76. self.b_vec = []
  77. for idx, output_dim in enumerate(self.output_dim_vec):
  78. weight_shape = [input_dims, output_dim] if transposed else [output_dim, input_dims]
  79. self.w_vec.append(self.create_param(param_name='w_sub_{}'.format(idx),
  80. shape=weight_shape,
  81. initializer=weight_init,
  82. optimizer=weight_optim,
  83. regularizer=weight_reg))
  84. self.b_vec.append(self.create_param(param_name='b_sub_{}'.format(idx),
  85. shape=[output_dim, ],
  86. initializer=weight_init,
  87. optimizer=weight_optim,
  88. regularizer=weight_reg))
  89. if axis == 1:
  90. output_shape = (output_dims, )
  91. else:
  92. output_shape = list(input_record.field_types()[0].shape)[0: axis - 1]
  93. output_shape = tuple(output_shape + [output_dims])
  94. self.output_schema = schema.Scalar(
  95. (np.float32, output_shape),
  96. self.get_next_blob_reference('output')
  97. )
  98. @staticmethod
  99. def calculate_fc_output_dims(max_fc_size, input_dim, output_dim):
  100. if not max_fc_size or max_fc_size < 0:
  101. return None
  102. assert max_fc_size >= input_dim, "Currently we split along the output " \
  103. "dimension. So we need max_fc_size >= input_dim. But, max_fc_size: " \
  104. "{}, input_dim: {}".format(max_fc_size, input_dim)
  105. output_dim_allowed = int(np.floor(max_fc_size / input_dim))
  106. num_fc = int(np.floor((output_dim - 1) / output_dim_allowed) + 1)
  107. output_dim_vec = [output_dim_allowed] * (num_fc - 1)
  108. output_dim_vec.append(output_dim - sum(output_dim_vec))
  109. return output_dim_vec
  110. def _insert_fc_ops(self, net, params, outputs, version):
  111. """
  112. Args:
  113. net: the caffe2 net to insert operator
  114. params: weight and bias for FC
  115. outputs: the output blobs
  116. version: support fp32 and fp16 for now.
  117. """
  118. if version == "fp32":
  119. if self.transposed:
  120. return net.FCTransposed(
  121. self.input_record.field_blobs() + params,
  122. outputs,
  123. axis=self.axis,
  124. **self.kwargs
  125. )
  126. else:
  127. return net.FC(
  128. self.input_record.field_blobs() + params,
  129. outputs,
  130. axis=self.axis,
  131. **self.kwargs
  132. )
  133. elif version == "fp16":
  134. return net.FbFCPacked(
  135. self.input_record.field_blobs() + params,
  136. outputs,
  137. axis=self.axis,
  138. **self.kwargs
  139. )
  140. else:
  141. raise Exception("unsupported FC type version {}".format(version))
  142. def _add_ops(self, net, params, version):
  143. """
  144. Args:
  145. params : the weight and bias,
  146. passed by either add_ops or add_train_ops function
  147. version : fp16 or fp32, might support in8 in the future.
  148. """
  149. if self.clip_args is not None:
  150. clipped_params = [net.NextScopedBlob(
  151. 'clipped_%s' % str(p)) for p in params]
  152. for p, cp in zip(params, clipped_params):
  153. net.Clip([p], [cp], **self.clip_args)
  154. params = clipped_params
  155. if self.output_dim_vec is None or len(self.output_dim_vec) == 1:
  156. self._insert_fc_ops(net, params, self.output_schema.field_blobs(), version)
  157. else:
  158. w_vec = params[:int(len(params) / 2)]
  159. b_vec = params[int(len(params) / 2):]
  160. assert len(w_vec) == len(b_vec)
  161. output_blob_vec = []
  162. for i in range(len(self.output_dim_vec)):
  163. output_blob = net.NextScopedBlob(
  164. 'output_sub_{}'.format(i))
  165. insert_ret = self._insert_fc_ops(
  166. net, [w_vec[i], b_vec[i]], [output_blob], version
  167. )
  168. output_blob_vec.append(insert_ret)
  169. net.Concat(output_blob_vec,
  170. self.output_schema.field_blobs() +
  171. [self.output_schema.field_blobs()[0] + "_concat_dims"])
  172. def add_ops(self, net):
  173. """Both the predict net and the eval net will call this function
  174. """
  175. version_info = get_current_scope().get(
  176. get_fc_predictor_version.__name__, {'fc_version': 'fp32'}
  177. )
  178. predictor_fc_fp_version = version_info['fc_version']
  179. self._add_ops(net, self.param_blobs, predictor_fc_fp_version)
  180. def add_train_ops(self, net):
  181. # use the train_param_blobs to be consistent with the SamplingTrain unittest
  182. self._add_ops(net, self.train_param_blobs, "fp32")
  183. def get_fp16_compatible_parameters(self):
  184. if self.output_dim_vec is None or len(self.output_dim_vec) == 1:
  185. return [self.w]
  186. else:
  187. return self.w_vec
  188. @property
  189. def param_blobs(self):
  190. if self.output_dim_vec is None or len(self.output_dim_vec) == 1:
  191. return [self.w, self.b]
  192. else:
  193. return self.w_vec + self.b_vec