| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291 |
- import torch
- from torch import Tensor
- from torch._decomp import register_decomposition
- from enum import Enum
- from typing import Tuple, Optional, List, Callable
- import torch.nn.functional as F
- import functools
- from torch.utils._pytree import tree_map, tree_flatten
- import torch._prims.utils as utils
- from torch._prims.wrappers import out_wrapper_multi
- # None of these functions are publicly accessible; get at them
- # from torch._decomp
- __all__: List[str] = []
- # Shorthand for the ATen operator namespace referenced by the registrations in this file.
- aten = torch.ops.aten
class Reduction(Enum):
    """Integer codes selecting how a per-element loss is reduced."""

    NONE = 0  # keep the per-element loss
    MEAN = 1  # average over all elements
    SUM = 2  # sum over all elements
- # This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided
- # We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops
- # Will need to validate the non-elementwise uses
def type_casts(f: Callable, type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND):
    """Wrap ``f`` so tensors are cast before the call and results cast back.

    The computation and result dtypes are obtained from
    ``utils.elementwise_dtypes`` over every tensor found in the positional
    and keyword arguments. Tensor arguments are converted to the computation
    dtype, ``f`` is invoked, and every tensor in its (possibly nested) result
    is converted to the result dtype. Non-tensor values pass through as-is.
    """

    @functools.wraps(f)
    def inner(*args, **kwargs):
        leaves, _ = tree_flatten((args, kwargs))
        tensor_leaves = [leaf for leaf in leaves if isinstance(leaf, Tensor)]
        computation_dtype, result_dtype = utils.elementwise_dtypes(
            *tensor_leaves, type_promotion_kind=type_promotion
        )

        def cast_tensors_to(dtype):
            # Build a tree_map callback that converts tensors only.
            def cast(x):
                return x.to(dtype) if isinstance(x, Tensor) else x

            return cast

        # TODO: pretty sure this is not quite right
        upcast = cast_tensors_to(computation_dtype)
        result = f(*tree_map(upcast, args), **tree_map(upcast, kwargs))
        return tree_map(cast_tensors_to(result_dtype), result)

    return inner
# Ready-made decorators: each one pins the promotion strategy handed to type_casts.
pw_cast_for_opmath = functools.partial(
    type_casts,
    type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
reduction_complex_to_real = functools.partial(
    type_casts,
    type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
pw_cast_for_int_to_real = functools.partial(
    type_casts,
    type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
- # This expands x until x.dim() == dim. Might be useful as an operator
- def _unsqueeze_to_dim(x: Tensor, dim: int):
- for _ in range(dim - x.dim()):
- x = x.unsqueeze(-1)
- return x
@register_decomposition(aten.tanh_backward)
@pw_cast_for_opmath
def tanh_backward(out_grad: Tensor, y: Tensor):
    """Return ``out_grad * conj(1 - y * y)`` for ``aten.tanh_backward``."""
    local_derivative = 1 - y * y
    return out_grad * local_derivative.conj_physical()
@register_decomposition(aten.sigmoid_backward)
@pw_cast_for_opmath
def sigmoid_backward(out_grad: Tensor, y: Tensor):
    """Return ``out_grad * conj(y * (1 - y))`` for ``aten.sigmoid_backward``."""
    local_derivative = y * (1 - y)
    return out_grad * local_derivative.conj_physical()
@register_decomposition(aten.softplus_backward)
@pw_cast_for_opmath
def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
    """Decompose ``aten.softplus_backward``.

    Where ``x * beta`` exceeds ``threshold`` the gradient passes through
    unchanged; elsewhere it is scaled by ``z / (z + 1)`` with
    ``z = exp(x * beta)``.
    """
    scaled = x * beta
    z = scaled.exp()
    above_threshold = scaled > threshold
    return torch.where(above_threshold, out_grad, out_grad * z / (z + 1.0))
@register_decomposition(aten.elu)
@pw_cast_for_opmath
def elu(
    self: Tensor, alpha: float = 1, scale: float = 1, input_scale: float = 1
) -> Tensor:
    """Decompose ``aten.elu``.

    Positive inputs yield ``self * scale``; the remaining inputs yield
    ``(exp(self * input_scale) - 1) * alpha * scale``.
    """
    negative_coefficient = alpha * scale
    positive_branch = self * scale
    negative_branch = (torch.exp(self * input_scale) - 1) * negative_coefficient
    return torch.where(self > 0, positive_branch, negative_branch)
@register_decomposition(aten.elu_backward)
@pw_cast_for_opmath
def elu_backward(
    grad_output: Tensor,
    alpha: float,
    scale: float,
    input_scale: float,
    is_result: bool,
    self_or_result: Tensor,
):
    """Decompose ``aten.elu_backward``.

    ``self_or_result`` is treated as the forward output when ``is_result`` is
    true and as the forward input otherwise. In both modes the gradient for
    elements greater than zero is ``grad_output * scale``; only the formula
    for the non-positive elements differs.
    """
    negcoef = alpha * scale
    poscoef = scale
    negiptcoef = input_scale
    non_positive = self_or_result <= 0
    if is_result:
        # Derivative expressed through the forward result.
        negative_grad = grad_output * negiptcoef * (self_or_result + negcoef)
    else:
        # Derivative expressed through the forward input.
        negative_grad = (
            grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef)
        )
    # The positive branch must scale the incoming gradient, not the saved
    # tensor (the original returned ``self_or_result * poscoef`` here).
    return torch.where(non_positive, negative_grad, grad_output * poscoef)
@register_decomposition(aten.hardsigmoid)
@pw_cast_for_opmath
def hardsigmoid(self: Tensor) -> Tensor:
    """Decompose ``aten.hardsigmoid`` as ``clamp(self + 3, 0, 6) / 6``."""
    shifted = self + 3
    lower_bounded = torch.clamp(shifted, min=0)
    bounded = torch.clamp(lower_bounded, max=6)
    return bounded / 6
@register_decomposition(aten.hardsigmoid_backward)
@pw_cast_for_opmath
def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
    """Decompose ``aten.hardsigmoid_backward``.

    The gradient is ``grad_output / 6`` strictly inside ``(-3, 3)`` and zero
    everywhere else.
    """
    in_linear_region = (self > -3.0) & (self < 3.0)
    scaled_grad = grad_output * (1.0 / 6.0)
    zero = grad_output.new_zeros(())
    return torch.where(in_linear_region, scaled_grad, zero)
@register_decomposition(aten.hardtanh)
@pw_cast_for_opmath
def hardtanh(self: Tensor, min_val: float = -1, max_val: float = 1) -> Tensor:
    """Decompose ``aten.hardtanh`` as a clamp of ``self`` to ``[min_val, max_val]``."""
    clamped = torch.clamp(self, min=min_val, max=max_val)
    return clamped
@register_decomposition(aten.hardtanh_backward)
@pw_cast_for_opmath
def hardtanh_backward(
    grad_output: Tensor, self: Tensor, min_val: float, max_val: float
):
    """Decompose ``aten.hardtanh_backward``.

    The gradient is zeroed wherever ``self`` is at or beyond either bound and
    passed through unchanged elsewhere.
    """
    saturated = (self <= min_val) | (self >= max_val)
    zero = grad_output.new_zeros(())
    return torch.where(saturated, zero, grad_output)
@register_decomposition(aten.hardshrink_backward)
@pw_cast_for_opmath
def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
    """Decompose ``aten.hardshrink_backward``.

    The gradient is zeroed where ``self`` lies in ``[-lambd, lambd]`` and
    passed through unchanged elsewhere.
    """
    inside_band = (self >= -lambd) & (self <= lambd)
    zero = grad_out.new_zeros(())
    return torch.where(inside_band, zero, grad_out)
@register_decomposition(aten.hardswish)
@pw_cast_for_opmath
def hardswish(self: Tensor) -> Tensor:
    """Decompose ``aten.hardswish`` as ``self * clamp(self + 3, 0, 6) / 6``."""
    lower_bounded = torch.clamp(self + 3, min=0)
    gate = torch.clamp(lower_bounded, max=6)
    return self * gate / 6
@register_decomposition(aten.hardswish_backward)
@pw_cast_for_opmath
def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    """Decompose ``aten.hardswish_backward``.

    Zero below ``-3``, ``grad_output * (self / 3 + 0.5)`` up to and including
    ``3``, and ``grad_output`` unchanged above that.
    """
    zero = grad_output.new_zeros(())
    middle_grad = grad_output * ((self / 3) + 0.5)
    upper_part = torch.where(self <= 3, middle_grad, grad_output)
    return torch.where(self < -3, zero, upper_part)
@register_decomposition(aten.threshold_backward)
@pw_cast_for_opmath
def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
    """Decompose ``aten.threshold_backward``: zero the gradient where ``self <= threshold``."""
    at_or_below = self <= threshold
    zero = grad_output.new_zeros(())
    return torch.where(at_or_below, zero, grad_output)
@register_decomposition(aten.leaky_relu)
@pw_cast_for_opmath
def leaky_relu(self: Tensor, negative_slope: float = 0.01) -> Tensor:
    """Decompose ``aten.leaky_relu``: keep positive inputs, scale the rest by ``negative_slope``."""
    is_positive = self > 0
    scaled = self * negative_slope
    return torch.where(is_positive, self, scaled)
@register_decomposition(aten.leaky_relu_backward)
@pw_cast_for_opmath
def leaky_relu_backward(
    grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool
):
    """Decompose ``aten.leaky_relu_backward``.

    The gradient passes through where ``self > 0`` and is multiplied by
    ``negative_slope`` elsewhere. ``self_is_result`` is accepted but not used.
    """
    is_positive = self > 0
    scaled_grad = grad_output * negative_slope
    return torch.where(is_positive, grad_output, scaled_grad)
@register_decomposition(aten.gelu)
@pw_cast_for_opmath
def gelu(self: Tensor, approximate: str = 'none') -> Tensor:
    """Decompose ``aten.gelu``.

    ``approximate == 'tanh'`` selects the tanh-based approximation; any other
    value selects the erf-based formula.
    """
    sqrt_2 = 1.41421356237309504880
    inv_sqrt_2 = 0.70710678118654752440
    two_over_sqrt_pi = 1.12837916709551257390
    if approximate != 'tanh':
        return self * 0.5 * (1 + torch.erf(self * inv_sqrt_2))
    beta = sqrt_2 * two_over_sqrt_pi * 0.5
    kappa = 0.044715
    cubed = self * self * self
    tanh_arg = beta * (self + kappa * cubed)
    return 0.5 * self * (1 + torch.tanh(tanh_arg))
@register_decomposition(aten.gelu_backward)
@pw_cast_for_opmath
def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
    """Decompose ``aten.gelu_backward``.

    ``approximate == 'tanh'`` differentiates the tanh approximation
    ``0.5 * x * (1 + tanh(inner))`` with the product rule; any other value
    uses the erf form, whose derivative is ``cdf + x * pdf``.
    """
    sqrt_2 = 1.41421356237309504880
    inv_sqrt_2 = 0.70710678118654752440
    two_over_sqrt_pi = 1.12837916709551257390
    if approximate != 'tanh':
        pdf_scale = two_over_sqrt_pi * inv_sqrt_2 * 0.5
        cdf = 0.5 * (1 + torch.erf(self * inv_sqrt_2))
        pdf = pdf_scale * torch.exp(self * self * -0.5)
        return grad * (cdf + self * pdf)

    beta = sqrt_2 * two_over_sqrt_pi * 0.5
    kappa = 0.044715
    squared = self * self
    cubed = squared * self
    tanh_arg = beta * (self + kappa * cubed)
    tanh_value = torch.tanh(tanh_arg)

    # Product rule on (0.5 * x) * (1 + tanh(inner)).
    half_x = 0.5 * self
    one_plus_tanh = 1 + tanh_value
    d_first_factor = 0.5 * one_plus_tanh
    d_tanh = 1 - tanh_value * tanh_value
    d_tanh_arg = beta * (1 + 3 * kappa * squared)
    d_second_factor = half_x * d_tanh * d_tanh_arg
    return grad * (d_first_factor + d_second_factor)
@register_decomposition(aten.mish_backward)
@pw_cast_for_opmath
def mish_backward(grad_output: Tensor, input: Tensor):
    """Decompose ``aten.mish_backward``.

    With ``t = tanh(softplus(input))`` the result is
    ``grad_output * (t + input * sigmoid(input) * (1 - t * t))``.
    """
    tanh_sp = torch.tanh(F.softplus(input))
    sig = torch.sigmoid(input)
    correction = input * sig * (1 - tanh_sp * tanh_sp)
    return grad_output * (tanh_sp + correction)
@register_decomposition(aten.silu)
@pw_cast_for_opmath
def silu(self: Tensor) -> Tensor:
    """Decompose ``aten.silu`` as ``self * sigmoid(self)``."""
    gate = torch.sigmoid(self)
    return self * gate
@register_decomposition(aten.silu_backward)
@pw_cast_for_opmath
def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    """Decompose ``aten.silu_backward``.

    With ``s = 1 / (1 + exp(-self))`` the result is
    ``grad_output * s * (1 + self * (1 - s))``.
    """
    s = 1 / (1 + torch.exp(-self))
    local_factor = 1 + self * (1 - s)
    return grad_output * s * local_factor
@register_decomposition(aten.softshrink_backward)
def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tensor:
    """Decompose ``aten.softshrink_backward``.

    The gradient is zeroed where ``self`` lies in ``[-lambd, lambd]`` and
    passed through unchanged elsewhere.
    """
    inside_band = (self >= -lambd) & (self <= lambd)
    zero = grad_output.new_zeros(())
    return torch.where(inside_band, zero, grad_output)
@register_decomposition(aten.prelu_backward)
@pw_cast_for_opmath
def prelu_backward(
    grad_output: Tensor, self: Tensor, weight: Tensor
) -> Tuple[Tensor, Tensor]:
    """Decompose ``aten.prelu_backward`` into ``(input_grad, weight_grad)``.

    The weight is either a scalar or a vector of size [C] that the forward
    pass broadcasts against [N, C, ...]. Trailing unit dimensions are added so
    it lines up with ``grad_output``; the per-element weight gradient is then
    reduced back to that broadcast shape and the added dimensions are removed.
    """
    broadcast_weight = weight
    trailing_dims = grad_output.dim() - 2
    for _ in range(trailing_dims):
        broadcast_weight = broadcast_weight.unsqueeze(-1)

    is_positive = self > 0
    input_grad = torch.where(is_positive, grad_output, broadcast_weight * grad_output)

    # Only non-positive inputs contribute to the weight gradient.
    per_element = torch.where(
        is_positive, grad_output.new_zeros(()), self * grad_output
    )
    weight_grad = per_element.sum_to_size(broadcast_weight.shape)
    while weight_grad.dim() > weight.dim():
        weight_grad = weight_grad.squeeze(-1)
    return (input_grad, weight_grad)
@register_decomposition(aten.rrelu_with_noise_backward)
@pw_cast_for_opmath
def rrelu_with_noise_backward(
    grad_output: Tensor,
    self: Tensor,
    noise: Tensor,
    lower: float,
    upper: float,
    training: bool,
    self_is_result: bool,
) -> Tensor:
    """Decompose ``aten.rrelu_with_noise_backward``.

    In training mode with ``upper - lower > 1e-6`` the gradient is multiplied
    by ``noise``. Otherwise the op is forwarded to
    ``aten.leaky_relu_backward`` with slope ``(lower + upper) / 2``.
    """
    uses_noise = training and upper - lower > 1e-6
    if uses_noise:
        return grad_output.mul(noise)
    mean_slope = (lower + upper) / 2
    return aten.leaky_relu_backward(grad_output, self, mean_slope, self_is_result)
@register_decomposition(aten.log_sigmoid_backward)
@pw_cast_for_opmath
def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
    """Decompose ``aten.log_sigmoid_backward``.

    Uses ``z = exp(-|self|)`` and picks constants by the sign of ``self``.
    ``buffer`` is accepted but not read by this formula.
    """
    is_negative = self < 0
    max_deriv = torch.where(is_negative, 1, 0)
    sign = torch.where(is_negative, 1, -1)
    z = torch.exp(-torch.abs(self))
    ratio = z / (1 + z)
    return grad_output * (max_deriv - sign * ratio)
    # CPU has a special formula that uses buffer, but disabled for convenience sake
    # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
def apply_loss_reduction(loss: Tensor, reduction: int):
    """Reduce ``loss`` according to a ``Reduction`` integer code.

    ``MEAN`` returns ``torch.mean(loss)``, ``SUM`` returns
    ``torch.sum(loss)``; any other code returns ``loss`` unchanged.
    """
    if reduction == Reduction.MEAN.value:
        return torch.mean(loss)
    if reduction == Reduction.SUM.value:
        return torch.sum(loss)
    return loss
def to_real_dtype(dtype: torch.dtype):
    """Return the real dtype that corresponds to ``dtype``.

    Complex dtypes map to the floating-point dtype of their components
    (complex32 -> float16, complex64 -> float32, complex128 -> float64).
    Every other dtype is returned unchanged, so the result is always a valid
    dtype; the original silently returned ``None`` for non-complex input.
    """
    complex_to_real = {
        torch.complex32: torch.float16,
        torch.complex64: torch.float32,
        torch.complex128: torch.float64,
    }
    return complex_to_real.get(dtype, dtype)
- # TODO: None of these loss castings are quite correct, see
- # https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels
- # perform the pointwise portion in opmath, but don't maintain it between the
- # pointwise portion and the reduction
@register_decomposition(aten.l1_loss)
def l1_loss(
    self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
    """Decompose ``aten.l1_loss``: reduce ``|self - target|`` per ``reduction``."""
    abs_error = (self - target).abs()
    reduced = apply_loss_reduction(abs_error, reduction)
    # PyTorch semantics result in the output of l1_loss having the corresponding
    # real dtype to self. This may not happen without explicit casting if say
    # self: complex64 and target: float64, which results in loss: float64
    real_dtype = to_real_dtype(self.dtype)
    return reduced.to(real_dtype)
@register_decomposition(aten.l1_loss_backward)
@pw_cast_for_opmath
def l1_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
):
    """Decompose ``aten.l1_loss_backward``.

    The result is ``grad_output * sign(self - target)``, additionally divided
    by ``self.numel()`` under mean reduction.
    """
    direction = torch.sign(self - target)
    if reduction == Reduction.MEAN.value:
        scaled_direction = direction / self.numel()
    else:
        scaled_direction = direction
    return grad_output * scaled_direction
@register_decomposition(aten.mse_loss)
@pw_cast_for_opmath
def mse_loss(
    self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
    """Decompose ``aten.mse_loss``: reduce ``(self - target) ** 2`` per ``reduction``."""
    error = self - target
    squared_error = error ** 2
    return apply_loss_reduction(squared_error, reduction)
@register_decomposition(aten.mse_loss_backward)
@pw_cast_for_opmath
def mse_loss_backward(
    grad_output: Tensor, input: Tensor, target: Tensor, reduction: int
):
    """Decompose ``aten.mse_loss_backward``.

    The result is ``2 * (input - target) * grad_output``, with the factor 2
    divided by ``input.numel()`` under mean reduction.
    """
    if reduction == Reduction.MEAN.value:
        scale = 2.0 / input.numel()
    else:
        scale = 2.0
    return scale * (input - target) * grad_output
@register_decomposition(aten.huber_loss)
@pw_cast_for_opmath
def huber_loss(
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
    delta: float = 1.0,
) -> Tensor:
    """Decompose ``aten.huber_loss``.

    With ``z = |self - target|`` the per-element loss is ``0.5 * z * z`` where
    ``z < delta`` and ``delta * (z - 0.5 * delta)`` elsewhere; it is then
    reduced per ``reduction``. Asserts that ``delta`` is positive.
    """
    assert delta > 0, "huber_loss does not support non-positive values for delta."
    z = (self - target).abs()
    quadratic = 0.5 * z * z
    linear = delta * (z - 0.5 * delta)
    per_element = torch.where(z < delta, quadratic, linear)
    return apply_loss_reduction(per_element, reduction)
@register_decomposition(aten.huber_loss_backward)
@pw_cast_for_opmath
def huber_loss_backward(
    grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float
):
    """Decompose ``aten.huber_loss_backward``.

    With ``x = self - target``: below ``-delta`` the gradient is
    ``-norm * grad_output * delta``, above ``delta`` it is
    ``norm * grad_output * delta``, and in between ``norm * x * grad_output``.
    ``norm`` is ``1 / self.numel()`` under mean reduction and 1 otherwise.
    """
    if reduction == Reduction.MEAN.value:
        norm = 1.0 / self.numel()
    else:
        norm = 1.0
    x = self - target
    lower_tail = -norm * grad_output * delta
    upper_tail = norm * grad_output * delta
    quadratic_zone = norm * x * grad_output
    not_lower = torch.where(x > delta, upper_tail, quadratic_zone)
    return torch.where(x < -delta, lower_tail, not_lower)
def _nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    """Shared implementation behind the nll_loss backward decompositions.

    Scatters ``-1`` into a zero tensor shaped like ``self`` at the target
    class along the channel dimension (0 when ``self`` has fewer than two
    dimensions, otherwise 1), then multiplies by ``grad_output`` after the
    optional mean normalisation, class weighting and ignore-index masking.
    """
    channel_dim = 1 if self.dim() >= 2 else 0
    if reduction == Reduction.MEAN.value:
        grad_output = grad_output / total_weight

    target = target.unsqueeze(channel_dim)
    grad_input = torch.scatter(torch.zeros_like(self), channel_dim, target, -1.0)

    # Give a non-scalar grad_output a channel axis so it broadcasts.
    if grad_input.dim() > grad_output.dim() > 0:
        grad_output = grad_output.unsqueeze(channel_dim)

    if weight is not None:
        # Reshape weight so its only non-unit axis is the channel axis.
        broadcast_shape = [1] * self.dim()
        broadcast_shape[channel_dim] = weight.shape[0]
        grad_output = grad_output * weight.reshape(broadcast_shape)

    # A negative ignore_index disables masking.
    if ignore_index >= 0:
        keep_mask = target != ignore_index
        grad_output = grad_output * keep_mask

    return grad_input * grad_output
@register_decomposition(aten.nll_loss_backward)
def nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    """Decompose ``aten.nll_loss_backward`` for inputs of at most two dimensions.

    Checks the shapes of ``self``, ``target``, ``total_weight``, ``weight``
    and ``grad_output`` with assertions, then delegates to
    ``_nll_loss_backward``.
    """
    assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D"
    assert (
        target.dim() <= 1
    ), "0D or 1D target tensor expected, multi-target not supported"

    no_batch_dim = self.dim() == 1 and target.dim() == 0
    assert no_batch_dim or (
        self.shape[0] == target.shape[0]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"

    # The two literals are concatenated into one string; the original had a
    # comma between them, which turned the assertion message into a tuple.
    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, got: "
        f"{total_weight.shape} ({total_weight.numel()} elements)"
    )

    assert (
        weight is None or weight.numel() == self.shape[-1]
    ), "weight tensor should be defined either for all or no classes"

    if reduction == Reduction.NONE.value and self.dim() == 2:
        # Unreduced batched loss: one gradient entry per batch row.
        assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], (
            f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but "
            f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}"
        )
    else:
        assert (
            grad_output.dim() <= 1 and grad_output.numel() == 1
        ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}"

    return _nll_loss_backward(
        grad_output, self, target, weight, reduction, ignore_index, total_weight
    )
@register_decomposition(aten.nll_loss2d_backward)
def nll_loss2d_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    """Decompose ``aten.nll_loss2d_backward`` for 4D inputs and 3D targets.

    Asserts that the batch and both spatial sizes of ``self`` and ``target``
    agree and that ``total_weight`` holds a single element, then delegates to
    ``_nll_loss_backward``.
    """
    assert (
        self.dim() == 4
    ), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"
    assert (
        target.dim() == 3
    ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"

    # self is indexed as (0, 2, 3) against target (0, 1, 2); axis 1 of self is skipped.
    sizes_match = (
        self.shape[0] == target.shape[0]
        and self.shape[2] == target.shape[1]
        and self.shape[3] == target.shape[2]
    )
    # Message fixed: the closing parenthesis was missing.
    assert (
        sizes_match
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"

    # Message fixed: it previously rendered as "( N, elements)".
    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, "
        f"got: {total_weight.shape} ({total_weight.numel()} elements)"
    )

    return _nll_loss_backward(
        grad_output, self, target, weight, reduction, ignore_index, total_weight
    )
@register_decomposition(aten.binary_cross_entropy)
@pw_cast_for_opmath
def binary_cross_entropy(
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    """Decompose ``aten.binary_cross_entropy``.

    Both log terms are bounded below by -100 before being combined; the
    optional ``weight`` multiplies the per-element loss, which is then
    reduced per ``reduction``.
    """
    # We cannot currently model this without introducing data-dependent control flow
    # TORCH_CHECK(
    #     (input_val >= 0) && (input_val <= 1),
    #     "all elements of input should be between 0 and 1"
    # )
    log_one_minus = torch.maximum(torch.log(1 - self), self.new_full((), -100))
    negative_term = (target - 1) * log_one_minus
    log_self = torch.maximum(torch.log(self), self.new_full((), -100))
    positive_term = target * log_self
    per_element = negative_term - positive_term
    if weight is not None:
        per_element = per_element * weight
    return apply_loss_reduction(per_element, reduction)
@register_decomposition(aten.binary_cross_entropy_backward)
@pw_cast_for_opmath
def binary_cross_entropy_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    """Decompose ``aten.binary_cross_entropy_backward``.

    Computes ``grad_output * (self - target) / max(self * (1 - self), 1e-12)``,
    then applies the optional ``weight`` and, under mean reduction, divides
    by ``self.numel()``.
    """
    EPSILON = 1e-12
    # Clamp the denominator so it never drops below EPSILON.
    denominator = torch.clamp(self * (1 - self), min=EPSILON)
    grad = grad_output * (self - target) / denominator
    if weight is not None:
        grad = grad * weight
    if reduction == Reduction.MEAN.value:
        grad = grad / self.numel()
    return grad
@register_decomposition(aten._euclidean_dist)
def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
    """Decompose ``aten._euclidean_dist`` using a single matmul.

    Each operand is augmented with its squared norm and a column of ones so
    that the product yields ``|x1|^2 - 2 * x1 . x2 + |x2|^2``; the result is
    clamped at zero before taking the square root.
    """
    sq_norm_1 = x1.pow(2).sum(-1, True)
    ones_1 = torch.ones_like(sq_norm_1, memory_format=torch.contiguous_format)
    sq_norm_2 = x2.pow(2).sum(-1, True)
    ones_2 = torch.ones_like(sq_norm_2, memory_format=torch.contiguous_format)
    augmented_1 = torch.cat([x1.mul(-2), sq_norm_1, ones_1], -1)
    augmented_2 = torch.cat([x2, ones_2, sq_norm_2], -1)
    squared_dist = augmented_1.matmul(augmented_2.mT)
    return squared_dist.clamp_min(0).sqrt()
@register_decomposition(aten.slice_backward)
def slice_backward(
    grad_output: Tensor,
    input_sizes: List[int],
    dim: int,
    start: int,
    end: int,
    step: int,
):
    """Decompose ``aten.slice_backward``: scatter ``grad_output`` into zeros of ``input_sizes``."""
    zeros = grad_output.new_zeros(input_sizes)
    return torch.slice_scatter(
        zeros, grad_output, dim, start, end, step
    )
@register_decomposition(aten.select_backward)
def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
    """Decompose ``aten.select_backward``: scatter ``grad_output`` into zeros of ``input_sizes``."""
    zeros = grad_output.new_zeros(input_sizes)
    return torch.select_scatter(
        zeros, grad_output, dim, index
    )
@register_decomposition(aten.diagonal_backward)
def diagonal_backward(
    grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int
):
    """Decompose ``aten.diagonal_backward``: scatter ``grad_output`` into zeros of ``input_sizes``."""
    zeros = grad_output.new_zeros(input_sizes)
    return torch.diagonal_scatter(
        zeros, grad_output, offset, dim1, dim2
    )
@register_decomposition(aten._softmax_backward_data)
@pw_cast_for_opmath
def _softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: int
):
    """Decompose ``aten._softmax_backward_data``.

    With ``g = grad_output * output`` the result is
    ``g - output * sum(g, dim, keepdim=True)``. ``input_dtype`` is accepted
    but not used.
    """
    weighted_grad = grad_output * output
    summed = torch.sum(weighted_grad, dim=dim, keepdim=True)
    return weighted_grad - output * summed
@register_decomposition(aten._log_softmax_backward_data)
@pw_cast_for_opmath
def _log_softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: int
):
    """Decompose ``aten._log_softmax_backward_data``.

    The result is ``grad_output - exp(output) * sum(grad_output, dim,
    keepdim=True)``. ``input_dtype`` is accepted but not used.
    """
    probabilities = torch.exp(output)
    grad_total = torch.sum(grad_output, dim=dim, keepdim=True)
    return grad_output - probabilities * grad_total
- # TODO: the type annotations on arguments are not quite right
@register_decomposition(aten.im2col_backward)
def im2col_backward(
    grad_output: Tensor,
    input_size: List[int],
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    """Decompose ``aten.im2col_backward`` by calling ``F.fold`` on ``grad_output``."""
    folded = F.fold(grad_output, input_size, kernel_size, dilation, padding, stride)  # type: ignore[arg-type]
    return folded
@register_decomposition(aten.col2im_backward)
def col2im_backward(
    grad_output: Tensor,
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    """Decompose ``aten.col2im_backward`` by calling ``F.unfold`` on ``grad_output``."""
    unfolded = F.unfold(grad_output, kernel_size, dilation, padding, stride)  # type: ignore[arg-type]
    return unfolded
@register_decomposition(aten.masked_fill.Scalar)
def masked_fill_Scalar(self: Tensor, mask: Tensor, value: float) -> Tensor:
    """Decompose ``aten.masked_fill.Scalar``.

    ``value`` is first converted with the Python type that
    ``utils.dtype_to_type`` returns for ``self.dtype``; it replaces the
    elements of ``self`` where ``mask`` is true.
    """
    python_type = utils.dtype_to_type(self.dtype)
    fill_value = python_type(value)
    return torch.where(mask, fill_value, self)
@register_decomposition(aten.masked_fill.Tensor)
def masked_fill_Tensor(self: Tensor, mask: Tensor, value: Tensor) -> Tensor:
    """Decompose ``aten.masked_fill.Tensor``: take ``value`` where ``mask`` is true, else ``self``."""
    filled = torch.where(mask, value, self)
    return filled
@register_decomposition(aten.native_dropout_backward)
@pw_cast_for_opmath
def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
    """Decompose ``aten.native_dropout_backward`` as ``grad_output * (mask * scale)``.

    ``mask`` is converted to the type of ``grad_output`` before scaling.
    """
    scaled_mask = mask.type_as(grad_output) * scale
    return grad_output * scaled_mask
@register_decomposition(aten.logit)
@pw_cast_for_int_to_real
def logit(self: Tensor, eps: Optional[float] = None) -> Tensor:
    """Decompose ``aten.logit`` as ``log(x / (1 - x))`` after clamping.

    The input is clamped to ``[eps, 1 - eps]``; a missing ``eps`` is replaced
    by ``-1.0``, which makes the clamp range ``[-1, 2]``.
    """
    effective_eps = -1.0 if eps is None else eps
    clamped = torch.clamp(self, effective_eps, 1 - effective_eps)
    odds = clamped / (1 - clamped)
    return odds.log()
@register_decomposition(aten.logit_backward)
@pw_cast_for_opmath
def logit_backward(
    grad_output: Tensor, self: Tensor, eps: Optional[float] = None
) -> Tensor:
    """Decompose ``aten.logit_backward``.

    Inside the valid range the gradient is
    ``grad_output / (self * (1 - self))``. With ``eps`` the range is
    ``[eps, 1 - eps]`` and the gradient outside it is zero; without ``eps``
    the range is ``[0, 1]`` and the gradient outside it is NaN.
    """
    if eps is None:
        in_range = torch.logical_and(self >= 0.0, self <= 1.0)
        inside_grad = grad_output / (self * (1.0 - self))
        outside_value = self.new_full((), float("nan"))
    else:
        lo = eps
        hi = 1.0 - lo
        in_range = torch.logical_and(self >= lo, self <= hi)
        inside_grad = grad_output / (self * (1.0 - self))
        outside_value = self.new_zeros(())
    return torch.where(in_range, inside_grad, outside_value)
@register_decomposition(aten.native_dropout)
@pw_cast_for_opmath
def native_dropout(input: Tensor, p: float, train: Optional[bool]):
    """Decompose ``aten.native_dropout`` into ``(output, keep_mask)``.

    ``p`` is the probability of dropping an element. When ``train`` is truthy
    each element is kept where a uniform sample exceeds ``p`` and the kept
    elements are scaled by ``1 / (1 - p)``. When ``train`` is falsy the input
    is returned together with an all-true mask.

    The original kept elements with probability ``p`` and scaled by ``1 / p``,
    i.e. it treated ``p`` as the keep probability.
    """
    if not train:
        return (input, torch.ones_like(input, dtype=torch.bool))
    if p == 1:
        # Everything is dropped; return zeros instead of dividing by zero.
        return (torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool))
    keep_mask = torch.rand_like(input) > p
    res = keep_mask * input * float(1.0 / (1.0 - p))
    return (res, keep_mask)
- # TODO: Correct the type promotion semantics
@register_decomposition(aten._softmax)
@pw_cast_for_opmath
def _softmax(x: Tensor, dim: int, half_to_float: bool):
    """Decompose ``aten._softmax``.

    The maximum along ``dim`` is subtracted before exponentiating, and the
    exponentials are normalised by their sum along ``dim``.
    ``half_to_float`` is accepted but not used.
    """
    x_max, _ = torch.max(x, dim, keepdim=True)
    exponentials = torch.exp(x - x_max)
    normaliser = torch.sum(exponentials, dim, keepdim=True)
    return exponentials / normaliser
- # TODO: Correct the type promotion semantics
@register_decomposition(aten._log_softmax)
@pw_cast_for_opmath
def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
    """Decompose ``aten._log_softmax``.

    The maximum along ``dim`` is subtracted from ``x``, then the log of the
    summed exponentials of the shifted values is subtracted as well.
    ``half_to_float`` is accepted but not used.
    """
    x_max, _ = torch.max(x, dim, keepdim=True)
    centered = x - x_max
    exp_sum = torch.sum(torch.exp(centered), dim, keepdim=True)
    return centered - torch.log(exp_sum)
@register_decomposition(aten.addcdiv)
@pw_cast_for_opmath
def addcdiv(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
    """Decompose ``aten.addcdiv`` as ``self + value * (tensor1 / tensor2)``."""
    quotient = tensor1 / tensor2
    return self + value * quotient
- # Remove special case when https://github.com/pytorch/pytorch/pull/72949 is landed.
@register_decomposition(aten.addcmul)
@pw_cast_for_opmath
def addcmul(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
    """Decompose ``aten.addcmul`` as ``self + value * tensor1 * tensor2``.

    ``value`` is truncated with ``int()`` when ``self`` is neither floating
    point nor complex.
    """
    keeps_fraction = self.is_floating_point() or self.is_complex()
    multiplier = value if keeps_fraction else int(value)
    return self + multiplier * tensor1 * tensor2
@register_decomposition(aten.rsub.Tensor)
def rsub_Tensor(self: Tensor, other: Tensor, alpha: float = 1) -> Tensor:
    """Decompose ``aten.rsub.Tensor`` as ``torch.sub`` with the operands swapped."""
    minuend, subtrahend = other, self
    return torch.sub(minuend, subtrahend, alpha=alpha)
@register_decomposition(aten.rsub.Scalar)
def rsub_Scalar(self: Tensor, other: float, alpha: float = 1) -> Tensor:
    """Decompose ``aten.rsub.Scalar`` as ``torch.sub`` with the operands swapped."""
    minuend, subtrahend = other, self
    return torch.sub(minuend, subtrahend, alpha=alpha)
@register_decomposition(aten.embedding)
def embedding(
    weight: Tensor,
    indices: Tensor,
    padding_idx: int = -1,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
) -> Tensor:
    """Decompose ``aten.embedding`` as a row lookup into a 2-D ``weight``.

    One-dimensional ``indices`` select rows directly. Other shapes are
    flattened for the lookup and the result is viewed as
    ``indices.shape + weight.shape[1:]``. ``padding_idx``,
    ``scale_grad_by_freq`` and ``sparse`` are accepted but not used.
    """
    assert weight.dim() == 2, "'weight' must be 2-D"
    # TODO: Assert not ported over yet
    #   auto indices_arg = TensorArg(indices, "indices", 1);
    #   checkScalarTypes("embedding", indices_arg, {kLong, kInt});

    if indices.dim() == 1:
        return weight.index_select(0, indices)

    out_shape = list(indices.shape) + list(weight.shape[1:])
    flat_rows = weight.index_select(0, indices.reshape(-1))
    return flat_rows.view(out_shape)
- # TODO: Correct the type promotion semantics
@register_decomposition(aten.embedding_dense_backward)
def embedding_dense_backward(
    grad_output: Tensor,
    indices: Tensor,
    num_weights: int,
    padding_idx: int,
    scale_grad_by_freq: bool,
):
    """Decompose ``aten.embedding_dense_backward``.

    Gradient rows are accumulated into a zero ``(num_weights, embedding_dim)``
    tensor at the flattened ``indices``. Rows whose index equals
    ``padding_idx`` contribute zero. With ``scale_grad_by_freq`` each row is
    first divided by the number of times its index occurs.
    """
    n_lookups = indices.numel()
    embedding_dim = grad_output.size(-1)
    flat_grad = grad_output.view(n_lookups, embedding_dim)
    grad_weight = grad_output.new_zeros((num_weights, grad_output.shape[-1]))
    flat_indices = indices.view(n_lookups)

    if scale_grad_by_freq:
        # Count occurrences of each index, then gather the count per lookup.
        occurrence_counts = indices.new_zeros((num_weights,)).index_put(
            [flat_indices], indices.new_ones((n_lookups,)), accumulate=True
        )
        per_lookup_count = occurrence_counts[flat_indices]
        flat_grad = flat_grad / per_lookup_count.unsqueeze(1)

    not_padding = (flat_indices != padding_idx).unsqueeze(1).expand_as(flat_grad)
    zeros = torch.full_like(flat_grad, 0)
    masked_grad = torch.where(not_padding, flat_grad, zeros)
    return grad_weight.index_put([flat_indices], masked_grad, accumulate=True)
def prod(x: List[int]):
    """Return the product of the entries of ``x`` (1 for an empty sequence)."""
    return functools.reduce(lambda acc, factor: acc * factor, x, 1)
@register_decomposition(aten.split_with_sizes)
def split_with_sizes(
    self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
    """Decompose ``aten.split_with_sizes`` into consecutive ``narrow`` calls along ``dim``."""
    pieces = []
    offset = 0
    for length in split_sizes:
        pieces.append(self.narrow(dim, offset, length))
        offset += length
    return pieces
@register_decomposition(aten.split.Tensor)
def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
    """Decompose ``aten.split.Tensor`` into a split with explicit sizes.

    Every chunk has ``split_size`` entries along ``dim`` except the last,
    which takes the remainder. ``split_size == 0`` is accepted only when the
    dimension is empty, in which case ``[self]`` is returned.
    """
    dim_size = self.shape[dim]
    if split_size == 0:
        assert dim_size == 0
        return [self]
    # Always produce at least one chunk. An empty dimension used to give
    # chunks == 0 and then an IndexError on the (empty) size list.
    chunks = max((dim_size + split_size - 1) // split_size, 1)
    split_sizes = [split_size] * chunks
    # Shrink the last chunk to whatever is left over.
    split_sizes[-1] = split_size - (split_size * chunks - dim_size)
    return torch.split(self, split_sizes, dim)
- # TODO: this doesn't appear to have enough precision in bfloat16
@register_decomposition(aten.addmm)
@pw_cast_for_opmath
def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
    """Decompose ``aten.addmm`` as ``beta * self + alpha * (mat1 @ mat2)``.

    ``beta`` and ``alpha`` are truncated with ``int()`` when ``self`` is
    neither floating point nor complex. When ``beta`` is zero ``self`` is
    left out of the result entirely.
    """
    is_integral = not self.is_floating_point() and not self.is_complex()
    if is_integral:
        beta, alpha = int(beta), int(alpha)
    product = alpha * torch.mm(mat1, mat2)
    if beta == 0:
        return product
    return beta * self + product
- # TODO: Correct the type promotion semantics
@register_decomposition(aten.native_layer_norm)
@pw_cast_for_opmath
def native_layer_norm(
    input: Tensor,
    normalized_shape: List[int],
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Decompose ``aten.native_layer_norm`` via ``aten.native_batch_norm``.

    The leading ``input.dim() - len(normalized_shape)`` dimensions are
    flattened into ``M`` rows. The input is viewed as ``(1, M, -1)`` and
    normalised by a training-mode batch norm without affine parameters;
    ``weight`` and ``bias`` are then applied element-wise. Returns
    ``(out, mean, rstd)`` with the statistics viewed as the leading
    dimensions followed by ones.
    """
    full_shape = input.shape
    axis = input.dim() - len(normalized_shape)
    M = prod(full_shape[:axis])  # type: ignore[arg-type]

    # Hmm... not sure how I get around this...
    # Basically, native_batch_norm doesn't support 0-entry tensors, while
    # native_layer_norm does (and is tested by OpInfos!)
    if not M > 0:
        return (input, input.new_zeros((0,)), input.new_zeros((0,)))
    rows = input.view(1, M, -1)

    # Unlike Batch Normalization, which applies scalar scale and bias for each
    # entire channel/plane with the affine option, Layer Normalization applies
    # per-element scale and bias. E.g. For input {N, C, H, W}, weight for
    # batchnorm has shape {C} while weight for layernorm has shape {H, W} or {W}.
    out, mean, rstd = aten.native_batch_norm(
        rows,
        weight=None,
        bias=None,
        running_mean=None,
        running_var=None,
        training=True,
        momentum=0.0,
        eps=eps,
    )
    out = out.view(full_shape)
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias

    # Statistics keep the leading dims and get a 1 for every normalised dim.
    stat_shape = list(full_shape[:axis]) + [1] * (input.dim() - axis)
    return (out, mean.view(stat_shape), rstd.view(stat_shape))
# TODO: Correct the type promotion semantics
@register_decomposition(aten.native_layer_norm_backward)
@pw_cast_for_opmath
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """Decompose ``aten.native_layer_norm_backward`` into reductions and pointwise ops.

    Returns:
        ``(d_input, d_weight, d_bias)``. ``d_input`` is ``None`` unless
        ``output_mask[0]`` is set; ``d_weight`` / ``d_bias`` are ``None`` unless
        the matching mask entry is set and ``weight`` / ``bias`` was given.
        When either the leading or the normalized extent is empty, three
        zero tensors are returned regardless of ``output_mask``.
    """
    input_shape = input.shape
    input_ndim = input.dim()
    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    # Dimension indices split into normalized (inner) and leading (outer) ones.
    inner_dim_indices: List[int] = [i for i in range(input_ndim) if i >= axis]
    outer_dim_indices: List[int] = [i for i in range(input_ndim) if i < axis]

    N = prod(inner_dims)  # type: ignore[arg-type]
    M = prod(outer_dims)  # type: ignore[arg-type]
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape),
            input.new_zeros(inner_dims),
            input.new_zeros(inner_dims),
        )

    x_hat = (input - mean) * rstd
    grad_x_hat = grad_out if weight is None else grad_out * weight

    # inner = N * g - sum(g) - x_hat * sum(g * x_hat), where g is grad_x_hat
    # and both sums run over the normalized dimensions (kept as size 1).
    scaled_grad = grad_x_hat * N
    grad_sum = torch.sum(grad_x_hat, inner_dim_indices, True)
    grad_times_x_hat = torch.mul(grad_x_hat, x_hat)
    grad_dot_x_hat = torch.sum(grad_times_x_hat, inner_dim_indices, True)
    projection = torch.mul(x_hat, grad_dot_x_hat)
    inner = scaled_grad - grad_sum - projection

    d_input: Optional[Tensor] = (rstd / N) * inner if output_mask[0] else None

    # Parameter gradients are summed over the leading dims when there are any.
    has_outer_dims = len(outer_dim_indices) > 0

    d_weight: Optional[Tensor] = None
    if output_mask[1] and weight is not None:
        d_weight = grad_out * x_hat
        if has_outer_dims:
            d_weight = torch.sum(d_weight, outer_dim_indices, False)

    d_bias: Optional[Tensor] = None
    if output_mask[2] and bias is not None:
        d_bias = (
            torch.sum(grad_out, outer_dim_indices, False)
            if has_outer_dims
            else grad_out
        )

    return (d_input, d_weight, d_bias)
# TODO: Correct the type promotion semantics
@register_decomposition(aten.native_batch_norm)
@pw_cast_for_opmath
def native_batch_norm(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Decompose ``aten.native_batch_norm`` into reductions and pointwise ops.

    Statistics are reduced over every dimension except dimension 1. In
    training mode they come from ``input`` and, when given, ``running_mean``
    and ``running_var`` are updated in place with ``copy_``. Otherwise the
    running statistics are used for normalization and must both be provided.

    Returns:
        ``(output, save_mean, save_invstd)``. Outside training mode the saved
        statistics are the running mean and its inverse std when the input
        device type is ``"cuda"``, and empty ``(0,)`` tensors otherwise.
    """
    reduction_dims = [0, *range(2, input.dim())]

    if training:
        biased_var, save_mean = torch.var_mean(
            input, dim=reduction_dims, unbiased=False
        )
        save_invstd = 1 / (torch.sqrt(biased_var + eps))

        if running_mean is not None:
            running_mean.copy_(momentum * save_mean + (1 - momentum) * running_mean)
        if running_var is not None:
            n = input.numel() / input.shape[1]
            # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction
            # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose
            # numerics probably don't matter.
            unbiased_var = biased_var * (n / (n - 1))
            running_var.copy_(momentum * unbiased_var + (1 - momentum) * running_var)

        mean, invstd = save_mean, save_invstd
    else:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = 1 / (torch.sqrt(running_var + eps))
        # Very annoying inconsistency where CPU and CUDA give different shapes
        if input.device.type != "cuda":
            save_mean = input.new_zeros((0,))
            save_invstd = input.new_zeros((0,))
        else:
            save_mean = running_mean
            save_invstd = invstd

    # Missing affine parameters default to 0-dim one / zero tensors.
    if weight is None:
        weight = input.new_ones(())
    if bias is None:
        bias = input.new_zeros(())

    mean = _unsqueeze_to_dim(mean, input.dim() - 1)
    invstd = _unsqueeze_to_dim(invstd, input.dim() - 1)
    weight = _unsqueeze_to_dim(weight, input.dim() - 1)
    bias = _unsqueeze_to_dim(bias, input.dim() - 1)

    normalized = (input - mean) * invstd
    output = normalized * weight + bias
    return output, save_mean, save_invstd
@register_decomposition(aten.clamp_min)
def clamp_min(self: Tensor, min: float):
    """Decompose ``aten.clamp_min`` into ``torch.clamp`` with only a lower bound."""
    lower_bound = min
    return torch.clamp(self, min=lower_bound)
@register_decomposition(aten.clamp_max)
def clamp_max(self: Tensor, max: float):
    """Decompose ``aten.clamp_max`` into ``torch.clamp`` with only an upper bound."""
    upper_bound = max
    return torch.clamp(self, max=upper_bound)
@register_decomposition(aten._fused_dropout)
@pw_cast_for_opmath
def _fused_dropout_decomposition(input, p, generator=None):
    """Decompose ``aten._fused_dropout`` into a random mask and a rescale.

    An element is kept where a uniform sample from ``torch.rand_like`` is
    below ``p``; kept elements are multiplied by ``1 / p``. ``generator`` is
    accepted but not used here.

    Returns:
        ``(res, mask)`` where ``mask`` is a ``uint8`` tensor of 0/1 flags.
    """
    keep = torch.rand_like(input) < p
    mask = keep.to(dtype=torch.uint8)
    scale = 1.0 / p
    res = mask.type_as(input) * input * scale
    return (res, mask)
# NOTE: the logical decomps derive truthiness by comparing against zero rather
# than casting with ``.to(dtype=torch.bool)``. A complex element is true when
# either component is non-zero, which a cast that only looks at the real part
# gets wrong.
@register_decomposition(aten.logical_xor)
def logical_xor(self: Tensor, other: Tensor) -> Tensor:
    """Decompose ``aten.logical_xor``: zeros are False, everything else is True.

    Returns:
        A boolean tensor that is True where exactly one of ``self`` and
        ``other`` is non-zero.
    """
    return (self != 0) ^ (other != 0)


@register_decomposition(aten.logical_not)
def logical_not(self: Tensor) -> Tensor:
    """Decompose ``aten.logical_not``: True exactly where ``self`` equals zero.

    Returns:
        A boolean tensor with the element-wise negation of the truthiness of
        ``self``.
    """
    return self == 0
@register_decomposition(aten.xlogy.Tensor)
@pw_cast_for_int_to_real
def xlogy(self: Tensor, other: Tensor) -> Tensor:
    """Decompose ``aten.xlogy.Tensor`` as ``self * log(other)`` with special cases.

    Element-wise result, in priority order:
      * NaN where ``other`` is NaN (even when ``self`` is 0),
      * 0 where ``self`` is 0,
      * ``self * log(other)`` everywhere else.
    """
    product = self * aten.log(other)
    zero_where_self_is_zero = aten.where(
        self == aten.new_zeros(self, ()),
        aten.new_zeros(product, ()),
        product,
    )
    # Where ``other`` is NaN, ``product`` is already NaN (x * NaN is NaN for
    # every x, including 0), so selecting it restores the NaN that the
    # zero-substitution above would otherwise hide.
    return aten.where(aten.isnan(other), product, zero_where_self_is_zero)
@register_decomposition(aten.var.correction)
@reduction_complex_to_real
def var_correction(
    x: Tensor,
    dims: Optional[List[int]],
    correction: Optional[int] = None,
    keepdim: bool = False,
):
    """Decompose ``aten.var.correction`` into mean / sum reductions.

    Args:
        x: input tensor.
        dims: dimensions to reduce; ``None`` or an empty list reduces over all
            of them.
        correction: difference between the sample size and the degrees of
            freedom. ``None`` selects Bessel's correction (1), matching the
            default of ``torch.var``.
        keepdim: whether the reduced dimensions are kept with size 1.

    Returns:
        The sum of squared deviations from the mean divided by
        ``max(0, n - correction)``, with ``n`` the number of reduced elements.
        For complex input, the variances of the real and imaginary parts are
        added.
    """
    if dims is None:
        dims = []

    if x.is_complex():
        # For complex, calculate variance of real and imaginary components
        # separately then add to get overall variance.
        var_real = torch.var(x.real, dims, correction=correction, keepdim=keepdim)
        var_imag = torch.var(x.imag, dims, correction=correction, keepdim=keepdim)
        return var_real + var_imag

    if correction is None:
        # An omitted correction means the unbiased estimator, not the biased one.
        correction = 1

    # Number of elements that take part in each reduction.
    if len(dims) == 0:
        n = prod(x.shape)  # type: ignore[arg-type]
    else:
        n = 1
        for dim in dims:
            n *= x.shape[dim]

    mean = torch.mean(x, dims, True)
    centered = x - mean
    sum_of_squares = torch.sum(centered * centered, dims, keepdim)
    # Clamp at zero so an over-large correction divides by 0 instead of
    # producing a negative "variance".
    return sum_of_squares / max(0, n - correction)
@register_decomposition(aten.std.correction)
@reduction_complex_to_real
def std_decomposition(
    x: Tensor, dims: List[int], correction: int = 0, keepdim: bool = False
):
    """Decompose ``aten.std.correction`` as the square root of ``torch.var``."""
    variance = torch.var(x, dims, correction=correction, keepdim=keepdim)
    return torch.sqrt(variance)
# Questionable decompositions
# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
# Note that this decomposition causes issues with in-place ops
@register_decomposition(aten.detach, disable_meta=True)
def detach_decomposition(x):
    """Decompose ``aten.detach`` into the identity: ``x`` itself is returned."""
    return x
@register_decomposition(aten.cudnn_batch_norm)
def cudnn_batch_norm(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
):
    """Decompose ``aten.cudnn_batch_norm`` by forwarding to ``aten.native_batch_norm``.

    Returns:
        A 4-tuple whose last entry is always an empty ``uint8`` tensor. In
        training mode the first three entries are the native batch norm
        outputs; otherwise only the normalized output is forwarded and the two
        statistics are replaced by empty ``(0,)`` tensors.
    """
    output, saved_first, saved_second = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )
    if not training:
        return (
            output,
            input.new_zeros((0,)),
            input.new_zeros((0,)),
            input.new_zeros((0,), dtype=torch.uint8),
        )
    # Cudnn return running mean and variance when training is True
    return (
        output,
        saved_first,
        saved_second,
        input.new_zeros((0,), dtype=torch.uint8),
    )
@register_decomposition(aten.cudnn_batch_norm_backward)
def cudnn_batch_norm_backward(
    input: Tensor,
    grad_output: Tensor,
    weight: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_var: Optional[Tensor],
    epsilon: float,
    reserveSpace: Tensor,
):
    """Decompose ``aten.cudnn_batch_norm_backward`` via ``aten.native_batch_norm_backward``.

    The native op is called with ``grad_output`` and ``input`` in its own
    argument order, the training flag set to ``True`` and all three gradients
    requested. ``reserveSpace`` is not used.
    """
    train = True
    output_mask = [True, True, True]
    return aten.native_batch_norm_backward(
        grad_output,
        input,
        weight,
        running_mean,
        running_var,
        save_mean,
        save_var,
        train,
        epsilon,
        output_mask,
    )
@register_decomposition(aten.rot90.default)
def rot90(self: Tensor, k: int = 1, dims: List[int] = [0, 1]) -> Tensor:  # noqa: B006
    """Decompose ``aten.rot90`` into ``flip`` / ``transpose`` (or a contiguous clone).

    ``k`` is reduced modulo 4. The two entries of ``dims`` must name distinct
    dimensions of ``self``, each within ``[-self.dim(), self.dim())``.

    Raises:
        AssertionError: if ``dims`` does not hold exactly two entries, ``self``
            has fewer than two dimensions, the two dims coincide, or either
            one is out of range.
    """
    total_dims = self.dim()
    total_rot_dims = len(dims)

    assert total_rot_dims == 2, f"expected total rotation dims == 2, but got dims = {total_rot_dims}"
    assert total_dims >= 2, f"expected total dims >= 2, but got total dims = {total_dims}"
    assert dims[0] != dims[1] and abs(dims[0] - dims[1]) != total_dims, (
        f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
    )
    assert -total_dims <= dims[0] < total_dims, f"Rotation dim0 out of range, dim0 = {dims[0]}"
    assert -total_dims <= dims[1] < total_dims, f"Rotation dim1 out of range, dim1 = {dims[1]}"

    quarter_turns = k % 4
    if quarter_turns == 0:
        return self.clone(memory_format=torch.contiguous_format)
    if quarter_turns == 2:
        return self.flip(dims)
    # One turn flips along dims[1], three turns along dims[0]; both then swap the two dims.
    flip_dim = dims[1] if quarter_turns == 1 else dims[0]
    return self.flip(flip_dim).transpose(dims[0], dims[1])
@register_decomposition(aten.transpose.int)
def transpose_int(self: Tensor, dim0: int, dim1: int) -> Tensor:
    """Decompose ``aten.transpose.int`` into ``torch.permute``.

    ``self`` is returned unchanged when it has at most one dimension or when
    both (canonicalized) dims are the same.
    """
    ndim = self.dim()
    dim0, dim1 = utils.canonicalize_dims(ndim, (dim0, dim1))  # type: ignore[misc]
    if ndim <= 1 or dim0 == dim1:
        return self
    # Identity permutation with the two requested positions exchanged.
    order = list(range(ndim))
    order[dim0] = dim1
    order[dim1] = dim0
    return torch.permute(self, order)
@register_decomposition(aten.t.default)
def t(self: Tensor) -> Tensor:
    """Decompose ``aten.t`` into ``transpose``: dims 0 and 1 are swapped, or 0 with itself below 2-D."""
    other_dim = 1 if self.dim() >= 2 else 0
    return self.transpose(0, other_dim)
def check_stack_inputs(tensors: List[Tensor]):
    """Assert that every tensor in ``tensors`` has the same shape as the first.

    Raises:
        AssertionError: if the shape of any entry differs from that of entry 0.
    """
    entry_shape = tensors[0].shape
    for i in range(1, len(tensors)):
        # The trailing space after "entry 0" keeps the two message halves
        # from running together.
        assert tensors[i].shape == entry_shape, (
            f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
            f"and {tensors[i].shape} at entry {i}"
        )
def get_stack_inputs(tensors: List[Tensor], dim: int):
    """Validate ``tensors`` with ``check_stack_inputs`` and unsqueeze each at ``dim``.

    Returns:
        A new list with one unsqueezed tensor per input, in the same order.
    """
    check_stack_inputs(tensors)
    unsqueezed = []
    for tensor in tensors:
        unsqueezed.append(tensor.unsqueeze(dim))
    return unsqueezed
@register_decomposition(aten.stack.default)
def stack(tensors: List[Tensor], dim: int = 0) -> Tensor:
    """Decompose ``aten.stack`` into ``torch.cat``.

    When the new axis is not the last one and the first tensor is not sparse,
    the inputs are concatenated along the wrapped dim and the result is viewed
    with the extra axis. Otherwise every input is unsqueezed first and the
    unsqueezed tensors are concatenated.

    Raises:
        AssertionError: if ``tensors`` is empty, or (through
            ``check_stack_inputs``) if the shapes differ.
    """
    assert len(tensors) > 0, "stack expects a non-empty TensorList"
    first = tensors[0]
    wrapped_dim = utils.canonicalize_dim(first.dim() + 1, dim)

    if wrapped_dim >= first.dim() or first.is_sparse:
        return torch.cat(get_stack_inputs(tensors, wrapped_dim), dim)

    check_stack_inputs(tensors)
    base_sizes = list(first.shape)
    result_sizes = base_sizes[:wrapped_dim] + [len(tensors)] + base_sizes[wrapped_dim:]
    return torch.cat(tensors, wrapped_dim).view(result_sizes)
def _squeeze_multiple(self: Tensor, dims: List[int]) -> Tensor:
    """Squeeze ``self`` at every dimension listed in ``dims``.

    Dimensions are visited from the last to the first so that squeezing one
    does not shift the index of those still to be visited.
    """
    ndim = self.dim()
    wrapped_dims = utils.canonicalize_dims(ndim, dims)
    assert isinstance(wrapped_dims, tuple)
    result = self
    for idx in reversed(range(ndim)):
        if idx not in wrapped_dims:
            continue
        result = result.squeeze(idx)
    return result
@register_decomposition(aten.logsumexp.default)
@pw_cast_for_int_to_real
def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor:
    """Decompose ``aten.logsumexp`` with the max-subtraction trick.

    Computes ``log(sum(exp(self - m))) + m`` over ``dim``, where ``m`` is the
    maximum over ``dim`` with infinite maxima replaced by 0. Empty inputs skip
    the shift and reduce ``exp(self)`` directly.
    """
    if self.numel() == 0:
        return torch.sum(torch.exp(self), dim, keepdim).log()

    maxes = torch.amax(self, dim, keepdim=True)
    # Zero out infinite maxima *before* they are subtracted from the input:
    # otherwise ``self - maxes`` evaluates ``inf - inf`` (or ``-inf + inf``)
    # and turns the whole reduction into NaN.
    maxes = torch.masked_fill(maxes, maxes.abs() == float('inf'), 0)
    maxes_squeezed = maxes if keepdim else _squeeze_multiple(maxes, dim)
    result = torch.sum(torch.exp(self - maxes), dim, keepdim)
    return result.log().add(maxes_squeezed)
@register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
    """Decompose ``aten.trace`` as the sum of ``torch.diag(self)``."""
    diagonal = torch.diag(self)
    return torch.sum(diagonal)
# nb: Should use acc_t, not op_math
@register_decomposition(aten.log_sigmoid_forward)
@out_wrapper_multi('output', 'buffer')
@pw_cast_for_opmath
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    """Decompose ``aten.log_sigmoid_forward`` as ``min(0, x) - log1p(exp(-|x|))``.

    Returns:
        ``(output, buffer)``. ``buffer`` is an empty ``(0,)`` tensor when
        ``self.is_cuda`` is true and ``exp(-|x|)`` otherwise.
    """
    non_positive_part = torch.minimum(self.new_zeros(()), self)
    z = torch.exp(-torch.abs(self))
    buffer = self.new_zeros((0,)) if self.is_cuda else z
    return non_positive_part - torch.log1p(z), buffer
|