symbolic_opset15.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
"""This file exports ONNX ops for opset 15.

Note [ONNX operators that are added/updated in opset 15]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set

New operators:
    Bernoulli
    CastLike
    Optional
    OptionalGetElement
    OptionalHasElement

Updated operators:
    BatchNormalization https://github.com/onnx/onnx/pull/3545
        Backwards compatible.
        TODO: test coverage for mixed-type inputs.
    Pow https://github.com/onnx/onnx/pull/3412
        Backwards compatible.
        TODO: bfloat16 support.
    Shape https://github.com/onnx/onnx/pull/3580
        Backwards compatible.
        TODO: optional start/end attributes.
"""
  22. # EDITING THIS FILE? READ THIS FIRST!
  23. # see Note [Edit Symbolic Files] in symbolic_helper.py
  24. import torch
  25. from torch import _C
  26. from torch.onnx import symbolic_helper
  27. from torch.onnx import symbolic_opset9 as opset9
  28. def __is_(g, self, other):
  29. if symbolic_helper._is_none(other):
  30. if isinstance(self.type(), _C.OptionalType):
  31. none = g.op("OptionalHasElement", self)
  32. return g.op("Not", none)
  33. else:
  34. return g.op("Constant", value_t=torch.BoolTensor([0]))
  35. return opset9.eq(g, self, other)
  36. @opset9.wrap_logical_op_with_negation
  37. def __isnot_(g, self, other):
  38. return __is_(g, self, other)
  39. class Prim:
  40. domain = "prim"
  41. @staticmethod
  42. def unchecked_cast(g, self):
  43. # exists to refine the type of the Value
  44. # if x is Optional[Tensor], unchecked_cast will cast
  45. # x to Tensor, so the rest of the graph knows that x is a Tensor.
  46. if isinstance(self.type(), _C.OptionalType):
  47. return g.op("OptionalGetElement", self)
  48. return self