NumericUtils.h 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. #pragma once
  2. #ifdef __HIPCC__
  3. #include <hip/hip_runtime.h>
  4. #endif
  5. #include <cmath>
  6. #include <complex>
  7. #include <type_traits>
  8. #include <c10/util/BFloat16.h>
  9. #include <c10/util/Half.h>
  10. #include <c10/macros/Macros.h>
  11. namespace at {
  12. // std::isnan isn't performant to use on integral types; it will
  13. // (uselessly) convert to floating point and then do the test.
  14. // This function is.
  15. template <typename T,
  16. typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
  17. inline C10_HOST_DEVICE bool _isnan(T /*val*/) {
  18. return false;
  19. }
  20. template <typename T,
  21. typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
  22. inline C10_HOST_DEVICE bool _isnan(T val) {
  23. #if defined(__CUDACC__) || defined(__HIPCC__)
  24. return ::isnan(val);
  25. #else
  26. return std::isnan(val);
  27. #endif
  28. }
  29. template <typename T,
  30. typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
  31. inline bool _isnan(T val) {
  32. return std::isnan(val.real()) || std::isnan(val.imag());
  33. }
  34. template <typename T,
  35. typename std::enable_if<std::is_same<T, at::Half>::value, int>::type = 0>
  36. inline C10_HOST_DEVICE bool _isnan(T val) {
  37. return at::_isnan(static_cast<float>(val));
  38. }
  39. template <typename T,
  40. typename std::enable_if<std::is_same<T, at::BFloat16>::value, int>::type = 0>
  41. inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
  42. return at::_isnan(static_cast<float>(val));
  43. }
  44. inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
  45. return at::_isnan(static_cast<float>(val));
  46. }
  47. // std::isinf isn't performant to use on integral types; it will
  48. // (uselessly) convert to floating point and then do the test.
  49. // This function is.
  50. template <typename T,
  51. typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
  52. inline C10_HOST_DEVICE bool _isinf(T /*val*/) {
  53. return false;
  54. }
  55. template <typename T,
  56. typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
  57. inline C10_HOST_DEVICE bool _isinf(T val) {
  58. #if defined(__CUDACC__) || defined(__HIPCC__)
  59. return ::isinf(val);
  60. #else
  61. return std::isinf(val);
  62. #endif
  63. }
  64. inline C10_HOST_DEVICE bool _isinf(at::Half val) {
  65. return at::_isinf(static_cast<float>(val));
  66. }
  67. inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
  68. return at::_isinf(static_cast<float>(val));
  69. }
  70. template <typename T>
  71. C10_HOST_DEVICE inline T exp(T x) {
  72. static_assert(!std::is_same<T, double>::value, "this template must be used with float or less precise type");
  73. #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
  74. // use __expf fast approximation for peak bandwidth
  75. return __expf(x);
  76. #else
  77. return ::exp(x);
  78. #endif
  79. }
  80. template <>
  81. C10_HOST_DEVICE inline double exp<double>(double x) {
  82. return ::exp(x);
  83. }
  84. template <typename T>
  85. C10_HOST_DEVICE inline T log(T x) {
  86. static_assert(!std::is_same<T, double>::value, "this template must be used with float or less precise type");
  87. #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
  88. // use __logf fast approximation for peak bandwidth
  89. return __logf(x);
  90. #else
  91. return ::log(x);
  92. #endif
  93. }
  94. template <>
  95. C10_HOST_DEVICE inline double log<double>(double x) {
  96. return ::log(x);
  97. }
  98. template <typename T>
  99. C10_HOST_DEVICE inline T tan(T x) {
  100. static_assert(!std::is_same<T, double>::value, "this template must be used with float or less precise type");
  101. #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
  102. // use __tanf fast approximation for peak bandwidth
  103. return __tanf(x);
  104. #else
  105. return ::tan(x);
  106. #endif
  107. }
  108. template <>
  109. C10_HOST_DEVICE inline double tan<double>(double x) {
  110. return ::tan(x);
  111. }
  112. } // namespace at