symbolic_opset11.py 46 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248
  1. """This file exports ONNX ops for opset 11."""
  2. import sys
  3. import warnings
  4. from typing import Tuple, Union
  5. import torch
  6. from torch import _C
  7. from torch.onnx import symbolic_helper
  8. from torch.onnx import symbolic_opset9 as opset9
  9. from torch.onnx import symbolic_opset10 as opset10
  10. from torch.onnx import utils
  11. from torch.onnx._globals import GLOBALS
  12. # EDITING THIS FILE? READ THIS FIRST!
  13. # see Note [Edit Symbolic Files] in symbolic_helper.py
  14. # This file exports ONNX ops for opset 11
  15. @symbolic_helper.parse_args("v", "f", "f")
  16. def hardtanh(g, self, min_val, max_val):
  17. dtype = self.type().scalarType()
  18. if dtype is None:
  19. dtype = symbolic_helper.ScalarType.FLOAT
  20. else:
  21. dtype = symbolic_helper.scalar_type_to_onnx.index(
  22. symbolic_helper.cast_pytorch_to_onnx[dtype]
  23. )
  24. min_val = g.op(
  25. "Constant",
  26. value_t=torch.tensor(
  27. min_val, dtype=symbolic_helper.scalar_type_to_pytorch_type[dtype]
  28. ),
  29. )
  30. max_val = g.op(
  31. "Constant",
  32. value_t=torch.tensor(
  33. max_val, dtype=symbolic_helper.scalar_type_to_pytorch_type[dtype]
  34. ),
  35. )
  36. return opset9.op_with_optional_float_cast(
  37. g, "Clip", self, min_val, max_val, opset_before=12
  38. )
  39. def clamp(g, self, min, max):
  40. dtype = self.type().scalarType()
  41. def _cast_if_not_none(tensor, dtype):
  42. if tensor is not None and not symbolic_helper._is_none(tensor):
  43. return g.op(
  44. "Cast", tensor, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype]
  45. )
  46. else:
  47. return tensor
  48. if dtype is not None:
  49. min = _cast_if_not_none(min, dtype)
  50. max = _cast_if_not_none(max, dtype)
  51. if symbolic_helper._is_none(min):
  52. return clamp_max(g, self, max)
  53. elif symbolic_helper._is_none(max):
  54. return clamp_min(g, self, min)
  55. else:
  56. if (
  57. symbolic_helper._get_tensor_rank(min) == 0
  58. and symbolic_helper._get_tensor_rank(max) == 0
  59. ):
  60. return opset9.op_with_optional_float_cast(
  61. g, "Clip", self, min, max, opset_before=12
  62. )
  63. else:
  64. return clamp_max(g, clamp_min(g, self, min), max)
  65. @symbolic_helper.parse_args("v", "v")
  66. def clamp_min(g, self, min):
  67. dtype = self.type().scalarType()
  68. min = g.op("Cast", min, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
  69. if symbolic_helper._get_tensor_rank(min) == 0:
  70. max = opset9.unused(g)
  71. return opset9.op_with_optional_float_cast(
  72. g, "Clip", self, min, max, opset_before=12
  73. )
  74. else:
  75. return opset9.op_with_optional_float_cast(g, "Max", self, min, opset_before=12)
  76. @symbolic_helper.parse_args("v", "v")
  77. def clamp_max(g, self, max):
  78. dtype = self.type().scalarType()
  79. max = g.op("Cast", max, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
  80. if symbolic_helper._get_tensor_rank(max) == 0:
  81. min = opset9.unused(g)
  82. return opset9.op_with_optional_float_cast(
  83. g, "Clip", self, min, max, opset_before=12
  84. )
  85. else:
  86. return opset9.op_with_optional_float_cast(g, "Min", self, max, opset_before=12)
  87. def relu6(g, input):
  88. relu = opset9.op_with_optional_float_cast(g, "Relu", input, opset_before=14)
  89. dtype = input.type().scalarType()
  90. if dtype is None:
  91. dtype = symbolic_helper.ScalarType.FLOAT
  92. else:
  93. dtype = symbolic_helper.scalar_type_to_onnx.index(
  94. symbolic_helper.cast_pytorch_to_onnx[dtype]
  95. )
  96. min_val = g.op(
  97. "Constant",
  98. value_t=torch.tensor(
  99. 0, dtype=symbolic_helper.scalar_type_to_pytorch_type[dtype]
  100. ),
  101. )
  102. max_val = g.op(
  103. "Constant",
  104. value_t=torch.tensor(
  105. 6, dtype=symbolic_helper.scalar_type_to_pytorch_type[dtype]
  106. ),
  107. )
  108. return clamp(g, relu, min_val, max_val)
  109. # Opset 11 gather accepts negative indices
  110. @symbolic_helper.parse_args("v", "i", "v")
  111. def select(g, self, dim, index):
  112. return g.op("Gather", self, index, axis_i=dim)
  113. def index_put(g, self, indices_list_value, values, accumulate=False):
  114. if symbolic_helper._is_packed_list(indices_list_value):
  115. indices_list = symbolic_helper._unpack_list(indices_list_value)
  116. else:
  117. indices_list = [indices_list_value]
  118. if symbolic_helper.is_caffe2_aten_fallback():
  119. args = [self] + indices_list + [values, accumulate]
  120. return g.at("index_put", *args)
  121. accumulate = symbolic_helper._parse_arg(accumulate, "b")
  122. if len(indices_list) == 0:
  123. return values
  124. if len(indices_list) > 1:
  125. for idx_ in range(len(indices_list)):
  126. if indices_list[idx_].type().scalarType() == "Bool": # type: ignore[attr-defined]
  127. # TODO(justinchuby): Remove type ignore after #81112 is checked in.
  128. indices_list[idx_] = g.op("NonZero", indices_list[idx_])
  129. index = indices_list[0]
  130. for ind in indices_list[1:]:
  131. index = opset9.add(g, index, ind)
  132. broadcast_index_shape = g.op("Shape", index)
  133. indices_list = [
  134. symbolic_helper._unsqueeze_helper(
  135. g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
  136. )
  137. for ind in indices_list
  138. ]
  139. index = g.op("Concat", *indices_list, axis_i=-1)
  140. else:
  141. # Replace index_put node with masked_scatter or masked_fill
  142. # when inputs to the index_put node contains a single boolean input.
  143. #
  144. # index_put -> masked_fill
  145. # * input index contains single tensor of Bool type (e.g.: %24 <- %23).
  146. # * input value contains single element (e.g.: %18).
  147. #
  148. # Torch IR
  149. # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
  150. # %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
  151. # aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
  152. # %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
  153. # %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
  154. # %24 : Tensor?[] = prim::ListConstruct(%23)
  155. # %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
  156. # aten::index_put(%mask, %24, %18, %30)
  157. # return (%25)
  158. #
  159. #
  160. # index_put -> masked_scatter
  161. # * input index contains single tensor of Bool type (e.g.: %32 <- %31).
  162. # * input value contains multiple elements (e.g.: %28).
  163. #
  164. # Torch IR
  165. # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
  166. # %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
  167. # = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
  168. # %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
  169. # = aten::ne(%mask, %some_const)
  170. # %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
  171. # = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
  172. # %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
  173. # %30 : int[] = prim::Constant[value=[-1]]()
  174. # %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
  175. # %32 : Tensor?[] = prim::ListConstruct(%31)
  176. # %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
  177. # = aten::index_put(%mask, %32, %28, %38)
  178. # return (%33)
  179. index = indices_list[0]
  180. bool_inp = index
  181. if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool": # type: ignore[attr-defined]
  182. # TODO(justinchuby): Remove type ignore after #81112 is checked in.
  183. rank = symbolic_helper._get_tensor_rank(values)
  184. if rank is not None and rank == 0:
  185. return opset9.masked_fill(g, self, bool_inp, values)
  186. return masked_scatter(g, self, bool_inp, values)
  187. broadcast_index_shape = g.op("Shape", index)
  188. index = symbolic_helper._unsqueeze_helper(g, index, [-1])
  189. sub_data_shape = symbolic_helper._slice_helper(
  190. g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
  191. )
  192. values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
  193. # Check if values is a singular value and expand accordingly
  194. rank = symbolic_helper._get_tensor_rank(values)
  195. if rank is not None and rank == 0:
  196. values = opset9.expand(g, values, values_shape, None)
  197. values = symbolic_helper._reshape_helper(g, values, values_shape)
  198. dtype = self.type().scalarType()
  199. if dtype is not None and dtype != values.type().scalarType():
  200. values = g.op("Cast", values, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
  201. dtype = symbolic_helper.scalar_type_to_onnx.index(
  202. symbolic_helper.cast_pytorch_to_onnx[dtype]
  203. )
  204. dtype = symbolic_helper.scalar_type_to_pytorch_type[dtype]
  205. if accumulate:
  206. zeros = g.op(
  207. "ConstantOfShape",
  208. g.op("Shape", self),
  209. value_t=torch.tensor([0], dtype=dtype),
  210. )
  211. result = g.op("ScatterND", zeros, index, values)
  212. result = add(g, self, result)
  213. else:
  214. result = g.op("ScatterND", self, index, values)
  215. return result
  216. @symbolic_helper.parse_args("v", "i")
  217. def pixel_shuffle(g, self, upscale_factor):
  218. rank = symbolic_helper._get_tensor_rank(self)
  219. if rank is not None and rank != 4:
  220. return symbolic_helper._unimplemented("pixel_shuffle", "only support 4d input")
  221. return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD")
  222. def _interpolate(name, dim, interpolate_mode):
  223. return symbolic_helper._interpolate_helper(name, dim, interpolate_mode)
  224. upsample_nearest1d = _interpolate("upsample_nearest1d", 3, "nearest")
  225. upsample_nearest2d = _interpolate("upsample_nearest2d", 4, "nearest")
  226. upsample_nearest3d = _interpolate("upsample_nearest3d", 5, "nearest")
  227. upsample_linear1d = _interpolate("upsample_linear1d", 3, "linear")
  228. upsample_bilinear2d = _interpolate("upsample_bilinear2d", 4, "linear")
  229. upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear")
  230. upsample_bicubic2d = _interpolate("upsample_bicubic2d", 4, "cubic")
  231. @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
  232. def __interpolate(
  233. g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
  234. ):
  235. return symbolic_helper.__interpolate_helper(
  236. g, input, size, scale_factor, mode, align_corners, recompute_scale_factor
  237. )
  238. @symbolic_helper.parse_args("v", "i", "v", "v")
  239. def gather(g, self, dim, index, sparse_grad=False):
  240. if symbolic_helper._maybe_get_const(sparse_grad, "i"):
  241. return symbolic_helper._unimplemented("gather", "sparse_grad == True")
  242. if symbolic_helper.is_caffe2_aten_fallback():
  243. return g.at("gather", self, dim, index, sparse_grad)
  244. return g.op("GatherElements", self, index, axis_i=dim)
  245. @symbolic_helper.parse_args("v", "i", "v", "v")
  246. def scatter(g, self, dim, index, src):
  247. if symbolic_helper.is_caffe2_aten_fallback():
  248. return g.at("scatter", self, dim, index, src, overload_name="src")
  249. src_type = src.type().scalarType()
  250. src = symbolic_helper._maybe_get_scalar(src)
  251. if symbolic_helper._is_value(src):
  252. return g.op("ScatterElements", self, index, src, axis_i=dim)
  253. else:
  254. # Check if scalar "src" has same type as self (PyTorch allows different
  255. # type for scalar src (but not when src is tensor)). If not, insert Cast node.
  256. if self.type().scalarType() != src_type:
  257. src = g.op(
  258. "Cast",
  259. src,
  260. to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
  261. )
  262. return g.op(
  263. "ScatterElements", self, index, opset9.expand_as(g, src, index), axis_i=dim
  264. )
  265. @symbolic_helper.parse_args("v", "i", "none")
  266. def cumsum(g, self, dim, dtype=None):
  267. dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int))
  268. if dtype and dtype.node().kind() != "prim::Constant":
  269. parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  270. cast = g.op(
  271. "Cast", self, to_i=symbolic_helper.scalar_type_to_onnx[parsed_dtype]
  272. )
  273. else:
  274. cast = self
  275. csum = g.op("CumSum", cast, dim_tensor)
  276. return csum
  277. def masked_select(g, self, mask):
  278. index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
  279. return g.op("GatherND", self, index)
  280. def masked_scatter(g, self, mask, source):
  281. index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
  282. # NOTE: source can have more elements than needed.
  283. # It could also have arbitrary shape.
  284. # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor.
  285. source = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1]))
  286. source = symbolic_helper._slice_helper(
  287. g,
  288. source,
  289. axes=torch.LongTensor([0]),
  290. starts=torch.LongTensor([0]),
  291. ends=opset9.size(g, index, torch.LongTensor([0])),
  292. dynamic_slice=True,
  293. )
  294. return g.op("ScatterND", self, index, source)
  295. def _len(g, self):
  296. if (
  297. symbolic_helper._is_tensor_list(self)
  298. or self.node().kind() == "onnx::SplitToSequence"
  299. ):
  300. return g.op("SequenceLength", self)
  301. sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0])))
  302. return symbolic_helper._squeeze_helper(g, sz_0, [0])
  303. def __getitem_(g, self, i):
  304. if symbolic_helper._is_tensor_list(self):
  305. # SequenceAt requires that the input be a List of Tensors
  306. return g.op("SequenceAt", self, i)
  307. else:
  308. from torch.onnx.symbolic_opset9 import __getitem_ as getitem
  309. return getitem(g, self, i)
  310. def _set_item(g, tensor_list, i, v):
  311. tensor_list = g.op("SequenceErase", tensor_list, i)
  312. return g.op("SequenceInsert", tensor_list, v, i)
  313. def append(g, self, tensor):
  314. return g.op("SequenceInsert", self, tensor)
  315. def add(g, self, other, alpha=None):
  316. if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self):
  317. tensor_list_node = other.node()
  318. if tensor_list_node.kind() != "prim::ListConstruct":
  319. return symbolic_helper._unimplemented(
  320. "add", "does not support adding dynamic tensor list to another"
  321. )
  322. tensors = symbolic_helper._unpack_list(other)
  323. l = self
  324. for t in tensors:
  325. l = g.op("SequenceInsert", l, t)
  326. return l
  327. return opset9.add(g, self, other, alpha)
  328. def insert(g, self, pos, tensor):
  329. return g.op("SequenceInsert", self, tensor, pos)
  330. def pop(g, tensor_list, dim):
  331. return g.op("SequenceErase", tensor_list, dim)
  332. def Delete(g, tensor_list, dim):
  333. return g.op("SequenceErase", tensor_list, dim)
  334. def cat(g, tensor_list, dim):
  335. if symbolic_helper._is_packed_list(tensor_list):
  336. return opset9.cat(g, tensor_list, dim)
  337. else:
  338. dim = symbolic_helper._get_const(dim, "i", "dim")
  339. return g.op("ConcatFromSequence", tensor_list, axis_i=dim)
  340. def stack(g, tensor_list, dim):
  341. if symbolic_helper._is_packed_list(tensor_list):
  342. return opset9.stack(g, tensor_list, dim)
  343. else:
  344. dim = symbolic_helper._get_const(dim, "i", "dim")
  345. return g.op("ConcatFromSequence", tensor_list, axis_i=dim, new_axis_i=1)
  346. @symbolic_helper.parse_args("v", "i", "i", "i")
  347. def _unique2(g, self, sorted, return_inverse, return_counts):
  348. u, indices, inverse_indices, counts = g.op(
  349. "Unique", self, sorted_i=sorted, outputs=4
  350. )
  351. return u, inverse_indices, counts
  352. def _avg_pool(name, tuple_fn):
  353. @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
  354. @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
  355. def symbolic_fn(
  356. g,
  357. input: _C.Value,
  358. kernel_size: Tuple[int, ...],
  359. stride: Tuple[int, ...],
  360. padding: Union[int, Tuple[int, ...]],
  361. ceil_mode: int,
  362. count_include_pad: int,
  363. divisor_override=None,
  364. ):
  365. padding = symbolic_helper._avgpool_helper(
  366. tuple_fn, padding, kernel_size, stride, divisor_override, name
  367. )
  368. if not stride:
  369. stride = kernel_size
  370. if count_include_pad:
  371. input = g.op(
  372. "Pad",
  373. input,
  374. g.op("Constant", value_t=torch.tensor(((0,) * 2 + padding) * 2)),
  375. mode_s="constant",
  376. )
  377. padding = (0,) * len(padding)
  378. output = g.op(
  379. "AveragePool",
  380. input,
  381. kernel_shape_i=tuple_fn(kernel_size),
  382. strides_i=tuple_fn(stride),
  383. pads_i=padding * 2,
  384. ceil_mode_i=ceil_mode,
  385. )
  386. return output
  387. return symbolic_fn
  388. avg_pool1d = _avg_pool("avg_pool1d", torch.nn.modules.utils._single)
  389. avg_pool2d = _avg_pool("avg_pool2d", torch.nn.modules.utils._pair)
  390. avg_pool3d = _avg_pool("avg_pool3d", torch.nn.modules.utils._triple)
  391. @symbolic_helper.parse_args("v", "i", "i", "i", "i")
  392. def unique_dim(g, self, dim, sorted, return_inverse, return_counts):
  393. u, indices, inverse_indices, counts = g.op(
  394. "Unique", self, axis_i=dim, sorted_i=sorted, outputs=4
  395. )
  396. return u, inverse_indices, counts
  397. @symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
  398. def topk(g, self, k, dim, largest, sorted, out=None):
  399. return symbolic_helper._topk_helper(
  400. g, self, k, dim, largest=largest, sorted=sorted, out=out
  401. )
  402. @symbolic_helper.parse_args("v", "i", "i", "none")
  403. def sort(g, self, dim, decending, out=None):
  404. return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)
  405. def round(g, self):
  406. return g.op("Round", self)
  407. def remainder(g, input, other):
  408. if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other):
  409. return opset9.remainder(g, input, other)
  410. return g.op("Mod", input, other, fmod_i=0)
  411. @symbolic_helper.parse_args("v", "v", "i", "i")
  412. def split(g, self, split_size_or_sizes, dim, _outputs=None):
  413. if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
  414. split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
  415. if _outputs is None:
  416. return split_out
  417. # Convert to multiple slice nodes iff number of splits and number of outputs are statically known.
  418. if (
  419. symbolic_helper._is_packed_list(split_size_or_sizes)
  420. and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs
  421. ):
  422. split_sizes = [
  423. symbolic_helper._unsqueeze_helper(g, v, [0])
  424. for v in symbolic_helper._unpack_list(split_size_or_sizes)
  425. ]
  426. start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
  427. axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
  428. res = []
  429. for i in range(_outputs):
  430. end = g.op(
  431. "Add", start, split_sizes[i]
  432. ) # split_sizes is a list of same length as _outputs
  433. res.append(g.op("Slice", self, start, end, axis))
  434. start = end
  435. return res
  436. return [
  437. g.op(
  438. "SequenceAt",
  439. split_out,
  440. g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
  441. )
  442. for i in range(_outputs)
  443. ]
  444. else:
  445. return opset9.split(g, self, split_size_or_sizes, dim, _outputs)
  446. @symbolic_helper.parse_args("v", "v", "i", "i")
  447. def split_with_sizes(g, self, split_sizes, dim, _outputs=None):
  448. return split(g, self, split_sizes, dim, _outputs)
  449. @symbolic_helper.parse_args("v", "i", "i")
  450. def unbind(g, self, dim=0, _outputs=None):
  451. if _outputs is None:
  452. return g.op(
  453. "SplitToSequence",
  454. self,
  455. g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
  456. axis_i=dim,
  457. keepdims_i=0,
  458. )
  459. else:
  460. return opset9.unbind(g, self, dim, _outputs)
  461. # Generate paddings in ONNX order based on pad in pytorch.
  462. # Args:
  463. # input: the input tensor.
  464. # pad: the paddings in pytorch.
  465. # The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end,
  466. # where m is in range [0, n].
  467. def _prepare_onnx_paddings(g, input, pad):
  468. if (
  469. not symbolic_helper._is_packed_list(pad)
  470. and symbolic_helper._is_list(pad)
  471. and symbolic_helper._is_scalar_list(pad)
  472. ):
  473. pad = g.op("ConcatFromSequence", pad, axis_i=0, new_axis_i=1)
  474. # The desired order of paddings is
  475. # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
  476. # n is the dimension of input.
  477. # Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning
  478. pad_len = opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0])))
  479. # Set extension = [0] * (dim * 2 - len(pad))
  480. rank = symbolic_helper._get_tensor_rank(input)
  481. if rank is None:
  482. rank = g.op("Size", g.op("Shape", input))
  483. else:
  484. rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64))
  485. extension = g.op(
  486. "Sub",
  487. g.op("Mul", rank, g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))),
  488. pad_len,
  489. )
  490. # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ]
  491. # Currently ONNX only supports int64 type for Pad
  492. pad = g.op("Cast", pad, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"])
  493. paddings = g.op(
  494. "Concat",
  495. pad,
  496. g.op(
  497. "ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64)
  498. ),
  499. axis_i=0,
  500. )
  501. # Reshape and reverse order and collate first beginnings and then ends
  502. # paddings = [[..., 0, dim_n-1_begin, dim_n_begin],
  503. # [..., 0, dim_n-1_end, dim_n_end]]
  504. # Reshape back to 1-D paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end]
  505. paddings = symbolic_helper._reshape_helper(
  506. g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2]))
  507. )
  508. paddings = g.op("Transpose", opset10.flip(g, paddings, [0]), perm_i=[1, 0])
  509. paddings = symbolic_helper._reshape_helper(
  510. g, paddings, g.op("Constant", value_t=torch.tensor([-1]))
  511. )
  512. padding_c = g.op(
  513. "Cast", paddings, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"]
  514. )
  515. return padding_c
  516. def constant_pad_nd(g, input, padding, value=None):
  517. mode = "constant"
  518. value = symbolic_helper._maybe_get_scalar(value)
  519. value = symbolic_helper._if_scalar_type_as(g, value, input)
  520. pad = _prepare_onnx_paddings(g, input, padding)
  521. return g.op("Pad", input, pad, value, mode_s=mode)
  522. def reflection_pad(g, input, padding):
  523. mode = "reflect"
  524. paddings = _prepare_onnx_paddings(g, input, padding)
  525. return g.op("Pad", input, paddings, mode_s=mode)
  526. def replication_pad(g, input, padding):
  527. mode = "edge"
  528. paddings = _prepare_onnx_paddings(g, input, padding)
  529. return g.op("Pad", input, paddings, mode_s=mode)
  530. reflection_pad1d = reflection_pad
  531. reflection_pad2d = reflection_pad
  532. reflection_pad3d = reflection_pad
  533. replication_pad1d = replication_pad
  534. replication_pad2d = replication_pad
  535. replication_pad3d = replication_pad
  536. def pad(g, input, pad, mode, value):
  537. mode = symbolic_helper._parse_arg(mode, "s")
  538. if mode == "replicate":
  539. return replication_pad(g, input, pad)
  540. elif mode == "reflect":
  541. return reflection_pad(g, input, pad)
  542. elif mode == "constant":
  543. return constant_pad_nd(g, input, pad, value)
  544. elif mode == "circular":
  545. return opset9._pad_circular(g, input, pad)
  546. else:
  547. raise RuntimeError(f"Unrecognized padding mode {mode}")
  548. def linalg_det(g, self):
  549. return g.op("Det", self)
  550. def logdet(g, input):
  551. return opset9.log(g, linalg_det(g, input))
  552. def arange(g, *args):
  553. def _get_arange_dtype(dtype):
  554. dtype = symbolic_helper._maybe_get_const(dtype, "i")
  555. return dtype
  556. if len(args) == 2 or len(args) == 5:
  557. if len(args) == 2:
  558. # aten::arange(Scalar end, Tensor out)
  559. dtype = None
  560. else:
  561. # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
  562. dtype = _get_arange_dtype(args[1])
  563. type, end, start, step = symbolic_helper._arange_cast_helper(
  564. g, end=args[0], dtype=dtype
  565. )
  566. start_default = g.op(
  567. "Constant",
  568. value_t=torch.tensor(
  569. 0, dtype=symbolic_helper.scalar_type_to_pytorch_type[type]
  570. ),
  571. )
  572. delta_default = g.op(
  573. "Constant",
  574. value_t=torch.tensor(
  575. 1, dtype=symbolic_helper.scalar_type_to_pytorch_type[type]
  576. ),
  577. )
  578. arange_tensor = g.op("Range", start_default, end, delta_default)
  579. elif len(args) == 4 or len(args) == 7:
  580. if len(args) == 4:
  581. # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
  582. dtype = None
  583. else:
  584. # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
  585. dtype = _get_arange_dtype(args[3])
  586. type, end, start, step = symbolic_helper._arange_cast_helper(
  587. g, start=args[0], end=args[1], step=args[2], dtype=dtype
  588. )
  589. arange_tensor = g.op("Range", start, end, step)
  590. elif len(args) == 6:
  591. # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
  592. dtype = _get_arange_dtype(args[2])
  593. type, end, start, step = symbolic_helper._arange_cast_helper(
  594. g, start=args[0], end=args[1], dtype=dtype
  595. )
  596. delta_default = g.op(
  597. "Constant",
  598. value_t=torch.tensor(
  599. 1, dtype=symbolic_helper.scalar_type_to_pytorch_type[type]
  600. ),
  601. )
  602. arange_tensor = g.op("Range", start, end, delta_default)
  603. else:
  604. raise NotImplementedError(
  605. "Unknown aten::arange signature taking " + str(len(args)) + " arguments."
  606. )
  607. return arange_tensor
  608. @symbolic_helper.parse_args("v", "i")
  609. def _dim_arange(g, like, dim):
  610. like_shape = g.op("Shape", like)
  611. stop = g.op(
  612. "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0
  613. )
  614. if symbolic_helper.is_caffe2_aten_fallback():
  615. return g.op("_caffe2::Range", stop)
  616. return arange(g, stop, 4, None, None, None)
  617. def size(g, self, dim=None):
  618. if dim is None:
  619. return g.op("Shape", self)
  620. return symbolic_helper._size_helper(g, self, dim)
  621. def squeeze(g, self, dim=None):
  622. if dim is None:
  623. return g.op("Squeeze", self)
  624. # dim as a tensor
  625. if not symbolic_helper._is_constant(dim):
  626. return symbolic_helper._squeeze_helper(g, self, [dim])
  627. dim = symbolic_helper._get_const(dim, "i", "dim")
  628. input_rank = symbolic_helper._get_tensor_rank(self)
  629. adjusted_dim = dim
  630. if input_rank is not None and dim < 0:
  631. adjusted_dim += input_rank
  632. dim_size = symbolic_helper._get_tensor_dim_size(self, adjusted_dim)
  633. if (dim < 0 and input_rank is None) or dim_size is None:
  634. # If onnx shape inference is not on, export always as dynamic.
  635. # Because we cannot tell if observed static shape is also static at runtime.
  636. # create "cond" node (condition is shape[i]==1)
  637. dim_constant = g.op("Constant", value_t=torch.tensor([dim]))
  638. size = symbolic_helper._size_helper(g, self, dim_constant)
  639. const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
  640. cond = g.op("Equal", size, const_one)
  641. # create the "If" node and add the "then" and "else" blocks to it.
  642. if_node_outputs = g.op("If", cond)
  643. if_node = if_node_outputs.node()
  644. if_block = utils._add_block(if_node)
  645. squeeze_ = symbolic_helper._squeeze_helper(if_block, self, [dim])
  646. utils._add_output_to_block(if_block, squeeze_)
  647. else_block = utils._add_block(if_node)
  648. identity_ = else_block.op("Identity", self)
  649. utils._add_output_to_block(else_block, identity_)
  650. return if_node_outputs
  651. # For static input shape
  652. dim = adjusted_dim
  653. if dim_size > 1:
  654. warnings.warn(
  655. "This model contains a squeeze operation on dimension "
  656. + str(dim)
  657. + ". The size of "
  658. + "this dimension in the given input is "
  659. + str(dim_size)
  660. + ". The model will "
  661. + "be exported without the squeeze node. If the model is intended to be used with dynamic "
  662. + "input shapes, please export with dynamic_axes argument."
  663. )
  664. return self
  665. return symbolic_helper._squeeze_helper(g, self, [dim])
  666. def unsqueeze(g, self, dim):
  667. if symbolic_helper._is_constant(dim):
  668. dim = symbolic_helper._get_const(dim, "i", "dim")
  669. return symbolic_helper._unsqueeze_helper(g, self, [dim])
  670. def mm(g, self, other):
  671. return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0)
  672. def index(g, self, index):
  673. if symbolic_helper.is_caffe2_aten_fallback():
  674. return g.at("index", self, index, overload_name="Tensor")
  675. if symbolic_helper._is_packed_list(index):
  676. indices = symbolic_helper._unpack_list(index)
  677. else:
  678. indices = [index]
  679. # Handle single mask index.
  680. if len(indices) == 1:
  681. index = indices[0]
  682. if not symbolic_helper._is_none(index) and (
  683. index.type().scalarType() == "Bool" or index.type().scalarType() == "Byte"
  684. ):
  685. index = opset9.nonzero(g, index)
  686. return g.op("GatherND", self, index)
  687. return opset9.index(g, self, index)
  688. def index_fill(g, self, dim, index, value):
  689. dim_value = symbolic_helper._parse_arg(dim, "i")
  690. if symbolic_helper.is_caffe2_aten_fallback():
  691. return g.at(
  692. "index_fill",
  693. self,
  694. index,
  695. value,
  696. overload_name="int_Scalar",
  697. dim_i=dim_value,
  698. )
  699. expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
  700. g, self, dim, index
  701. )
  702. value = symbolic_helper._maybe_get_scalar(value)
  703. value = symbolic_helper._if_scalar_type_as(g, value, self)
  704. expanded_value = opset9.expand(g, value, expanded_index_shape, None)
  705. return scatter(g, self, dim, expanded_index, expanded_value)
  706. def index_copy(g, self, dim, index, source):
  707. dim_value = symbolic_helper._parse_arg(dim, "i")
  708. if symbolic_helper.is_caffe2_aten_fallback():
  709. return g.at("index_copy", self, index, source, dim_i=dim_value)
  710. expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
  711. g, self, dim, index
  712. )
  713. return scatter(g, self, dim, expanded_index, source)
  714. def __rshift_(g, self, other):
  715. # make sure to cast other to self's type
  716. # (when self is long, make sure that other is not float)
  717. if other.type().scalarType() != self.type().scalarType():
  718. other = g.op(
  719. "Cast",
  720. other,
  721. to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
  722. )
  723. if self.type().scalarType() == "Byte":
  724. return g.op("BitShift", self, other, direction_s="RIGHT")
  725. two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
  726. # exponent (same type as self) has to be float or double in onnx::Pow
  727. if not symbolic_helper._is_fp(self):
  728. other = g.op("Cast", other, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"])
  729. two_pow = g.op("Pow", two, other)
  730. two_pow = g.op(
  731. "Cast",
  732. two_pow,
  733. to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
  734. )
  735. rshift = g.op("Div", self, two_pow)
  736. return rshift
  737. def __lshift_(g, self, other):
  738. # make sure to cast other to self's type
  739. # (when self is long, make sure that other is not float)
  740. if other.type().scalarType() != self.type().scalarType():
  741. other = g.op(
  742. "Cast",
  743. other,
  744. to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
  745. )
  746. if self.type().scalarType() == "Byte":
  747. return g.op("BitShift", self, other, direction_s="LEFT")
  748. two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
  749. # exponent (same type as self) has to be float or double in onnx::Pow
  750. if not symbolic_helper._is_fp(self):
  751. other = g.op("Cast", other, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"])
  752. two_pow = g.op("Pow", two, other)
  753. two_pow = g.op(
  754. "Cast",
  755. two_pow,
  756. to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
  757. )
  758. lshift = g.op("Mul", self, two_pow)
  759. return lshift
  760. def _get_im2col_indices_along_dim(
  761. g, input_d, kernel_size_d, dilation_d, padding_d, stride_d
  762. ):
  763. # Input is always 4-D (N, C, H, W)
  764. # Calculate indices of sliding blocks along spatial dimension
  765. # Slide kernel over input each dim d:
  766. # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1)
  767. # with steps = stride
  768. blocks_d = g.op(
  769. "Add", input_d, g.op("Constant", value_t=torch.tensor(padding_d * 2))
  770. )
  771. blocks_d = g.op(
  772. "Sub",
  773. blocks_d,
  774. g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1))),
  775. )
  776. # Stride kernel over input and find starting indices along dim d
  777. blocks_d_indices = g.op(
  778. "Range",
  779. g.op("Constant", value_t=torch.tensor(0)),
  780. blocks_d,
  781. g.op("Constant", value_t=torch.tensor(stride_d)),
  782. )
  783. # Apply dilation on kernel and find its indices along dim d
  784. kernel_grid = torch.arange(0, kernel_size_d * dilation_d, dilation_d)
  785. kernel_grid = g.op("Constant", value_t=kernel_grid.unsqueeze(0))
  786. # Broadcast and add kernel staring positions (indices) with
  787. # kernel_grid along dim d, to get block indices along dim d
  788. blocks_d_indices = symbolic_helper._unsqueeze_helper(
  789. g, blocks_d_indices, [0]
  790. ) # Reshape to [1, -1]
  791. kernel_mask = symbolic_helper._reshape_helper(
  792. g, kernel_grid, g.op("Constant", value_t=torch.tensor([-1, 1]))
  793. )
  794. block_mask = g.op("Add", blocks_d_indices, kernel_mask)
  795. return block_mask
  796. def _get_im2col_padded_input(g, input, padding_h, padding_w):
  797. # Input is always 4-D tensor (N, C, H, W)
  798. # Padding tensor has the following format: (padding_h, padding_w)
  799. # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...)
  800. pad = g.op("Constant", value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2))
  801. return g.op("Pad", input, pad)
  802. def _get_im2col_output_shape(g, input, kernel_h, kernel_w):
  803. batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0)))
  804. channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1)))
  805. channel_unfolded = g.op(
  806. "Mul", channel_dim, g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w))
  807. )
  808. return g.op(
  809. "Concat",
  810. symbolic_helper._unsqueeze_helper(g, batch_dim, [0]),
  811. symbolic_helper._unsqueeze_helper(g, channel_unfolded, [0]),
  812. g.op("Constant", value_t=torch.tensor([-1])),
  813. axis_i=0,
  814. )
  815. @symbolic_helper.parse_args("v", "is", "is", "is", "is")
  816. def im2col(g, input, kernel_size, dilation, padding, stride):
  817. # Input is always 4-D tensor (N, C, H, W)
  818. # All other args are int[2]
  819. input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2)))
  820. input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3)))
  821. stride_h, stride_w = stride[0], stride[1]
  822. padding_h, padding_w = padding[0], padding[1]
  823. dilation_h, dilation_w = dilation[0], dilation[1]
  824. kernel_h, kernel_w = kernel_size[0], kernel_size[1]
  825. blocks_row_indices = _get_im2col_indices_along_dim(
  826. g, input_h, kernel_h, dilation_h, padding_h, stride_h
  827. )
  828. blocks_col_indices = _get_im2col_indices_along_dim(
  829. g, input_w, kernel_w, dilation_w, padding_w, stride_w
  830. )
  831. output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w)
  832. padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w)
  833. # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1
  834. # [[[[1., 2., 3.,],
  835. # [4., 5., 6.,],
  836. # [7., 8., 9.,]]]]
  837. # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get:
  838. # [[[[[1., 2., 3.],
  839. # [4., 5., 6.]],
  840. # [[4., 5., 6.],
  841. # [7., 8., 9.]]]]]
  842. # And then gather along cols (dim=4) with blocks_row_indices = [[0,1], [1,2]] to get:
  843. # [[[[[[1., 2.],
  844. # [4., 5.]],
  845. # [[2., 3.],
  846. # [5., 6]]],
  847. # [[[4., 5.],
  848. # [7., 8.]],
  849. # [[5., 6.],
  850. # [8., 9.]]]]]]
  851. # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get:
  852. # [[[1., 2., 4., 5.],
  853. # [2., 3., 5., 6.],
  854. # [4., 5., 7., 8.],
  855. # [5., 6., 8., 9.]]]
  856. output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2)
  857. output = g.op("Gather", output, blocks_col_indices, axis_i=4)
  858. output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5])
  859. return symbolic_helper._reshape_helper(g, output, output_shape)
  860. def narrow(g, input, dim, start, length):
  861. end = g.op("Add", start, length)
  862. return symbolic_helper._slice_helper(
  863. g, input, axes=dim, starts=start, ends=end, dynamic_slice=True
  864. )
  865. @symbolic_helper.quantized_args(True, False, False)
  866. @symbolic_helper.parse_args("v", "i", "i")
  867. def flatten(g, input, start_dim, end_dim):
  868. dim = symbolic_helper._get_tensor_rank(input)
  869. if dim == 1:
  870. return input
  871. # use ONNX's Flatten operator for cases where the output shape is 2D
  872. if start_dim == 1:
  873. if end_dim == -1 or (dim is not None and end_dim == dim - 1):
  874. return g.op("Flatten", input, axis_i=start_dim)
  875. elif start_dim == 0:
  876. if end_dim == -2 or (dim is not None and end_dim == dim - 2):
  877. return g.op("Flatten", input, axis_i=end_dim + 1)
  878. if dim is None:
  879. return symbolic_helper._unimplemented(
  880. "dim",
  881. "ONNX and PyTorch use different strategies to split the input. "
  882. "Input rank must be known at export time.",
  883. )
  884. # if end_dim is negative add dim
  885. if end_dim < 0:
  886. end_dim = dim + end_dim
  887. return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim)
  888. @symbolic_helper.parse_args("v", "f", "is", "i", "v")
  889. def linalg_vector_norm(g, self, ord, dim, keepdim, dtype):
  890. if ord == 0:
  891. if dim is None:
  892. self = symbolic_helper._reshape_helper(
  893. g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
  894. )
  895. keepdim = None
  896. cond_op = g.op(
  897. "Not", g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0])))
  898. )
  899. cond_op = g.op(
  900. "Cast", cond_op, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"]
  901. )
  902. return symbolic_helper._reducesum_helper(
  903. g, cond_op, axes_i=dim, keepdims_i=keepdim
  904. )
  905. else:
  906. return opset9.linalg_vector_norm(g, self, ord, dim, keepdim, dtype)
  907. @symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
  908. def embedding_bag(
  909. g,
  910. embedding_matrix,
  911. indices,
  912. offsets,
  913. scale_grad_by_freq,
  914. mode,
  915. sparse,
  916. per_sample_weights,
  917. include_last_offset,
  918. padding_idx,
  919. ):
  920. if scale_grad_by_freq and GLOBALS.training_mode:
  921. return symbolic_helper._onnx_unsupported(
  922. "embedding_bag with scale_grad_by_freq for training mode"
  923. )
  924. if padding_idx is not None and padding_idx >= 0:
  925. raise RuntimeError("embedding_bag with padding_idx")
  926. loop_condition = g.op("Constant", value_t=torch.tensor(1))
  927. loop_condition = g.op("Cast", loop_condition, to_i=9)
  928. zero = g.op("Constant", value_t=torch.tensor([0]))
  929. indices_len = symbolic_helper._unsqueeze_helper(
  930. g,
  931. symbolic_helper._size_helper(
  932. g, indices, g.op("Constant", value_t=torch.tensor(0))
  933. ),
  934. [0],
  935. )
  936. if not include_last_offset:
  937. offsets = [offsets, indices_len]
  938. offsets = g.op("Concat", *offsets, axis_i=0)
  939. # Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by
  940. # offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings.
  941. # The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in.
  942. offsets_starts = symbolic_helper._slice_helper(
  943. g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1]
  944. )
  945. offsets_ends = symbolic_helper._slice_helper(
  946. g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1]
  947. )
  948. loop_len = symbolic_helper._size_helper(
  949. g, offsets_ends, g.op("Constant", value_t=torch.tensor(0))
  950. )
  951. loop = g.op("Loop", loop_len, loop_condition)
  952. loop_block = utils._add_block(loop.node())
  953. block_input_iter = utils._add_input_to_block(loop_block)
  954. cond = utils._add_input_to_block(loop_block)
  955. indices_start = loop_block.op("Gather", offsets_starts, block_input_iter, axis_i=0)
  956. indices_end = loop_block.op("Gather", offsets_ends, block_input_iter, axis_i=0)
  957. indices_start = symbolic_helper._unsqueeze_helper(loop_block, indices_start, [0])
  958. indices_end = symbolic_helper._unsqueeze_helper(loop_block, indices_end, [0])
  959. indices_row = loop_block.op("Slice", indices, indices_start, indices_end, zero)
  960. embeddings = loop_block.op("Gather", embedding_matrix, indices_row, axis_i=0)
  961. if not symbolic_helper._is_none(per_sample_weights):
  962. per_sample_weights_row = loop_block.op(
  963. "Slice", per_sample_weights, indices_start, indices_end, zero
  964. )
  965. per_sample_weights_row = symbolic_helper._unsqueeze_helper(
  966. loop_block, per_sample_weights_row, [1]
  967. )
  968. embeddings = loop_block.op("Mul", embeddings, per_sample_weights_row)
  969. if mode == 0:
  970. embeddings = symbolic_helper._reducesum_helper(
  971. loop_block, embeddings, axes_i=[0], keepdims_i=0
  972. )
  973. elif mode == 1:
  974. embeddings = loop_block.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
  975. else:
  976. embeddings = loop_block.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)
  977. cond_out = loop_block.op("Cast", loop_condition, to_i=9)
  978. utils._add_output_to_block(loop_block, cond_out)
  979. utils._add_output_to_block(loop_block, embeddings)
  980. # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
  981. # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
  982. return loop.node().output(), None, None, None
  983. @symbolic_helper.parse_args("v", "v", "f", "f")
  984. def embedding_renorm(g, weight, indices, max_norm, norm_type):
  985. unique_indices = g.op("Unique", indices)
  986. partial_weight = g.op("Gather", weight, unique_indices)
  987. norm_type = int(norm_type)
  988. if norm_type == 1:
  989. norm_type = "ReduceL1"
  990. elif norm_type == 2:
  991. norm_type = "ReduceL2"
  992. else:
  993. raise RuntimeError(
  994. f"Unsupported: ONNX export of embedding_renorm with norm: {norm_type}. "
  995. "Only 1. and 2. are supported."
  996. )
  997. partial_weight_norm = g.op(norm_type, partial_weight, axes_i=[1], keepdims_i=1)
  998. # https://github.com/pytorch/pytorch/blob/0a07488ed2c47765e337e290bd138c0e6e459cbd/aten/src/ATen/native/Embedding.cpp#L177
  999. # Add 1e-7 to prevent division by zero.
  1000. partial_weight_norm_ = g.op(
  1001. "Add", partial_weight_norm, g.op("Constant", value_t=torch.tensor(1e-7))
  1002. )
  1003. max_norm = torch.tensor(max_norm)
  1004. scales = g.op("Div", max_norm, partial_weight_norm_)
  1005. partial_weight_renorm = g.op("Mul", partial_weight, scales)
  1006. partial_weight_renorm = g.op(
  1007. "Where",
  1008. g.op("Greater", partial_weight_norm, max_norm),
  1009. partial_weight_renorm,
  1010. partial_weight,
  1011. )
  1012. return g.op(
  1013. "ScatterND",
  1014. weight,
  1015. symbolic_helper._unsqueeze_helper(g, unique_indices, [1]),
  1016. partial_weight_renorm,
  1017. )
  1018. def chunk(g, self, chunks, dim):
  1019. # Calculate chunk size for dynamic chunk
  1020. dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0)
  1021. chunk_size_s = g.op(
  1022. "Sub", chunks, g.op("Constant", value_t=torch.tensor([1], dtype=torch.long))
  1023. )
  1024. chunk_size = g.op("Div", g.op("Add", dim_size, chunk_size_s), chunks)
  1025. # Create splits vector
  1026. chunk_vec = [
  1027. opset9.expand(g, chunk_size, chunk_size_s, None),
  1028. g.op("Sub", dim_size, g.op("Mul", chunk_size, chunk_size_s)),
  1029. ]
  1030. chunk_vec = g.op("Concat", *chunk_vec, axis_i=0)
  1031. return split(g, self, chunk_vec, dim)
  1032. def normal(g, loc, scale, seed):
  1033. # If you can sample from a given distribution with mean 0 and variance 1, then you can easily sample from a
  1034. # scale-location transformation of that distribution, which has mean μ and variance σ's square. If x is a sample
  1035. # from a mean 0 and variance 1 distribution then
  1036. # σx+μ
  1037. # is a sample with mean μ and variance σ's square.
  1038. result = opset9.mul(g, scale, g.op("RandomNormalLike", loc))
  1039. return add(g, result, loc)
  1040. class Prim:
  1041. domain = "prim"
  1042. @staticmethod
  1043. def ConstantChunk(g, self, chunks, dim):
  1044. input_shape = g.op("Shape", self)
  1045. axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
  1046. input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0)
  1047. start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
  1048. chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long))
  1049. chunk_size_minus_1 = g.op(
  1050. "Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long)
  1051. )
  1052. input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1)
  1053. chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size)
  1054. res = []
  1055. for i in range(chunks):
  1056. index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long))
  1057. end = g.op("Mul", chunk_dim, index)
  1058. res.append(g.op("Slice", self, start, end, axis))
  1059. start = end
  1060. return res