config.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. from torchgen.model import NativeFunctionsGroup
  2. from typing import Dict
  3. def func_name_base_str(g: NativeFunctionsGroup) -> str:
  4. return str(g.functional.func.name.name.base)
  5. is_hand_written_ops_ = frozenset(
  6. (
  7. "abs",
  8. "add",
  9. "addmm",
  10. "all",
  11. "any",
  12. "argmin",
  13. "bmm",
  14. "clamp",
  15. "clamp_min",
  16. "cumsum",
  17. "div",
  18. "fmod",
  19. "index_select",
  20. "leaky_relu",
  21. "linear",
  22. "log",
  23. "matmul",
  24. "mul",
  25. "narrow_copy",
  26. "nonzero",
  27. "pow",
  28. "remainder",
  29. "sigmoid",
  30. "sign",
  31. "sub",
  32. "tanh",
  33. )
  34. )
  35. def is_hand_written(g: NativeFunctionsGroup) -> bool:
  36. name_base = func_name_base_str(g)
  37. return name_base in is_hand_written_ops_
  38. def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> None:
  39. assert index == 0 or index == 1
  40. if op_name == "addr":
  41. if index == 0:
  42. arg_map["self"] = "at::rand({6, 6})"
  43. arg_map["vec1"] = "at::rand({6})"
  44. arg_map["vec2"] = "at::rand({6})"
  45. else:
  46. arg_map["self"] = "at::rand({22, 22})"
  47. arg_map["vec1"] = "at::rand({22})"
  48. arg_map["vec2"] = "at::rand({22})"
  49. return
  50. if op_name == "mv":
  51. if index == 0:
  52. arg_map["self"] = "at::rand({6, 6})"
  53. arg_map["vec"] = "at::rand({6})"
  54. else:
  55. arg_map["self"] = "at::rand({22, 22})"
  56. arg_map["vec"] = "at::rand({22})"
  57. return
  58. if op_name == "addbmm":
  59. if index == 0:
  60. arg_map["self"] = "at::rand({6, 6})"
  61. else:
  62. arg_map["self"] = "at::rand({22, 22})"
  63. return
  64. if op_name == "cross":
  65. if index == 0:
  66. arg_map["self"] = "at::rand({3, 3, 3})"
  67. arg_map["other"] = "at::rand({3, 3, 3})"
  68. else:
  69. arg_map["self"] = "at::rand({22, 3, 22})"
  70. arg_map["other"] = "at::rand({22, 3, 22})"
  71. return
  72. if op_name == "take":
  73. if index == 0:
  74. arg_map["index"] = "at::randint(0, 216, {20}, torch::kInt64)"
  75. else:
  76. arg_map["index"] = "at::randint(0, 1000, {100}, torch::kInt64)"
  77. return
  78. if op_name == "take_along_dim":
  79. if index == 0:
  80. arg_map["indices"] = "at::argsort(self0, 1)"
  81. else:
  82. arg_map["indices"] = "at::argsort(self1, 1)"
  83. return
  84. if op_name == "masked_select":
  85. if index == 0:
  86. arg_map["mask"] = "at::randn({6, 6, 6}) > 0.5"
  87. else:
  88. arg_map["mask"] = "at::rand({22, 22, 22}) > 0.5"
  89. return
  90. if op_name == "orgqr":
  91. if index == 0:
  92. arg_map["input2"] = "at::rand({6, 6})"
  93. else:
  94. arg_map["input2"] = "at::rand({22, 22})"
  95. return
  96. if op_name == "ormqr":
  97. if index == 0:
  98. arg_map["input2"] = "at::rand({6, 6})"
  99. else:
  100. arg_map["input2"] = "at::rand({22, 22})"
  101. return
  102. if op_name == "quantile":
  103. if index == 0:
  104. arg_map["q"] = "at::rand({6})"
  105. arg_map["interpolation"] = '"linear"'
  106. else:
  107. arg_map["q"] = "at::rand({22})"
  108. arg_map["interpolation"] = '"linear"'
  109. return
  110. if op_name == "nanquantile":
  111. if index == 0:
  112. arg_map["q"] = "at::rand({6})"
  113. arg_map["interpolation"] = '"linear"'
  114. else:
  115. arg_map["q"] = "at::rand({22})"
  116. arg_map["interpolation"] = '"linear"'
  117. return
  118. if op_name == "multi_margin_loss":
  119. if index == 0:
  120. arg_map["self"] = "at::rand({6, 6})"
  121. arg_map["target"] = "at::randint(6, {6}, torch::kInt64)"
  122. arg_map["weight"] = "at::rand({6})"
  123. else:
  124. arg_map["self"] = "at::rand({22, 22})"
  125. arg_map["target"] = "at::randint(22, {22}, torch::kInt64)"
  126. arg_map["weight"] = "at::rand({22})"
  127. return
  128. if op_name == "multilabel_margin_loss":
  129. if index == 0:
  130. arg_map["self"] = "at::rand({6, 6})"
  131. arg_map["target"] = "at::randint(6, {6, 6}, torch::kInt64)"
  132. else:
  133. arg_map["self"] = "at::rand({22, 22})"
  134. arg_map["target"] = "at::randint(22, {22, 22}, torch::kInt64)"
  135. return
  136. if op_name == "nll_loss":
  137. if index == 0:
  138. arg_map["self"] = "at::rand({6, 6})"
  139. arg_map["target"] = "at::randint(6, {6}, torch::kInt64)"
  140. arg_map["weight"] = "at::rand({6})"
  141. else:
  142. arg_map["self"] = "at::rand({22, 22})"
  143. arg_map["target"] = "at::randint(22, {22}, torch::kInt64)"
  144. arg_map["weight"] = "at::rand({22})"
  145. return
  146. if op_name == "nll_loss2d":
  147. if index == 0:
  148. arg_map["self"] = "at::rand({6, 6, 6, 6})"
  149. arg_map["target"] = "at::randint(6, {6, 6, 6}, torch::kInt64)"
  150. arg_map["weight"] = "at::rand({6})"
  151. else:
  152. arg_map["self"] = "at::rand({22, 22, 22, 22})"
  153. arg_map["target"] = "at::randint(22, {22, 22, 22}, torch::kInt64)"
  154. arg_map["weight"] = "at::rand({22})"
  155. return
  156. if op_name in (
  157. "fft_fft",
  158. "fft_ifft",
  159. "fft_rfft",
  160. "fft_irfft",
  161. "fft_hfft",
  162. "fft_ihfft",
  163. ):
  164. arg_map["norm"] = '"forward"'
  165. return
  166. if op_name == "linalg_tensorinv":
  167. if index == 0:
  168. arg_map["self"] = "at::rand({6, 6, 6, 6})"
  169. arg_map["ind"] = "2"
  170. else:
  171. arg_map["self"] = "at::rand({22, 22, 22, 22})"
  172. arg_map["ind"] = "2"
  173. return
  174. if op_name == "addmv":
  175. if index == 0:
  176. arg_map["self"] = "at::rand({2})"
  177. arg_map["mat"] = "at::rand({2, 2})"
  178. arg_map["vec"] = "at::rand({2})"
  179. else:
  180. arg_map["self"] = "at::rand({35})"
  181. arg_map["mat"] = "at::rand({35, 35})"
  182. arg_map["vec"] = "at::rand({35})"
  183. return
  184. if op_name == "acosh":
  185. if index == 0:
  186. arg_map["self"] = "at::rand({2, 2, 2}) + at::ones({2, 2, 2})"
  187. else:
  188. arg_map["self"] = "at::rand({5, 5, 5}) + at::ones({5, 5, 5})"
  189. return
  190. if op_name == "adaptive_max_pool2d_backward":
  191. if index == 0:
  192. arg_map["grad_output"] = "at::randint(-3, 2, {2,2,2})"
  193. arg_map["self"] = "at::randint(-3, 2, {2,2,2})"
  194. arg_map["indices"] = "at::randint(0, 1, {2,2,2}, at::kLong)"
  195. else:
  196. arg_map["grad_output"] = "at::randint(-3, 3, {3,3,3})"
  197. arg_map["self"] = "at::randint(-3, 2, {3,3,3})"
  198. arg_map["indices"] = "at::randint(0, 1, {3,3,3}, at::kLong)"
  199. return
  200. if op_name == "adaptive_max_pool3d_backward":
  201. if index == 0:
  202. arg_map["grad_output"] = "at::randint(-3, 2, {2,2,2,2})"
  203. arg_map["self"] = "at::randint(-3, 2, {2,2,2,2})"
  204. arg_map["indices"] = "at::randint(0, 1, {2,2,2,2}, at::kLong)"
  205. else:
  206. arg_map["grad_output"] = "at::randint(-3, 3, {3,3,3,3})"
  207. arg_map["self"] = "at::randint(-3, 2, {3,3,3,3})"
  208. arg_map["indices"] = "at::randint(0, 1, {3,3,3,3}, at::kLong)"
  209. return
  210. if op_name == "gather":
  211. if index == 0:
  212. arg_map["self"] = "at::randint(1, 100, {2,2,2}, at::kInt)"
  213. arg_map["dim"] = "1"
  214. arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)"
  215. arg_map["sparse_grad"] = "false"
  216. else:
  217. arg_map["self"] = "at::randint(1, 100, {5,5,5}, at::kInt)"
  218. arg_map["dim"] = "1"
  219. arg_map["index"] = "at::randint(0, 4, {5,5,5}, torch::kInt64)"
  220. arg_map["sparse_grad"] = "false"
  221. return
  222. if op_name == "gelu":
  223. if index == 0:
  224. arg_map["self"] = "at::rand({6, 6, 6})"
  225. arg_map["approximate"] = '"tanh"'
  226. else:
  227. arg_map["self"] = "at::rand({22, 22, 22})"
  228. arg_map["approximate"] = '"tanh"'
  229. return
  230. if op_name == "gelu_backward":
  231. if index == 0:
  232. arg_map["grad_output"] = "at::rand({6, 6, 6})"
  233. arg_map["self"] = "at::rand({6, 6, 6})"
  234. arg_map["approximate"] = '"tanh"'
  235. else:
  236. arg_map["grad_output"] = "at::rand({22, 22, 22})"
  237. arg_map["self"] = "at::rand({22, 22, 22})"
  238. arg_map["approximate"] = '"tanh"'
  239. return
  240. if op_name == "index_add":
  241. if index == 0:
  242. arg_map["self"] = "at::rand({2})"
  243. arg_map["dim"] = "0"
  244. arg_map["index"] = "at::randint(0, 1, {2}, at::kInt)"
  245. arg_map["source"] = "at::rand({2})"
  246. arg_map["alpha"] = "2"
  247. else:
  248. arg_map["self"] = "at::rand({16})"
  249. arg_map["dim"] = "0"
  250. arg_map["index"] = "at::randint(0, 10, {16}, at::kInt)"
  251. arg_map["source"] = "at::rand({16})"
  252. arg_map["alpha"] = "2"
  253. return
  254. if op_name == "index_copy":
  255. if index == 0:
  256. arg_map["self"] = "at::rand({2})"
  257. arg_map["dim"] = "0"
  258. arg_map["index"] = "at::randint(0, 1, {2}, at::kLong)"
  259. arg_map["source"] = "at::rand({2})"
  260. else:
  261. arg_map["self"] = "at::rand({32})"
  262. arg_map["dim"] = "0"
  263. arg_map["index"] = "at::randint(0, 10, {32}, at::kLong)"
  264. arg_map["source"] = "at::rand({32})"
  265. return
  266. if op_name == "linalg_cross":
  267. if index == 0:
  268. arg_map["self"] = "at::rand({6, 3, 6})"
  269. arg_map["other"] = "at::rand({6, 3, 6})"
  270. arg_map["dim"] = "1"
  271. else:
  272. arg_map["self"] = "at::rand({22, 3, 22})"
  273. arg_map["other"] = "at::rand({22, 3, 22})"
  274. arg_map["dim"] = "1"
  275. return
  276. if op_name == "nll_loss_backward":
  277. if index == 0:
  278. arg_map["grad_output"] = "at::rand({})"
  279. arg_map["self"] = "at::rand({6})"
  280. arg_map["target"] = "at::randint(0, 5, {6}, torch::kInt64)"
  281. arg_map["weight"] = "at::rand({6})"
  282. arg_map["reduction"] = "1"
  283. arg_map["ignore_index"] = "1"
  284. arg_map["total_weight"] = "at::rand({})"
  285. else:
  286. arg_map["grad_output"] = "at::rand({})"
  287. arg_map["self"] = "at::rand({36})"
  288. arg_map["target"] = "at::randint(0, 11, {36}, torch::kInt64)"
  289. arg_map["weight"] = "at::rand({36})"
  290. arg_map["reduction"] = "1"
  291. arg_map["ignore_index"] = "1"
  292. arg_map["total_weight"] = "at::rand({})"
  293. return
  294. if op_name in ["scatter", "scatter_add", "_scatter_reduce"]:
  295. if index == 0:
  296. arg_map["self"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)"
  297. arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)"
  298. arg_map["src"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)"
  299. else:
  300. arg_map["self"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)"
  301. arg_map["index"] = "at::randint(0, 1, {5,5,5}, torch::kInt64)"
  302. arg_map["src"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)"
  303. if "reduce" in arg_map:
  304. arg_map["reduce"] = '"sum"' if op_name == "_scatter_reduce" else '"add"'
  305. return
  306. if op_name == "scatter_reduce":
  307. arg_map["reduce"] = '"mean"'
  308. if index == 0:
  309. arg_map["index"] = "at::randint(6, {6, 6, 6}, torch::kInt64)"
  310. else:
  311. arg_map["index"] = "at::randint(22, {22, 22, 22}, torch::kInt64)"
  312. return
  313. if op_name == "special_zeta":
  314. if index == 0:
  315. arg_map["self"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})"
  316. arg_map["other"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})"
  317. else:
  318. arg_map["self"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})"
  319. arg_map["other"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})"
  320. return
  321. if op_name == "_convert_indices_from_csr_to_coo":
  322. if index == 0:
  323. arg_map["crow_indices"] = "torch::tensor({1}, torch::kInt32)"
  324. arg_map["col_indices"] = "torch::tensor({0, 1, 0}, torch::kInt32)"
  325. arg_map["out_int32"] = "false"
  326. else:
  327. arg_map["crow_indices"] = "torch::tensor({0}, torch::kInt32)"
  328. arg_map[
  329. "col_indices"
  330. ] = "torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)"
  331. arg_map["out_int32"] = "false"
  332. return
  333. if op_name == "_convert_indices_from_coo_to_csr":
  334. if index == 0:
  335. arg_map["self"] = "at::randint(0, 3, {2}, at::kInt)"
  336. arg_map["size"] = "10"
  337. arg_map["out_int32"] = "false"
  338. else:
  339. arg_map["self"] = "at::randint(0, 3, {12}, at::kInt)"
  340. arg_map["size"] = "24"
  341. arg_map["out_int32"] = "false"
  342. return