symbolic_caffe2.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. import importlib
  2. from inspect import getmembers, isfunction
  3. from torch.onnx import symbolic_helper
  4. from torch.onnx import symbolic_opset9 as opset9
  5. from torch.onnx import symbolic_registry
  6. def register_quantized_ops(domain: str, version: int):
  7. # Register all the non-quantized ops
  8. symbolic_registry.register_version("", version)
  9. # Register all quantized ops
  10. module = importlib.import_module("torch.onnx.symbolic_caffe2")
  11. symbolic_registry._symbolic_versions["caffe2"] = module
  12. quant_version_ops = getmembers(symbolic_registry._symbolic_versions["caffe2"])
  13. for op in quant_version_ops:
  14. if isfunction(op[1]) and not symbolic_registry.is_registered_op(
  15. op[0], domain, version
  16. ):
  17. aten_q_ops = [
  18. "relu",
  19. "_empty_affine_quantized",
  20. "dequantize",
  21. "quantize_per_tensor",
  22. "upsample_nearest2d",
  23. "avg_pool2d",
  24. "reshape",
  25. "slice",
  26. "cat",
  27. "max_pool2d",
  28. "sigmoid",
  29. ]
  30. if op[0] in aten_q_ops:
  31. symbolic_registry.register_op(op[0], op[1], "", version)
  32. symbolic_registry.register_op(op[0], op[1], domain, version)
  33. def _permute_helper(g, input, axes):
  34. quant_args = {
  35. "axes_i": axes,
  36. "Y_scale_f": input.node()["Y_scale"],
  37. "Y_zero_point_i": input.node()["Y_zero_point"],
  38. }
  39. output = g.op("_caffe2::Int8Transpose", input, **quant_args)
  40. symbolic_helper._quantized_ops.add(output)
  41. return output
  42. def nchw2nhwc(g, input):
  43. axes = [0, 2, 3, 1]
  44. return _permute_helper(g, input, axes)
  45. def nhwc2nchw(g, input):
  46. axes = [0, 3, 1, 2]
  47. return _permute_helper(g, input, axes)
  48. def linear_prepack(g, weight, bias):
  49. # Mapping to a dummy caffe2 prepack node.
  50. # During the onnx -> c2 conversion we can look up original weight and bias
  51. # from this node
  52. output = g.op("_caffe2::WeightPrepack", weight, bias)
  53. symbolic_helper._quantized_ops.add(output)
  54. return output
  55. @symbolic_helper.parse_args("v", "v", "v", "f", "i")
  56. def linear(g, input, weight, bias, scale, zero_point):
  57. kwargs = {
  58. "Y_scale_f": scale,
  59. "Y_zero_point_i": zero_point,
  60. }
  61. output = g.op("_caffe2::Int8FC", input, weight, bias, **kwargs)
  62. symbolic_helper._quantized_ops.add(output)
  63. return output
  64. def conv_prepack(g, input, weight, bias, stride, padding, dilation, groups):
  65. # Mapping to a dummy caffe2 prepack node.
  66. # During the onnx -> c2 conversion we can look up original weight and bias
  67. # from this node
  68. output = g.op("_caffe2::WeightPrepack", input, weight, bias)
  69. symbolic_helper._quantized_ops.add(output)
  70. return output
  71. @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
  72. def conv2d(
  73. g, input, weight, bias, stride, padding, dilation, groups, scale, zero_point
  74. ):
  75. kernel_size = weight.node()["shape"][1:3]
  76. kwargs = {
  77. "strides_i": stride,
  78. "pads_i": padding + padding,
  79. "dilations_i": dilation,
  80. "group_i": groups,
  81. "kernels_i": kernel_size,
  82. "order_s": "NHWC",
  83. "Y_scale_f": scale,
  84. "Y_zero_point_i": zero_point,
  85. }
  86. output = g.op("_caffe2::Int8Conv", input, weight, bias, **kwargs)
  87. symbolic_helper._quantized_ops.add(output)
  88. return output
  89. @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
  90. def conv2d_relu(
  91. g, input, weight, bias, stride, padding, dilation, groups, scale, zero_point
  92. ):
  93. kernel_size = weight.node()["shape"][1:3]
  94. kwargs = {
  95. "strides_i": stride,
  96. "pads_i": padding + padding,
  97. "dilations_i": dilation,
  98. "group_i": groups,
  99. "kernels_i": kernel_size,
  100. "order_s": "NHWC",
  101. "Y_scale_f": scale,
  102. "Y_zero_point_i": zero_point,
  103. }
  104. output = g.op("_caffe2::Int8ConvRelu", input, weight, bias, **kwargs)
  105. symbolic_helper._quantized_ops.add(output)
  106. return output
  107. @symbolic_helper.parse_args("v", "v", "f", "i")
  108. def add(g, input_a, input_b, scale, zero_point):
  109. kwargs = {
  110. "Y_scale_f": scale,
  111. "Y_zero_point_i": zero_point,
  112. }
  113. output = g.op("_caffe2::Int8Add", input_a, input_b, **kwargs)
  114. symbolic_helper._quantized_ops.add(output)
  115. return output
  116. @symbolic_helper.parse_args("v")
  117. def relu(g, input):
  118. if input not in symbolic_helper._quantized_ops:
  119. return opset9.relu(g, input)
  120. kwargs = {
  121. "Y_scale_f": input.node()["Y_scale"],
  122. "Y_zero_point_i": input.node()["Y_zero_point"],
  123. }
  124. output = g.op("_caffe2::Int8Relu", input, **kwargs)
  125. symbolic_helper._quantized_ops.add(output)
  126. return output
  127. @symbolic_helper.parse_args("v", "f", "i", "t")
  128. def quantize_per_tensor(g, input, scale, zero_point, dtype):
  129. kwargs = {
  130. "Y_scale_f": scale,
  131. "Y_zero_point_i": zero_point,
  132. }
  133. output = g.op("_caffe2::Int8Quantize", input, **kwargs)
  134. symbolic_helper._quantized_ops.add(output)
  135. return output
  136. @symbolic_helper.parse_args("v")
  137. def dequantize(g, input):
  138. return g.op("_caffe2::Int8Dequantize", input)
  139. @symbolic_helper.parse_args("v", "t", "t", "t", "t", "t", "t", "t")
  140. def _empty_affine_quantized(
  141. g, input, shape, scale, zero_point, dtype, pin_memory, memory_format, layout
  142. ):
  143. return input
  144. def upsample_nearest2d(
  145. g, input, output_size, align_corners=None, scales_h=None, scales_w=None
  146. ):
  147. if input not in symbolic_helper._quantized_ops:
  148. return opset9.upsample_nearest2d(g, input, output_size, align_corners)
  149. output_size = symbolic_helper._parse_arg(output_size, "is")
  150. kwargs = {
  151. "output_size_i": output_size,
  152. "Y_scale_f": input.node()["Y_scale"],
  153. "Y_zero_point_i": input.node()["Y_zero_point"],
  154. }
  155. input = nchw2nhwc(g, input)
  156. output = g.op("_caffe2::Int8ResizeNearest", input, **kwargs)
  157. output = nhwc2nchw(g, output)
  158. symbolic_helper._quantized_ops.add(output)
  159. return output
  160. @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
  161. def max_pool2d(g, input, kernel_size, stride, padding, dilation, ceil_mode):
  162. if input not in symbolic_helper._quantized_ops:
  163. return opset9.max_pool2d(
  164. g, input, kernel_size, stride, padding, dilation, ceil_mode
  165. )
  166. kwargs = {
  167. "strides_i": stride,
  168. "pads_i": padding + padding,
  169. "kernel_i": kernel_size[0],
  170. "order_s": "NHWC",
  171. "Y_scale_f": input.node()["Y_scale"],
  172. "Y_zero_point_i": input.node()["Y_zero_point"],
  173. }
  174. input = nchw2nhwc(g, input)
  175. output = g.op("_caffe2::Int8MaxPool", input, **kwargs)
  176. output = nhwc2nchw(g, output)
  177. symbolic_helper._quantized_ops.add(output)
  178. return output
  179. @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
  180. def avg_pool2d(
  181. g,
  182. input,
  183. kernel_size,
  184. stride,
  185. padding,
  186. ceil_mode,
  187. count_include_pad,
  188. divisor_override=None,
  189. ):
  190. if input not in symbolic_helper._quantized_ops:
  191. return opset9.avg_pool2d(
  192. g,
  193. input,
  194. kernel_size,
  195. stride,
  196. padding,
  197. ceil_mode,
  198. count_include_pad,
  199. divisor_override,
  200. )
  201. kwargs = {
  202. "strides_i": stride,
  203. "pads_i": padding + padding,
  204. "kernel_i": kernel_size[0],
  205. "order_s": "NHWC",
  206. "Y_scale_f": input.node()["Y_scale"],
  207. "Y_zero_point_i": input.node()["Y_zero_point"],
  208. }
  209. input = nchw2nhwc(g, input)
  210. output = g.op("_caffe2::Int8AveragePool", input, **kwargs)
  211. output = nhwc2nchw(g, output)
  212. symbolic_helper._quantized_ops.add(output)
  213. return output
  214. def reshape(g, input, shape):
  215. if input not in symbolic_helper._quantized_ops:
  216. return opset9.reshape(g, input, shape)
  217. kwargs = {
  218. "Y_scale_f": input.node()["Y_scale"],
  219. "Y_zero_point_i": input.node()["Y_zero_point"],
  220. }
  221. output = g.op("_caffe2::Int8Reshape", input, shape, **kwargs)
  222. symbolic_helper._quantized_ops.add(output)
  223. return output
  224. @symbolic_helper.parse_args("v", "v", "v", "v", "i")
  225. def slice(g, input, dim, start, end, step):
  226. if input not in symbolic_helper._quantized_ops:
  227. return opset9.slice(g, input, dim, start, end, step)
  228. if step != 1:
  229. raise RuntimeError("ONNX quantized slice export only works for step 1.")
  230. start = symbolic_helper._parse_arg(start, "i")
  231. end = symbolic_helper._parse_arg(end, "i")
  232. dim = symbolic_helper._parse_arg(dim, "i")
  233. kwargs = {
  234. "start_idx_i": start,
  235. "end_idx_i": end,
  236. "dim_i": dim,
  237. "Y_scale_f": input.node()["Y_scale"],
  238. "Y_zero_point_i": input.node()["Y_zero_point"],
  239. }
  240. output = g.op("_caffe2::Int8Slice", input, **kwargs)
  241. symbolic_helper._quantized_ops.add(output)
  242. return output
  243. def cat(g, tensor_list, dim, scale=None, zero_point=None):
  244. tensors = symbolic_helper._unpack_list(tensor_list)
  245. input = tensors[0]
  246. if input not in symbolic_helper._quantized_ops:
  247. return opset9.cat(g, tensor_list, dim)
  248. dim = symbolic_helper._parse_arg(dim, "i")
  249. kwargs = {
  250. "Y_scale_f": tensors[0].node()["Y_scale"],
  251. "Y_zero_point_i": tensors[0].node()["Y_zero_point"],
  252. }
  253. output = g.op("_caffe2::Int8Concat", *tensors, axis_i=dim, **kwargs)
  254. symbolic_helper._quantized_ops.add(output)
  255. return output
  256. @symbolic_helper.parse_args("v")
  257. def sigmoid(g, input):
  258. if input not in symbolic_helper._quantized_ops:
  259. return opset9.sigmoid(g, input)
  260. # Caffe2 expects the output scale to be 1/2^8
  261. # and output zero_point to be 0 (quint8 type)
  262. out_scale = 1.0 / 256
  263. zero_point = 0
  264. kwargs = {
  265. "Y_scale_f": out_scale,
  266. "Y_zero_point_i": zero_point,
  267. }
  268. output = g.op("_caffe2::Int8Sigmoid", input, **kwargs)
  269. symbolic_helper._quantized_ops.add(output)
  270. return output