symbolic_opset7.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
"""
Note [ONNX operators that are added/updated from opset 7 to opset 8]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
  Expand
  Scan

Updated operators:
  Min, Max, Sum, Mean: support multidirectional broadcasting.
  MaxPool: added optional indices output.
"""
  11. import warnings
  12. from torch.onnx import symbolic_helper
  13. from torch.onnx import symbolic_opset9 as opset9
  14. block_listed_operators = [
  15. "scan",
  16. "expand",
  17. "expand_as",
  18. "meshgrid",
  19. "adaptive_max_pool1d",
  20. "adaptive_max_pool2d",
  21. "adaptive_max_pool3d",
  22. "max_pool1d_with_indices",
  23. "max_pool2d_with_indices",
  24. "max_pool3d_with_indices",
  25. ]
  26. # NOTE: max, min, sum, mean: broadcasting is not supported in opset 7.
  27. # torch.max (same for torch.min) actually has two interfaces smashed together:
  28. # torch.max(x, dim, keepdim) and torch.max(x, y)
  29. def max(g, self, dim_or_y=None, keepdim=None):
  30. # torch.max(input, other)
  31. if keepdim is None and dim_or_y is not None:
  32. warnings.warn(
  33. "Multidirectional broadcasting is not supported in opset 7. "
  34. "This might cause the onnx model to be incorrect, if inputs to max operators "
  35. "have different shapes"
  36. )
  37. return opset9.max(g, self, dim_or_y, keepdim)
  38. def min(g, self, dim_or_y=None, keepdim=None):
  39. # torch.min(input, other)
  40. if keepdim is None and dim_or_y is not None:
  41. warnings.warn(
  42. "Multidirectional broadcasting is not supported in opset 7. "
  43. "This might cause the onnx model to be incorrect, if inputs to min operators "
  44. "have different shapes"
  45. )
  46. return opset9.min(g, self, dim_or_y, keepdim)
  47. for block_listed_op in block_listed_operators:
  48. vars()[block_listed_op] = symbolic_helper._block_list_in_opset(block_listed_op)