decompositions.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291
  1. import torch
  2. from torch import Tensor
  3. from torch._decomp import register_decomposition
  4. from enum import Enum
  5. from typing import Tuple, Optional, List, Callable
  6. import torch.nn.functional as F
  7. import functools
  8. from torch.utils._pytree import tree_map, tree_flatten
  9. import torch._prims.utils as utils
  10. from torch._prims.wrappers import out_wrapper_multi
  11. # None of these functions are publicly accessible; get at them
  12. # from torch._decomps
  13. __all__: List[str] = []
  14. aten = torch.ops.aten
  15. class Reduction(Enum):
  16. NONE = 0
  17. MEAN = 1
  18. SUM = 2
  19. # This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided
  20. # We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops
  21. # Will need to validate the non-elementwise uses
  22. def type_casts(f: Callable, type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND):
  23. @functools.wraps(f)
  24. def inner(*args, **kwargs):
  25. flat_args = [x for x in tree_flatten((args, kwargs))[0] if isinstance(x, Tensor)]
  26. computation_dtype, result_dtype = utils.elementwise_dtypes(*flat_args,
  27. type_promotion_kind=type_promotion)
  28. # TODO: pretty sure this is not quite right
  29. def increase_prec(x):
  30. if isinstance(x, Tensor):
  31. return x.to(computation_dtype)
  32. else:
  33. return x
  34. def decrease_prec(x):
  35. if isinstance(x, Tensor):
  36. return x.to(result_dtype)
  37. else:
  38. return x
  39. r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
  40. return tree_map(decrease_prec, r)
  41. return inner
  42. pw_cast_for_opmath = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
  43. reduction_complex_to_real = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT)
  44. pw_cast_for_int_to_real = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
  45. # This expands x until x.dim() == dim. Might be useful as an operator
  46. def _unsqueeze_to_dim(x: Tensor, dim: int):
  47. for _ in range(dim - x.dim()):
  48. x = x.unsqueeze(-1)
  49. return x
  50. @register_decomposition(aten.tanh_backward)
  51. @pw_cast_for_opmath
  52. def tanh_backward(out_grad: Tensor, y: Tensor):
  53. return out_grad * (1 - y * y).conj_physical()
  54. @register_decomposition(aten.sigmoid_backward)
  55. @pw_cast_for_opmath
  56. def sigmoid_backward(out_grad: Tensor, y: Tensor):
  57. return out_grad * (y * (1 - y)).conj_physical()
  58. @register_decomposition(aten.softplus_backward)
  59. @pw_cast_for_opmath
  60. def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
  61. z = (x * beta).exp()
  62. return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))
  63. @register_decomposition(aten.elu)
  64. @pw_cast_for_opmath
  65. def elu(
  66. self: Tensor, alpha: float = 1, scale: float = 1, input_scale: float = 1
  67. ) -> Tensor:
  68. negcoef = alpha * scale
  69. poscoef = scale
  70. negiptcoef = input_scale
  71. return torch.where(
  72. self > 0, self * poscoef, (torch.exp(self * negiptcoef) - 1) * negcoef
  73. )
  74. @register_decomposition(aten.elu_backward)
  75. @pw_cast_for_opmath
  76. def elu_backward(
  77. grad_output: Tensor,
  78. alpha: float,
  79. scale: float,
  80. input_scale: float,
  81. is_result: bool,
  82. self_or_result: Tensor,
  83. ):
  84. negcoef = alpha * scale
  85. poscoef = scale
  86. negiptcoef = input_scale
  87. if is_result:
  88. return torch.where(
  89. self_or_result <= 0,
  90. grad_output * negiptcoef * (self_or_result + negcoef),
  91. self_or_result * poscoef,
  92. )
  93. else:
  94. return torch.where(
  95. self_or_result <= 0,
  96. grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef),
  97. grad_output * poscoef,
  98. )
  99. @register_decomposition(aten.hardsigmoid)
  100. @pw_cast_for_opmath
  101. def hardsigmoid(self: Tensor) -> Tensor:
  102. return torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6
  103. @register_decomposition(aten.hardsigmoid_backward)
  104. @pw_cast_for_opmath
  105. def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
  106. return torch.where(
  107. (self > -3.0) & (self < 3.0),
  108. grad_output * (1.0 / 6.0),
  109. grad_output.new_zeros(()),
  110. )
  111. @register_decomposition(aten.hardtanh)
  112. @pw_cast_for_opmath
  113. def hardtanh(self: Tensor, min_val: float = -1, max_val: float = 1) -> Tensor:
  114. return torch.clamp(self, min_val, max_val)
  115. @register_decomposition(aten.hardtanh_backward)
  116. @pw_cast_for_opmath
  117. def hardtanh_backward(
  118. grad_output: Tensor, self: Tensor, min_val: float, max_val: float
  119. ):
  120. return torch.where(
  121. (self <= min_val) | (self >= max_val), grad_output.new_zeros(()), grad_output
  122. )
  123. @register_decomposition(aten.hardshrink_backward)
  124. @pw_cast_for_opmath
  125. def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
  126. return torch.where(
  127. (self >= -lambd) & (self <= lambd), grad_out.new_zeros(()), grad_out
  128. )
  129. @register_decomposition(aten.hardswish)
  130. @pw_cast_for_opmath
  131. def hardswish(self: Tensor) -> Tensor:
  132. return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6
  133. @register_decomposition(aten.hardswish_backward)
  134. @pw_cast_for_opmath
  135. def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:
  136. return torch.where(
  137. self < -3,
  138. grad_output.new_zeros(()),
  139. torch.where(self <= 3, grad_output * ((self / 3) + 0.5), grad_output),
  140. )
  141. @register_decomposition(aten.threshold_backward)
  142. @pw_cast_for_opmath
  143. def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
  144. return torch.where(self <= threshold, grad_output.new_zeros(()), grad_output)
  145. @register_decomposition(aten.leaky_relu)
  146. @pw_cast_for_opmath
  147. def leaky_relu(self: Tensor, negative_slope: float = 0.01) -> Tensor:
  148. return torch.where(self > 0, self, self * negative_slope)
  149. @register_decomposition(aten.leaky_relu_backward)
  150. @pw_cast_for_opmath
  151. def leaky_relu_backward(
  152. grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool
  153. ):
  154. return torch.where(self > 0, grad_output, grad_output * negative_slope)
  155. @register_decomposition(aten.gelu)
  156. @pw_cast_for_opmath
  157. def gelu(self: Tensor, approximate: str = 'none') -> Tensor:
  158. M_SQRT2 = 1.41421356237309504880
  159. M_SQRT1_2 = 0.70710678118654752440
  160. M_2_SQRTPI = 1.12837916709551257390
  161. if approximate == 'tanh':
  162. kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
  163. kKappa = 0.044715
  164. x_cube = self * self * self
  165. inner = kBeta * (self + kKappa * x_cube)
  166. return 0.5 * self * (1 + torch.tanh(inner))
  167. else:
  168. kAlpha = M_SQRT1_2
  169. return self * 0.5 * (1 + torch.erf(self * kAlpha))
  170. @register_decomposition(aten.gelu_backward)
  171. @pw_cast_for_opmath
  172. def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
  173. M_SQRT2 = 1.41421356237309504880
  174. M_SQRT1_2 = 0.70710678118654752440
  175. M_2_SQRTPI = 1.12837916709551257390
  176. if approximate == 'tanh':
  177. kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
  178. kKappa = 0.044715
  179. x_sq = self * self
  180. x_cube = x_sq * self
  181. inner = kBeta * (self + kKappa * x_cube)
  182. tanh_inner = torch.tanh(inner)
  183. left = 0.5 * self
  184. right = 1 + tanh_inner
  185. left_derivative = 0.5 * right
  186. tanh_derivative = 1 - tanh_inner * tanh_inner
  187. inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
  188. right_derivative = left * tanh_derivative * inner_derivative
  189. return grad * (left_derivative + right_derivative)
  190. else:
  191. kAlpha = M_SQRT1_2
  192. kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
  193. cdf = 0.5 * (1 + torch.erf(self * kAlpha))
  194. pdf = kBeta * torch.exp(self * self * -0.5)
  195. return grad * (cdf + self * pdf)
  196. @register_decomposition(aten.mish_backward)
  197. @pw_cast_for_opmath
  198. def mish_backward(grad_output: Tensor, input: Tensor):
  199. input_tanh_softplus = torch.tanh(F.softplus(input))
  200. input_sigmoid = torch.sigmoid(input)
  201. out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus)
  202. return grad_output * (input_tanh_softplus + out)
  203. @register_decomposition(aten.silu)
  204. @pw_cast_for_opmath
  205. def silu(self: Tensor) -> Tensor:
  206. return self * torch.sigmoid(self)
  207. @register_decomposition(aten.silu_backward)
  208. @pw_cast_for_opmath
  209. def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
  210. sigmoid = 1 / (1 + torch.exp(-self))
  211. return grad_output * sigmoid * (1 + self * (1 - sigmoid))
  212. @register_decomposition(aten.softshrink_backward)
  213. def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tensor:
  214. return torch.where(
  215. (self >= -lambd) & (self <= lambd), grad_output.new_zeros(()), grad_output
  216. )
  217. @register_decomposition(aten.prelu_backward)
  218. @pw_cast_for_opmath
  219. def prelu_backward(
  220. grad_output: Tensor, self: Tensor, weight: Tensor
  221. ) -> Tuple[Tensor, Tensor]:
  222. # Logic is more complicated than I would like. Basically, weight can either
  223. # be a scalar or a vector of size [C], and in the forward pass it's
  224. # broadcast against [N, C, ...]. So now, we need to do the corresponding
  225. # reduction, which is harder than we'd like...
  226. cur_weight = weight
  227. for _ in range(2, grad_output.dim()):
  228. cur_weight = cur_weight.unsqueeze(-1)
  229. input_grad = torch.where(self > 0, grad_output, cur_weight * grad_output)
  230. weight_grad_collector = torch.where(
  231. self > 0, grad_output.new_zeros(()), self * grad_output
  232. )
  233. out = weight_grad_collector.sum_to_size(cur_weight.shape)
  234. while out.dim() > weight.dim():
  235. out = out.squeeze(-1)
  236. return (input_grad, out)
  237. @register_decomposition(aten.rrelu_with_noise_backward)
  238. @pw_cast_for_opmath
  239. def rrelu_with_noise_backward(
  240. grad_output: Tensor,
  241. self: Tensor,
  242. noise: Tensor,
  243. lower: float,
  244. upper: float,
  245. training: bool,
  246. self_is_result: bool,
  247. ) -> Tensor:
  248. if training and upper - lower > 1e-6:
  249. return grad_output.mul(noise)
  250. else:
  251. negative_slope = (lower + upper) / 2
  252. return aten.leaky_relu_backward(grad_output, self, negative_slope, self_is_result)
  253. @register_decomposition(aten.log_sigmoid_backward)
  254. @pw_cast_for_opmath
  255. def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
  256. in_negative = self < 0
  257. max_deriv = torch.where(in_negative, 1, 0)
  258. sign = torch.where(in_negative, 1, -1)
  259. z = torch.exp(-torch.abs(self))
  260. return grad_output * (max_deriv - sign * (z / (1 + z)))
  261. # CPU has a special formula that uses buffer, but disabled for convenience sake
  262. # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
  263. def apply_loss_reduction(loss: Tensor, reduction: int):
  264. if reduction == Reduction.MEAN.value:
  265. return torch.mean(loss)
  266. elif reduction == Reduction.SUM.value:
  267. return torch.sum(loss)
  268. else:
  269. return loss
  270. def to_real_dtype(dtype: torch.dtype):
  271. if dtype == torch.complex32:
  272. return torch.float16
  273. elif dtype == torch.complex64:
  274. return torch.float32
  275. elif dtype == torch.complex128:
  276. return torch.float64
  277. # TODO: None of these loss castings are quite correct, see
  278. # https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels
  279. # perform the pointwise portion in opmath, but don't maintain it between the
  280. # pointwise portion and the reduction
  281. @register_decomposition(aten.l1_loss)
  282. def l1_loss(
  283. self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
  284. ) -> Tensor:
  285. loss = (self - target).abs()
  286. # PyTorch semantics result in the output of l1_loss having the corresponding
  287. # real dtype to self. This may not happen without explicit casting if say
  288. # self: complex64 and target: float64, which results in loss: float64
  289. float_type = to_real_dtype(self.dtype)
  290. return apply_loss_reduction(loss, reduction).to(float_type)
  291. @register_decomposition(aten.l1_loss_backward)
  292. @pw_cast_for_opmath
  293. def l1_loss_backward(
  294. grad_output: Tensor,
  295. self: Tensor,
  296. target: Tensor,
  297. reduction: int = Reduction.MEAN.value,
  298. ):
  299. sign = torch.sign(self - target)
  300. norm = sign / self.numel() if reduction == Reduction.MEAN.value else sign
  301. return grad_output * norm
  302. @register_decomposition(aten.mse_loss)
  303. @pw_cast_for_opmath
  304. def mse_loss(
  305. self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
  306. ) -> Tensor:
  307. loss = (self - target) ** 2
  308. return apply_loss_reduction(loss, reduction)
  309. @register_decomposition(aten.mse_loss_backward)
  310. @pw_cast_for_opmath
  311. def mse_loss_backward(
  312. grad_output: Tensor, input: Tensor, target: Tensor, reduction: int
  313. ):
  314. norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0
  315. return norm * (input - target) * grad_output
  316. @register_decomposition(aten.huber_loss)
  317. @pw_cast_for_opmath
  318. def huber_loss(
  319. self: Tensor,
  320. target: Tensor,
  321. reduction: int = Reduction.MEAN.value,
  322. delta: float = 1.0,
  323. ) -> Tensor:
  324. assert delta > 0, "huber_loss does not support non-positive values for delta."
  325. z = (self - target).abs()
  326. loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
  327. return apply_loss_reduction(loss, reduction)
  328. @register_decomposition(aten.huber_loss_backward)
  329. @pw_cast_for_opmath
  330. def huber_loss_backward(
  331. grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float
  332. ):
  333. norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
  334. x = self - target
  335. return torch.where(
  336. x < -delta,
  337. -norm * grad_output * delta,
  338. torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output),
  339. )
  340. def _nll_loss_backward(
  341. grad_output: Tensor,
  342. self: Tensor,
  343. target: Tensor,
  344. weight: Optional[Tensor],
  345. reduction: int,
  346. ignore_index: int,
  347. total_weight: Tensor,
  348. ) -> Tensor:
  349. channel_dim = 0 if self.dim() < 2 else 1
  350. if reduction == Reduction.MEAN.value:
  351. grad_output = grad_output / total_weight
  352. target = target.unsqueeze(channel_dim)
  353. grad_input = torch.zeros_like(self)
  354. grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)
  355. if grad_input.dim() > grad_output.dim() > 0:
  356. grad_output = grad_output.unsqueeze(channel_dim)
  357. if weight is not None:
  358. new_shape = [1 for _ in range(self.dim())]
  359. new_shape[channel_dim] = weight.shape[0]
  360. weight = weight.reshape(new_shape)
  361. grad_output = grad_output * weight
  362. has_ignore_index = ignore_index >= 0
  363. if has_ignore_index:
  364. ignore_index_mask = target != ignore_index
  365. grad_output = grad_output * ignore_index_mask
  366. return grad_input * grad_output
  367. @register_decomposition(aten.nll_loss_backward)
  368. def nll_loss_backward(
  369. grad_output: Tensor,
  370. self: Tensor,
  371. target: Tensor,
  372. weight: Optional[Tensor],
  373. reduction: int,
  374. ignore_index: int,
  375. total_weight: Tensor,
  376. ) -> Tensor:
  377. assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D"
  378. assert (
  379. target.dim() <= 1
  380. ), "0D or 1D target tensor expected, multi-target not supported"
  381. no_batch_dim = self.dim() == 1 and target.dim() == 0
  382. assert no_batch_dim or (
  383. self.shape[0] == target.shape[0]
  384. ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
  385. assert total_weight.numel() == 1, (
  386. "expected total_weight to be a single element tensor, got: ",
  387. f"{total_weight.shape} ({total_weight.numel()} elements)",
  388. )
  389. assert (
  390. weight is None or weight.numel() == self.shape[-1]
  391. ), "weight tensor should be defined either for all or no classes"
  392. if reduction == Reduction.NONE.value and self.dim() == 2:
  393. assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], (
  394. f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but "
  395. f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}"
  396. )
  397. else:
  398. assert (
  399. grad_output.dim() <= 1 and grad_output.numel() == 1
  400. ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}"
  401. return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight)
  402. @register_decomposition(aten.nll_loss2d_backward)
  403. def nll_loss2d_backward(
  404. grad_output: Tensor,
  405. self: Tensor,
  406. target: Tensor,
  407. weight: Optional[Tensor],
  408. reduction: int,
  409. ignore_index: int,
  410. total_weight: Tensor,
  411. ) -> Tensor:
  412. assert (
  413. self.dim() == 4
  414. ), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"
  415. assert (
  416. target.dim() == 3
  417. ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"
  418. assert(
  419. self.shape[0] == target.shape[0] and self.shape[2] == target.shape[1] and self.shape[3] == target.shape[2]
  420. ), f"size mismatch (got input: {self.shape}, target: {target.shape}"
  421. assert (
  422. total_weight.numel() == 1
  423. ), (
  424. "expected total_weight to be a single element tensor, "
  425. f"got: {total_weight.shape} ( {total_weight.numel()}, elements)"
  426. )
  427. return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight)
  428. @register_decomposition(aten.binary_cross_entropy)
  429. @pw_cast_for_opmath
  430. def binary_cross_entropy(
  431. self: Tensor,
  432. target: Tensor,
  433. weight: Optional[Tensor] = None,
  434. reduction: int = Reduction.MEAN.value,
  435. ) -> Tensor:
  436. # We cannot currently model this without introducing data-dependent control flow
  437. # TORCH_CHECK(
  438. # (input_val >= 0) && (input_val <= 1),
  439. # "all elements of input should be between 0 and 1"
  440. # )
  441. loss = (target - 1) * torch.maximum(
  442. torch.log(1 - self), self.new_full((), -100)
  443. ) - target * torch.maximum(torch.log(self), self.new_full((), -100))
  444. if weight is not None:
  445. loss = loss * weight
  446. return apply_loss_reduction(loss, reduction)
  447. @register_decomposition(aten.binary_cross_entropy_backward)
  448. @pw_cast_for_opmath
  449. def binary_cross_entropy_backward(
  450. grad_output: Tensor,
  451. self: Tensor,
  452. target: Tensor,
  453. weight: Optional[Tensor] = None,
  454. reduction: int = Reduction.MEAN.value,
  455. ) -> Tensor:
  456. EPSILON = 1e-12
  457. result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON)
  458. if weight is not None:
  459. result = result * weight
  460. if reduction == Reduction.MEAN.value:
  461. result = result / self.numel()
  462. return result
  463. @register_decomposition(aten._euclidean_dist)
  464. def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
  465. x1_norm = x1.pow(2).sum(-1, True)
  466. x1_pad = torch.ones_like(x1_norm, memory_format=torch.contiguous_format)
  467. x2_norm = x2.pow(2).sum(-1, True)
  468. x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format)
  469. x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1)
  470. x2_ = torch.cat([x2, x2_pad, x2_norm], -1)
  471. result = x1_.matmul(x2_.mT)
  472. return result.clamp_min(0).sqrt()
  473. @register_decomposition(aten.slice_backward)
  474. def slice_backward(
  475. grad_output: Tensor,
  476. input_sizes: List[int],
  477. dim: int,
  478. start: int,
  479. end: int,
  480. step: int,
  481. ):
  482. grad_input = grad_output.new_zeros(input_sizes)
  483. return torch.slice_scatter(grad_input, grad_output, dim, start, end, step)
  484. @register_decomposition(aten.select_backward)
  485. def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
  486. grad_input = grad_output.new_zeros(input_sizes)
  487. return torch.select_scatter(grad_input, grad_output, dim, index)
  488. @register_decomposition(aten.diagonal_backward)
  489. def diagonal_backward(
  490. grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int
  491. ):
  492. grad_input = grad_output.new_zeros(input_sizes)
  493. return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)
  494. @register_decomposition(aten._softmax_backward_data)
  495. @pw_cast_for_opmath
  496. def _softmax_backward_data(
  497. grad_output: Tensor, output: Tensor, dim: int, input_dtype: int
  498. ):
  499. new_grad = grad_output * output
  500. return new_grad - output * torch.sum(new_grad, dim=dim, keepdim=True)
  501. @register_decomposition(aten._log_softmax_backward_data)
  502. @pw_cast_for_opmath
  503. def _log_softmax_backward_data(
  504. grad_output: Tensor, output: Tensor, dim: int, input_dtype: int
  505. ):
  506. grad_input = grad_output - torch.exp(output) * torch.sum(
  507. grad_output, dim=dim, keepdim=True
  508. )
  509. return grad_input
  510. # TODO: the type annotations on arguments are not quite right
  511. @register_decomposition(aten.im2col_backward)
  512. def im2col_backward(
  513. grad_output: Tensor,
  514. input_size: List[int],
  515. kernel_size: List[int],
  516. dilation: List[int],
  517. padding: List[int],
  518. stride: List[int],
  519. ) -> Tensor:
  520. return F.fold(grad_output, input_size, kernel_size, dilation, padding, stride) # type: ignore[arg-type]
  521. @register_decomposition(aten.col2im_backward)
  522. def col2im_backward(
  523. grad_output: Tensor,
  524. kernel_size: List[int],
  525. dilation: List[int],
  526. padding: List[int],
  527. stride: List[int],
  528. ) -> Tensor:
  529. return F.unfold(grad_output, kernel_size, dilation, padding, stride) # type: ignore[arg-type]
  530. @register_decomposition(aten.masked_fill.Scalar)
  531. def masked_fill_Scalar(self: Tensor, mask: Tensor, value: float) -> Tensor:
  532. return torch.where(mask, utils.dtype_to_type(self.dtype)(value), self)
  533. @register_decomposition(aten.masked_fill.Tensor)
  534. def masked_fill_Tensor(self: Tensor, mask: Tensor, value: Tensor) -> Tensor:
  535. return torch.where(mask, value, self)
  536. @register_decomposition(aten.native_dropout_backward)
  537. @pw_cast_for_opmath
  538. def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
  539. return grad_output * (mask.type_as(grad_output) * scale)
  540. @register_decomposition(aten.logit)
  541. @pw_cast_for_int_to_real
  542. def logit(self: Tensor, eps: Optional[float] = None) -> Tensor:
  543. if eps is None:
  544. eps = -1.0
  545. lo = eps
  546. hi = 1 - eps
  547. self = torch.clamp(self, lo, hi)
  548. return (self / (1 - self)).log()
  549. @register_decomposition(aten.logit_backward)
  550. @pw_cast_for_opmath
  551. def logit_backward(
  552. grad_output: Tensor, self: Tensor, eps: Optional[float] = None
  553. ) -> Tensor:
  554. if eps is not None:
  555. lo = eps
  556. hi = 1.0 - lo
  557. return torch.where(
  558. torch.logical_and(self >= lo, self <= hi),
  559. grad_output / (self * (1.0 - self)),
  560. self.new_zeros(()),
  561. )
  562. else:
  563. return torch.where(
  564. torch.logical_and(self >= 0.0, self <= 1.0),
  565. grad_output / (self * (1.0 - self)),
  566. self.new_full((), float("nan")),
  567. )
  568. @register_decomposition(aten.native_dropout)
  569. @pw_cast_for_opmath
  570. def native_dropout(input: Tensor, p: float, train: Optional[bool]):
  571. if train:
  572. bool_mask = torch.rand_like(input) < p
  573. res = bool_mask * input * float(1.0 / p)
  574. return (res, bool_mask)
  575. else:
  576. return (input, torch.ones_like(input, dtype=torch.bool))
  577. # TODO: Correct the type promotion semantics
  578. @register_decomposition(aten._softmax)
  579. @pw_cast_for_opmath
  580. def _softmax(x: Tensor, dim: int, half_to_float: bool):
  581. x_max = torch.max(x, dim, keepdim=True)[0]
  582. unnormalized = torch.exp(x - x_max)
  583. return unnormalized / torch.sum(unnormalized, dim, keepdim=True)
  584. # TODO: Correct the type promotion semantics
  585. @register_decomposition(aten._log_softmax)
  586. @pw_cast_for_opmath
  587. def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
  588. x_max = torch.max(x, dim, keepdim=True)[0]
  589. shifted = x - x_max
  590. shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
  591. return shifted - shifted_logsumexp
  592. @register_decomposition(aten.addcdiv)
  593. @pw_cast_for_opmath
  594. def addcdiv(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
  595. return self + value * (tensor1 / tensor2)
  596. # Remove special case when https://github.com/pytorch/pytorch/pull/72949 is landed.
  597. @register_decomposition(aten.addcmul)
  598. @pw_cast_for_opmath
  599. def addcmul(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
  600. if self.is_floating_point() or self.is_complex():
  601. return self + value * tensor1 * tensor2
  602. else:
  603. return self + int(value) * tensor1 * tensor2
  604. @register_decomposition(aten.rsub.Tensor)
  605. def rsub_Tensor(self: Tensor, other: Tensor, alpha: float = 1) -> Tensor:
  606. return torch.sub(other, self, alpha=alpha)
  607. @register_decomposition(aten.rsub.Scalar)
  608. def rsub_Scalar(self: Tensor, other: float, alpha: float = 1) -> Tensor:
  609. return torch.sub(other, self, alpha=alpha)
  610. @register_decomposition(aten.embedding)
  611. def embedding(
  612. weight: Tensor,
  613. indices: Tensor,
  614. padding_idx: int = -1,
  615. scale_grad_by_freq: bool = False,
  616. sparse: bool = False,
  617. ) -> Tensor:
  618. assert weight.dim() == 2, "'weight' must be 2-D"
  619. # TODO: Assert not ported over yet
  620. # auto indices_arg = TensorArg(indices, "indices", 1);
  621. # checkScalarTypes("embedding", indices_arg, {kLong, kInt});
  622. if indices.dim() == 1:
  623. return weight.index_select(0, indices)
  624. size = list(indices.shape)
  625. for d in weight.shape[1:]:
  626. size.append(d)
  627. return weight.index_select(0, indices.reshape(-1)).view(size)
  628. # TODO: Correct the type promotion semantics
  629. @register_decomposition(aten.embedding_dense_backward)
  630. def embedding_dense_backward(
  631. grad_output: Tensor,
  632. indices: Tensor,
  633. num_weights: int,
  634. padding_idx: int,
  635. scale_grad_by_freq: bool,
  636. ):
  637. numel = indices.numel()
  638. grad = grad_output.view(numel, grad_output.size(-1))
  639. grad_weight = grad_output.new_zeros((num_weights, grad_output.shape[-1]))
  640. indices_rank1 = indices.view(numel)
  641. if scale_grad_by_freq:
  642. counts = indices.new_zeros((num_weights,))
  643. ones = indices.new_ones((numel,))
  644. counts = counts.index_put([indices_rank1], ones, accumulate=True)
  645. grad_weights_scale = counts[indices_rank1]
  646. grad = grad / grad_weights_scale.unsqueeze(1)
  647. skip_padding = (indices_rank1 != padding_idx).unsqueeze(1)
  648. skip_padding = skip_padding.expand_as(grad)
  649. zero_grad = torch.full_like(grad, 0)
  650. return grad_weight.index_put(
  651. [indices_rank1], torch.where(skip_padding, grad, zero_grad), accumulate=True
  652. )
  653. def prod(x: List[int]):
  654. r = 1
  655. for i in x:
  656. r *= i
  657. return r
  658. @register_decomposition(aten.split_with_sizes)
  659. def split_with_sizes(
  660. self: Tensor, split_sizes: List[int], dim: int = 0
  661. ) -> List[Tensor]:
  662. num_splits = len(split_sizes)
  663. splits = []
  664. start_idx = 0
  665. for i in range(num_splits):
  666. length = split_sizes[i]
  667. splits.append(self.narrow(dim, start_idx, length))
  668. start_idx += length
  669. return splits
  670. @register_decomposition(aten.split.Tensor)
  671. def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
  672. input_sizes = self.shape
  673. dim_size = input_sizes[dim]
  674. if split_size == 0:
  675. assert dim_size == 0
  676. return [self]
  677. chunks = (dim_size + split_size - 1) // split_size
  678. split_sizes = [split_size for i in range(chunks)]
  679. split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
  680. return torch.split(self, split_sizes, dim)
  681. # TODO: this doesn't appear to have enough precision in bfloat16
  682. @register_decomposition(aten.addmm)
  683. @pw_cast_for_opmath
  684. def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
  685. if not self.is_floating_point() and not self.is_complex():
  686. beta = int(beta)
  687. alpha = int(alpha)
  688. out = alpha * torch.mm(mat1, mat2)
  689. if beta == 0:
  690. return out
  691. return beta * self + out
  692. # TODO: Correct the type promotion semantics
  693. @register_decomposition(aten.native_layer_norm)
  694. @pw_cast_for_opmath
  695. def native_layer_norm(
  696. input: Tensor,
  697. normalized_shape: List[int],
  698. weight: Optional[Tensor],
  699. bias: Optional[Tensor],
  700. eps: float,
  701. ) -> Tuple[Tensor, Tensor, Tensor]:
  702. input_shape = input.shape
  703. input_ndim = input.dim()
  704. axis = input_ndim - len(normalized_shape)
  705. M = prod(input_shape[:axis]) # type: ignore[arg-type]
  706. # Hmm... not sure how I get around this...
  707. # Basically, native_batch_norm doesn't support 0-entry tensors, while
  708. # native_layer_norm does (and is tested by OpInfos!)
  709. if M > 0:
  710. input_reshaped = input.view(1, M, -1)
  711. else:
  712. return (input, input.new_zeros((0,)), input.new_zeros((0,)))
  713. # Unlike Batch Normalization, which applies scalar scale and bias for each
  714. # entire channel/plane with the affine option, Layer Normalization applies
  715. # per-element scale and bias. E.g. For input {N, C, H, W}, weight for
  716. # batchnorm has shape {C} while weight for layernorm has shape {H, W} or {W}.
  717. out, mean, rstd = aten.native_batch_norm(
  718. input_reshaped,
  719. weight=None,
  720. bias=None,
  721. running_mean=None,
  722. running_var=None,
  723. training=True,
  724. momentum=0.0,
  725. eps=eps,
  726. )
  727. out = out.view(input_shape)
  728. if weight is not None:
  729. out = out * weight
  730. if bias is not None:
  731. out = out + bias
  732. stat_shape = list(input_shape[:axis])
  733. for _ in range(axis, input.dim()):
  734. stat_shape.append(1)
  735. mean = mean.view(stat_shape)
  736. rstd = rstd.view(stat_shape)
  737. return (out, mean, rstd)
# TODO: Correct the type promotion semantics
@register_decomposition(aten.native_layer_norm_backward)
@pw_cast_for_opmath
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """Backward of ``native_layer_norm``.

    Args:
        grad_out: gradient w.r.t. the layer-norm output (presumably the same
            shape as ``input`` -- TODO confirm).
        input: the forward input.
        normalized_shape: the trailing dims that were normalized.
        mean, rstd: statistics saved by the forward. They are broadcast
            against ``input`` below, so presumably they are shaped
            ``outer_dims + [1] * len(normalized_shape)`` -- TODO confirm.
        weight, bias: the optional affine parameters from the forward.
        output_mask: which of ``(d_input, d_weight, d_bias)`` to compute.

    Returns:
        ``(d_input, d_weight, d_bias)``. An entry is ``None`` when its mask
        bit is False; ``d_weight``/``d_bias`` are also ``None`` when
        ``weight``/``bias`` is ``None``. If the outer or inner element count
        is zero, three zero tensors are returned regardless of ``output_mask``.
    """
    input_shape = input.shape
    input_ndim = input.dim()
    # Dims [axis, ndim) are the normalized ("inner") dims; [0, axis) are "outer".
    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices: List[int] = []
    outer_dim_indices: List[int] = []
    for i in range(input_ndim):
        if i >= axis:
            inner_dim_indices.append(i)
        else:
            outer_dim_indices.append(i)
    # N: elements per normalized group; M: number of groups.
    N = prod(inner_dims)  # type: ignore[arg-type]
    M = prod(outer_dims)  # type: ignore[arg-type]
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape),
            input.new_zeros(input_shape[axis:]),
            input.new_zeros(input_shape[axis:]),
        )
    # Recompute the normalized input from the saved statistics.
    x_hat = (input - mean) * rstd
    if weight is not None:
        grad_x_hat = grad_out * weight
    else:
        grad_x_hat = grad_out
    # With g = grad_x_hat and sums over the inner dims:
    #   d_input = rstd / N * (N * g - sum(g) - x_hat * sum(g * x_hat))
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)
    inner = a - b - c3
    if output_mask[0]:
        d_input: Optional[Tensor] = (rstd / N) * inner
    else:
        d_input = None
    # Affine grads reduce over the outer dims; with no outer dims there is
    # nothing to reduce.
    if output_mask[1] and weight is not None:
        if len(outer_dim_indices) > 0:
            d_weight: Optional[Tensor] = torch.sum(
                grad_out * x_hat, outer_dim_indices, False
            )
        else:
            d_weight = grad_out * x_hat
    else:
        d_weight = None
    if output_mask[2] and bias is not None:
        if len(outer_dim_indices) > 0:
            d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
        else:
            d_bias = grad_out
    else:
        d_bias = None
    return (d_input, d_weight, d_bias)
# TODO: Correct the type promotion semantics
@register_decomposition(aten.native_batch_norm)
@pw_cast_for_opmath
def native_batch_norm(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Batch normalization; statistics are reduced over every dim except dim 1.

    Training mode:

    - The mean and biased variance come from ``input``.
    - If given, ``running_mean`` and ``running_var`` are updated IN PLACE as
      ``momentum * new + (1 - momentum) * old``.
    - The running variance is updated with the unbiased estimate.

    Eval mode:

    - ``running_mean`` and ``running_var`` must be provided, and are used.

    Returns:
        ``(output, save_mean, save_invstd)``. In training these are the batch
        mean and ``1 / sqrt(biased_var + eps)``. In eval they are
        ``running_mean`` and its invstd when ``input`` is on CUDA, and empty
        tensors for any other device type.
    """
    reduction_dims = [0] + list(range(2, input.dim()))
    if training:
        # save_mean = torch.sum(input / (input.shape[0] * input.shape[2]), dim=reduction_dims)
        biased_var, save_mean = torch.var_mean(
            input, dim=reduction_dims, unbiased=False
        )
        save_invstd = 1 / (torch.sqrt(biased_var + eps))
        if running_mean is not None:
            running_mean.copy_(momentum * save_mean + (1 - momentum) * running_mean)
        if running_var is not None:
            # Number of elements reduced per channel.
            n = input.numel() / input.shape[1]
            # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction
            # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose
            # numerics probably don't matter.
            unbiased_var = biased_var * (n / (n - 1))
            running_var.copy_(momentum * unbiased_var + (1 - momentum) * running_var)
        mean = save_mean
        invstd = save_invstd
    else:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = 1 / (torch.sqrt(running_var + eps))
        # Very annoying inconsistency where CPU and CUDA give different shapes
        if input.device.type == "cuda":
            save_mean = running_mean
            save_invstd = invstd
        else:
            save_mean = input.new_zeros((0,))
            save_invstd = input.new_zeros((0,))

    # Missing affine parameters become 0-d identities (scale 1, shift 0).
    if weight is None:
        weight = input.new_ones(())

    if bias is None:
        bias = input.new_zeros(())

    # Presumably _unsqueeze_to_dim appends trailing singleton dims so that the
    # per-channel vectors broadcast over dims 2+ -- helper defined elsewhere; confirm.
    mean = _unsqueeze_to_dim(mean, input.dim() - 1)
    invstd = _unsqueeze_to_dim(invstd, input.dim() - 1)
    weight = _unsqueeze_to_dim(weight, input.dim() - 1)
    bias = _unsqueeze_to_dim(bias, input.dim() - 1)
    output = ((input - mean) * invstd) * weight + bias
    return output, save_mean, save_invstd
  855. @register_decomposition(aten.clamp_min)
  856. def clamp_min(self: Tensor, min: float):
  857. return torch.clamp(self, min=min)
  858. @register_decomposition(aten.clamp_max)
  859. def clamp_max(self: Tensor, max: float):
  860. return torch.clamp(self, max=max)
  861. @register_decomposition(aten._fused_dropout)
  862. @pw_cast_for_opmath
  863. def _fused_dropout_decomposition(input, p, generator=None):
  864. mask = (torch.rand_like(input) < p).to(dtype=torch.uint8)
  865. res = mask.type_as(input) * input * (1.0 / p)
  866. return (res, mask)
  867. # TODO: these logical decomps are buggy for complex inputs
  868. @register_decomposition(aten.logical_xor)
  869. def logical_xor(self: Tensor, other: Tensor) -> Tensor:
  870. return self.to(dtype=torch.bool) ^ other.to(dtype=torch.bool)
  871. @register_decomposition(aten.logical_not)
  872. def logical_not(self: Tensor) -> Tensor:
  873. return ~self.to(dtype=torch.bool)
  874. @register_decomposition(aten.xlogy.Tensor)
  875. @pw_cast_for_int_to_real
  876. def xlogy(self: Tensor, other: Tensor) -> Tensor:
  877. return aten.where(aten.isnan(self),
  878. self,
  879. aten.where(self == aten.new_zeros(self, ()),
  880. aten.new_zeros(self, ()),
  881. self * aten.log(other)))
  882. @register_decomposition(aten.var.correction)
  883. @reduction_complex_to_real
  884. def var_correction(
  885. x: Tensor,
  886. dims: Optional[List[int]],
  887. correction: Optional[int] = None,
  888. keepdim: bool = False,
  889. ):
  890. if dims is None:
  891. dims = []
  892. if x.is_complex():
  893. # For complex, calculate variance of real and imaginary components
  894. # separately then add to get overall variance.
  895. real_in = x.real
  896. var_real = torch.var(real_in, dims, correction=correction, keepdim=keepdim)
  897. imag_in = x.imag
  898. var_imag = torch.var(imag_in, dims, correction=correction, keepdim=keepdim)
  899. return var_real + var_imag
  900. if correction is None:
  901. correction = 0
  902. if len(dims) == 0:
  903. n = prod(x.shape) # type: ignore[arg-type]
  904. else:
  905. n = 1
  906. for dim in dims:
  907. n *= x.shape[dim]
  908. mean = torch.mean(x, dims, True)
  909. sub = x - mean
  910. sq = sub * sub
  911. sum = torch.sum(sq, dims, keepdim)
  912. if correction:
  913. n = n - correction
  914. return sum / n
  915. @register_decomposition(aten.std.correction)
  916. @reduction_complex_to_real
  917. def std_decomposition(
  918. x: Tensor, dims: List[int], correction: int = 0, keepdim: bool = False
  919. ):
  920. return torch.sqrt(torch.var(x, dims, correction=correction, keepdim=keepdim))
  921. # Questionable decompositions
  922. # This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
  923. # Note that this decomposition causes issues with in-place ops
  924. @register_decomposition(aten.detach, disable_meta=True)
  925. def detach_decomposition(x):
  926. return x
  927. @register_decomposition(aten.cudnn_batch_norm)
  928. def cudnn_batch_norm(
  929. input: Tensor,
  930. weight: Tensor,
  931. bias: Optional[Tensor],
  932. running_mean: Optional[Tensor],
  933. running_var: Optional[Tensor],
  934. training: bool,
  935. exponential_average_factor: float,
  936. epsilon: float,
  937. ):
  938. a, b, c = aten.native_batch_norm(
  939. input,
  940. weight,
  941. bias,
  942. running_mean,
  943. running_var,
  944. training,
  945. exponential_average_factor,
  946. epsilon,
  947. )
  948. # Cudnn return running mean and variance when training is True
  949. if training:
  950. return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
  951. return (a, input.new_zeros((0,)), input.new_zeros((0,)), input.new_zeros((0,), dtype=torch.uint8))
  952. @register_decomposition(aten.cudnn_batch_norm_backward)
  953. def cudnn_batch_norm_backward(
  954. input: Tensor,
  955. grad_output: Tensor,
  956. weight: Tensor,
  957. running_mean: Optional[Tensor],
  958. running_var: Optional[Tensor],
  959. save_mean: Optional[Tensor],
  960. save_var: Optional[Tensor],
  961. epsilon: float,
  962. reserveSpace: Tensor,
  963. ):
  964. return aten.native_batch_norm_backward(
  965. grad_output,
  966. input,
  967. weight,
  968. running_mean,
  969. running_var,
  970. save_mean,
  971. save_var,
  972. True,
  973. epsilon,
  974. [True, True, True],
  975. )
  976. @register_decomposition(aten.rot90.default)
  977. def rot90(self: Tensor, k: int = 1, dims: List[int] = [0, 1]) -> Tensor: # noqa: B006
  978. total_dims = self.dim()
  979. total_rot_dims = len(dims)
  980. assert total_rot_dims == 2, f"expected total rotation dims == 2, but got dims = {total_rot_dims}"
  981. assert total_dims >= 2, f"expected total dims >= 2, but got total dims = {total_dims}"
  982. assert dims[0] != dims[1] and abs(dims[0] - dims[1]) != total_dims,\
  983. f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
  984. assert dims[0] < total_dims and dims[0] >= -total_dims, f"Rotation dim0 out of range, dim0 = {dims[0]}"
  985. assert dims[1] < total_dims and dims[1] >= -total_dims, f"Rotation dim1 out of range, dim1 = {dims[1]}"
  986. k = k % 4
  987. if k == 1:
  988. return self.flip(dims[1]).transpose(dims[0], dims[1])
  989. elif k == 2:
  990. return self.flip(dims)
  991. elif k == 3:
  992. return self.flip(dims[0]).transpose(dims[0], dims[1])
  993. else:
  994. return self.clone(memory_format=torch.contiguous_format)
  995. @register_decomposition(aten.transpose.int)
  996. def transpose_int(self: Tensor, dim0: int, dim1: int) -> Tensor:
  997. dim0, dim1 = utils.canonicalize_dims(self.dim(), (dim0, dim1)) # type: ignore[misc]
  998. if self.dim() <= 1:
  999. return self
  1000. if dim0 == dim1:
  1001. return self
  1002. perm = list(range(self.dim()))
  1003. perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
  1004. return torch.permute(self, perm)
  1005. @register_decomposition(aten.t.default)
  1006. def t(self: Tensor) -> Tensor:
  1007. return self.transpose(0, 0 if self.dim() < 2 else 1)
  1008. def check_stack_inputs(tensors: List[Tensor]):
  1009. entry_shape = tensors[0].shape
  1010. for i in range(1, len(tensors)):
  1011. assert tensors[i].shape == entry_shape, (f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0"
  1012. f"and {tensors[i].shape} at entry {i}")
  1013. def get_stack_inputs(tensors: List[Tensor], dim: int):
  1014. check_stack_inputs(tensors)
  1015. return [t.unsqueeze(dim) for t in tensors]
  1016. @register_decomposition(aten.stack.default)
  1017. def stack(tensors: List[Tensor], dim: int = 0) -> Tensor:
  1018. assert len(tensors) > 0, "stack expects a non-empty TensorList"
  1019. wrapped_dim = utils.canonicalize_dim(tensors[0].dim() + 1, dim)
  1020. if wrapped_dim < tensors[0].dim() and not tensors[0].is_sparse:
  1021. check_stack_inputs(tensors)
  1022. result_sizes = list(tensors[0].shape)
  1023. result_sizes.insert(wrapped_dim, len(tensors))
  1024. out = torch.cat(tensors, wrapped_dim)
  1025. return out.view(result_sizes)
  1026. else:
  1027. return torch.cat(get_stack_inputs(tensors, wrapped_dim), dim)
  1028. def _squeeze_multiple(self: Tensor, dims: List[int]) -> Tensor:
  1029. ndim = self.dim()
  1030. wrapped_dims = utils.canonicalize_dims(ndim, dims)
  1031. assert isinstance(wrapped_dims, tuple)
  1032. for idx in range(ndim - 1, -1, -1):
  1033. if idx in wrapped_dims:
  1034. self = self.squeeze(idx)
  1035. return self
  1036. @register_decomposition(aten.logsumexp.default)
  1037. @pw_cast_for_int_to_real
  1038. def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor:
  1039. if self.numel() == 0:
  1040. return torch.sum(torch.exp(self), dim, keepdim).log()
  1041. maxes = torch.amax(self, dim, keepdim=True)
  1042. maxes_squeezed = maxes if keepdim else _squeeze_multiple(maxes, dim)
  1043. maxes_squeezed = torch.masked_fill(maxes_squeezed, maxes_squeezed.abs() == float('inf'), 0)
  1044. result = torch.sum(torch.exp(self - maxes), dim, keepdim)
  1045. return result.log().add(maxes_squeezed)
  1046. @register_decomposition(aten.trace.default)
  1047. def trace(self: Tensor) -> Tensor:
  1048. return torch.sum(torch.diag(self))
  1049. # nb: Should use acc_t, not op_math
  1050. @register_decomposition(aten.log_sigmoid_forward)
  1051. @out_wrapper_multi('output', 'buffer')
  1052. @pw_cast_for_opmath
  1053. def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
  1054. min = torch.minimum(self.new_zeros(()), self)
  1055. z = torch.exp(-torch.abs(self))
  1056. if self.is_cuda:
  1057. buffer = self.new_zeros((0,))
  1058. else:
  1059. buffer = z
  1060. return min - torch.log1p(z), buffer