_builtins.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import math
  2. import cmath
  3. import warnings
  4. import torch
  5. import torch.backends.cudnn as cudnn
  6. from ..nn.modules.utils import _single, _pair, _triple, _quadruple, _list_with_default
  7. from collections import OrderedDict
  8. from typing import Dict, Optional
  9. _builtin_table: Optional[Dict[int, str]] = None
  10. _modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._sparse, torch._C._special) # type: ignore[attr-defined] # noqa: B950
  11. _builtin_ops = [
  12. # Pairs of (function, op_name)
  13. (_pair, "aten::_pair"),
  14. (_quadruple, "aten::_quadruple"),
  15. (_single, "aten::_single"),
  16. (_triple, "aten::_triple"),
  17. (_list_with_default, "aten::list_with_default"),
  18. (OrderedDict, "aten::dict"),
  19. (dict, "aten::dict"),
  20. (cudnn.is_acceptable, "aten::cudnn_is_acceptable"),
  21. (math.ceil, "aten::ceil"),
  22. (math.copysign, "aten::copysign"),
  23. (math.erf, "aten::erf"),
  24. (math.erfc, "aten::erfc"),
  25. (math.exp, "aten::exp"),
  26. (math.expm1, "aten::expm1"),
  27. (math.fabs, "aten::fabs"),
  28. (math.floor, "aten::floor"),
  29. (math.gamma, "aten::gamma"),
  30. (math.lgamma, "aten::lgamma"),
  31. (math.log, "aten::log"),
  32. (math.log10, "aten::log10"),
  33. (math.log1p, "aten::log1p"),
  34. (math.pow, "aten::pow"),
  35. (math.sqrt, "aten::sqrt"),
  36. (math.isnan, "aten::isnan"),
  37. (math.asinh, "aten::asinh"),
  38. (math.atanh, "aten::atanh"),
  39. (math.cosh, "aten::cosh"),
  40. (math.sinh, "aten::sinh"),
  41. (math.tanh, "aten::tanh"),
  42. (math.acos, "aten::acos"),
  43. (math.asin, "aten::asin"),
  44. (math.atan, "aten::atan"),
  45. (math.atan2, "aten::atan2"),
  46. (math.cos, "aten::cos"),
  47. (math.sin, "aten::sin"),
  48. (math.tan, "aten::tan"),
  49. (math.asinh, "aten::asinh"),
  50. (math.atanh, "aten::atanh"),
  51. (math.acosh, "aten::acosh"),
  52. (math.fmod, "aten::fmod"),
  53. (math.modf, "aten::modf"),
  54. (math.factorial, "aten::factorial"),
  55. (math.frexp, "aten::frexp"),
  56. (math.isinf, "aten::isinf"),
  57. (math.degrees, "aten::degrees"),
  58. (math.radians, "aten::radians"),
  59. (cmath.isnan, "aten::isnan"),
  60. (cmath.isfinite, "aten::isfinite"),
  61. (cmath.isinf, "aten::isinf"),
  62. (cmath.phase, "aten::angle"),
  63. (cmath.rect, "aten::polar"),
  64. (cmath.log, "aten::log"),
  65. (cmath.log10, "aten::log10"),
  66. (cmath.sqrt, "aten::sqrt"),
  67. (cmath.exp, "aten::exp"),
  68. (cmath.sin, "aten::sin"),
  69. (cmath.tan, "aten::tan"),
  70. (cmath.cos, "aten::cos"),
  71. (cmath.asin, "aten::asin"),
  72. (cmath.acos, "aten::acos"),
  73. (cmath.atan, "aten::atan"),
  74. (cmath.sinh, "aten::sinh"),
  75. (cmath.cosh, "aten::cosh"),
  76. (cmath.tanh, "aten::tanh"),
  77. (cmath.asinh, "aten::asinh"),
  78. (cmath.acosh, "aten::acosh"),
  79. (cmath.atanh, "aten::atanh"),
  80. (math.ldexp, "aten::ldexp"),
  81. (torch._assert, "aten::_assert"),
  82. (torch.autograd.grad, "aten::grad"),
  83. (torch.autograd.backward, "aten::backward"),
  84. (torch._C._infer_size, "aten::_infer_size"),
  85. (torch.nn.functional._no_grad_embedding_renorm_, "aten::_no_grad_embedding_renorm_"), # type: ignore[attr-defined]
  86. (torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"),
  87. (torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"),
  88. (torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"),
  89. (torch.nn.init._no_grad_uniform_, "aten::_no_grad_uniform_"),
  90. (torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
  91. (torch._C._get_tracing_state, "aten::_get_tracing_state"),
  92. (warnings.warn, "aten::warn"),
  93. (torch._VF.stft, "aten::stft"), # type: ignore[attr-defined]
  94. (torch._VF.istft, "aten::istft"), # type: ignore[attr-defined]
  95. (torch._VF.cdist, "aten::cdist"), # type: ignore[attr-defined]
  96. (torch._VF.norm, "aten::norm"), # type: ignore[attr-defined]
  97. (torch._VF.unique_dim, "aten::unique_dim"),
  98. (torch._VF.unique_consecutive, "aten::unique_consecutive"), # type: ignore[attr-defined]
  99. (torch._VF.nuclear_norm, "aten::nuclear_norm"),
  100. (torch._VF.frobenius_norm, "aten::frobenius_norm"),
  101. (torch._VF.tensordot, "aten::tensordot"), # type: ignore[attr-defined]
  102. ]
  103. # ops in torch.functional are bound to torch
  104. # in these cases, we want to resolve the function to their python implementation
  105. # instead looking up a builtin "aten::" schema
  106. def _gen_torch_functional_registered_ops():
  107. # eventually ops should encompass all of torch/functional.py, (torch.functional.__all__)
  108. # but we are currently only able to compile some of the functions. additionally,
  109. # some functions directly map to their aten:: implementations.
  110. # TODO: add support for more ops
  111. ops = ["stft", "istft", "lu", "cdist", "norm", "unique", "unique_consecutive", "tensordot"]
  112. return set(getattr(torch.functional, name) for name in ops)
  113. _functional_registered_ops = _gen_torch_functional_registered_ops()
  114. def _is_special_functional_bound_op(fn):
  115. return fn in _functional_registered_ops
  116. # lazily built to ensure the correct initialization order
  117. def _get_builtin_table():
  118. global _builtin_table
  119. if _builtin_table is not None:
  120. return _builtin_table
  121. _builtin_table = {}
  122. def register_all(mod):
  123. for name in dir(mod):
  124. v = getattr(mod, name)
  125. if callable(v) and not _is_special_functional_bound_op(v) and v is not torch.no_grad and v is not torch.autocast:
  126. _builtin_ops.append((v, "aten::" + name))
  127. for mod in _modules_containing_builtins:
  128. register_all(mod)
  129. _builtin_ops.append((math.gcd, "aten::gcd"))
  130. _builtin_ops.append((math.isfinite, "aten::isfinite"))
  131. _builtin_ops.append((math.remainder, "aten::mathremainder")) # type: ignore[attr-defined]
  132. import torch.distributed.autograd as dist_autograd
  133. if dist_autograd.is_available():
  134. _builtin_ops.append((dist_autograd.get_gradients, "aten::get_gradients"))
  135. _builtin_ops.append((dist_autograd.backward, "aten::dist_backward"))
  136. # populate the _builtin_table from _builtin_ops
  137. for builtin, aten_op in _builtin_ops:
  138. _builtin_table[id(builtin)] = aten_op
  139. return _builtin_table
  140. def _register_builtin(fn, op):
  141. _get_builtin_table()[id(fn)] = op
  142. def _find_builtin(fn):
  143. return _get_builtin_table().get(id(fn))