# hp_emblookup_codegen.py
import argparse
import sys
  3. sizeof = {"float": 4, "at::Half": 2, "uint8_t": 1}
  4. def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
  5. def compute(regid, InType, use_weights, isa, prefetch):
  6. code = []
  7. if InType == "float":
  8. code.append(
  9. " vop%d = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (%d)), vop%d);" # noqa
  10. % (regid, regid, regid)
  11. )
  12. elif InType == "at::Half":
  13. code.append(
  14. " vop%d = _mm256_fmadd_ps(\n"
  15. " vwgt,\n"
  16. " _mm256_cvtph_ps(\n"
  17. " _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))),\n" # noqa
  18. " vop%d);" % (regid, regid, regid)
  19. )
  20. elif InType == "uint8_t":
  21. code.append(
  22. " vop%d = _mm256_fmadd_ps(\n"
  23. " vwgt,\n"
  24. " _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(\n"
  25. " _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (%d))))),\n" # noqa
  26. " _mm256_add_ps(vop%d, vbio));" % (regid, regid, regid)
  27. )
  28. else:
  29. assert False
  30. if prefetch:
  31. code.append(
  32. " _mm_prefetch(\n"
  33. " reinterpret_cast<const char*>(&ip_next_T0[%d]), _MM_HINT_T0);"
  34. % (regid)
  35. )
  36. else:
  37. code.append(
  38. " // skip unnecessary prefetch of (&ip_next_T0[%d])" % (regid)
  39. )
  40. return code
  41. code = []
  42. code.append(" // unrolling " + str(uf) + " times")
  43. if use_offsets:
  44. code.append(
  45. " for ("
  46. + IndexType
  47. + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
  48. )
  49. else:
  50. code.append(
  51. " for ("
  52. + IndexType
  53. + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
  54. )
  55. code.append(" " + OutType + "* op = &out[rangeIndex * block_size];")
  56. for i in range(0, uf):
  57. j = 8 * i
  58. code.append(" __m256 vop" + str(j) + " = _mm256_setzero_ps();")
  59. # inner loop
  60. if use_offsets:
  61. code.append(
  62. " if (dataInd != offsets[rangeIndex] - offsets[0]) {\n"
  63. + " return false;\n"
  64. + " }"
  65. )
  66. code.append("""\
  67. int64_t end_offset = offsets[rangeIndex + 1];
  68. int64_t length = end_offset - offsets[rangeIndex];""")
  69. code.append(
  70. " for ("
  71. + "int64_t"
  72. + " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {" # noqa
  73. )
  74. else:
  75. code.append(
  76. " if (dataInd + lengths[rangeIndex] > index_size) {\n"
  77. + " return false;\n"
  78. + " }"
  79. )
  80. code.append(
  81. " for ("
  82. + IndexType
  83. + " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {" # noqa
  84. )
  85. code.append(" const " + IndexType + " idx = indices[dataInd];")
  86. code.append(
  87. " if (idx < 0 || idx >= data_size) {\n"
  88. + " return false;\n"
  89. + " }"
  90. )
  91. if InType == "uint8_t":
  92. code.append(" " + OutType + " wgt = 1.f;")
  93. code.append(" " + OutType + " bio;")
  94. code.append(" if (weights) {")
  95. code.append(
  96. " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa
  97. )
  98. code.append(" }")
  99. if fused:
  100. code.append(
  101. " const float* scale_bias = reinterpret_cast<const float*>(\n"
  102. " &input[idx * fused_block_size + block_size]);"
  103. )
  104. code.append(" bio = wgt * scale_bias[1];")
  105. code.append(" wgt = wgt * scale_bias[0];")
  106. else:
  107. code.append(" bio = wgt * scale_bias[2 * idx + 1];")
  108. code.append(" wgt = wgt * scale_bias[2 * idx];")
  109. code.append(" __m256 vbio = _mm256_set1_ps(bio);")
  110. else:
  111. code.append(" " + OutType + " wgt = 1.f;")
  112. code.append(" if (weights) {")
  113. code.append(
  114. " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa
  115. )
  116. code.append(" }")
  117. code.append(" __m256 vwgt = _mm256_set1_ps(wgt);")
  118. code.append(" const {}* ip = &input[idx * fused_block_size];".format(InType))
  119. code.append(
  120. " const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
  121. " ? (dataInd + prefdist_T0)\n : dataInd;".format(
  122. IndexType
  123. )
  124. )
  125. code.append(" const " + IndexType + " idx_pref_T0 = indices[next_T0];")
  126. code.append(
  127. " if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {\n"
  128. + " return false;\n"
  129. + " }"
  130. )
  131. code.append(
  132. " const {}* ip_next_T0 = "
  133. "&input[idx_pref_T0 * fused_block_size];".format(InType)
  134. )
  135. for i in range(0, uf):
  136. j = 8 * i
  137. cachelinesize = 64
  138. byteoffset = sizeof[InType] * j
  139. prefetch = (byteoffset % cachelinesize) == 0
  140. code.extend(compute(j, InType, use_weights, isa, prefetch))
  141. code.append(" }")
  142. if use_offsets:
  143. code.append(" if (!normalize_by_lengths || length == 0) {")
  144. else:
  145. code.append(" if (!normalize_by_lengths || lengths[rangeIndex] == 0) {")
  146. for i in range(0, uf):
  147. j = 8 * i
  148. code.append(" _mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");")
  149. code.append(" } else {")
  150. # inv of length
  151. if use_offsets:
  152. code.append(" __m256 vlen_inv = _mm256_set1_ps(1.0f / length);")
  153. else:
  154. code.append(" __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);")
  155. for i in range(0, uf):
  156. j = 8 * i
  157. code.append(
  158. " _mm256_storeu_ps(&op["
  159. + str(j)
  160. + "], _mm256_mul_ps("
  161. + "vop"
  162. + str(j)
  163. + ", vlen_inv));"
  164. )
  165. code.append(" }")
  166. code.append(" }")
  167. return code
  168. def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
  169. def compute(InType, use_weights, isa):
  170. code = []
  171. if InType == "float":
  172. code.append(
  173. " _mm256_storeu_ps(\n"
  174. " &op[j],\n"
  175. " _mm256_fmadd_ps(\n"
  176. " vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));" # noqa
  177. )
  178. elif InType == "at::Half":
  179. code.append(
  180. " _mm256_storeu_ps(\n"
  181. " &op[j],\n"
  182. " _mm256_fmadd_ps(\n"
  183. " vwgt,\n"
  184. " _mm256_cvtph_ps(_mm_loadu_si128(\n"
  185. " reinterpret_cast<const __m128i*>(&ip[j]))),\n"
  186. " _mm256_loadu_ps(&op[j])));"
  187. )
  188. elif InType == "uint8_t":
  189. code.append(
  190. " _mm256_storeu_ps(\n"
  191. " &op[j],\n"
  192. " _mm256_fmadd_ps(\n"
  193. " vwgt,\n"
  194. " _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n" # noqa
  195. " reinterpret_cast<const __m128i*>(&ip[j])))),\n"
  196. " _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));"
  197. )
  198. else:
  199. assert False
  200. code.append(
  201. " _mm_prefetch(\n"
  202. " reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);"
  203. )
  204. return code
  205. code = []
  206. if InType == "at::Half":
  207. code.append(" alignas(64) at::Half vtmp1[8] = {0};")
  208. if use_offsets:
  209. code.append(
  210. " for ("
  211. + IndexType
  212. + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
  213. )
  214. else:
  215. code.append(
  216. " for ("
  217. + IndexType
  218. + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
  219. )
  220. code.append(" " + OutType + "* op = &out[rangeIndex * block_size];")
  221. # initialize to 0
  222. code.append(" int64_t j = 0;")
  223. code.append(" for (; j + 8 <= block_size; j += 8) {")
  224. code.append(" _mm256_storeu_ps(op + j, _mm256_setzero_ps());")
  225. code.append(" }")
  226. code.append(" for (; j < block_size; j++) {")
  227. code.append(" op[j] = 0.0f;")
  228. code.append(" }")
  229. # inner loop
  230. if use_offsets:
  231. code.append(
  232. " if (dataInd != offsets[rangeIndex] - offsets[0]) {\n"
  233. + " return false;\n"
  234. + " }"
  235. )
  236. code.append("""\
  237. int64_t end_offset = offsets[rangeIndex + 1];
  238. int64_t length = end_offset - offsets[rangeIndex];""")
  239. code.append(
  240. " for ("
  241. + "int64_t"
  242. + " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {" # noqa
  243. )
  244. else:
  245. code.append(
  246. " if (dataInd + lengths[rangeIndex] > index_size) {\n"
  247. + " return false;\n"
  248. + " }"
  249. )
  250. code.append(
  251. " for ("
  252. + IndexType
  253. + " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {" # noqa
  254. )
  255. code.append(" const " + IndexType + " idx = indices[dataInd];")
  256. code.append(
  257. " if (idx < 0 || idx >= data_size) {\n"
  258. + " return false;\n"
  259. + " }"
  260. )
  261. if InType == "uint8_t":
  262. code.append(" " + OutType + " wgt = 1.f;")
  263. code.append(" " + OutType + " bio;")
  264. code.append(" if (weights) {")
  265. code.append(
  266. " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa
  267. )
  268. code.append(" }")
  269. if fused:
  270. code.append(
  271. " const float* scale_bias = reinterpret_cast<const float*>(\n"
  272. " &input[idx * fused_block_size + block_size]);"
  273. )
  274. code.append(" bio = wgt * scale_bias[1];")
  275. code.append(" wgt = wgt * scale_bias[0];")
  276. else:
  277. code.append(" bio = wgt * scale_bias[2 * idx + 1];")
  278. code.append(" wgt = wgt * scale_bias[2 * idx];")
  279. code.append(" __m256 vbio = _mm256_set1_ps(bio);")
  280. else:
  281. code.append(" " + OutType + " wgt = 1.f;")
  282. code.append(" if (weights) {")
  283. code.append(
  284. " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa
  285. )
  286. code.append(" }")
  287. code.append(" __m256 vwgt = _mm256_set1_ps(wgt);")
  288. code.append(" const {}* ip = &input[idx * fused_block_size];".format(InType))
  289. code.append(
  290. " const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
  291. " ? (dataInd + prefdist_T0)\n : dataInd;".format(
  292. IndexType
  293. )
  294. )
  295. code.append(" const " + IndexType + " idx_pref_T0 = indices[next_T0];")
  296. code.append(
  297. " if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {\n"
  298. + " return false;\n"
  299. + " }"
  300. )
  301. code.append(
  302. " const {}* ip_next_T0 = "
  303. "&input[idx_pref_T0 * fused_block_size];".format(InType)
  304. )
  305. # compute and store main loop
  306. code.append(" j = 0;")
  307. code.append(" for (; j + 8 <= block_size; j += 8) {")
  308. code.extend(compute(InType, use_weights, isa))
  309. code.append(" }")
  310. # leftover
  311. code.append(" for (; j < block_size; j++) {")
  312. if InType == "float":
  313. code.append(" op[j] = std::fma(wgt, ip[j], op[j]);")
  314. elif InType == "at::Half":
  315. code.append(" vtmp1[0] = ip[j];")
  316. code.append(
  317. " __m256 vtmp2 =\n"
  318. " _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));"
  319. )
  320. code.append(" op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
  321. elif InType == "uint8_t":
  322. code.append(" op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);")
  323. else:
  324. assert False
  325. code.append(" }")
  326. code.append(" }")
  327. if use_offsets:
  328. code.append(" if (normalize_by_lengths && length) {")
  329. code.append(" float len_inv = 1.0f / length;")
  330. else:
  331. code.append(" if (normalize_by_lengths && lengths[rangeIndex]) {")
  332. code.append(" float len_inv = 1.0f / lengths[rangeIndex];")
  333. code.append(" __m256 vlen_inv = _mm256_set1_ps(len_inv);")
  334. code.append(" j = 0;")
  335. code.append(" for (; j + 8 <= block_size; j += 8) {")
  336. code.append(
  337. " _mm256_storeu_ps(\n"
  338. " &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));"
  339. )
  340. code.append(" }")
  341. code.append(" for (; j < block_size; j++) {")
  342. code.append(" op[j] = len_inv * op[j];")
  343. code.append(" }")
  344. code.append(" }")
  345. code.append(" }")
  346. return code
  347. # start main code
  348. parser = argparse.ArgumentParser()
  349. parser.add_argument("-f", "--filename", help="file name")
  350. parser.add_argument("--fused", action="store_true")
  351. parser.add_argument("--use-offsets", action="store_true")
  352. opts = parser.parse_args()
  353. if opts.filename:
  354. filename = opts.filename
  355. elif opts.fused:
  356. if opts.use_offsets:
  357. filename = "embedding_lookup_fused_8bit_rowwise_idx_avx2.cc"
  358. else:
  359. filename = "embedding_lookup_fused_8bit_rowwise_avx2.cc"
  360. else:
  361. if opts.use_offsets:
  362. filename = "embedding_lookup_idx_avx2.cc"
  363. else:
  364. filename = "embedding_lookup_avx2.cc"
  365. options = [
  366. ["int32_t", "int", "float", "float", "float", "float"],
  367. ["int64_t", "int64_t", "float", "float", "float", "float"],
  368. ["int32_t", "int", "half", "at::Half", "float", "float"],
  369. ["int64_t", "int64_t", "half", "at::Half", "float", "float"],
  370. ["int32_t", "int", "uint8_t", "uint8_t", "float", "float"],
  371. ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"],
  372. ]
  373. code = []
  374. # includes
  375. code.append("//// --------------------------")
  376. code.append("//// ATTENTION:")
  377. code.append("//// THIS CODE IS AUTOGENERATED")
  378. code.append("//// BY {}".format(sys.argv[0]))
  379. code.append("//// DO NOT MODIFY!!!")
  380. code.append("//// --------------------------\n")
  381. code.append("#include <c10/util/Half.h>")
  382. code.append("#include <immintrin.h>")
  383. code.append("namespace caffe2 {\n")
  384. for o in options:
  385. [IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType] = o
  386. prefix = "Fused8BitRowwise" if opts.fused else ""
  387. code.append("template <bool IS_WEIGHT_POSITIONAL>")
  388. if opts.use_offsets:
  389. fn_base = "{}EmbeddingLookupIdx_{}_{}_{}".format(
  390. prefix, IndexTypeName, InTypeName, OutTypeName
  391. )
  392. else:
  393. fn_base = "{}EmbeddingLookup_{}_{}_{}".format(
  394. prefix, IndexTypeName, InTypeName, OutTypeName
  395. )
  396. suffix = "__avx2_fma"
  397. fn = "static bool " + fn_base + suffix
  398. code.append(fn + "(")
  399. args = []
  400. args.append(" const int64_t block_size,")
  401. args.append(" const int64_t output_size,")
  402. args.append(" const int64_t index_size,")
  403. args.append(" const int64_t data_size,")
  404. args.append(" const " + InType + "* input,")
  405. args.append(" const " + IndexType + "* indices,")
  406. if opts.use_offsets:
  407. args.append(" const " + IndexType + "* offsets,")
  408. else:
  409. args.append(" const int* lengths,")
  410. args.append(" const float* weights,")
  411. if not opts.fused:
  412. args.append(" const float* scale_bias,")
  413. args.append(" bool normalize_by_lengths,")
  414. args.append(" " + OutType + "* out) {")
  415. code += args
  416. code.append(" const " + IndexType + " prefdist_T0 = 16;")
  417. # block_size is the number of elements and fused_block_size is the size of
  418. # an entire row, including scale and bias.
  419. offset = (8 // sizeof[InType]) if opts.fused else 0
  420. code.append(
  421. " const {} fused_block_size = block_size + {};".format(IndexType, offset)
  422. )
  423. if opts.use_offsets:
  424. code.append(" int64_t dataInd = 0;")
  425. else:
  426. code.append(" " + IndexType + " dataInd = 0;")
  427. # code.append("printf(\"calling " + fn + "\\n\");");
  428. code.append(" if (block_size == 128) {")
  429. code += unroll(16, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
  430. code.append(" } else if (block_size == 64) {")
  431. code += unroll(8, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
  432. code.append(" } else if (block_size == 32) {")
  433. code += unroll(4, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
  434. code.append(" } else if (block_size == 16) {")
  435. code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
  436. code.append(" } else {")
  437. code.append(" // generic code")
  438. code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
  439. code.append(" }")
  440. code.append(" return dataInd == index_size;")
  441. code.append("}")
  442. for is_weight_positional in ["false", "true"]:
  443. code.append("bool " + fn_base + "_" + is_weight_positional + suffix + "(")
  444. code += args
  445. # Resolve the Lint warnings: Limit of 80 characters in one line.
  446. extra_space = "\n "
  447. ret_string = " return " + fn_base + suffix + "<" + is_weight_positional + ">("
  448. if len(ret_string) <= 80:
  449. code.append(ret_string)
  450. else:
  451. code.append(" return " + fn_base + suffix + "<" + extra_space + is_weight_positional + ">(")
  452. code.append(" block_size,")
  453. code.append(" output_size,")
  454. code.append(" index_size,")
  455. code.append(" data_size,")
  456. code.append(" input,")
  457. code.append(" indices,")
  458. if opts.use_offsets:
  459. code.append(" offsets,")
  460. else:
  461. code.append(" lengths,")
  462. code.append(" weights,")
  463. if not opts.fused:
  464. code.append(" scale_bias,")
  465. code.append(" normalize_by_lengths,")
  466. code.append(" out);")
  467. code.append("}")
  468. code.append("")
  469. code.append("} // namespace caffe2")
  470. with open(filename, "w") as fout:
  471. for c in code:
  472. # print(c, file = fout)
  473. fout.write(c + "\n")
  474. print("Created " + filename)