symbolic_opset12.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. import sys
  2. import warnings
  3. import torch
  4. from torch.onnx import symbolic_helper
  5. from torch.onnx import symbolic_opset9 as opset9
  6. from torch.onnx import utils
  7. # EDITING THIS FILE? READ THIS FIRST!
  8. # see Note [Edit Symbolic Files] in symbolic_helper.py
  9. # This file exports ONNX ops for opset 12
  10. def einsum_helper(g, equation, tensors):
  11. if not tensors:
  12. raise RuntimeError("Einsum inputs are empty.")
  13. # ONNX does not support bool for Einsum inputs.
  14. if tensors[0].type().scalarType() == "Bool":
  15. tensors = [
  16. g.op("Cast", tensor, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"])
  17. for tensor in tensors
  18. ]
  19. return g.op(
  20. "Cast",
  21. g.op("Einsum", *tensors, equation_s=equation),
  22. to_i=symbolic_helper.cast_pytorch_to_onnx["Bool"],
  23. )
  24. else:
  25. return g.op("Einsum", *tensors, equation_s=equation)
  26. @symbolic_helper.parse_args("s", "v")
  27. def einsum(g, equation, tensor_list):
  28. tensors = symbolic_helper._unpack_list(tensor_list)
  29. return einsum_helper(g, equation, tensors)
  30. @symbolic_helper.parse_args("v", "v")
  31. def outer(g, input, other):
  32. # make sure to cast other to self's type
  33. if other.type().scalarType() != input.type().scalarType():
  34. other = g.op(
  35. "Cast",
  36. other,
  37. to_i=symbolic_helper.cast_pytorch_to_onnx[input.type().scalarType()],
  38. )
  39. return einsum_helper(g, "i,j->ij", [input, other])
  40. @symbolic_helper.parse_args("v", "f", "i")
  41. def dropout(g, input, p, train):
  42. symbolic_helper.check_training_mode(train, "dropout")
  43. # in eval mode, dropout is non-op - if the node's train param is set to False, dropout is non-op
  44. if not train:
  45. return input
  46. warnings.warn(
  47. "Dropout is a training op and should not be exported in inference mode. "
  48. "For inference, make sure to call eval() on the model and to export it with param training=False."
  49. )
  50. p = g.op("Constant", value_t=torch.tensor(p))
  51. t = g.op("Constant", value_t=torch.tensor(True))
  52. r, _ = g.op("Dropout", input, p, t, outputs=2)
  53. return r
  54. def nll_loss(g, self, target, weight, reduction, ignore_index):
  55. # none reduction : onnx::Constant[value={0}]
  56. # mean reduction : onnx::Constant[value={1}]
  57. # sum reduction : onnx::Constant[value={2}]
  58. reduction = symbolic_helper._maybe_get_const(reduction, "i")
  59. reduction_vals = ["none", "mean", "sum"]
  60. reduction = reduction_vals[reduction]
  61. # in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value.
  62. # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
  63. ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
  64. if weight.node().mustBeNone():
  65. nllloss = g.op(
  66. "NegativeLogLikelihoodLoss",
  67. self,
  68. target,
  69. reduction_s=reduction,
  70. ignore_index_i=ignore_index,
  71. )
  72. else:
  73. nllloss = g.op(
  74. "NegativeLogLikelihoodLoss",
  75. self,
  76. target,
  77. weight,
  78. reduction_s=reduction,
  79. ignore_index_i=ignore_index,
  80. )
  81. return nllloss
  82. def nll_loss2d(g, self, target, weight, reduction, ignore_index):
  83. return nll_loss(g, self, target, weight, reduction, ignore_index)
  84. def nll_loss_nd(g, self, target, weight, reduction, ignore_index):
  85. return nll_loss(g, self, target, weight, reduction, ignore_index)
  86. def cross_entropy_loss(
  87. g, self, target, weight, reduction, ignore_index, label_smoothing
  88. ):
  89. # none reduction : onnx::Constant[value={0}]
  90. # mean reduction : onnx::Constant[value={1}]
  91. # sum reduction : onnx::Constant[value={2}]
  92. reduction = symbolic_helper._maybe_get_const(reduction, "i")
  93. reduction_vals = ["none", "mean", "sum"]
  94. reduction = reduction_vals[reduction]
  95. label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
  96. if label_smoothing > 0.0:
  97. raise RuntimeError("Unsupported: ONNX does not support label_smoothing")
  98. # in onnx SoftmaxCrossEntropyLoss specification, ignore_index is optional without default value.
  99. # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
  100. ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
  101. if weight.node().mustBeNone():
  102. celoss = g.op(
  103. "SoftmaxCrossEntropyLoss",
  104. self,
  105. target,
  106. reduction_s=reduction,
  107. ignore_index_i=ignore_index,
  108. )
  109. else:
  110. celoss = g.op(
  111. "SoftmaxCrossEntropyLoss",
  112. self,
  113. target,
  114. weight,
  115. reduction_s=reduction,
  116. ignore_index_i=ignore_index,
  117. )
  118. return celoss
  119. @symbolic_helper.parse_args("v", "v", "v", "v", "i")
  120. def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduction):
  121. p = g.op("Constant", value_t=torch.tensor([1]))
  122. sig_x = opset9.sigmoid(g, input)
  123. log_sig_x = opset9.log(g, sig_x)
  124. sub_1_x = opset9.sub(g, p, sig_x)
  125. sub_1_y = opset9.sub(g, p, target)
  126. log_1_x = opset9.log(g, sub_1_x)
  127. if pos_weight is None or symbolic_helper._is_none(pos_weight):
  128. output = opset9.neg(
  129. g,
  130. opset9.add(
  131. g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
  132. ),
  133. )
  134. else:
  135. output = opset9.neg(
  136. g,
  137. opset9.add(
  138. g,
  139. opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
  140. opset9.mul(g, sub_1_y, log_1_x),
  141. ),
  142. )
  143. if weight is not None and not symbolic_helper._is_none(weight):
  144. output = opset9.mul(g, weight, output)
  145. reduction = symbolic_helper._maybe_get_const(reduction, "i")
  146. if reduction == 0:
  147. return output
  148. elif reduction == 1:
  149. return g.op("ReduceMean", output, keepdims_i=0)
  150. elif reduction == 2:
  151. return g.op("ReduceSum", output, keepdims_i=0)
  152. else:
  153. return symbolic_helper._onnx_unsupported(
  154. "binary_cross_entropy_with_logits with reduction other than none, mean, or sum"
  155. )
  156. def celu(g, self, alpha):
  157. alpha = symbolic_helper._maybe_get_const(alpha, "f")
  158. # if the input is of type double cast it to float
  159. if self.type().scalarType() == "Double":
  160. self = g.op("Cast", self, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"])
  161. out = g.op("Celu", self, alpha_f=alpha)
  162. return g.op("Cast", out, to_i=symbolic_helper.cast_pytorch_to_onnx["Double"])
  163. return g.op("Celu", self, alpha_f=alpha)
  164. def argmax(g, input, dim, keepdim):
  165. if symbolic_helper._is_none(dim):
  166. flattened = symbolic_helper._reshape_helper(
  167. g, input, g.op("Constant", value_t=torch.tensor([-1]))
  168. )
  169. return g.op(
  170. "ArgMax", flattened, axis_i=0, keepdims_i=False, select_last_index_i=False
  171. )
  172. else:
  173. dim = symbolic_helper._parse_arg(dim, "i")
  174. keepdim = symbolic_helper._parse_arg(keepdim, "i")
  175. return g.op(
  176. "ArgMax", input, axis_i=dim, keepdims_i=keepdim, select_last_index_i=False
  177. )
  178. def argmin(g, input, dim, keepdim):
  179. if symbolic_helper._is_none(dim):
  180. flattened = symbolic_helper._reshape_helper(
  181. g, input, g.op("Constant", value_t=torch.tensor([-1]))
  182. )
  183. return g.op(
  184. "ArgMin", flattened, axis_i=0, keepdims_i=False, select_last_index_i=False
  185. )
  186. else:
  187. dim = symbolic_helper._parse_arg(dim, "i")
  188. keepdim = symbolic_helper._parse_arg(keepdim, "i")
  189. return g.op(
  190. "ArgMin", input, axis_i=dim, keepdims_i=keepdim, select_last_index_i=False
  191. )
  192. def pow(g, self, exponent):
  193. return g.op("Pow", self, exponent)
  194. def ge(g, input, other):
  195. return g.op("GreaterOrEqual", input, other)
  196. def le(g, input, other):
  197. return g.op("LessOrEqual", input, other)
  198. @symbolic_helper.parse_args("v", "i", "v", "v")
  199. def unfold(g, input, dimension, size, step):
  200. const_size = symbolic_helper._maybe_get_const(size, "i")
  201. const_step = symbolic_helper._maybe_get_const(step, "i")
  202. if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
  203. const_step
  204. ):
  205. return opset9.unfold(g, input, dimension, const_size, const_step)
  206. if symbolic_helper.is_caffe2_aten_fallback():
  207. return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)
  208. sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
  209. if sizedim is not None:
  210. low_start = g.op("Constant", value_t=torch.tensor(0))
  211. low_end = g.op("Constant", value_t=torch.tensor(sizedim))
  212. hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
  213. low_indices = g.op("Range", low_start, low_end, step)
  214. hi_indices = g.op("Range", size, hi_end, step)
  215. low_size = symbolic_helper._size_helper(
  216. g, low_indices, g.op("Constant", value_t=torch.tensor(0))
  217. )
  218. hi_size = symbolic_helper._size_helper(
  219. g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
  220. )
  221. ndim = symbolic_helper._get_tensor_rank(input)
  222. perm = list(range(0, ndim))
  223. perm.append(perm.pop(dimension))
  224. unsqueeze_list = []
  225. loop_condition = g.op("Constant", value_t=torch.tensor(1))
  226. loop_condition = g.op("Cast", loop_condition, to_i=9)
  227. loop_len = g.op("Min", low_size, hi_size)
  228. loop = g.op("Loop", loop_len, loop_condition)
  229. loop_block = utils._add_block(loop.node())
  230. block_input_iter = utils._add_input_to_block(loop_block)
  231. cond = utils._add_input_to_block(loop_block)
  232. starts = loop_block.op("Gather", low_indices, block_input_iter)
  233. ends = loop_block.op("Gather", hi_indices, block_input_iter)
  234. axes = loop_block.op("Constant", value_t=torch.tensor([2]))
  235. starts = symbolic_helper._unsqueeze_helper(loop_block, starts, [0])
  236. ends = symbolic_helper._unsqueeze_helper(loop_block, ends, [0])
  237. stack = loop_block.op("Slice", input, starts, ends, axes)
  238. unsqueeze = symbolic_helper._unsqueeze_helper(
  239. loop_block, loop_block.op("Transpose", stack, perm_i=perm), [dimension]
  240. )
  241. unsqueeze_list.append(unsqueeze)
  242. concat = loop_block.op("Concat", *unsqueeze_list, axis_i=0)
  243. cond_out = loop_block.op("Cast", loop_condition, to_i=9)
  244. utils._add_output_to_block(loop_block, cond_out)
  245. utils._add_output_to_block(loop_block, concat)
  246. loop_output = loop.node().output()
  247. perm = [0, 1, 2, 3, 4]
  248. perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
  249. transpose = g.op("Transpose", loop_output, perm_i=perm)
  250. squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])
  251. return squeeze
  252. else:
  253. return symbolic_helper._unimplemented("Unfold", "input size not accessible")
  254. @symbolic_helper.parse_args("v", "v", "is", "is", "v")
  255. def tensordot(g, input_a, input_b, dims_a, dims_b, out=None):
  256. if out is not None:
  257. symbolic_helper._unimplemented(
  258. "Tensordot", "Out parameter is not supported for tensordot."
  259. )
  260. dim_count_a = symbolic_helper._get_tensor_rank(input_a)
  261. if dim_count_a is None:
  262. raise RuntimeError(
  263. "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank."
  264. )
  265. dim_count_b = symbolic_helper._get_tensor_rank(input_b)
  266. if dim_count_b is None:
  267. raise RuntimeError(
  268. "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank."
  269. )
  270. dims_a = [
  271. (dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i]
  272. for i in range(len(dims_a))
  273. ]
  274. dims_b = [
  275. (dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i]
  276. for i in range(len(dims_b))
  277. ]
  278. left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
  279. left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]
  280. new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
  281. new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)
  282. input_shape = g.op("Shape", new_input_a)
  283. left_sizes_a = symbolic_helper._slice_helper(
  284. g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
  285. )
  286. shape_sizes = [
  287. left_sizes_a,
  288. g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
  289. ]
  290. output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
  291. input_shape = g.op("Shape", output_a)
  292. slices = symbolic_helper._slice_helper(
  293. g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
  294. )
  295. shape_sizes = [
  296. g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
  297. slices,
  298. ]
  299. output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
  300. input_shape = g.op("Shape", new_input_b)
  301. left_sizes_b = symbolic_helper._slice_helper(
  302. g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
  303. )
  304. slices = symbolic_helper._slice_helper(
  305. g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
  306. )
  307. shape_sizes = [
  308. slices,
  309. g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
  310. ]
  311. output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
  312. input_shape = g.op("Shape", output_b)
  313. slices = symbolic_helper._slice_helper(
  314. g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
  315. )
  316. shape_sizes = [
  317. g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
  318. slices,
  319. ]
  320. output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
  321. output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))
  322. shape_sizes = [left_sizes_a, left_sizes_b]
  323. return opset9._reshape_from_tensor(g, output, shape_sizes)