symbolic_opset14.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. """This file exports ONNX ops for opset 14.
  2. Note [ONNX operators that are added/updated in opset 14]
  3. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  4. New operators:
  5. HardSwish, Trilu
  6. Updated operators:
  7. Reshape
  8. Add, Sub, Mul, Div
  9. GRU, LSTM, RNN
  10. BatchNormalization, CumSum, Relu
  11. """
  12. # EDITING THIS FILE? READ THIS FIRST!
  13. # see Note [Edit Symbolic Files] in symbolic_helper.py
  14. import torch
  15. from torch.onnx import symbolic_helper
  16. from torch.onnx._globals import GLOBALS
  17. @symbolic_helper.parse_args("v")
  18. def hardswish(g, self):
  19. return g.op("HardSwish", self)
  20. @symbolic_helper.parse_args("v", "i")
  21. def tril(g, self, diagonal, out=None):
  22. k = g.op("Constant", value_t=torch.tensor(diagonal, dtype=torch.int64))
  23. return g.op("Trilu", self, k, upper_i=0)
  24. @symbolic_helper.parse_args("v", "i")
  25. def triu(g, self, diagonal, out=None):
  26. k = g.op("Constant", value_t=torch.tensor(diagonal, dtype=torch.int64))
  27. return g.op("Trilu", self, k, upper_i=1)
  28. @symbolic_helper.parse_args("v", "v")
  29. def reshape(g, self, shape):
  30. # NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664
  31. # Reshape export cannot utilize the new allowzero attribute introduced in opset 14.
  32. return symbolic_helper._reshape_helper(g, self, shape, allowzero=0)
  33. @symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
  34. def batch_norm(
  35. g,
  36. input,
  37. weight,
  38. bias,
  39. running_mean,
  40. running_var,
  41. training,
  42. momentum,
  43. eps,
  44. cudnn_enabled,
  45. ):
  46. if (
  47. torch.is_autocast_enabled()
  48. and not symbolic_helper.args_have_same_dtype(
  49. [input, weight, bias, running_mean, running_var]
  50. )
  51. and GLOBALS.export_onnx_opset_version < 15
  52. ):
  53. return symbolic_helper._onnx_opset_unsupported_detailed(
  54. "BatchNormalization",
  55. 14,
  56. 15,
  57. "All input tensors must have the same `dtype`."
  58. " Turn off Autocast or export using opset version 15.",
  59. )
  60. symbolic_helper.check_training_mode(training, "batch_norm")
  61. weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
  62. g, input, weight, bias, running_mean, running_var
  63. )
  64. out = g.op(
  65. "BatchNormalization",
  66. input,
  67. weight,
  68. bias,
  69. running_mean,
  70. running_var,
  71. epsilon_f=eps,
  72. momentum_f=1 - momentum,
  73. training_mode_i=0 if not training else 1,
  74. outputs=1 if not training else 3,
  75. )
  76. if not training:
  77. return out
  78. else:
  79. res, new_running_mean, new_running_var = out
  80. new_running_mean.setType(running_mean.type())
  81. new_running_var.setType(running_var.type())
  82. return res
  83. class Quantized:
  84. """
  85. https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
  86. """
  87. domain = "quantized"
  88. @staticmethod
  89. def hardswish(g, x, op_scale, op_zero_point):
  90. x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
  91. output = hardswish(g, x)
  92. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)