symbolic_opset10.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644
  1. import sys
  2. import warnings
  3. from typing import Sequence
  4. import torch
  5. import torch._C._onnx as _C_onnx
  6. import torch.onnx
  7. from torch import _C
  8. # This import monkey-patches graph manipulation methods on Graph, used for the
  9. # ONNX symbolics
  10. from torch.onnx import _patch_torch # noqa: F401
  11. from torch.onnx import symbolic_helper
  12. from torch.onnx import symbolic_opset9 as opset9
  13. from torch.onnx._globals import GLOBALS
  14. # EDITING THIS FILE? READ THIS FIRST!
  15. # see Note [Edit Symbolic Files] in symbolic_helper.py
  16. # This file exports ONNX ops for opset 10
# Opset 10 is supported by ONNX release 1.5.0,
# released on 04/24/19
  19. def div(g, self, other, *args):
  20. if len(args) == 0:
  21. return opset9.true_divide(g, self, other)
  22. else:
  23. return _div_rounding_mode(g, self, other, *args)
  24. @symbolic_helper.parse_args("v", "v", "s")
  25. def _div_rounding_mode(g, self, other, rounding_mode):
  26. if rounding_mode == "floor":
  27. return _floor_divide(g, self, other)
  28. else:
  29. return opset9._div_rounding_mode(g, self, other, rounding_mode)
  30. def _floor_divide(g, self, other):
  31. if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
  32. out = opset9.true_divide(g, self, other)
  33. return g.op("Floor", out)
  34. else:
  35. # Integer division does trunction rounding
  36. div = g.op("Div", self, other)
  37. # Division is negative if: self < 0 != other < 0
  38. zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
  39. negative = g.op("Xor", g.op("Less", self, zero), g.op("Less", other, zero))
  40. # For negative numbers with self % other != 0, subtract 1 to round down instead of up
  41. mod = g.op("Mod", self, other, fmod_i=0)
  42. fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero)))
  43. one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
  44. fixup = g.op("Sub", div, one)
  45. return g.op("Where", fixup_mask, fixup, div)
  46. @symbolic_helper.parse_args("v", "i", "i", "none")
  47. def sort(g, self, dim, decending, out=None):
  48. return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)
  49. @symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
  50. def topk(g, self, k, dim, largest, sorted, out=None):
  51. return symbolic_helper._topk_helper(
  52. g, self, k, dim, largest=largest, sorted=sorted, out=out
  53. )
  54. def _max_pool(name, tuple_fn, ndims, return_indices):
  55. @symbolic_helper.quantized_args(True, False, False, False, False, False)
  56. @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
  57. def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
  58. if not stride:
  59. stride = kernel_size
  60. kwargs = {
  61. "kernel_shape_i": tuple_fn(kernel_size),
  62. "pads_i": tuple_fn(padding) * 2,
  63. "strides_i": tuple_fn(stride),
  64. "ceil_mode_i": ceil_mode,
  65. }
  66. if set(tuple_fn(dilation)) != {1}:
  67. kwargs["dilations_i"] = tuple_fn(dilation)
  68. # easy but hacky way to get flattened indices values
  69. # to be used to convert the indices values to non-flattened.
  70. # In ONNX the indices are computed as a flatten 1-D tensor,
  71. # so the values in indices are in [0, N x C x D1 x ... x Dn).
  72. # To convert the indices to the same format used by Pytorch,
  73. # we first execute a maxpool with a kernel and stride of 1 on the same input.
  74. # This will result in a tensor of indices in which each index will have it's own value.
  75. # Using this tensor as a reference, we extract the first index of each axis and subtract
  76. # it from each index of this axis in the indices to convert.
  77. # This step will result in a tensor were each dimension has values of indices within
  78. # the dimension it is in.
  79. # For more information :
  80. # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
  81. if return_indices:
  82. r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
  83. _, flattened_indices = g.op(
  84. "MaxPool",
  85. input,
  86. outputs=2,
  87. kernel_shape_i=[1 for _ in range(ndims)],
  88. strides_i=[1 for _ in range(ndims)],
  89. )
  90. # convert indices to have non-flattened indices values
  91. s = symbolic_helper._slice_helper(
  92. g,
  93. flattened_indices,
  94. axes=[2 + i for i in range(ndims)],
  95. starts=tuple_fn(0),
  96. ends=tuple_fn(1),
  97. )
  98. indices = opset9.sub(g, indices, s)
  99. return r, indices
  100. else:
  101. r = g.op("MaxPool", input, outputs=1, **kwargs)
  102. return r
  103. return symbolic_fn
  104. max_pool1d = _max_pool(
  105. "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
  106. )
  107. max_pool2d = _max_pool(
  108. "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
  109. )
  110. max_pool3d = _max_pool(
  111. "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
  112. )
  113. max_pool1d_with_indices = _max_pool(
  114. "max_pool1d_with_indices", torch.nn.modules.utils._single, 1, return_indices=True
  115. )
  116. max_pool2d_with_indices = _max_pool(
  117. "max_pool2d_with_indices", torch.nn.modules.utils._pair, 2, return_indices=True
  118. )
  119. max_pool3d_with_indices = _max_pool(
  120. "max_pool3d_with_indices", torch.nn.modules.utils._triple, 3, return_indices=True
  121. )
  122. def _avg_pool(name, tuple_fn):
  123. @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
  124. @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
  125. def symbolic_fn(
  126. g,
  127. input: _C.Value,
  128. kernel_size: Sequence[int],
  129. stride: Sequence[int],
  130. padding: Sequence[int],
  131. ceil_mode: int,
  132. count_include_pad: int,
  133. divisor_override=None,
  134. ):
  135. if not stride:
  136. stride = kernel_size
  137. padding = symbolic_helper._avgpool_helper(
  138. tuple_fn, padding, kernel_size, stride, divisor_override, name
  139. )
  140. if count_include_pad:
  141. input = opset9.op_with_optional_float_cast(
  142. g,
  143. "Pad",
  144. input,
  145. pads_i=((0,) * 2 + padding) * 2,
  146. mode_s="constant",
  147. value_f=0.0,
  148. opset_before=11,
  149. )
  150. padding = (0,) * len(padding)
  151. output = g.op(
  152. "AveragePool",
  153. input,
  154. kernel_shape_i=tuple_fn(kernel_size),
  155. strides_i=tuple_fn(stride),
  156. pads_i=padding * 2,
  157. ceil_mode_i=ceil_mode,
  158. )
  159. return output
  160. return symbolic_fn
  161. avg_pool1d = _avg_pool("avg_pool1d", torch.nn.modules.utils._single)
  162. avg_pool2d = _avg_pool("avg_pool2d", torch.nn.modules.utils._pair)
  163. avg_pool3d = _avg_pool("avg_pool3d", torch.nn.modules.utils._triple)
  164. def _interpolate(name, dim, interpolate_mode):
  165. @symbolic_helper.quantized_args(True, False, False)
  166. def symbolic_fn(g, input, output_size, *args):
  167. scales, align_corners = symbolic_helper._get_interpolate_attributes(
  168. g, interpolate_mode, args
  169. )
  170. symbolic_helper._interpolate_warning(interpolate_mode)
  171. align_corners = symbolic_helper._maybe_get_scalar(align_corners)
  172. if align_corners:
  173. return symbolic_helper._unimplemented(name, "align_corners == True")
  174. if scales is None:
  175. scales = symbolic_helper._interpolate_size_to_scales(
  176. g, input, output_size, dim
  177. )
  178. return g.op("Resize", input, scales, mode_s=interpolate_mode)
  179. return symbolic_fn
  180. upsample_nearest1d = _interpolate("upsample_nearest1d", 3, "nearest")
  181. upsample_nearest2d = _interpolate("upsample_nearest2d", 4, "nearest")
  182. upsample_nearest3d = _interpolate("upsample_nearest3d", 5, "nearest")
  183. upsample_linear1d = _interpolate("upsample_linear1d", 3, "linear")
  184. upsample_bilinear2d = _interpolate("upsample_bilinear2d", 4, "linear")
  185. upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear")
  186. def __interpolate(
  187. g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
  188. ):
  189. scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
  190. g, input, size, scale_factor, mode, align_corners
  191. )
  192. return g.op("Resize", input, scales, mode_s=mode)
  193. def _slice(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
  194. if dynamic_slice:
  195. starts = symbolic_helper._unsqueeze_helper(g, starts, [0])
  196. ends = symbolic_helper._unsqueeze_helper(g, ends, [0])
  197. if isinstance(axes, int):
  198. axes = g.op("Constant", value_t=torch.tensor(axes))
  199. axes = symbolic_helper._unsqueeze_helper(g, axes, [0])
  200. else:
  201. assert len(starts) == len(ends)
  202. assert len(starts) == len(axes)
  203. assert steps is None or len(starts) == len(steps)
  204. if (
  205. len(starts) == 1
  206. and starts[0] == 0
  207. and ends[0] == 9223372036854775807
  208. and (steps is None or (len(steps) == 1 and steps[0] == 1))
  209. ):
  210. return input
  211. axes = g.op("Constant", value_t=torch.tensor(axes))
  212. starts = g.op("Constant", value_t=torch.tensor(starts))
  213. ends = g.op("Constant", value_t=torch.tensor(ends))
  214. if steps is None:
  215. return g.op("Slice", input, starts, ends, axes)
  216. steps = g.op("Constant", value_t=torch.tensor(steps))
  217. return g.op("Slice", input, starts, ends, axes, steps)
  218. def slice(g, self, *args):
  219. if len(args) == 4:
  220. # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
  221. dim, start, end, step = args
  222. elif len(args) == 3:
  223. # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]
  224. start, end, step = args
  225. dim = 0
  226. else:
  227. raise NotImplementedError("Unknown aten::slice signature")
  228. is_start_none = (
  229. start.node().kind() == "prim::Constant" and start.type().kind() == "NoneType"
  230. )
  231. is_end_none = (
  232. end.node().kind() == "prim::Constant" and end.type().kind() == "NoneType"
  233. )
  234. is_start_onnx_const = start.node().kind() == "onnx::Constant"
  235. is_end_onnx_const = end.node().kind() == "onnx::Constant"
  236. step = symbolic_helper._parse_arg(step, "i")
  237. if (
  238. (not is_start_none and not is_start_onnx_const)
  239. or (not isinstance(end, int) and not is_end_none and not is_end_onnx_const)
  240. or (not isinstance(dim, int) and dim.node().kind() != "onnx::Constant")
  241. ):
  242. dynamic_slice = True
  243. if is_start_none:
  244. start = g.op("Constant", value_t=torch.tensor(0))
  245. if is_end_none:
  246. end = g.op("Constant", value_t=torch.tensor(9223372036854775807))
  247. else:
  248. start = [0 if is_start_none else symbolic_helper._parse_arg(start, "i")]
  249. end = [
  250. 9223372036854775807 if is_end_none else symbolic_helper._parse_arg(end, "i")
  251. ]
  252. dim = [symbolic_helper._parse_arg(dim, "i")]
  253. dynamic_slice = False
  254. return symbolic_helper._slice_helper(
  255. g,
  256. self,
  257. axes=dim,
  258. starts=start,
  259. ends=end,
  260. steps=[step],
  261. dynamic_slice=dynamic_slice,
  262. )
  263. @symbolic_helper.parse_args("v", "is")
  264. def flip(g, input, dims):
  265. return symbolic_helper._slice_helper(
  266. g,
  267. input,
  268. axes=dims,
  269. starts=[-1] * len(dims),
  270. ends=[-9223372036854775807] * len(dims),
  271. steps=[-1] * len(dims),
  272. )
  273. def fmod(g, input, other):
  274. return g.op("Mod", input, other, fmod_i=1)
  275. @symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
  276. def embedding_bag(
  277. g,
  278. embedding_matrix,
  279. indices,
  280. offsets,
  281. scale_grad_by_freq,
  282. mode,
  283. sparse,
  284. per_sample_weights,
  285. include_last_offset,
  286. padding_idx,
  287. ):
  288. if scale_grad_by_freq and GLOBALS.training_mode:
  289. return symbolic_helper._onnx_unsupported(
  290. "embedding_bag with scale_grad_by_freq for training mode"
  291. )
  292. if padding_idx is not None and padding_idx >= 0:
  293. raise RuntimeError("embedding_bag with padding_idx")
  294. warnings.warn(
  295. "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
  296. "Please use opset 11 or higher to export model for dynamic input shape.'"
  297. )
  298. offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0)
  299. if offsets_dim_0 is not None:
  300. if include_last_offset:
  301. offset_len = offsets_dim_0 - 1
  302. offsets_extended = offsets
  303. else:
  304. offset_len = offsets_dim_0
  305. offsets_extended = [
  306. offsets,
  307. g.op("Constant", value_t=torch.tensor([sys.maxsize])),
  308. ]
  309. offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)
  310. list_ = []
  311. for i in range(offset_len):
  312. start_ = symbolic_helper._unsqueeze_helper(
  313. g,
  314. opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)),
  315. [0],
  316. )
  317. end_ = symbolic_helper._unsqueeze_helper(
  318. g,
  319. opset9.select(
  320. g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)
  321. ),
  322. [0],
  323. )
  324. axes_ = g.op("Constant", value_t=torch.tensor([0]))
  325. indices_row = g.op("Slice", indices, start_, end_, axes_)
  326. embeddings = g.op("Gather", embedding_matrix, indices_row)
  327. if not symbolic_helper._is_none(per_sample_weights):
  328. per_sample_weights_row = g.op(
  329. "Slice", per_sample_weights, start_, end_, axes_
  330. )
  331. per_sample_weights_row = symbolic_helper._unsqueeze_helper(
  332. g, per_sample_weights_row, [1]
  333. )
  334. embeddings = g.op("Mul", embeddings, per_sample_weights_row)
  335. if mode == 0:
  336. embeddings = symbolic_helper._reducesum_helper(
  337. g, embeddings, axes_i=[0], keepdims_i=0
  338. )
  339. elif mode == 1:
  340. embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
  341. else:
  342. embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)
  343. embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0])
  344. list_.append(embeddings)
  345. output = g.op("Concat", *list_, axis_i=0)
  346. # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
  347. # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
  348. return output, None, None, None
  349. else:
  350. return symbolic_helper._onnx_unsupported(
  351. "embedding_bag with unknown shape of offsets for opset 10 is not supported. "
  352. "please use opset 11 or higher."
  353. )
  354. @symbolic_helper.parse_args("v", "v", "v", "i", "i")
  355. def fake_quantize_per_tensor_affine(
  356. g, inputs, scale, zero_point, quant_min=-128, quant_max=127
  357. ):
  358. # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
  359. # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
  360. if (quant_min, quant_max) == (0, 127):
  361. symbolic_helper._onnx_opset_unsupported_detailed(
  362. "fake_quantize_per_tensor_affine",
  363. 10,
  364. 13,
  365. "Quantize range (0, 127) not supported, requires opset 13 Clip",
  366. )
  367. if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
  368. raise RuntimeError(
  369. "For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
  370. "Got ({}, {})".format(quant_min, quant_max)
  371. )
  372. scale = symbolic_helper._maybe_get_scalar(scale)
  373. if scale is None:
  374. symbolic_helper._onnx_opset_unsupported_detailed(
  375. "fake_quantize_per_tensor_affine",
  376. 10,
  377. 13,
  378. "Non-constant scale not supported",
  379. )
  380. scale = scale.float().data # Avoid exporter generating double type
  381. if quant_min == 0:
  382. zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
  383. else:
  384. zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
  385. return g.op(
  386. "DequantizeLinear",
  387. g.op("QuantizeLinear", inputs, scale, zero_point),
  388. scale,
  389. zero_point,
  390. )
  391. def isinf(g, input):
  392. return g.op("IsInf", opset9._cast_Double(g, input, False)) # type: ignore[attr-defined]
  393. def isfinite(g, input):
  394. from torch.onnx.symbolic_opset9 import __not_, __or_
  395. inf_node = isinf(g, input)
  396. nan_node = opset9.isnan(g, input)
  397. return __not_(g, __or_(g, inf_node, nan_node))
  398. def quantize_per_tensor(g, input, scale, zero_point, dtype):
  399. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  400. zero_point = g.op(
  401. "Cast", zero_point, to_i=symbolic_helper.scalar_type_to_onnx[dtype]
  402. )
  403. scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
  404. return symbolic_helper.quantize_helper(g, input, scale, zero_point)
  405. def dequantize(g, input):
  406. return symbolic_helper.dequantize_helper(g, input)[0]
  407. @symbolic_helper.parse_args("v", "f", "f", "f")
  408. def nan_to_num(g, input, nan, posinf, neginf):
  409. # Cannot create a int type tensor with inf/nan values, so we simply
  410. # return the original tensor
  411. if not symbolic_helper._is_fp(input):
  412. return input
  413. input_dtype = symbolic_helper.pytorch_name_to_type[input.type().scalarType()]
  414. if nan is None:
  415. nan = 0.0
  416. nan_cond = opset9.isnan(g, input)
  417. nan_result = g.op(
  418. "Where",
  419. nan_cond,
  420. g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)),
  421. input,
  422. )
  423. # For None values of posinf, neginf we use the greatest/lowest finite
  424. # value representable by input’s dtype.
  425. finfo = torch.finfo(input_dtype)
  426. if posinf is None:
  427. posinf = finfo.max
  428. posinf_cond = opset9.logical_and(
  429. g,
  430. isinf(g, nan_result),
  431. opset9.gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))),
  432. )
  433. nan_posinf_result = g.op(
  434. "Where",
  435. posinf_cond,
  436. g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)),
  437. nan_result,
  438. )
  439. if neginf is None:
  440. neginf = finfo.min
  441. neginf_cond = opset9.logical_and(
  442. g,
  443. isinf(g, nan_posinf_result),
  444. opset9.lt(
  445. g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0]))
  446. ),
  447. )
  448. return g.op(
  449. "Where",
  450. neginf_cond,
  451. g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)),
  452. nan_posinf_result,
  453. )
  454. # https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
  455. class Quantized:
  456. """
  457. https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
  458. Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were introduced in opset version 10.
  459. """
  460. domain = "quantized"
  461. @staticmethod
  462. def linear(g, q_input, q_weight, bias, op_scale, op_zero_point):
  463. input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
  464. weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
  465. q_bias = symbolic_helper.requantize_bias_helper(
  466. g, bias, input_scale, weight_scale
  467. )
  468. bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
  469. output = opset9.linear(g, input, weight, bias)
  470. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  471. @staticmethod
  472. def add(g, x, y, op_scale, op_zero_point):
  473. x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
  474. y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
  475. output = opset9.add(g, x, y)
  476. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  477. @staticmethod
  478. def add_relu(g, x, y, op_scale, op_zero_point):
  479. x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
  480. y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
  481. output = opset9.add(g, x, y)
  482. output = opset9.relu(g, output)
  483. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  484. @staticmethod
  485. def mul(g, x, y, op_scale, op_zero_point):
  486. x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
  487. y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
  488. output = opset9.mul(g, x, y)
  489. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  490. @staticmethod
  491. def hardswish(g, x, op_scale, op_zero_point):
  492. x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
  493. output = opset9.hardswish(g, x)
  494. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  495. @staticmethod
  496. def conv2d_relu(
  497. g,
  498. q_input,
  499. q_weight,
  500. bias,
  501. stride,
  502. padding,
  503. dilation,
  504. groups,
  505. op_scale,
  506. op_zero_point,
  507. ):
  508. input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
  509. weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
  510. q_bias = symbolic_helper.requantize_bias_helper(
  511. g, bias, input_scale, weight_scale
  512. )
  513. bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
  514. output = opset9.conv2d(
  515. g, input, weight, bias, stride, padding, dilation, groups
  516. )
  517. output = opset9.relu(g, output)
  518. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  519. @staticmethod
  520. def conv2d(
  521. g,
  522. q_input,
  523. q_weight,
  524. bias,
  525. stride,
  526. padding,
  527. dilation,
  528. groups,
  529. op_scale,
  530. op_zero_point,
  531. ):
  532. input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
  533. weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
  534. q_bias = symbolic_helper.requantize_bias_helper(
  535. g, bias, input_scale, weight_scale
  536. )
  537. bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
  538. output = opset9.conv2d(
  539. g, input, weight, bias, stride, padding, dilation, groups
  540. )
  541. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  542. @staticmethod
  543. @symbolic_helper.parse_args("v", "i", "v", "v")
  544. def cat(
  545. g,
  546. q_inputs: _C.Value,
  547. dim: int,
  548. op_scale: _C.Value,
  549. op_zero_point: _C.Value,
  550. ) -> _C.Value:
  551. unpacked_inputs = symbolic_helper._unpack_list(q_inputs)
  552. dequantized = [
  553. symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs
  554. ]
  555. concatenated = g.op("Concat", *dequantized, axis_i=dim)
  556. return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point)