utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. import copy
  2. import logging
  3. from collections import defaultdict
  4. import numpy as np
  5. from caffe2.python import core, utils
  6. from caffe2.python.fb import hardcode_scale_zp # type: ignore[import]
  7. logger = logging.getLogger(__name__)
  8. logger.setLevel(logging.DEBUG)
  9. def pairwise(iterable):
  10. "s -> (s0,s1), (s1,s2), (s2, s3), ..."
  11. from itertools import tee
  12. a, b = tee(iterable)
  13. next(b, None)
  14. return zip(a, b)
  15. def blob_uses(net, blob):
  16. u = []
  17. for i, op in enumerate(net.op):
  18. if blob in op.input or blob in op.control_input:
  19. u.append(i)
  20. return u
  21. def fuse_first_bn(net, params, removed_tensors, begin_op_index):
  22. net = copy.deepcopy(net)
  23. params = copy.deepcopy(params)
  24. for i, conv in enumerate(net.op[begin_op_index:], begin_op_index):
  25. if conv.type not in ["Conv", "ConvTranspose"]:
  26. continue
  27. uses = blob_uses(net, conv.output[0])
  28. if len(uses) == 0:
  29. continue
  30. j = uses[0]
  31. bn = net.op[j]
  32. if bn.type != "SpatialBN" or (len(uses) > 1 and conv.output[0] != bn.output[0]):
  33. if bn.type == "SpatialBN":
  34. logger.debug("Can't fuse if more than one user {}".format(uses))
  35. # Can't fuse if more than one user unless SpatialBN is inplace
  36. # An example of inplace SpatialBN where we want to allow multiple uses:
  37. # x = Conv(...)
  38. # ... // no interferring use or def of x (will be checked below)
  39. # x = SpatialBN(x, ...)
  40. # ...
  41. # z = Foo(..., x, ...)
  42. # ...
  43. # w = Boo(..., x, ...)
  44. # Here, we still want to fuse Conv and SpatialBN
  45. continue
  46. # There shouldn't be any def of conv.output[0] and any use or def of bn.output[0] between conv and bn
  47. if any(
  48. blob in net.op[k].input or blob in net.op[k].output
  49. for blob in [conv.output[0], bn.output[0]]
  50. for k in range(i + 1, j)
  51. ):
  52. logger.debug(
  53. "Can't fuse because of the following interferring uses or defs:"
  54. )
  55. for k in range(i, j + 1):
  56. logger.debug(net.op[k])
  57. continue
  58. # else, can fuse
  59. fused_conv = copy.deepcopy(conv)
  60. fused_conv.output[0] = bn.output[0]
  61. conv_weight = params[conv.input[1]]
  62. if len(conv.input) > 2:
  63. conv_bias = params[conv.input[2]]
  64. else:
  65. conv_bias = np.zeros(len(params[bn.input[2]])).astype(np.float32)
  66. bn_scale = params[bn.input[1]]
  67. bn_bias = params[bn.input[2]]
  68. bn_running_mean = params[bn.input[3]]
  69. bn_running_var = params[bn.input[4]]
  70. # First, BN computation can be phrased as follows:
  71. # (X - running_mean) * (1.0 / sqrt(running_var + eps)) *
  72. # bn_scale + bias
  73. # Thus, we can rewrite bn_scale as:
  74. # X * bn_scale * 1.0 / (sqrt(running_var + eps)) + (bias -
  75. # running_mean * (1.0 / sqrt(running_var + eps)) * bn_scale)
  76. # Thus, can just have the affine transform
  77. # X * A + B
  78. # where
  79. # A = bn_scale * 1.0 / (sqrt(running_var + eps))
  80. # B = (bias - running_mean * (1.0 / sqrt(running_var + eps))
  81. # * bn_scale)
  82. eps = 1.0e-5
  83. for arg in bn.arg:
  84. if arg.name == "epsilon":
  85. eps = arg.f
  86. A = bn_scale * 1.0 / (np.sqrt(bn_running_var + eps))
  87. B = bn_bias - bn_running_mean * A
  88. # This identity should hold if we have correctly fused
  89. # np.testing.assert_array_equal(
  90. # params[conv.output[0]] * A + B,
  91. # params[bn.output[0]])
  92. # Now, we have that the computation made is the following:
  93. # ((X `conv` W) + b) * A + B
  94. # Then, we can simply fuse this as follows:
  95. # (X `conv` (W * A)) + b * A + B
  96. # which is simply
  97. # (X `conv` Q) + C
  98. # where
  99. # Q = W * A
  100. # C = b * A + B
  101. # For ConvTranspose, from the view of convolutions as a
  102. # Toepeliz multiplication, we have W_ = W^T, so the weights
  103. # are laid out as (R, S, K, K) (vs (S, R, K, K) for a Conv),
  104. # so the weights broadcast slightly differently. Remember, our
  105. # BN scale 'B' is of size (S,)
  106. A_ = (
  107. A.reshape((-1,) + tuple([1] * (conv_weight.ndim - 1)))
  108. if conv.type == "Conv"
  109. else A.reshape((1, -1) + tuple([1] * (conv_weight.ndim - 2)))
  110. )
  111. C = conv_bias * A + B
  112. Q = conv_weight * A_
  113. assert params[conv.input[1]].shape == Q.shape
  114. if len(conv.input) > 2:
  115. assert params[conv.input[2]].shape == C.shape
  116. else:
  117. assert bn_bias.shape == C.shape
  118. params[conv.input[1]] = Q
  119. if len(conv.input) > 2:
  120. params[conv.input[2]] = C
  121. else:
  122. params[bn.input[2]] = C
  123. fused_conv.input.append(bn.input[2])
  124. new_ops = net.op[:i] + [fused_conv] + net.op[i + 1 : j] + net.op[j + 1 :]
  125. del net.op[:]
  126. removed_tensors.append(bn.input[1])
  127. if len(conv.input) > 2:
  128. removed_tensors.append(bn.input[2])
  129. removed_tensors.append(bn.input[3])
  130. removed_tensors.append(bn.input[4])
  131. del params[bn.input[1]]
  132. if len(conv.input) > 2:
  133. del params[bn.input[2]]
  134. del params[bn.input[3]]
  135. del params[bn.input[4]]
  136. net.op.extend(new_ops)
  137. return net, params, removed_tensors, i + 1
  138. return net, params, removed_tensors, None
  139. def fuse_bn(net, params, ignore_failure):
  140. # Run until we hit a fixed point
  141. removed_tensors = []
  142. begin_op_index = 0
  143. while True:
  144. (next_net, next_params, removed_tensors, begin_op_index) = fuse_first_bn(
  145. net, params, removed_tensors, begin_op_index
  146. )
  147. if begin_op_index is None:
  148. if any(op.type == "SpatialBN" for op in next_net.op) and not ignore_failure:
  149. raise Exception(
  150. "Model contains SpatialBN op after fusion: %s", next_net
  151. )
  152. return (next_net, next_params, removed_tensors)
  153. net, params, removed_tensors = (next_net, next_params, removed_tensors)
  154. def fuse_first_scale(net, params, removed_tensors):
  155. net = copy.deepcopy(net)
  156. params = copy.deepcopy(params)
  157. for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
  158. if next_.input[0] != current.output[0]:
  159. continue
  160. if (
  161. current.type != "SpatialBN"
  162. or next_.type != "Mul"
  163. or len(net.op) <= j + 1
  164. or net.op[j + 1].type != "Add"
  165. ):
  166. continue
  167. # else, can fuse
  168. bn = current
  169. mul = next_
  170. add = net.op[j + 1]
  171. fused_bn = copy.deepcopy(bn)
  172. fused_bn.output[0] = add.output[0]
  173. bn_scale = params[bn.input[1]]
  174. mul_scale = params[mul.input[1]]
  175. bn_bias = params[bn.input[2]]
  176. add_bias = params[add.input[1]]
  177. params[bn.input[1]] = bn_scale * mul_scale
  178. params[bn.input[2]] = mul_scale * bn_bias + add_bias
  179. new_ops = net.op[:i] + [fused_bn] + net.op[j + 2 :]
  180. del net.op[:]
  181. removed_tensors.append(mul.input[1])
  182. removed_tensors.append(add.input[1])
  183. del params[mul.input[1]]
  184. del params[add.input[1]]
  185. net.op.extend(new_ops)
  186. break
  187. return net, params, removed_tensors
  188. def fuse_scale(net, params, ignore_failure):
  189. # Run until we hit a fixed point
  190. removed_tensors = []
  191. while True:
  192. (next_net, next_params, removed_tensors) = fuse_first_scale(
  193. net, params, removed_tensors
  194. )
  195. if len(next_net.op) == len(net.op):
  196. return (next_net, next_params, removed_tensors)
  197. net, params, removed_tensors = (next_net, next_params, removed_tensors)
  198. def fuse_first_relu(net, begin_op_index, ignore_op_with_output=None):
  199. net = copy.deepcopy(net)
  200. for i, conv in enumerate(net.op[begin_op_index:], begin_op_index):
  201. if conv.type not in ["Conv", "ConvTranspose", "Sum", "SpatialBN"]:
  202. continue
  203. uses = blob_uses(net, conv.output[0])
  204. if (
  205. len(uses) == 0
  206. or ignore_op_with_output
  207. and conv.output[0] in ignore_op_with_output
  208. ):
  209. continue
  210. j = uses[0]
  211. relu = net.op[j]
  212. if relu.type != "Relu" or len(uses) > 1 and conv.output[0] != relu.output[0]:
  213. # Can't fuse if more than one user unless Relu is inplace
  214. if relu.type == "Relu":
  215. logger.debug("Can't fuse if more than one user {}".format(uses))
  216. continue
  217. # There shouldn't be any def of conv.output[0] and any use or def of relu.output[0] between conv and relu
  218. if any(
  219. blob in net.op[k].input or blob in net.op[k].output
  220. for blob in [conv.output[0], relu.output[0]]
  221. for k in range(i + 1, j)
  222. ):
  223. logger.debug(
  224. "Can't fuse because of the following interferring uses or defs:"
  225. )
  226. for k in range(i, j + 1):
  227. logger.debug(net.op[k])
  228. continue
  229. # else, can fuse
  230. fused_conv = copy.deepcopy(conv)
  231. fused_conv.type = conv.type + "Relu"
  232. fused_conv.output[0] = relu.output[0]
  233. new_ops = net.op[:i] + [fused_conv] + net.op[i + 1 : j] + net.op[j + 1 :]
  234. del net.op[:]
  235. net.op.extend(new_ops)
  236. return net, i + 1
  237. return net, None
  238. def fuse_relu(net, ignore_failure, ignore_op_with_output=None):
  239. # Run until we hit a fixed point
  240. begin_op_index = 0
  241. while True:
  242. next_net, begin_op_index = fuse_first_relu(
  243. net, begin_op_index, ignore_op_with_output
  244. )
  245. if begin_op_index is None:
  246. if any(op.type == "Relu" for op in next_net.op) and not ignore_failure:
  247. raise Exception("Model contains Relu op after fusion: %s", next_net)
  248. return next_net
  249. net = next_net
  250. def last_producer(ops, blob):
  251. for (i, op) in reversed(list(enumerate(ops))):
  252. if op.output[0] == blob:
  253. return i
  254. raise ValueError("Failed to find last producer of blob, %s", blob)
  255. def swap_first_concat_relu(net, ignore_op_with_output=None):
  256. net = copy.deepcopy(net)
  257. for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
  258. if next_.input[0] != current.output[0]:
  259. continue
  260. if current.type != "Concat" or next_.type != "Relu":
  261. continue
  262. if ignore_op_with_output and current.output[0] in ignore_op_with_output:
  263. continue
  264. # else, can swap
  265. concat = copy.deepcopy(current)
  266. relu = copy.deepcopy(next_)
  267. pre_ops = copy.deepcopy(net.op[:i])
  268. post_ops = copy.deepcopy(net.op[j + 1 :])
  269. # Delete the Relu after Concat
  270. concat.output[0] = relu.output[0]
  271. # Insert Relu after each op that produces inputs to Concat
  272. for blob in concat.input:
  273. k = last_producer(pre_ops, blob)
  274. producer = pre_ops[k]
  275. assert producer.output[0] == blob
  276. producer.output[0] = blob + "_pre_relu"
  277. new_relu = copy.deepcopy(relu)
  278. new_relu.input[0] = producer.output[0]
  279. new_relu.output[0] = blob
  280. pre_ops = pre_ops[: k + 1] + [new_relu] + pre_ops[k + 1 :]
  281. new_ops = pre_ops + [concat] + post_ops
  282. del net.op[:]
  283. net.op.extend(new_ops)
  284. break
  285. return net
  286. def swap_concat_relu(net, ignore_op_with_output=None):
  287. # Run until we hit a fixed point
  288. while True:
  289. next_net = swap_first_concat_relu(net, ignore_op_with_output)
  290. if len(next_net.op) == len(net.op):
  291. return next_net
  292. net = next_net
  293. def add_version_to_conv_bias(net, init_net):
  294. """
  295. In architectures such as FPN (https://arxiv.org/abs/1612.03144), few Conv
  296. ops share the same weight and bias and are run at different scales of
  297. the input. Since 'bias_scale = input_scale * weight_scale', sharing the
  298. same bias blob among multiple Conv ops means that we need different bias
  299. scale for each of the ops. To achieve this, we just duplicate those bias
  300. blobs that are used by multiple Conv ops before performing int8 rewrite.
  301. """
  302. bias_count = defaultdict(int)
  303. for op in net._net.op:
  304. if "Conv" in op.type and len(op.input) >= 3:
  305. bias_count[op.input[2]] += 1
  306. bias_fill_op = {}
  307. for op in init_net._net.op:
  308. if bias_count[op.output[0]] > 1:
  309. bias_fill_op[op.output[0]] = op
  310. bias_version = defaultdict(int)
  311. for op in net._net.op:
  312. if "Conv" in op.type and len(op.input) >= 3:
  313. bias = op.input[2]
  314. if bias_count[bias] <= 1:
  315. continue
  316. version = bias_version[bias]
  317. bias_version[bias] += 1
  318. if version == 0:
  319. continue
  320. new_bias = bias + "_v" + str(version)
  321. fill_op = copy.deepcopy(bias_fill_op[bias])
  322. fill_op.output[0] = new_bias
  323. init_net._net.op.extend([fill_op])
  324. op.input[2] = new_bias
  325. net._net.external_input.append(new_bias)
  326. def add_quantization_param_args_(op, q_param):
  327. op.arg.extend(
  328. [
  329. utils.MakeArgument("Y_scale", q_param.scale),
  330. utils.MakeArgument("Y_zero_point", q_param.zero_point),
  331. ]
  332. )
  333. def choose_quantization_params(tensor_min, tensor_max, preserve_sparsity=False):
  334. if tensor_min < 0 and tensor_max > 0 and preserve_sparsity:
  335. symmetric_qmin = -(255 // 2 + 1)
  336. symmetric_qmax = 255 // 2
  337. max_scale = max(
  338. abs(tensor_min / symmetric_qmin), abs(tensor_max / symmetric_qmax)
  339. )
  340. tensor_min = max_scale * symmetric_qmin
  341. tensor_max = max_scale * symmetric_qmax
  342. q_param = hardcode_scale_zp.choose_quantization_params(tensor_min, tensor_max)
  343. if tensor_min < 0 and tensor_max > 0 and preserve_sparsity:
  344. q_param = hardcode_scale_zp.QuantizationParam(q_param.scale, 128)
  345. return q_param
  346. def add_quantization_param_args(op, tensor, preserve_sparsity=False):
  347. tensor_min = 0 if tensor.size == 0 else tensor.min()
  348. tensor_max = 0 if tensor.size == 0 else tensor.max()
  349. q_param = choose_quantization_params(tensor_min, tensor_max, preserve_sparsity)
  350. add_quantization_param_args_(op, q_param)
  351. return q_param
  352. def create_int8_given_tensor_fill(tensor, out_blob_name, preserve_sparsity=False):
  353. """
  354. Create Int8GivenTensorFill op that quantizes the given tensor and outputs
  355. an Int8Tensor with out_blob_name.
  356. """
  357. op = core.CreateOperator("Int8GivenTensorFill", [], out_blob_name)
  358. q_param = add_quantization_param_args(op, tensor, preserve_sparsity)
  359. quantized_tensor = (
  360. np.around(tensor / q_param.scale).astype(np.int32) + q_param.zero_point
  361. )
  362. quantized_tensor = np.maximum(0, np.minimum(quantized_tensor, 255))
  363. op.arg.extend(
  364. [
  365. utils.MakeArgument("values", quantized_tensor.astype(np.uint8).tobytes()),
  366. utils.MakeArgument("shape", quantized_tensor.shape),
  367. ]
  368. )
  369. return op, q_param
  370. def create_int8_bias_tensor_fill(tensor, out_blob_name, x_q_param, w_q_param):
  371. """
  372. Similar to create_int8_given_tensor_fill, but for bias blobs to be stored
  373. as int32.
  374. """
  375. scale = x_q_param.scale * w_q_param.scale
  376. quantized_tensor = np.around(tensor / scale).astype(np.int32)
  377. quantized_tensor.reshape(-1)
  378. op = core.CreateOperator("Int8GivenIntTensorFill", [], out_blob_name)
  379. op.arg.extend(
  380. [
  381. utils.MakeArgument("values", quantized_tensor),
  382. utils.MakeArgument("shape", quantized_tensor.shape),
  383. ]
  384. )
  385. q_param = hardcode_scale_zp.QuantizationParam(scale, 0)
  386. add_quantization_param_args_(op, q_param)
  387. return op