symbolic_opset13.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600
  1. # EDITING THIS FILE? READ THIS FIRST!
  2. # see Note [Edit Symbolic Files] in symbolic_helper.py
  3. # This file exports ONNX ops for opset 13
  4. import torch
  5. import torch._C._onnx as _C_onnx
  6. from torch.onnx import symbolic_helper
  7. from torch.onnx import symbolic_opset9 as opset9
  8. from torch.onnx import symbolic_opset11 as opset11
  9. from torch.onnx import utils
  10. @symbolic_helper.parse_args("v", "i", "none")
  11. def softmax(g, input, dim, dtype=None):
  12. softmax = g.op("Softmax", input, axis_i=dim)
  13. if dtype and dtype.node().kind() != "prim::Constant":
  14. parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  15. softmax = g.op(
  16. "Cast", softmax, to_i=symbolic_helper.scalar_type_to_onnx[parsed_dtype]
  17. )
  18. return softmax
  19. @symbolic_helper.parse_args("v", "i", "none")
  20. def log_softmax(g, input, dim, dtype=None):
  21. return_op = g.op("LogSoftmax", input, axis_i=dim)
  22. if dtype and dtype.node().kind() != "prim::Constant":
  23. parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  24. return_op = g.op(
  25. "Cast", return_op, to_i=symbolic_helper.scalar_type_to_onnx[parsed_dtype]
  26. )
  27. return return_op
  28. @symbolic_helper.parse_args("v", "v", "i")
  29. def frobenius_norm(g, self, dim=None, keepdim=False):
  30. dim_val = symbolic_helper._maybe_get_const(dim, "is")
  31. if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0:
  32. return g.op("ReduceL2", self, keepdims_i=0)
  33. sqr = g.op("Mul", self, self)
  34. sumsqr = symbolic_helper._reducesum_helper(g, sqr, dim, keepdims_i=keepdim)
  35. return g.op("Sqrt", sumsqr)
  36. @symbolic_helper.parse_args("v", "v", "i", "i")
  37. def split(g, self, split_size_or_sizes, dim, _outputs=None):
  38. if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
  39. split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
  40. if _outputs is None:
  41. return split_out
  42. # Convert to multiple slice nodes iff number of splits and number of outputs are statically known.
  43. if (
  44. symbolic_helper._is_packed_list(split_size_or_sizes)
  45. and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs
  46. ):
  47. split_sizes = [
  48. symbolic_helper._unsqueeze_helper(g, v, [0])
  49. for v in symbolic_helper._unpack_list(split_size_or_sizes)
  50. ]
  51. start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
  52. axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
  53. res = []
  54. for i in range(_outputs):
  55. end = g.op(
  56. "Add", start, split_sizes[i]
  57. ) # split_sizes is a list of same length as _outputs
  58. res.append(g.op("Slice", self, start, end, axis))
  59. start = end
  60. return res
  61. return [
  62. g.op(
  63. "SequenceAt",
  64. split_out,
  65. g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
  66. )
  67. for i in range(_outputs)
  68. ]
  69. split_val = split_size_or_sizes.node()["value"]
  70. if split_val.dim() > 0:
  71. return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs)
  72. split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")
  73. size = symbolic_helper._get_tensor_dim_size(self, dim)
  74. if size is None:
  75. if _outputs is not None:
  76. size = split_size * _outputs
  77. else:
  78. raise RuntimeError("Unknown dimension size not supported")
  79. splits = [split_size] * (size // split_size)
  80. leftover = size % split_size
  81. if leftover:
  82. splits.append(leftover)
  83. splits = g.op("Constant", value_t=torch.tensor(splits))
  84. return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
  85. def split_with_sizes(g, self, split_sizes, dim, _outputs=None):
  86. return split(g, self, split_sizes, dim, _outputs)
  87. def unsafe_split(g, self, split_size_or_sizes, dim, _outputs=None):
  88. return split(g, self, split_size_or_sizes, dim, _outputs)
  89. def unsafe_split_with_sizes(g, self, split_sizes, dim, _outputs=None):
  90. return split_with_sizes(g, self, split_sizes, dim, _outputs)
  91. @symbolic_helper.parse_args("v", "i", "i")
  92. def unbind(g, self, dim=0, _outputs=None):
  93. if _outputs is None:
  94. return g.op(
  95. "SplitToSequence",
  96. self,
  97. g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
  98. axis_i=dim,
  99. keepdims_i=0,
  100. )
  101. splits = g.op("Constant", value_t=torch.tensor([1] * _outputs))
  102. outputs = g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
  103. outputs = [outputs] if _outputs == 1 else outputs
  104. squeezed_outputs = [
  105. g.op("Squeeze", out, g.op("Constant", value_t=torch.tensor([dim])))
  106. for out in outputs
  107. ]
  108. return squeezed_outputs
  109. # Emitted from `torch.nonzero(x, as_tuple=True)`
  110. def nonzero_numpy(g, input, _outputs=None):
  111. return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs)
  112. @symbolic_helper.parse_args("v", "v", "v", "i")
  113. def where(g, condition, self=None, other=None, _outputs=None):
  114. # Assumes that torch.where's first argument takes only Bool and Byte tensors.
  115. if condition.type().scalarType() != "Bool":
  116. condition = g.op(
  117. "Cast", condition, to_i=symbolic_helper.cast_pytorch_to_onnx["Bool"]
  118. )
  119. if self is None:
  120. condition = opset9.nonzero(g, condition)
  121. return symbolic_helper._unbind_helper(
  122. g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs
  123. )
  124. return g.op("Where", condition, self, other)
  125. @symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
  126. def fake_quantize_per_channel_affine(
  127. g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127
  128. ):
  129. # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
  130. # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
  131. if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
  132. raise RuntimeError(
  133. "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
  134. "Got ({}, {})".format(quant_min, quant_max)
  135. )
  136. # ONNX defines zero_point to be int8 or uint8
  137. if quant_min == 0:
  138. zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
  139. else:
  140. zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
  141. quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis)
  142. if (quant_min, quant_max) == (0, 127):
  143. quantized = g.op(
  144. "Clip",
  145. quantized,
  146. opset9.unused(g),
  147. g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
  148. )
  149. return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis)
  150. @symbolic_helper.parse_args("v", "v", "v", "i", "i")
  151. def fake_quantize_per_tensor_affine(
  152. g, inputs, scale, zero_point, quant_min=-128, quant_max=127
  153. ):
  154. # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
  155. # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
  156. if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
  157. raise RuntimeError(
  158. "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
  159. "Got ({}, {})".format(quant_min, quant_max)
  160. )
  161. if quant_min == 0:
  162. zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
  163. else:
  164. zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
  165. if scale.type().scalarType() != "Float":
  166. scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
  167. quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
  168. if (quant_min, quant_max) == (0, 127):
  169. quantized = g.op(
  170. "Clip",
  171. quantized,
  172. opset9.unused(g),
  173. g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
  174. )
  175. return g.op("DequantizeLinear", quantized, scale, zero_point)
  176. def _reduce_op_symbolic(onnx_op_name):
  177. def symbolic(g, self, dim=None, keepdim=None):
  178. self = opset9._maybe_cast_reduce_op_input(g, self)
  179. if dim is None:
  180. # all-reduce path
  181. return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name)
  182. else:
  183. keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
  184. return g.op(onnx_op_name, self, dim, keepdims_i=keepdim)
  185. return symbolic
  186. def _reduce_with_dtype(onnx_op, name):
  187. symbolic = _reduce_op_symbolic(onnx_op)
  188. @opset9.overload_by_arg_count
  189. def reduce(g, *args, **kwargs):
  190. @symbolic_helper.parse_args("v", "none")
  191. def reduce_nodim(g, self, dtype):
  192. if dtype.node().kind() == "onnx::Constant":
  193. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  194. self = g.op(
  195. "Cast", self, to_i=symbolic_helper.scalar_type_to_onnx[dtype]
  196. )
  197. elif dtype.node().kind() != "prim::Constant":
  198. return symbolic_helper._unimplemented(name, "dtype")
  199. return symbolic(g, self)
  200. @symbolic_helper.parse_args("v", "v", "i", "none")
  201. def reduce_dim(g, self, dim, keepdim, dtype):
  202. if dtype.node().kind() == "onnx::Constant":
  203. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  204. self = g.op(
  205. "Cast", self, to_i=symbolic_helper.scalar_type_to_onnx[dtype]
  206. )
  207. elif dtype.node().kind() != "prim::Constant":
  208. return symbolic_helper._unimplemented(name, "dtype")
  209. return symbolic(g, self, dim, keepdim)
  210. return reduce_nodim, reduce_dim
  211. return reduce
  212. # TODO(justinchuby): Rename the op to avoid colliding with the builtin sum.
  213. sum = _reduce_with_dtype("ReduceSum", "sum")
  214. @symbolic_helper.parse_args("v", "i", "i", "i")
  215. def unsafe_chunk(g, self, chunks, dim, _outputs=None):
  216. if _outputs is None:
  217. return g.op(
  218. "SplitToSequence",
  219. self,
  220. g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
  221. axis_i=dim,
  222. keepdims_i=0,
  223. )
  224. size = symbolic_helper._get_tensor_dim_size(self, dim)
  225. if size is None:
  226. return symbolic_helper._unimplemented("unsafe_chunk", "unknown dimension size")
  227. split_size = (size + chunks - 1) // chunks
  228. splits = [split_size] * (size // split_size)
  229. leftover = size % split_size
  230. if leftover:
  231. splits.append(leftover)
  232. # TODO: So far we don"t have a module using this method. We"ll keep
  233. # this as a constant unless we see a request of dynamics in any
  234. # user's modules.
  235. splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long))
  236. return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
def repeat_interleave(g, self, repeats, dim=None, output_size=None):
    """Emit ONNX nodes for ``repeat_interleave`` using a ``Loop`` over slices.

    When the size of ``dim`` is unknown or ``repeats`` is a 1-D tensor of
    unknown length, the input and ``repeats`` are split into unit pieces and
    a ``Loop`` expands each input slice by its repeat count, collecting the
    results in a sequence that is concatenated along ``dim``. In every other
    case the call is forwarded to ``opset9.repeat_interleave``.

    Args:
        g: Graph the nodes are added to.
        self: Value to repeat.
        repeats: Value holding the repeat count(s).
        dim: Axis to repeat along; a None value means the input is flattened
            first and axis 0 is used.
        output_size: Not read by this function.

    Raises:
        RuntimeError: If the rank or sizes of ``repeats``, or the sizes of the
            input, cannot be determined.
    """
    input = self
    # Keep the caller's dim untouched for the opset9 fallback below.
    final_dim = dim
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if symbolic_helper._is_none(dim):
        input = symbolic_helper._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([-1]))
        )
        dim = 0
    else:
        dim = symbolic_helper._maybe_get_scalar(dim)

    repeats_dim = symbolic_helper._get_tensor_rank(repeats)
    repeats_sizes = symbolic_helper._get_tensor_sizes(repeats)
    input_sizes = symbolic_helper._get_tensor_sizes(input)
    if repeats_dim is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown " "repeats rank."
        )
    if repeats_sizes is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown " "repeats size."
        )
    if input_sizes is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown " "input size."
        )
    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    # Unknown sizes are marked 0 in output_sizes and -1 in input_sizes;
    # the 0 marker is what the dynamic-size test below looks for.
    output_sizes = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            output_sizes[idx], input_sizes[idx] = 0, -1

    cond_dynamic_repeats = repeats_dim == 1 and repeats_sizes[0] is None
    # If input size is dynamic or repeats vector is dynamic
    if output_sizes[dim] == 0 or cond_dynamic_repeats:
        # reps: run-time size of the input along dim, as a 1-D value.
        reps = symbolic_helper._size_helper(g, input, dim)
        reps = opset11.unsqueeze(g, reps, 0)
        # Check if repeats vector is a single integer value
        # or a single dimension tensor with non-dynamic values
        if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
            if not symbolic_helper._is_tensor(repeats):
                repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
            repeats = g.op("Expand", repeats, reps)
        # Check if repeats is dynamic
        # As repeats is dynamic, we use a where node as a substitute for the if statement
        # If repeats_dim = 1, expand repeats otherwise use original tensor
        elif cond_dynamic_repeats:
            repeat_dim = symbolic_helper._size_helper(
                g, repeats, g.op("Constant", value_t=torch.LongTensor([0]))
            )
            repeat_cond = g.op(
                "Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1]))
            )
            repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats)
    # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
    # provided along one of the dynamic axes provided. A simple example would be
    # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
    # Now, repeat interleaving can be performed in pytorch when the value of * matches
    # with the number of elements in repeat, for example if * -> 2, number of repeats
    # should be 2 as well.
    else:
        # Everything static: the opset9 implementation handles it.
        return opset9.repeat_interleave(g, self, repeats, final_dim)

    # A tensor of ones shaped like repeats, used as the split sizes so that
    # both repeats and the input are cut into unit pieces.
    reps_like = g.op(
        "ConstantOfShape",
        g.op("Shape", repeats),
        value_t=torch.tensor([1], dtype=torch.long),
    )
    r_splits = split(g, repeats, reps_like, 0)
    i_splits = split(g, input, reps_like, dim)

    # Each slice has size 1 along dim; the output size there is left to be
    # inferred by Reshape (-1).
    output_sizes[dim], input_sizes[dim] = -1, 1

    # Create a loop to iterate over each value along the dimension
    # and perform individual interleaving using the repeats tensor
    # Loop is of the following pattern
    #   input (trip_count, cond)
    #       int trip_count = ...;
    #       bool cond = ...;
    #       for (int i=0; i < trip_count && cond; ++i) {
    #           cond = ...;
    #       }

    # Loop conditions
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    # to_i=9 is presumably the ONNX BOOL type code -- TODO confirm
    loop_condition = g.op("Cast", loop_condition, to_i=9)
    loop_len = reps

    # Create an empty sequence to store final expansions
    final_splits = g.op("SequenceEmpty")
    loop = g.op("Loop", loop_len, loop_condition, final_splits)

    # Loop inputs
    loop_block = utils._add_block(loop.node())
    block_input_iter = utils._add_input_to_block(loop_block)
    # cond is registered as a block input but not read afterwards.
    cond = utils._add_input_to_block(loop_block)
    final_splits = utils._add_input_to_block(loop_block)

    r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
    i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)

    # Insert a new axis after dim, expand it to the repeat count, then
    # reshape so the new axis is merged back into dim.
    i_split = opset11.unsqueeze(loop_block, i_split, dim + 1)
    r_concat = [
        loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[: dim + 1])),
        r_split,
        loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1 :])),
    ]
    r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
    i_split = opset9.expand(loop_block, i_split, r_concat, None)
    # NOTE(review): the shape Constant below is created on `g`, not on
    # `loop_block`, unlike the other constants in the loop body -- confirm
    # this is intended.
    i_split = symbolic_helper._reshape_helper(
        loop_block, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes))
    )
    final_splits = loop_block.op("SequenceInsert", final_splits, i_split)

    # Loop outputs
    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
    utils._add_output_to_block(loop_block, cond_out)
    utils._add_output_to_block(loop_block, final_splits)

    loop_out = loop.node().output()
    loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
    return loop_out
  351. @symbolic_helper.parse_args("v", "i", "i", "i")
  352. def diagonal(g, self, offset, dim1, dim2):
  353. dim1_size = opset9.size(
  354. g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1]))
  355. )
  356. dim2_size = opset9.size(
  357. g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2]))
  358. )
  359. # Create appropriate mask
  360. mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0)
  361. mask = opset9.zeros(g, mask_shape, None, None, None)
  362. mask = g.op("EyeLike", mask, k_i=offset)
  363. # dim1 and dim2 appended as a dimension at the end of the shape
  364. rank = symbolic_helper._get_tensor_rank(self)
  365. if rank is not None:
  366. axes = list(range(rank))
  367. axes.remove(dim1)
  368. axes.remove(dim2)
  369. self = g.op("Transpose", self, perm_i=axes + [dim1, dim2])
  370. else:
  371. return symbolic_helper._unimplemented("diagonal", "unknown input rank")
  372. # Multiply input and mask to calculate values along diagonal
  373. # The mask consists of one values where diagonal values are to be calculated
  374. # For example:
  375. # [[1.1, 1.2, 1.3], * [[1, 0, 0] = [[1.1, 0, 0],
  376. # [2.1, 2.2, 2.3], [0, 1, 0] [0, 2.2, 0],
  377. # [3.1, 3.2, 3.3]] [0, 0, 1]] [0, 0, 3.3]]
  378. result = g.op("Mul", self, mask)
  379. result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0)
  380. # Calculate gather indices based on offset and dims
  381. # If offset is greater than zero, set offset to zero as this aids in
  382. # calculation of selection window
  383. offset_op = g.op("Constant", value_t=torch.LongTensor([offset]))
  384. if offset >= 0:
  385. diag_size = g.op(
  386. "Max",
  387. g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)),
  388. g.op("Constant", value_t=torch.LongTensor([0])),
  389. )
  390. offset = 0
  391. else:
  392. diag_size = g.op(
  393. "Max",
  394. g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size),
  395. g.op("Constant", value_t=torch.LongTensor([0])),
  396. )
  397. diag_size = g.op("Concat", diag_size, axis_i=0)
  398. # Calculate which diagonal values to select
  399. # For example, in cases with offsets:
  400. # [[0, 1.1, 0]
  401. # [0, 0, 2.2]]
  402. # we need to select the last two columns, so we create a tensor
  403. # with all columns that are to be selected
  404. # So in this example, it is [1, 2]
  405. select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None)
  406. select_window = g.op(
  407. "CumSum",
  408. select_window_ones_fill,
  409. g.op("Constant", value_t=torch.LongTensor([0])),
  410. )
  411. select_window = g.op(
  412. "Add",
  413. select_window,
  414. g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])),
  415. )
  416. gather_shape = [
  417. opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis])))
  418. for axis in list(range(rank))[:-2]
  419. ]
  420. gather_shape.append(diag_size)
  421. gather_shape = g.op("Concat", *gather_shape, axis_i=0)
  422. gather_indices = opset9.zeros(g, gather_shape, 4, None, None)
  423. # There might be cases where offset value is greater than number of rows/columns
  424. # and might cause the diagonal to overrun and as a result of this, diag_size would be zero.
  425. # For example, if
  426. # offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows)
  427. # diag_size = max(min(2, (4-9)), 0) = 0, based on calculation above
  428. # Cases with diagonal overrun always result in diag_size = max(0, -ve value) = 0
  429. # In cases without diagonal overrun, we select the appropriate rows/columns along which we
  430. # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has
  431. # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially
  432. # returning an empty tensor
  433. overrun_cond = g.op(
  434. "Not",
  435. g.op(
  436. "Equal",
  437. diag_size,
  438. g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)),
  439. ),
  440. )
  441. if_op = g.op("If", overrun_cond)
  442. if_node = if_op.node()
  443. if_block = utils._add_block(if_node)
  444. gather_indices_if_block = if_block.op("Add", gather_indices, select_window)
  445. gather_indices_if_block = symbolic_helper._unsqueeze_helper(
  446. if_block, gather_indices_if_block, [rank - 1]
  447. )
  448. final_non_overrun_ = if_block.op(
  449. "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2
  450. )
  451. utils._add_output_to_block(if_block, final_non_overrun_)
  452. else_block = utils._add_block(if_node)
  453. final_overrun_ = opset9.zeros(else_block, gather_shape, 6, None, None)
  454. utils._add_output_to_block(else_block, final_overrun_)
  455. return if_op
  456. class Quantized:
  457. """
  458. https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
  459. """
  460. domain = "quantized"
  461. @staticmethod
  462. def linear(g, q_input, q_weight, bias, op_scale, op_zero_point):
  463. input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
  464. weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
  465. q_bias = symbolic_helper.requantize_bias_helper(
  466. g, bias, input_scale, weight_scale, axis
  467. )
  468. bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
  469. output = opset9.linear(g, input, weight, bias)
  470. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  471. @staticmethod
  472. def conv2d(
  473. g,
  474. q_input,
  475. q_weight,
  476. bias,
  477. stride,
  478. padding,
  479. dilation,
  480. groups,
  481. op_scale,
  482. op_zero_point,
  483. ):
  484. input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
  485. weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
  486. q_bias = symbolic_helper.requantize_bias_helper(
  487. g, bias, input_scale, weight_scale, axis
  488. )
  489. bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
  490. output = opset9.conv2d(
  491. g, input, weight, bias, stride, padding, dilation, groups
  492. )
  493. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  494. @staticmethod
  495. def conv2d_relu(
  496. g,
  497. q_input,
  498. q_weight,
  499. bias,
  500. stride,
  501. padding,
  502. dilation,
  503. groups,
  504. op_scale,
  505. op_zero_point,
  506. ):
  507. input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
  508. weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
  509. q_bias = symbolic_helper.requantize_bias_helper(
  510. g, bias, input_scale, weight_scale, axis
  511. )
  512. bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
  513. output = opset9.conv2d(
  514. g, input, weight, bias, stride, padding, dilation, groups
  515. )
  516. output = opset9.relu(g, output)
  517. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)