Dispatch.h 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926
  1. #pragma once
#include <ATen/core/DeprecatedTypeProperties.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Half.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/complex.h>
#include <c10/util/string_view.h>

#include <climits>
#include <string>
  9. #ifdef TEMPLATE_SELECTIVE_BUILD
  10. #include <ATen/selected_mobile_ops.h>
  11. #else
  12. namespace at {
  13. /**
  14. * The method should_include_kernel_dtype() returns true/false
  15. * based on whether the switching code for a specific dtype should be
  16. * included based on build time constants generated from tracing model
  17. * execution. This method will be implmeneted via code-generation and
  18. * included in this file when code-gen is ready.
  19. */
  20. inline constexpr bool should_include_kernel_dtype(
  21. const char* /*kernel_tag_str*/,
  22. at::ScalarType /*scalar_type*/
  23. ) {
  24. return true;
  25. }
  26. }
  27. #endif
  28. /**
  29. * In the Facebook internal build (using BUCK), this macro is enabled by
  30. * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
  31. * binary.
  32. */
  33. #if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
  34. namespace at {
  35. namespace detail {
  36. TORCH_API void record_kernel_function_dtype(std::string name);
  37. }
  38. }
  39. #define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
  40. at::detail::record_kernel_function_dtype( \
  41. std::string(NAME) + "$" + toString(enum_type));
  42. #else
  43. #define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)
  44. #endif
  45. #if defined __cpp_if_constexpr
  46. #define AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \
  47. case enum_type: { \
  48. if constexpr (!at::should_include_kernel_dtype(NAME, enum_type)) { \
  49. AT_ERROR("dtype '", toString(enum_type), "' not selected for kernel tag ", #NAME); \
  50. } \
  51. using HINT = type; \
  52. return __VA_ARGS__(); \
  53. }
  54. #else
  55. #define AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \
  56. case enum_type: { \
  57. at::guts::if_constexpr<(!at::should_include_kernel_dtype(NAME, enum_type))>( \
  58. [] { \
  59. AT_ERROR("dtype '" #enum_type "' not selected for kernel tag " #NAME); \
  60. } \
  61. ); \
  62. using HINT = type; \
  63. return __VA_ARGS__(); \
  64. }
  65. #endif \
  66. #define AT_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \
  67. AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, scalar_t, __VA_ARGS__)
  68. // Workaround for C10_UNUSED because CUDA 10.1 and below fails to handle unused
  69. // attribute in the type aliasing context. Keep name long and verbose to avoid
  70. // macro collisions.
  71. #if defined(__CUDACC__) && defined(CUDA_VERSION) && CUDA_VERSION <= 10010
  72. #define C10_UNUSED_DISPATCH_CUDA_WORKAROUND
  73. #else
  74. #define C10_UNUSED_DISPATCH_CUDA_WORKAROUND C10_UNUSED
  75. #endif // defined(__CUDACC__) && defined(CUDA_VERSION) && CUDA_VERSION <= 10010
  76. #if defined __cpp_if_constexpr
  77. #define AT_QINT_PRIVATE_CASE_TYPE( \
  78. NAME, enum_type, type, underlying_enum, underlying_type, ...) \
  79. case enum_type: { \
  80. if constexpr (!at::should_include_kernel_dtype(NAME, enum_type)) { \
  81. AT_ERROR("dtype '", toString(enum_type), "' not selected for kernel tag ", #NAME); \
  82. } \
  83. using scalar_t = type; \
  84. using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
  85. scalar_t::underlying; \
  86. const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \
  87. const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
  88. toUnderlying(enum_type); \
  89. (void)SCALAR_TYPE; /* Suppress unused-var compiler warning */ \
  90. /* TODO: Use [[maybe-unused]] when C++17 becomes the standard */ \
  91. return __VA_ARGS__(); \
  92. }
  93. #else
  94. #define AT_QINT_PRIVATE_CASE_TYPE( \
  95. NAME, enum_type, type, underlying_enum, underlying_type, ...) \
  96. case enum_type: { \
  97. at::guts::if_constexpr<(!at::should_include_kernel_dtype(NAME, enum_type))>( \
  98. [] { \
  99. AT_ERROR("dtype '" #enum_type "' not selected for kernel tag " #NAME); \
  100. } \
  101. ); \
  102. using scalar_t = type; \
  103. using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
  104. scalar_t::underlying; \
  105. const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \
  106. const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
  107. toUnderlying(enum_type); \
  108. (void)SCALAR_TYPE; /* Suppress unused-var compiler warning */ \
  109. /* TODO: Use [[maybe-unused]] when C++17 becomes the standard */ \
  110. return __VA_ARGS__(); \
  111. }
  112. #endif
  113. #if defined __cpp_if_constexpr
  114. #define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  115. NAME, enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \
  116. case enum_type: { \
  117. if constexpr (!at::should_include_kernel_dtype(NAME, enum_type)) { \
  118. AT_ERROR("dtype '", toString(enum_type), "' not selected for kernel tag ", #NAME); \
  119. } \
  120. using scalar_t = type; \
  121. using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
  122. scalar_t::underlying; \
  123. const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \
  124. const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
  125. toUnderlying(enum_type); \
  126. C10_UNUSED int bit_width = bitwidth; \
  127. C10_UNUSED int64_t quant_min = qmin; \
  128. C10_UNUSED int64_t quant_max = qmax; \
  129. (void)bit_width; /* Suppress unused variable warning */ \
  130. (void)quant_min; /* Suppress unused variable warning */ \
  131. (void)quant_max; /* Suppress unused variable warning */ \
  132. return __VA_ARGS__(); \
  133. }
  134. #else
  135. #define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  136. NAME, enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \
  137. case enum_type: { \
  138. at::guts::if_constexpr<(!at::should_include_kernel_dtype(NAME, enum_type))>( \
  139. [] { \
  140. AT_ERROR("dtype '" #enum_type "' not selected for kernel tag " #NAME); \
  141. } \
  142. ); \
  143. using scalar_t = type; \
  144. using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
  145. scalar_t::underlying; \
  146. const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \
  147. const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
  148. toUnderlying(enum_type); \
  149. int bit_width = bitwidth; \
  150. int64_t quant_min = qmin; \
  151. int64_t quant_max = qmax; \
  152. (void)bit_width; /* Suppress unused variable warning */ \
  153. (void)quant_min; /* Suppress unused variable warning */ \
  154. (void)quant_max; /* Suppress unused variable warning */ \
  155. return __VA_ARGS__(); \
  156. }
  157. #endif
  158. namespace detail {
  159. inline at::ScalarType scalar_type(at::ScalarType s) {
  160. return s;
  161. }
  162. C10_DEPRECATED_MESSAGE(
  163. "passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, "
  164. "pass an at::ScalarType instead")
  165. inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) {
  166. return t.scalarType();
  167. }
  168. C10_DEPRECATED_MESSAGE(
  169. "AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, "
  170. "use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
  171. inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
  172. C10_DEPRECATED_MESSAGE(
  173. "AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, "
  174. "use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) "
  175. "instead")
  176. inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
  177. } // namespace detail
  178. // The AT_DISPATCH_* family of macros provides the ability to
  179. // conveniently generate specializations of a kernel over all of the
  180. // dtypes we care about in PyTorch. We call it "dispatch" because
  181. // we are "dispatching" to the correct, dtype-specific kernel.
  182. //
  183. // A standard usage looks like:
  184. //
  185. // AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] {
  186. // // Your code here, with 'scalar_t' now defined to
  187. // // be the dtype in question
  188. // })
  189. //
  190. // There are many variations of this macro, so it's important to
  191. // understand exactly /which/ dtypes you want to get instantiated, as
  192. // well as what the "default" set is.
  193. //
  194. // The default set of dtypes that are instantiated (e.g., by
  195. // AT_DISPATCH_ALL_TYPES) are floating point types (float, double),
  196. // and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t),
  197. // but NOT booleans (bool), half-precision floats (Half) or
  198. // complex number (c10::complex<float>, c10::complex<double>).
  199. // This "cut" is somewhat historical (the default types are the
  200. // ones that TH historically supported), but it also reflects the
  201. // fact that the non-default types are "poorly" behaved (booleans
  202. // are NOT integers mod 2, half precision operations ~essentially
  203. // don't exist on CPU, complex numbers are an experimental application).
  204. //
  205. // Here are the questions you should generally ask to decide which
  206. // dispatch you want:
  207. //
  208. // 1. Is this an integral or floating point specific operation?
  209. // (If so, you'll want one of the FLOATING or INTEGRAL macros.)
  210. //
  211. // 2. Should half be supported? (If you're on CPU, the answer is almost
  212. // definitely no. If you do want support, use one of the AND_HALF
  213. // macros)
  214. //
  215. // Much rarer situations:
  216. //
  217. // 3. Should bool be supported? (You often have to write your kernel
  218. // differently if arithmetic operations are involved.) If so,
  219. // Use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool
  220. //
  221. // 4. Should complex be supported? The answer is almost always no,
  222. // unless you are working on "generic" code that should work on
  223. // all dtypes.
  224. //
  225. // Parameters:
  226. // -----------
  227. //
  228. // 1. The NAME argument is a "tag" that is used to trace and then
  229. // conditionally compile fragments of the case statements such
  230. // that the kernel functions are specialized only for the dtypes
  231. // that are needed. The NAME parameter *must* be a build time
  232. // cons char* (can't be std::string, etc...)
  233. //
  234. // Please ensure that the NAME is unique for every implementation
  235. // or you run the risk of over-including code for the kernel
  236. // functions. There is no risk of missing out on any code, so
  237. // it's mostly a risk of a Type-2 error, and not a Type-1 error.
  238. //
  239. // NB: the the_type variable is not used, but we have kept it for
  240. // backwards compatibility. It's probably not used by anyone though;
  241. // but we're just being safe (and it doesn't hurt.) Note we must
  242. // use it to shut up warnings about unused store.
  243. #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  244. [&] { \
  245. const auto& the_type = TYPE; \
  246. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  247. at::ScalarType _st = ::detail::scalar_type(the_type); \
  248. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  249. switch (_st) { \
  250. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
  251. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
  252. default: \
  253. AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
  254. } \
  255. }()
  256. #define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
  257. [&] { \
  258. const auto& the_type = TYPE; \
  259. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  260. at::ScalarType _st = ::detail::scalar_type(the_type); \
  261. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  262. switch (_st) { \
  263. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
  264. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
  265. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Half, at::Half, __VA_ARGS__) \
  266. default: \
  267. AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
  268. } \
  269. }()
  270. #define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  271. [&] { \
  272. const auto& the_type = TYPE; \
  273. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  274. at::ScalarType _st = ::detail::scalar_type(the_type); \
  275. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  276. switch (_st) { \
  277. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
  278. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
  279. AT_PRIVATE_CASE_TYPE(NAME, \
  280. SCALARTYPE, \
  281. decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE>::t), \
  282. __VA_ARGS__) \
  283. default: \
  284. AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  285. } \
  286. }()
  287. #define AT_DISPATCH_FLOATING_TYPES_AND2( \
  288. SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  289. [&] { \
  290. const auto& the_type = TYPE; \
  291. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  292. at::ScalarType _st = ::detail::scalar_type(the_type); \
  293. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  294. switch (_st) { \
  295. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
  296. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
  297. AT_PRIVATE_CASE_TYPE( \
  298. NAME, \
  299. SCALARTYPE1, \
  300. decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), \
  301. __VA_ARGS__) \
  302. AT_PRIVATE_CASE_TYPE( \
  303. NAME, \
  304. SCALARTYPE2, \
  305. decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), \
  306. __VA_ARGS__) \
  307. default: \
  308. AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  309. } \
  310. }()
  311. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
  312. [&] { \
  313. const auto& the_type = TYPE; \
  314. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  315. at::ScalarType _st = ::detail::scalar_type(the_type); \
  316. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  317. switch (_st) { \
  318. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
  319. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
  320. AT_PRIVATE_CASE_TYPE( \
  321. NAME, \
  322. at::ScalarType::ComplexDouble, \
  323. c10::complex<double>, \
  324. __VA_ARGS__) \
  325. AT_PRIVATE_CASE_TYPE( \
  326. NAME, \
  327. at::ScalarType::ComplexFloat, \
  328. c10::complex<float>, \
  329. __VA_ARGS__) \
  330. default: \
  331. AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
  332. } \
  333. }()
  334. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \
  335. SCALARTYPE, TYPE, NAME, ...) \
  336. [&] { \
  337. const auto& the_type = TYPE; \
  338. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  339. at::ScalarType _st = ::detail::scalar_type(the_type); \
  340. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  341. switch (_st) { \
  342. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
  343. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
  344. AT_PRIVATE_CASE_TYPE( \
  345. NAME, \
  346. at::ScalarType::ComplexDouble, c10::complex<double>, __VA_ARGS__) \
  347. AT_PRIVATE_CASE_TYPE( \
  348. NAME, \
  349. at::ScalarType::ComplexFloat, c10::complex<float>, __VA_ARGS__) \
  350. AT_PRIVATE_CASE_TYPE( \
  351. NAME, \
  352. SCALARTYPE, \
  353. decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE>::t), \
  354. __VA_ARGS__) \
  355. default: \
  356. AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
  357. } \
  358. }()
  359. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( \
  360. SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  361. [&] { \
  362. const auto& the_type = TYPE; \
  363. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  364. at::ScalarType _st = ::detail::scalar_type(the_type); \
  365. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  366. switch (_st) { \
  367. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
  368. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
  369. AT_PRIVATE_CASE_TYPE( \
  370. NAME, \
  371. at::ScalarType::ComplexDouble, \
  372. c10::complex<double>, \
  373. __VA_ARGS__) \
  374. AT_PRIVATE_CASE_TYPE( \
  375. NAME, \
  376. at::ScalarType::ComplexFloat, \
  377. c10::complex<float>, \
  378. __VA_ARGS__) \
  379. AT_PRIVATE_CASE_TYPE( \
  380. NAME, \
  381. SCALARTYPE1, \
  382. decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), \
  383. __VA_ARGS__) \
  384. AT_PRIVATE_CASE_TYPE( \
  385. NAME, \
  386. SCALARTYPE2, \
  387. decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), \
  388. __VA_ARGS__) \
  389. default: \
  390. AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
  391. } \
  392. }()
  393. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3( \
  394. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  395. [&] { \
  396. const auto& the_type = TYPE; \
  397. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  398. at::ScalarType _st = ::detail::scalar_type(the_type); \
  399. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  400. switch (_st) { \
  401. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
  402. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
  403. AT_PRIVATE_CASE_TYPE( \
  404. NAME, \
  405. at::ScalarType::ComplexDouble, \
  406. c10::complex<double>, \
  407. __VA_ARGS__) \
  408. AT_PRIVATE_CASE_TYPE( \
  409. NAME, \
  410. at::ScalarType::ComplexFloat, \
  411. c10::complex<float>, \
  412. __VA_ARGS__) \
  413. AT_PRIVATE_CASE_TYPE( \
  414. NAME, \
  415. SCALARTYPE1, \
  416. decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), \
  417. __VA_ARGS__) \
  418. AT_PRIVATE_CASE_TYPE( \
  419. NAME, \
  420. SCALARTYPE2, \
  421. decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), \
  422. __VA_ARGS__) \
  423. AT_PRIVATE_CASE_TYPE( \
  424. NAME, \
  425. SCALARTYPE3, \
  426. decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE3>::t), \
  427. __VA_ARGS__) \
  428. default: \
  429. AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
  430. } \
  431. }()
  432. #define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  433. [&] { \
  434. const auto& the_type = TYPE; \
  435. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  436. at::ScalarType _st = ::detail::scalar_type(the_type); \
  437. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  438. switch (_st) { \
  439. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
  440. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \
  441. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \
  442. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \
  443. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \
  444. default: \
  445. AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
  446. } \
  447. }()
  448. #define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  449. [&] { \
  450. const auto& the_type = TYPE; \
  451. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  452. at::ScalarType _st = ::detail::scalar_type(the_type); \
  453. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  454. switch (_st) { \
  455. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
  456. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \
  457. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \
  458. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \
  459. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \
  460. AT_PRIVATE_CASE_TYPE(NAME, \
  461. SCALARTYPE, \
  462. decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE>::t), \
  463. __VA_ARGS__) \
  464. default: \
  465. AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
  466. } \
  467. }()
  468. #define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
  469. [&] { \
  470. const auto& the_type = TYPE; \
  471. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  472. at::ScalarType _st = ::detail::scalar_type(the_type); \
  473. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  474. switch (_st) { \
  475. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
  476. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \
  477. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
  478. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
  479. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \
  480. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \
  481. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \
  482. default: \
  483. AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
  484. } \
  485. }()
  486. #define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
  487. [&] { \
  488. const auto& the_type = TYPE; \
  489. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  490. at::ScalarType _st = ::detail::scalar_type(the_type); \
  491. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  492. switch (_st) { \
  493. AT_PRIVATE_CASE_TYPE( \
  494. NAME, \
  495. at::ScalarType::ComplexFloat, \
  496. c10::complex<float>, \
  497. __VA_ARGS__) \
  498. AT_PRIVATE_CASE_TYPE( \
  499. NAME, \
  500. at::ScalarType::ComplexDouble, \
  501. c10::complex<double>, \
  502. __VA_ARGS__) \
  503. default: \
  504. AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
  505. } \
  506. }()
  507. #define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  508. [&] { \
  509. const auto& the_type = TYPE; \
  510. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  511. at::ScalarType _st = ::detail::scalar_type(the_type); \
  512. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  513. switch (_st) { \
  514. AT_PRIVATE_CASE_TYPE( \
  515. NAME, \
  516. at::ScalarType::ComplexFloat, \
  517. c10::complex<float>, \
  518. __VA_ARGS__) \
  519. AT_PRIVATE_CASE_TYPE( \
  520. NAME, \
  521. at::ScalarType::ComplexDouble, \
  522. c10::complex<double>, \
  523. __VA_ARGS__) \
  524. AT_PRIVATE_CASE_TYPE( \
  525. NAME, \
  526. SCALARTYPE, \
  527. decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE>::t), \
  528. __VA_ARGS__) \
  529. default: \
  530. AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
  531. } \
  532. }()
  533. #define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
  534. [&] { \
  535. const auto& the_type = TYPE; \
  536. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  537. at::ScalarType _st = ::detail::scalar_type(the_type); \
  538. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  539. switch (_st) { \
  540. AT_QINT_PRIVATE_CASE_TYPE( \
  541. NAME, at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \
  542. AT_QINT_PRIVATE_CASE_TYPE( \
  543. NAME, at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__) \
  544. AT_QINT_PRIVATE_CASE_TYPE( \
  545. NAME, at::kQInt32, at::qint32, at::kInt, int, __VA_ARGS__) \
  546. default: \
  547. AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  548. } \
  549. }()
  550. #define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
  551. [&] { \
  552. const auto& the_type = TYPE; \
  553. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  554. at::ScalarType _st = ::detail::scalar_type(the_type); \
  555. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  556. switch (_st) { \
  557. AT_QINT_PRIVATE_CASE_TYPE( \
  558. NAME, at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \
  559. AT_QINT_PRIVATE_CASE_TYPE( \
  560. NAME, at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__) \
  561. default: \
  562. AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  563. } \
  564. }()
  565. #define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
  566. [&] { \
  567. const auto& the_type = TYPE; \
  568. /* don't use TYPE again in case it is an expensive or side-effect op */ \
  569. at::ScalarType _st = ::detail::scalar_type(the_type); \
  570. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  571. switch (_st) { \
  572. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  573. NAME, at::kQInt8, at::qint8, int8_t, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
  574. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  575. NAME, at::kQUInt8, at::quint8, uint8_t, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
  576. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  577. NAME, at::kQInt32, at::qint32, int, CHAR_BIT * sizeof(int), INT_MIN, INT_MAX, __VA_ARGS__) \
  578. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  579. NAME, at::kQUInt4x2, at::quint4x2, uint8_t, 4, 0, 15, __VA_ARGS__) \
  580. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  581. NAME, at::kQUInt2x4, at::quint2x4, uint8_t, 2, 0, 3, __VA_ARGS__) \
  582. default: \
  583. AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  584. } \
  585. }()
  586. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
  587. [&] { \
  588. const auto& the_type = TYPE; \
  589. /* don't use TYPE again in case it is an expensive or side-effect op*/ \
  590. at::ScalarType _st = ::detail::scalar_type(the_type); \
  591. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  592. switch (_st) { \
  593. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
  594. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \
  595. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
  596. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
  597. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \
  598. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \
  599. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \
  600. AT_PRIVATE_CASE_TYPE(NAME, \
  601. at::ScalarType::ComplexFloat, c10::complex<float>, __VA_ARGS__) \
  602. AT_PRIVATE_CASE_TYPE(NAME, \
  603. at::ScalarType::ComplexDouble, c10::complex<double>, __VA_ARGS__) \
  604. default: \
  605. AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
  606. } \
  607. }()
  608. #define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  609. [&] { \
  610. const auto& the_type = TYPE; \
  611. /* don't use TYPE again in case it is an expensive or side-effect op*/ \
  612. at::ScalarType _st = ::detail::scalar_type(the_type); \
  613. RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
  614. switch (_st) { \
  615. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
  616. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \
  617. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
  618. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
  619. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \
  620. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \
  621. AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \
  622. AT_PRIVATE_CASE_TYPE( \
  623. NAME, \
  624. SCALARTYPE, \
  625. decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE>::t), \
  626. __VA_ARGS__) \
  627. default: \
  628. AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
  629. } \
  630. }()
  // AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...)
  //
  // Switches on the scalar type of TYPE and runs the callable passed in
  // __VA_ARGS__ for Byte, Char, Double, Float, Int, Long, Short,
  // ComplexFloat, ComplexDouble, plus the single caller-chosen SCALARTYPE.
  // Every case is emitted through AT_PRIVATE_CASE_TYPE, which makes the
  // matching C++ type available to the callable.
  //
  // The expansion is an immediately-invoked `[&]` lambda, so it can be used
  // as an expression. Unlisted scalar types hit AT_ERROR, with NAME and the
  // offending type in the message. If SCALARTYPE equals one of the listed
  // types, the switch gets a duplicate case label and fails to compile.
  #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
    [&] { \
      /* Bind TYPE once: re-expanding it could repeat costly or */ \
      /* side-effecting work. */ \
      const auto& the_type = TYPE; \
      at::ScalarType _st = ::detail::scalar_type(the_type); \
      RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
      switch (_st) { \
        /* Built-in real types. */ \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \
        /* Complex types. */ \
        AT_PRIVATE_CASE_TYPE( \
            NAME, at::ScalarType::ComplexFloat, c10::complex<float>, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE( \
            NAME, at::ScalarType::ComplexDouble, c10::complex<double>, __VA_ARGS__) \
        /* Caller-selected extra type, mapped to C++ via ScalarTypeToCPPType. */ \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE>::t), \
            __VA_ARGS__) \
        default: \
          AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
      } \
    }()
  // AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)
  //
  // Switches on the scalar type of TYPE and runs the callable passed in
  // __VA_ARGS__ for Byte, Char, Double, Float, Int, Long, Short, plus the two
  // caller-chosen types SCALARTYPE1 and SCALARTYPE2. Complex types are NOT
  // covered. Every case is emitted through AT_PRIVATE_CASE_TYPE, which makes
  // the matching C++ type available to the callable.
  //
  // The expansion is an immediately-invoked `[&]` lambda, so it can be used
  // as an expression. Unlisted scalar types hit AT_ERROR, with NAME and the
  // offending type in the message.
  #define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
    [&] { \
      /* Bind TYPE once: re-expanding it could repeat costly or */ \
      /* side-effecting work. */ \
      const auto& the_type = TYPE; \
      at::ScalarType _st = ::detail::scalar_type(the_type); \
      RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
      switch (_st) { \
        /* Built-in real types. */ \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \
        /* Caller-selected extra types, mapped via ScalarTypeToCPPType. */ \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE1, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), \
            __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE2, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), \
            __VA_ARGS__) \
        default: \
          AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
      } \
    }()
  // AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(SCALARTYPE1, SCALARTYPE2, TYPE,
  //                                        NAME, ...)
  //
  // Switches on the scalar type of TYPE and runs the callable passed in
  // __VA_ARGS__ for Byte, Char, Double, Float, Int, Long, Short,
  // ComplexFloat, ComplexDouble, plus the two caller-chosen types
  // SCALARTYPE1 and SCALARTYPE2. Every case is emitted through
  // AT_PRIVATE_CASE_TYPE, which makes the matching C++ type available to the
  // callable.
  //
  // The expansion is an immediately-invoked `[&]` lambda, so it can be used
  // as an expression. Unlisted scalar types hit AT_ERROR, with NAME and the
  // offending type in the message.
  #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
      SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
    [&] { \
      /* Bind TYPE once: re-expanding it could repeat costly or */ \
      /* side-effecting work. */ \
      const auto& the_type = TYPE; \
      at::ScalarType _st = ::detail::scalar_type(the_type); \
      RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
      switch (_st) { \
        /* Built-in real types. */ \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \
        /* Complex types. */ \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            at::ScalarType::ComplexFloat, \
            c10::complex<float>, \
            __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            at::ScalarType::ComplexDouble, \
            c10::complex<double>, \
            __VA_ARGS__) \
        /* Caller-selected extra types, mapped via ScalarTypeToCPPType. */ \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE1, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), \
            __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE2, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), \
            __VA_ARGS__) \
        default: \
          AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
      } \
    }()
  // AT_DISPATCH_ALL_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE,
  //                            NAME, ...)
  //
  // Switches on the scalar type of TYPE and runs the callable passed in
  // __VA_ARGS__ for Byte, Char, Double, Float, Int, Long, Short, plus the
  // three caller-chosen types SCALARTYPE1..SCALARTYPE3. Complex types are
  // NOT covered. Every case is emitted through AT_PRIVATE_CASE_TYPE, which
  // makes the matching C++ type available to the callable.
  //
  // The expansion is an immediately-invoked `[&]` lambda, so it can be used
  // as an expression. Unlisted scalar types hit AT_ERROR, with NAME and the
  // offending type in the message.
  #define AT_DISPATCH_ALL_TYPES_AND3( \
      SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
    [&] { \
      /* Bind TYPE once: re-expanding it could repeat costly or */ \
      /* side-effecting work. */ \
      const auto& the_type = TYPE; \
      at::ScalarType _st = ::detail::scalar_type(the_type); \
      RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
      switch (_st) { \
        /* Built-in real types. */ \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \
        /* Caller-selected extra types, mapped via ScalarTypeToCPPType. */ \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE1, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), \
            __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE2, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), \
            __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE3, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE3>::t), \
            __VA_ARGS__) \
        default: \
          AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
      } \
    }()
  // AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(SCALARTYPE1, SCALARTYPE2,
  //                                        SCALARTYPE3, TYPE, NAME, ...)
  //
  // Switches on the scalar type of TYPE and runs the callable passed in
  // __VA_ARGS__ for Byte, Char, Double, Float, Int, Long, Short,
  // ComplexFloat, ComplexDouble, plus the three caller-chosen types
  // SCALARTYPE1..SCALARTYPE3. Every case is emitted through
  // AT_PRIVATE_CASE_TYPE, which makes the matching C++ type available to the
  // callable.
  //
  // The expansion is an immediately-invoked `[&]` lambda, so it can be used
  // as an expression. Unlisted scalar types hit AT_ERROR, with NAME and the
  // offending type in the message.
  #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
      SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
    [&] { \
      /* Bind TYPE once: re-expanding it could repeat costly or */ \
      /* side-effecting work. */ \
      const auto& the_type = TYPE; \
      at::ScalarType _st = ::detail::scalar_type(the_type); \
      RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
      switch (_st) { \
        /* Built-in real types. */ \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \
        /* Complex types. */ \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            at::ScalarType::ComplexFloat, \
            c10::complex<float>, \
            __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            at::ScalarType::ComplexDouble, \
            c10::complex<double>, \
            __VA_ARGS__) \
        /* Caller-selected extra types, mapped via ScalarTypeToCPPType. */ \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE1, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), \
            __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE2, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), \
            __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE3, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE3>::t), \
            __VA_ARGS__) \
        default: \
          AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
      } \
    }()
  // AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(SCALARTYPE1, SCALARTYPE2,
  //                                        SCALARTYPE3, SCALARTYPE4, TYPE,
  //                                        NAME, ...)
  //
  // Switches on the scalar type of TYPE and runs the callable passed in
  // __VA_ARGS__ for Byte, Char, Double, Float, Int, Long, Short,
  // ComplexFloat, ComplexDouble, plus the four caller-chosen types
  // SCALARTYPE1..SCALARTYPE4. Every case is emitted through
  // AT_PRIVATE_CASE_TYPE, which makes the matching C++ type available to the
  // callable.
  //
  // The expansion is an immediately-invoked `[&]` lambda, so it can be used
  // as an expression. Unlisted scalar types hit AT_ERROR, with NAME and the
  // offending type in the message.
  #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
      SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
    [&] { \
      /* Bind TYPE once: re-expanding it could repeat costly or */ \
      /* side-effecting work. */ \
      const auto& the_type = TYPE; \
      at::ScalarType _st = ::detail::scalar_type(the_type); \
      RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
      switch (_st) { \
        /* Built-in real types. */ \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \
        /* Complex types. */ \
        AT_PRIVATE_CASE_TYPE( \
            NAME, at::ScalarType::ComplexFloat, c10::complex<float>, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE( \
            NAME, at::ScalarType::ComplexDouble, c10::complex<double>, __VA_ARGS__) \
        /* Caller-selected extra types, mapped via ScalarTypeToCPPType. */ \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE1, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), \
            __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE2, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), \
            __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE3, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE3>::t), \
            __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE( \
            NAME, \
            SCALARTYPE4, \
            decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE4>::t), \
            __VA_ARGS__) \
        default: \
          AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
      } \
    }()
  // AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...)
  //
  // Switches on the scalar type of TYPE, accepting only Int and Long, and
  // runs the callable passed in __VA_ARGS__. Each case goes through
  // AT_PRIVATE_CASE_TYPE_USING_HINT with the hint name `index_t`, so the
  // callable presumably sees the selected C++ type (int32_t or int64_t)
  // under that alias. Confirm this against the helper's definition.
  //
  // The expansion is an immediately-invoked `[&]` lambda, so it can be used
  // as an expression. Any other scalar type hits AT_ERROR, with NAME and the
  // offending type in the message.
  #define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
    [&] { \
      /* Bind TYPE once: re-expanding it could repeat costly or */ \
      /* side-effecting work. */ \
      const auto& the_index_type = TYPE; \
      at::ScalarType _it = ::detail::scalar_type(the_index_type); \
      /* Terminated with ';' like every other dispatch macro, so this */ \
      /* stays valid however RECORD_KERNEL_FUNCTION_DTYPE expands. */ \
      RECORD_KERNEL_FUNCTION_DTYPE(NAME, _it); \
      switch (_it) { \
        AT_PRIVATE_CASE_TYPE_USING_HINT( \
            NAME, at::ScalarType::Int, int32_t, index_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE_USING_HINT( \
            NAME, at::ScalarType::Long, int64_t, index_t, __VA_ARGS__) \
        default: \
          AT_ERROR(#NAME, " not implemented for '", toString(_it), "'"); \
      } \
    }()
  // ----------------------------------------------------------------------------
  // DEPRECATED MACROS, DON'T USE THESE
  // ----------------------------------------------------------------------------
  // AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...)   [DEPRECATED]
  //
  // Switches on the scalar type of TYPE and runs the callable passed in
  // __VA_ARGS__ for Byte, Char, Double, Float, Int, Long, Short and Half.
  // Every case is emitted through AT_PRIVATE_CASE_TYPE, which makes the
  // matching C++ type available to the callable. Unlisted scalar types hit
  // AT_ERROR, with NAME and the offending type in the message.
  //
  // The first statement calls ::detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF().
  // NOTE(review): that helper presumably carries a deprecation attribute,
  // so each use emits a compiler warning. Confirm against its declaration.
  //
  // The call is fully qualified, matching ::detail::scalar_type below. An
  // unqualified `detail::` is found by ordinary lookup from the expansion
  // site, so inside any namespace that has its own nested `detail` it
  // would resolve to the wrong namespace and fail to compile.
  #define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
    [&] { \
      ::detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \
      /* Bind TYPE once: re-expanding it could repeat costly or */ \
      /* side-effecting work. */ \
      const auto& the_type = TYPE; \
      at::ScalarType _st = ::detail::scalar_type(the_type); \
      RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
      switch (_st) { \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \
        AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Half, at::Half, __VA_ARGS__) \
        default: \
          AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
      } \
    }()