symbolic_opset16.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. """This file exports ONNX ops for opset 16.
  2. Note [ONNX Operators that are added/updated in opset 16]
  3. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  4. https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
  5. New operators:
  6. GridSample https://github.com/onnx/onnx/pull/3557
  7. Updated operators:
  8. Identity
  9. If
  10. LeakyRelu
  11. Loop
  12. PRelu
  13. RoiAlign
  14. Scan
  15. ScatterElements
  16. ScatterND
  17. Where
  18. GreaterOrEqual
  19. LessOrEqual
  20. SequenceMap
  21. """
  22. # EDITING THIS FILE? READ THIS FIRST!
  23. # see Note [Edit Symbolic Files] in symbolic_helper.py
  24. from torch.nn.functional import (
  25. GRID_SAMPLE_INTERPOLATION_MODES,
  26. GRID_SAMPLE_PADDING_MODES,
  27. )
  28. from torch.onnx import symbolic_helper
  29. # note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
  30. # Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
  31. @symbolic_helper.parse_args("v", "v", "i", "i", "b")
  32. def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners):
  33. mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg]
  34. padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum] # type: ignore[call-arg]
  35. return g.op(
  36. "GridSample",
  37. input,
  38. grid,
  39. align_corners_i=int(align_corners),
  40. mode_s=mode_s,
  41. padding_mode_s=padding_mode_s,
  42. )