symbolic_opset8.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. """
  2. Note [ONNX operators that are added/updated from opset 8 to opset 9]
  3. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  4. New operators:
  5. Compress
  6. ConstantOfShape
  7. EyeLike
  8. MaxUnpool
  9. OneHot
  10. Sinh
  11. Cosh
  12. Asinh
  13. Acosh
  14. Atanh
  15. Shrink
  16. IsNaN
  17. Sign
  18. Erf
  19. Scatter
  20. Where
  21. NonZero
  22. TfIdfVectorizer
  23. MeanVarianceNormalization
  24. Updated operators:
  25. BatchNormalization: removed spatial attribute.
  26. Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types{integers} supported.
  27. Cast: more data types{string} supported.
  28. Upsample: moved scales from attribute to input.
  29. Scan
  30. """
  31. import warnings
  32. import torch
  33. from torch.onnx import symbolic_helper
  34. from torch.onnx import symbolic_opset9 as opset9
  35. block_listed_operators = [
  36. "nonzero",
  37. "where",
  38. "scatter",
  39. "scatter_add",
  40. "erf",
  41. "sign",
  42. "isnan",
  43. "gather",
  44. "arange",
  45. "masked_fill",
  46. "index_fill",
  47. "index_copy",
  48. "repeat_interleave",
  49. "isnan",
  50. "any",
  51. "all",
  52. ]
  53. for block_listed_op in block_listed_operators:
  54. vars()[block_listed_op] = symbolic_helper._block_list_in_opset(block_listed_op)
  55. def _interpolate(name, dim, interpolate_mode):
  56. def symbolic_fn(g, input, output_size, *args):
  57. scales, align_corners = symbolic_helper._get_interpolate_attributes(
  58. g, interpolate_mode, args
  59. )
  60. symbolic_helper._interpolate_warning(interpolate_mode)
  61. align_corners = symbolic_helper._maybe_get_scalar(align_corners)
  62. if align_corners:
  63. return symbolic_helper._unimplemented(name, "align_corners == True")
  64. output_size = symbolic_helper._maybe_get_const(output_size, "is")
  65. if symbolic_helper._is_value(output_size):
  66. return symbolic_helper._unimplemented(
  67. name, "torch._C.Value (output_size) indexing"
  68. )
  69. if scales is None:
  70. scales = [
  71. 1.0
  72. if i < 2
  73. else float(output_size[-(dim - i)])
  74. / float(input.type().sizes()[-(dim - i)])
  75. for i in range(0, dim)
  76. ]
  77. return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales)
  78. return symbolic_fn
  79. upsample_nearest1d = _interpolate("upsample_nearest1d", 3, "nearest")
  80. upsample_nearest2d = _interpolate("upsample_nearest2d", 4, "nearest")
  81. upsample_nearest3d = _interpolate("upsample_nearest3d", 5, "nearest")
  82. upsample_linear1d = _interpolate("upsample_linear1d", 3, "linear")
  83. upsample_bilinear2d = _interpolate("upsample_bilinear2d", 4, "linear")
  84. upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear")
  85. def __interpolate(
  86. g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
  87. ):
  88. align_corners = symbolic_helper._maybe_get_const(align_corners, "b")
  89. if not symbolic_helper._is_none(align_corners) and align_corners:
  90. return symbolic_helper._unimplemented("interpolate", "align_corners == True")
  91. if not symbolic_helper._is_none(scale_factor) and symbolic_helper._is_value(
  92. scale_factor
  93. ):
  94. return symbolic_helper._unimplemented(
  95. "interpolate", "dynamic scales in opset 8"
  96. )
  97. if not symbolic_helper._is_none(size) and symbolic_helper._is_value(size):
  98. return symbolic_helper._unimplemented("interpolate", "dynamic size in opset 8")
  99. scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
  100. g, input, size, scale_factor, mode, align_corners
  101. )
  102. return g.op("Upsample", input, mode_s=mode, scales_f=scales)
  103. # NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation
  104. # issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which
  105. # is lost after casting.
  106. def _try_cast_integer_to_float(g, *args):
  107. floating_scalar_types = ["Half", "Float", "Double"]
  108. old_type = None
  109. # Cast the input tensor to Float if its scalarType is known and is not floating number.
  110. # If casting is performed, return the old scalarType, otherwise return None.
  111. arg0_type = args[0].type().scalarType()
  112. if arg0_type is not None:
  113. old_type = arg0_type
  114. if old_type not in floating_scalar_types:
  115. # TODO(justinchuby): Remove the type ignore hint once _cast_Float is
  116. # properly defined.
  117. # NOTE: _cast_Float is generated programmatically so we need to make the
  118. # type checker happy with ignore[attr-defined].
  119. args = tuple(opset9._cast_Float(g, arg, False) for arg in args) # type: ignore[attr-defined]
  120. else:
  121. return (None,) + args
  122. else:
  123. warnings.warn(
  124. "Only floating datatype is supported for these operators: "
  125. "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
  126. "the onnx model to be incorrect, if inputs have integer datatypes."
  127. )
  128. return (old_type,) + args
  129. def _cast_to_type(g, input, to_type):
  130. if to_type is None:
  131. return input
  132. return getattr(opset9, "_cast_{}".format(to_type))(g, input, False)
  133. def _comparison_operator(g, input, other, op_name):
  134. other = symbolic_helper._maybe_get_scalar(other)
  135. other = symbolic_helper._if_scalar_type_as(g, other, input)
  136. _, input, other = _try_cast_integer_to_float(g, input, other)
  137. return g.op(op_name, input, other)
  138. # NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
  139. # integer input type not supported in opset8. Cast to float if possible.
  140. def gt(g, input, other):
  141. return _comparison_operator(g, input, other, "Greater")
  142. def lt(g, input, other):
  143. return _comparison_operator(g, input, other, "Less")
  144. def bmm(g, self, other):
  145. if symbolic_helper._try_get_scalar_type(self):
  146. old_type, self, other = _try_cast_integer_to_float(g, self, other)
  147. return _cast_to_type(g, g.op("MatMul", self, other), old_type)
  148. else:
  149. return g.op("MatMul", self, other)
  150. def matmul(g, self, other):
  151. return bmm(g, self, other)
  152. def prelu(g, self, weight):
  153. self_rank = symbolic_helper._get_tensor_rank(self)
  154. if self_rank is not None and self_rank > 2:
  155. weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1)))
  156. if symbolic_helper._try_get_scalar_type(self):
  157. old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
  158. return _cast_to_type(g, g.op("PRelu", self, weight), old_type)
  159. else:
  160. return g.op("PRelu", self, weight)
  161. def mm(g, self, other):
  162. # Create a dummy C tensor. Only needed for API purposes, the value is
  163. # since beta = 0
  164. ty = symbolic_helper._try_get_scalar_type(self, other).lower()
  165. C = g.constant(0, [1], ty)
  166. if symbolic_helper._try_get_scalar_type(self):
  167. old_type, self, other, C = _try_cast_integer_to_float(g, self, other, C)
  168. return _cast_to_type(
  169. g, g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0), old_type
  170. )
  171. else:
  172. return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)
  173. @symbolic_helper.parse_args("v", "v", "v", "t", "t")
  174. def addmm(g, self, mat1, mat2, beta, alpha):
  175. if symbolic_helper._try_get_scalar_type(self):
  176. old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
  177. return _cast_to_type(
  178. g,
  179. g.op(
  180. "Gemm",
  181. mat1,
  182. mat2,
  183. self,
  184. beta_f=symbolic_helper._scalar(beta),
  185. alpha_f=symbolic_helper._scalar(alpha),
  186. ),
  187. old_type,
  188. )
  189. else:
  190. return g.op(
  191. "Gemm",
  192. mat1,
  193. mat2,
  194. self,
  195. beta_f=symbolic_helper._scalar(beta),
  196. alpha_f=symbolic_helper._scalar(alpha),
  197. )
  198. def flatten(g, input, start_dim, end_dim):
  199. start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim")
  200. end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim")
  201. dim = input.type().dim()
  202. if end_dim_i < 0:
  203. end_dim_i = dim + end_dim_i
  204. # use ONNX's Flatten operator for cases where the output shape is 2D
  205. if start_dim_i == 1 and end_dim_i == dim - 1:
  206. if symbolic_helper._try_get_scalar_type(input):
  207. old_type, input = _try_cast_integer_to_float(g, input)
  208. return _cast_to_type(
  209. g, g.op("Flatten", input, axis_i=start_dim_i), old_type
  210. )
  211. else:
  212. return g.op("Flatten", input, axis_i=start_dim_i)
  213. if start_dim_i == 0 and end_dim_i == dim - 2:
  214. if symbolic_helper._try_get_scalar_type(input):
  215. old_type, input = _try_cast_integer_to_float(g, input)
  216. return _cast_to_type(
  217. g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type
  218. )
  219. else:
  220. return g.op("Flatten", input, axis_i=end_dim_i + 1)
  221. return opset9.flatten(g, input, start_dim, end_dim)
  222. def _constant_fill(g, sizes, dtype, const_value):
  223. if dtype is None:
  224. dtype = symbolic_helper.ScalarType.FLOAT
  225. if not symbolic_helper.scalar_type_to_pytorch_type[dtype].is_floating_point:
  226. result = g.op(
  227. "ConstantFill",
  228. sizes,
  229. dtype_i=symbolic_helper.cast_pytorch_to_onnx["Float"],
  230. input_as_shape_i=1,
  231. value_f=const_value,
  232. )
  233. return symbolic_helper._cast_func_template(
  234. symbolic_helper.scalar_type_to_onnx[dtype], g, result, None
  235. )
  236. else:
  237. return g.op(
  238. "ConstantFill",
  239. sizes,
  240. dtype_i=symbolic_helper.scalar_type_to_onnx[dtype],
  241. input_as_shape_i=1,
  242. value_f=const_value,
  243. )
  244. @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
  245. def empty(g, sizes, dtype, layout, device, pin_memory=False, memory_format=None):
  246. return zeros(g, sizes, dtype, layout, device, pin_memory)
  247. @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
  248. def empty_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None):
  249. return zeros_like(g, input, dtype, layout, device, pin_memory)
  250. @symbolic_helper.parse_args("v", "i", "v", "v", "v")
  251. def zeros(g, sizes, dtype, layout, device, pin_memory=False):
  252. # NOTE: no way to set device and layout in ONNX, so we ignore it
  253. return _constant_fill(g, sizes, dtype, 0)
  254. @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
  255. def zeros_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None):
  256. shape = g.op("Shape", input)
  257. return _constant_fill(g, shape, dtype, 0)
  258. @symbolic_helper.parse_args("v", "i", "v", "v", "v")
  259. def ones(g, sizes, dtype, layout, device, pin_memory=False):
  260. return _constant_fill(g, sizes, dtype, 1)
  261. @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
  262. def ones_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None):
  263. shape = g.op("Shape", input)
  264. return _constant_fill(g, shape, dtype, 1)
  265. def full(g, sizes, value, dtype, layout, device, pin_memory=False):
  266. const_value = symbolic_helper._maybe_get_const(value, "t")
  267. if symbolic_helper._is_value(const_value):
  268. tmp = zeros(g, sizes, dtype, layout, device)
  269. return opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
  270. else:
  271. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  272. return _constant_fill(g, sizes, dtype, const_value)
  273. @symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v")
  274. def full_like(
  275. g, input, fill_value, dtype, layout, device, pin_memory=False, memory_format=None
  276. ):
  277. shape = g.op("Shape", input)
  278. return _constant_fill(g, shape, dtype, fill_value)
  279. def repeat(g, self, repeats):
  280. if not symbolic_helper._is_value(repeats):
  281. repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
  282. if symbolic_helper._is_packed_list(repeats):
  283. repeat_size_len = len(symbolic_helper._unpack_list(repeats))
  284. else:
  285. const_repeats = symbolic_helper._maybe_get_const(repeats, "is")
  286. repeat_size_len = len(const_repeats)
  287. if self.isCompleteTensor():
  288. sizes = self.type().sizes()
  289. diff_dims = repeat_size_len - len(sizes)
  290. if diff_dims > 0:
  291. self = opset9.view(
  292. g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes))
  293. )
  294. return g.op("Tile", self, repeats)