"""This file exports ONNX ops for opset 9. Opset 9 is supported by ONNX release 1.4.1 release on 01/23/19 """ import functools import math import sys import warnings from typing import List, Optional, Tuple, Union import torch import torch._C._onnx as _C_onnx import torch.nn.modules.utils import torch.onnx from torch import _C # This import monkey-patches graph manipulation methods on Graph, used for the # ONNX symbolics from torch.onnx import _patch_torch # noqa: F401 from torch.onnx import symbolic_helper from torch.onnx._globals import GLOBALS # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in symbolic_helper.py # Note [Pointwise by scalar] # ~~~~~~~~~~~~~~~~~~~~~~~~~~ # What happens if you add a tensor with a constant (e.g., x + 2)? There are # some moving parts to implementing the ONNX translation in this case: # # - By the time we get the scalar in a symbolic function here, it is no longer # a Python long/float, but a PyTorch tensor with numel == 1 (eventually, we # want it to be a zero dim tensor but this change has not happened yet.) # However, the type of this scalar is *exactly* what the user wrote in # Python, which may not match the tensor it is being added to. PyTorch # will do implicit conversions on scalars; however, ONNX will not, so # we must do the conversion ourselves. This is what _if_scalar_type_as # does. # # - Dispatch to these functions takes advantage an outrageous coincidence # between the tensor and scalar name. When we add two tensors together, # you get the dispatch: # # add(*[self, other], **{"alpha": alpha}) # # When you add a tensor and a scalar, you get the dispatch: # # add(*[self], **{"other": other, "alpha": alpha}) # # By having the argument name line up with the name of the scalar attribute # if it exists, we can write a single function for both overloads. 
# # used to represent "missing" optional inputs def unused(g): n = g.op("prim::Constant") n.setType(_C.OptionalType.ofTensor()) return n def _shape_as_tensor(g, input): return g.op("Shape", input) def _reshape_from_tensor(g, input, shape): if isinstance(shape, list): shape = g.op("Concat", *shape, axis_i=0) return reshape(g, input, shape) def reshape(g, self, shape): return symbolic_helper._reshape_helper(g, self, shape) def reshape_as(g, self, other): shape = g.op("Shape", other) return reshape(g, self, shape) def add(g, self, other, alpha=None): if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): return symbolic_helper._onnx_opset_unsupported_detailed( "Add", 9, 11, "Add between list of tensors not supported" ) # default alpha arg is to allow no-alpha add (aten add st overload no alpha) if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: return symbolic_helper._unimplemented("add", "alpha != 1") return g.op("Add", self, other) def sub(g, self, other, alpha=None): # default alpha arg is to allow no-alpha sub (aten sub st overload no alpha) if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: return symbolic_helper._unimplemented("sub", "alpha != 1") return g.op("Sub", self, other) def rsub(g, self, other, alpha=None): return sub(g, other, self, alpha=alpha) def mul(g, self, other): return g.op("Mul", self, other) def div(g, self, other, *args): if len(args) == 0: return true_divide(g, self, other) else: return _div_rounding_mode(g, self, other, *args) @symbolic_helper.parse_args("v", "v", "v", "f") def addcmul(g, self, tensor1, tensor2, value=1.0): value_tens = g.op("Constant", value_t=torch.tensor([value])) return add(g, self, mul(g, mul(g, tensor1, tensor2), value_tens)) @symbolic_helper.parse_args("v", "v", "s") def _div_rounding_mode(g, self, other, rounding_mode): if rounding_mode is None: return true_divide(g, self, other) elif rounding_mode == "floor": return 
def _trunc_divide(g, self, other):
    """Emit a truncating division built from ``Div`` followed by casts.

    The quotient is cast to Long (which drops the fractional part) and then
    cast again to the output type chosen by the rules listed below.
    """
    out = g.op("Div", self, other)
    # the correct operation is truncate, which is not supported in ONNX,
    # we cannot call floor since it will behave differently for negative numbers
    # (eg. -0.1 should become -0 )
    # - if scalar_type information are not available, assume that
    # we need to call floor (treat as float)
    out = g.op("Cast", out, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"])

    # Matching PyTorch's behavior:
    # - if self is fp the output's type is self's type
    # - if self is not fp and other is fp, the output is of type "Float"
    # - self is not fp and other is not fp, the output's type is self's output type
    # - the output type defaults to Float
    scalar_type = self.type().scalarType()

    if scalar_type is not None:
        if (
            not symbolic_helper._is_fp(self)
            and other.type().scalarType() is not None
            and symbolic_helper._is_fp(other)
        ):
            out = g.op("Cast", out, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"])
        else:
            out = g.op(
                "Cast", out, to_i=symbolic_helper.cast_pytorch_to_onnx[scalar_type]
            )
    else:
        # No scalar type recorded for self: fall back to Float.
        out = g.op("Cast", out, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"])
    return out


def _floor_divide(g, self, other):
    """Emit a floor division.

    If either operand is floating point the result is ``Floor(true_divide)``.
    Otherwise an integer ``Div`` is emitted and 1 is subtracted wherever the
    operands have different signs and the remainder is non-zero.
    """
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        out = true_divide(g, self, other)
        return g.op("Floor", out)
    else:
        # Integer division does truncation rounding
        div = g.op("Div", self, other)
        # Division is negative if: self < 0 != other < 0
        zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
        negative = g.op(
            "Xor",
            symbolic_helper._lt_helper(g, self, zero),
            symbolic_helper._lt_helper(g, other, zero),
        )

        # For negative numbers with self % other != 0, subtract 1 to round down instead of up
        mod = g.op("Sub", self, g.op("Mul", div, other))
        fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero)))

        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        fixup = g.op("Mul", fixup_mask, one)
        return g.op("Sub", div, fixup)
def floor_divide(g, self, other):
    """Lower ``floor_divide`` via ``_trunc_divide`` (see the comment below)."""
    # Deprecated behavior, floor_divide actually truncates
    return _trunc_divide(g, self, other)


def floordiv(g, self, other):
    """Alias that forwards to ``floor_divide``."""
    return floor_divide(g, self, other)


def true_divide(g, self, other):
    """Division where both inputs are cast to floating types

    If both inputs are floating, performs div as usual
    If only one input is a floating type, the other input is cast to its type
    If neither input is a floating type, both inputs are cast to the default scalar type
    """

    # Case 1: either values are floating
    # Performs div as usual.
    # Implicit casting will be handled in scalar type analysis pass.
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        return g.op("Div", self, other)

    # Case 2: neither is floating
    # Casts both inputs to the default scalar type
    scalar_type = torch.get_default_dtype()
    onnx_scalar_type = symbolic_helper.cast_pytorch_to_onnx["Float"]
    # Only float and double are handled as default dtypes here.
    assert scalar_type is torch.float or scalar_type is torch.double
    if torch.get_default_dtype() is torch.double:
        onnx_scalar_type = symbolic_helper.cast_pytorch_to_onnx["Double"]

    self = g.op("Cast", self, to_i=onnx_scalar_type)
    other = g.op("Cast", other, to_i=onnx_scalar_type)
    return g.op("Div", self, other)


def reciprocal(g, self):
    """Emit ONNX ``Reciprocal``, casting non-floating input to Float first."""
    # torch.reciprocal implicitly casts to float, so we do the same.
    if not symbolic_helper._is_fp(self):
        self = g.op("Cast", self, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"])
    return g.op("Reciprocal", self)
@symbolic_helper.parse_args("v", "i")
def cat(g, tensor_list, dim):
    """Unpack ``tensor_list`` and emit ONNX ``Concat`` along ``dim``."""
    tensors = symbolic_helper._unpack_list(tensor_list)
    return g.op("Concat", *tensors, axis_i=dim)


@symbolic_helper.parse_args("v", "i")
def stack(g, tensor_list, dim):
    """Unsqueeze every tensor of the list at ``dim`` and ``Concat`` them there."""
    unsqueezed = [
        symbolic_helper._unsqueeze_helper(g, t, [dim])
        for t in symbolic_helper._unpack_list(tensor_list)
    ]
    return g.op("Concat", *unsqueezed, axis_i=dim)


def _list(g, self):
    """Identity: return ``self`` unchanged."""
    return self


def mm(g, self, other):
    """Emit ``Gemm(self, other, C)`` with ``alpha=1`` and ``beta=0``."""
    # Create a dummy C tensor. Only needed for API purposes, the value is
    # ignored since beta = 0
    C = g.op("Constant", value_t=torch.tensor([1]))
    return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)


def bmm(g, self, other):
    """Emit ONNX ``MatMul``."""
    return g.op("MatMul", self, other)


def matmul(g, self, other):
    """Emit ONNX ``MatMul``."""
    return g.op("MatMul", self, other)


@symbolic_helper.parse_args("v", "v", "v", "t", "t")
def addmm(g, self, mat1, mat2, beta, alpha):
    """Lower ``beta * self + alpha * (mat1 @ mat2)``.

    When a scalar type is known and either matrix has a known rank other than
    2, the result is composed from ``MatMul``/``Mul``/``Add`` (a ``Mul`` is
    only emitted for a factor that differs from 1). In every other case a
    single ``Gemm`` node is emitted.
    """
    # First available scalar type among self, mat1, mat2 (in that order).
    dtype = None
    self_dtype = symbolic_helper._try_get_scalar_type(self)
    mat1_dtype = symbolic_helper._try_get_scalar_type(mat1)
    mat2_dtype = symbolic_helper._try_get_scalar_type(mat2)
    if self_dtype is not None:
        dtype = self_dtype
    elif mat1_dtype is not None:
        dtype = mat1_dtype
    elif mat2_dtype is not None:
        dtype = mat2_dtype

    mat1_rank = symbolic_helper._get_tensor_rank(mat1)
    mat2_rank = symbolic_helper._get_tensor_rank(mat2)

    def isNotNoneAnd(v, u):
        # True only when v is known and differs from u.
        return v is not None and v != u

    if dtype is not None and (isNotNoneAnd(mat1_rank, 2) or isNotNoneAnd(mat2_rank, 2)):
        # Translate the scalar-type name into the torch dtype used for the
        # alpha/beta constants below.
        dtype = symbolic_helper.scalar_type_to_onnx.index(
            symbolic_helper.cast_pytorch_to_onnx[dtype]
        )
        dtype = symbolic_helper.scalar_type_to_pytorch_type[dtype]

        res1 = g.op("MatMul", mat1, mat2)
        res2 = self

        alpha = symbolic_helper._scalar(alpha)
        beta = symbolic_helper._scalar(beta)

        if alpha != 1:
            alpha = g.op("Constant", value_t=torch.tensor(alpha, dtype=dtype))
            res1 = g.op("Mul", res1, alpha)
        if beta != 1:
            beta = g.op(
                "Constant",
                value_t=torch.tensor(symbolic_helper._scalar(beta), dtype=dtype),
            )
            res2 = g.op("Mul", res2, beta)

        return g.op("Add", res1, res2)

    return g.op(
        "Gemm",
        mat1,
        mat2,
        self,
        beta_f=symbolic_helper._scalar(beta),
        alpha_f=symbolic_helper._scalar(alpha),
    )
def neg(g, self):
    """Emit ONNX ``Neg``."""
    return g.op("Neg", self)


def sqrt(g, self):
    """Emit ONNX ``Sqrt``."""
    return g.op("Sqrt", self)


def rsqrt(g, self):
    """Emit ``1 / Sqrt(self)`` as a ``Div`` whose numerator is a ones tensor

    passed through ``symbolic_helper._if_scalar_type_as`` together with ``self``.
    """
    return g.op(
        "Div", symbolic_helper._if_scalar_type_as(g, torch.ones(1), self), sqrt(g, self)
    )


def tanh(g, self):
    """Emit ONNX ``Tanh``."""
    return g.op("Tanh", self)


def sin(g, self):
    """Emit ONNX ``Sin``."""
    return g.op("Sin", self)


def cos(g, self):
    """Emit ONNX ``Cos``."""
    return g.op("Cos", self)


def tan(g, self):
    """Emit ONNX ``Tan``."""
    return g.op("Tan", self)


def asin(g, self):
    """Emit ONNX ``Asin``."""
    return g.op("Asin", self)


def acos(g, self):
    """Emit ONNX ``Acos``."""
    return g.op("Acos", self)


def atan(g, self):
    """Emit ONNX ``Atan``."""
    return g.op("Atan", self)


# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
def sigmoid(g, self):
    """Emit ONNX ``Sigmoid``."""
    return g.op("Sigmoid", self)


def sign(g, self):
    """Emit ONNX ``Sign``."""
    return g.op("Sign", self)


def _slice(g, input, axes, starts, ends):
    """Emit ONNX ``Slice`` with attribute-style axes/starts/ends.

    A single slice from 0 to 9223372036854775807 (INT64 max) selects
    everything, so ``input`` is returned unchanged in that case.
    """
    assert len(starts) == len(ends)
    if len(starts) == 1 and starts[0] == 0 and ends[0] == 9223372036854775807:
        return input
    return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)


def _maybe_cast_reduce_op_input(g, self):
    """Cast a non-floating, non-Long reduce input to Long when its dtype is known."""
    dtype = self.type().scalarType()
    # This check only covers traced modules where dtype is present
    if dtype is not None:
        # pytorch reduce-ops cast all other integral types to int64
        if not symbolic_helper._is_fp(self) and not (dtype == "Long"):
            self = _cast_Long(g, self, False)  # type: ignore[name-defined]
    return self


def _reduce_op_symbolic(onnx_op_name, allow_multi_dim_support=True):
    """Build a symbolic function that emits the reduce op ``onnx_op_name``.

    The returned ``symbolic(g, self, dim=None, keepdim=None)`` first applies
    ``_maybe_cast_reduce_op_input``. Without ``dim`` it defers to
    ``symbolic_helper._handle_reduce_dim_none``; with ``dim`` it emits the op
    with ``axes_i``/``keepdims_i`` attributes. ``allow_multi_dim_support``
    selects whether ``dim`` is parsed as an int list (``"is"``) or a single
    int (``"i"``) that is then wrapped in a list.
    """

    def symbolic(g, self, dim=None, keepdim=None):
        self = _maybe_cast_reduce_op_input(g, self)
        if dim is None:
            # all-reduce path
            return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name)
        else:
            # dim-reduce path
            desc = "is" if allow_multi_dim_support else "i"
            dim, keepdim = symbolic_helper._get_const(
                dim, desc, "dim"
            ), symbolic_helper._get_const(keepdim, "i", "keepdim")
            dim_list = dim if allow_multi_dim_support else [dim]
            return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim)

    return symbolic
def overload_by_arg_count(fn):
    """Decorate ``fn`` so that one of the overloads it returns is picked by arity.

    ``fn(g, *args)`` must return an iterable of candidate functions, each of
    which exposes an ``_arg_descriptors`` attribute. The first candidate
    whose number of descriptors equals ``len(args)`` is called with
    ``(g, *args)`` and its result returned.

    Raises:
        NotImplementedError: if no candidate matches the argument count.
    """

    @functools.wraps(fn)
    def wrapper(g, *args):
        for overload in fn(g, *args):
            if len(overload._arg_descriptors) == len(args):
                return overload(g, *args)
        raise NotImplementedError(f"Unknown aten::{fn.__name__} signature")

    return wrapper


def _reduce_with_dtype(onnx_op, name, allow_multi_dim_support=True):
    """Build a reduce symbolic that also understands the ``dtype`` argument.

    Two overloads are produced and selected by argument count: one without
    ``dim`` (``self, dtype``) and one with it (``self, dim, keepdim, dtype``).
    In both, a ``dtype`` coming from an ``onnx::Constant`` node makes the
    input be ``Cast`` to that type first; a ``dtype`` that is neither an
    ``onnx::Constant`` nor a ``prim::Constant`` is reported through
    ``symbolic_helper._unimplemented``.

    Args:
        onnx_op: name of the ONNX reduce operator to emit.
        name: aten op name used in the "unimplemented" message.
        allow_multi_dim_support: whether ``dim`` may be a list of ints.
    """
    symbolic = _reduce_op_symbolic(
        onnx_op, allow_multi_dim_support=allow_multi_dim_support
    )

    @overload_by_arg_count
    def reduce(g, *args, **kwargs):
        @symbolic_helper.parse_args("v", "none")
        def reduce_nodim(g, self, dtype):
            if dtype.node().kind() == "onnx::Constant":
                dtype = symbolic_helper._get_const(dtype, "i", "dtype")
                self = g.op(
                    "Cast", self, to_i=symbolic_helper.scalar_type_to_onnx[dtype]
                )
            elif dtype.node().kind() != "prim::Constant":
                return symbolic_helper._unimplemented(name, "dtype")
            return symbolic(g, self)

        dim_desc = "is" if allow_multi_dim_support else "i"

        @symbolic_helper.parse_args("v", dim_desc, "i", "none")
        def reduce_dim(g, self, dim, keepdim, dtype):
            if dtype.node().kind() == "onnx::Constant":
                dtype = symbolic_helper._get_const(dtype, "i", "dtype")
                self = g.op(
                    "Cast", self, to_i=symbolic_helper.scalar_type_to_onnx[dtype]
                )
            elif dtype.node().kind() != "prim::Constant":
                return symbolic_helper._unimplemented(name, "dtype")
            return symbolic(g, self, dim, keepdim)

        return reduce_nodim, reduce_dim

    return reduce


sum = _reduce_with_dtype("ReduceSum", "sum")
mean = _reduce_with_dtype("ReduceMean", "mean")
# torch.prod does not support multidimensional "dim"
prod = _reduce_with_dtype("ReduceProd", "prod", allow_multi_dim_support=False)


@symbolic_helper.parse_args("v", "i", "none")
def cumsum(g, input, dim, dtype):
    """Export ``cumsum`` through the Caffe2 ATen fallback only.

    With the fallback enabled an ATen ``cumsum`` node is emitted; a ``dtype``
    whose producing node is not ``prim::Constant`` is reported as
    unimplemented. Without the fallback the op is reported as unsupported in
    opset 9 (supported from opset 11).
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        if dtype.node().kind() != "prim::Constant":
            # The op name is spelled out: there is no `name` variable in this scope.
            return symbolic_helper._unimplemented("cumsum", "dtype")
        return g.at("cumsum", input, dim_i=dim)
    else:
        # Propagate the helper's result like the other symbolics in this file do.
        return symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11)
def _sample_dirichlet(g, self, generator):
    """Export ``_sample_dirichlet`` via the Caffe2 ATen fallback.

    A non-None ``generator`` is reported as unimplemented; without the
    fallback the op is reported as unsupported.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        if not symbolic_helper._is_none(generator):
            return symbolic_helper._unimplemented(
                "_sample_dirichlet", "We are not able to export generator"
            )
        return g.at("_sample_dirichlet", self)
    else:
        return symbolic_helper._onnx_unsupported("_sample_dirichlet")


def _standard_gamma(g, self, generator):
    """Export ``_standard_gamma`` via the Caffe2 ATen fallback.

    A non-None ``generator`` is reported as unimplemented; without the
    fallback the op is reported as unsupported.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        if not symbolic_helper._is_none(generator):
            return symbolic_helper._unimplemented(
                "_standard_gamma", "We are not able to export generator"
            )
        return g.at("_standard_gamma", self)
    else:
        return symbolic_helper._onnx_unsupported("_standard_gamma")


def t(g, self):
    """Emit ``Transpose`` with the fixed permutation ``(1, 0)``."""
    return g.op("Transpose", self, perm_i=(1, 0))


def expand(g, self, size, implicit):
    """Emit ONNX ``Expand`` of ``self`` to ``size``.

    A constant ``size`` becomes a Long ``Constant``; a packed list is stacked
    and reshaped to 1-D. Entries equal to -1 are then replaced by 1 before the
    ``Expand`` node is emitted. ``implicit`` is not used.
    """
    size = symbolic_helper._maybe_get_const(size, "is")
    if not symbolic_helper._is_value(size):
        size = g.op("Constant", value_t=torch.LongTensor(size))
    elif symbolic_helper._is_packed_list(size):
        # Expand with -1 dim value means dim is unchanged.
        # Since onnx::expand supports two-way broadcasting,
        # -1 dim value can be exported to onnx as 1
        size = symbolic_helper._reshape_helper(
            g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1]))
        )
    dtype = symbolic_helper.ScalarType.INT64
    ones = ones_like(g, size, dtype)
    neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1)))
    # Where size == -1 use 1, otherwise keep the requested size.
    size = where(g, g.op("Equal", size, neg_ones), ones, size)
    return g.op("Expand", self, size)
# Since onnx::expand supports two-way broadcasting, # -1 dim value can be exported to onnx as 1 size = symbolic_helper._reshape_helper( g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) ) dtype = symbolic_helper.ScalarType.INT64 ones = ones_like(g, size, dtype) neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) size = where(g, g.op("Equal", size, neg_ones), ones, size) return g.op("Expand", self, size) def expand_as(g, self, other): self_t = symbolic_helper._maybe_get_const(self, "t") if isinstance(self_t, torch.Tensor): orig_type = self_t.dtype self_t = self_t.to(torch.double) dims = [] for d in range(self_t.dim()): if torch.equal(self_t.mean(d).unsqueeze(d).expand_as(self_t), self_t): dims.append(d) self = g.op("Constant", value_t=self_t.mean(dims).to(orig_type)) shape = g.op("Shape", other) return g.op("Expand", self, shape) @symbolic_helper.parse_args("v", "v", "i", "b", "v") def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): if scale_grad_by_freq and GLOBALS.training_mode: raise RuntimeError( "Unsupported: ONNX export of embedding with scale_grad_by_freq=True " "for training mode. ONNX does not support scaling the gradients." ) if padding_idx >= 0 and GLOBALS.training_mode: warnings.warn( "Warning: ONNX export of embedding with padding_idx >= 0 " "for training mode. " "ONNX does not support not updating the embedding vector at padding_idx during training." 
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
def embedding_bag(
    g,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Export ``embedding_bag`` via the Caffe2 ATen fallback (4 outputs).

    ``per_sample_weights`` other than None, and export without the fallback,
    are both reported as unsupported.
    """
    if not symbolic_helper._is_none(per_sample_weights):
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with per_sample_weights"
        )
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at(
            "embedding_bag",
            embedding_matrix,
            indices,
            offsets,
            outputs=4,
            scale_grad_by_freq_i=scale_grad_by_freq,
            mode_i=mode,
            sparse_i=sparse,
            include_last_offset_i=include_last_offset,
            padding_idx_i=padding_idx,
        )
    else:
        return symbolic_helper._onnx_unsupported("embedding_bag")


def size(g, self, dim=None):
    """Return the shape of ``self``, or its size along ``dim``.

    Without ``dim`` an ONNX ``Shape`` node is emitted. A negative constant
    ``dim`` is converted to a non-negative ``Constant`` when the rank of
    ``self`` is known; the lookup itself is done by
    ``symbolic_helper._size_helper``.
    """
    if dim is None:
        return g.op("Shape", self)
    if symbolic_helper._maybe_get_const(dim, "i") < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is not None:
            dim = symbolic_helper._maybe_get_const(dim, "i") + rank
            dim = g.op("Constant", value_t=torch.tensor(dim))
    return symbolic_helper._size_helper(g, self, dim)


@symbolic_helper.parse_args("v", "i", "i")
def transpose(g, self, dim0, dim1):
    """Swap ``dim0`` and ``dim1`` of ``self`` with an ONNX ``Transpose``.

    Raises:
        RuntimeError: when the rank of ``self`` is unknown and the Caffe2
            ATen fallback is not enabled.
    """
    if dim0 == dim1:  # micro-optimization
        return self

    # NB: Transpose in ONNX is actually a Permute
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is not None:
        axes = list(range(rank))
        axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
        return g.op("Transpose", self, perm_i=axes)
    else:
        # if we don't have dim information we cannot
        # output a permute so use ATen instead
        if symbolic_helper.is_caffe2_aten_fallback():
            return g.at(
                "transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1
            )
        else:
            raise RuntimeError(
                "Unsupported: ONNX export of transpose for tensor "
                "of unknown rank."
            )
@symbolic_helper.parse_args("v", "is")
def permute(g, self, dims):
    """Emit ``Transpose`` with ``perm=dims``; the identity permutation returns ``self``."""
    if dims == list(range(0, len(dims))):
        return self
    return g.op("Transpose", self, perm_i=dims)


def view(g, self, size):
    """Lower ``view`` as ``reshape``."""
    return reshape(g, self, size)


def view_as(g, self, other):
    """Reshape ``self`` to the ``Shape`` of ``other``."""
    shape = g.op("Shape", other)
    return reshape(g, self, shape)


@symbolic_helper.parse_args("v", "i", "i", "i")
def unsafe_chunk(g, self, chunks, dim, _outputs=None):
    """Emit ONNX ``Split`` producing ``chunks`` pieces along ``dim``.

    The per-chunk size is ``ceil(size / chunks)``; a smaller trailing piece is
    appended when the dimension does not divide evenly. An unknown number of
    outputs or an unknown dimension size is reported through the helper
    functions instead.
    """
    if _outputs is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "unsafe_chunk", 9, 11, "Dynamic number of outputs not supported"
        )
    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        return symbolic_helper._unimplemented("unsafe_chunk", "unknown dimension size")
    # Ceiling division: size of every chunk except possibly the last.
    split_size = (size + chunks - 1) // chunks
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs)


@symbolic_helper.parse_args("v", "v", "v", "i")
def split(g, self, split_size_or_sizes, dim, _outputs=None):
    """Emit ONNX ``Split`` for a static split.

    A non-scalar ``split_size_or_sizes`` is forwarded to ``split_with_sizes``.
    For a scalar split size the list of piece sizes is computed from the size
    of ``dim``; when that size is unknown it is taken to be
    ``split_size * _outputs`` if ``_outputs`` is given. Non-static splits and
    an unknown size without ``_outputs`` are reported as unsupported in
    opset 9.
    """
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "split", 9, 11, "Dynamic number of outputs not supported"
        )
    split_val = split_size_or_sizes.node()["value"]
    if split_val.dim() > 0:
        return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs)
    split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")
    dim = symbolic_helper._get_const(dim, "i", "dim")

    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        if _outputs is not None:
            size = split_size * _outputs
        else:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "split", 9, 11, "Unknown dimension size not supported"
            )
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs)


def unsafe_split(g, self, split_size_or_sizes, dim, _outputs=None):
    """Alias that forwards to ``split``."""
    return split(g, self, split_size_or_sizes, dim, _outputs)
@symbolic_helper.parse_args("v", "is", "i", "i")
def split_with_sizes(g, self, split_sizes, dim, _outputs=None):
    """Emit ONNX ``Split`` with explicit ``split_sizes``; non-static splits are unsupported."""
    if not symbolic_helper._is_split_static(split_sizes, _outputs):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "split_with_sizes", 9, 11, "Dynamic number of outputs not supported"
        )
    return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs)


def unsafe_split_with_sizes(g, self, split_sizes, dim, _outputs=None):
    """Alias that forwards to ``split_with_sizes``."""
    return split_with_sizes(g, self, split_sizes, dim, _outputs)


@symbolic_helper.parse_args("v", "i", "i")
def unbind(g, self, dim=0, _outputs=None):
    """Split ``self`` into ``_outputs`` size-1 slices along ``dim`` and squeeze each.

    Returns a list of squeezed values; an unknown ``_outputs`` is reported as
    unsupported in opset 9.
    """
    if _outputs is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "unbind", 9, 11, "Dynamic number of outputs not supported"
        )

    outputs = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs)
    # Wrap a single output in a list so the comprehension below can iterate it.
    outputs = [outputs] if _outputs == 1 else outputs
    squeezed_outputs = [
        symbolic_helper._squeeze_helper(g, out, [dim]) for out in outputs
    ]
    return squeezed_outputs


@symbolic_helper.parse_args("v", "i", "v")
def select(g, self, dim, index):
    """Select one index along ``dim``.

    A constant negative ``index`` is lowered to a one-element slice followed
    by a squeeze (``-1`` slices up to INT64 max); everything else becomes an
    ONNX ``Gather``.
    """
    index = symbolic_helper._maybe_get_scalar(index)
    if (not symbolic_helper._is_value(index)) and (index < 0):
        if index == -1:
            end_index = 9223372036854775807
        else:
            end_index = index + 1
        slice_node = symbolic_helper._slice_helper(
            g, self, axes=[dim], starts=[index], ends=[end_index]
        )
        return symbolic_helper._squeeze_helper(g, slice_node, [dim])
    else:
        return g.op("Gather", self, index, axis_i=dim)


def square(g, self):
    """Emit ``Mul(self, self)``."""
    return g.op("Mul", self, self)


def squeeze(g, self, dim=None):
    """Emit ONNX ``Squeeze``, optionally restricted to ``dim``.

    Without ``dim`` a plain ``Squeeze`` is emitted. With ``dim``:
    a negative value is normalised using the known rank (or reported as
    unimplemented when the rank is unknown); if the size of the dimension is
    unknown the squeeze is emitted with a warning; if it is larger than 1 the
    input is returned unchanged with a warning; otherwise the squeeze is
    emitted with a warning about dynamic shapes.
    """
    if dim is None:
        return g.op("Squeeze", self)

    squeeze_dim = symbolic_helper._get_const(dim, "i", "dim")
    # Handle negative dims
    if squeeze_dim < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is not None:
            warnings.warn(
                "ONNX export squeeze with negative axis "
                + str(squeeze_dim)
                + " might cause the onnx model to be incorrect. "
                + "Negative axis is not supported in ONNX. "
                + "Axis is converted to "
                + str(squeeze_dim + rank)
                + " based on input shape at export time. "
                + "Passing an tensor of different rank in execution will be incorrect."
            )
            squeeze_dim += rank
        else:
            return symbolic_helper._unimplemented(
                "squeeze", "negative axis with unknown input rank"
            )

    dim_size = symbolic_helper._get_tensor_dim_size(self, squeeze_dim)
    if dim_size is None:
        warnings.warn(
            "This model contains a squeeze operation on dimension "
            + str(squeeze_dim)
            + " on an input "
            + "with unknown shape. Note that if the size of dimension "
            + str(squeeze_dim)
            + " of the input "
            + "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on "
            + "non-singleton dimensions, it is recommended to export this model using opset "
            + "version 11 or higher."
        )
        return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim])
    if dim_size > 1:
        warnings.warn(
            "This model contains a squeeze operation on dimension "
            + str(squeeze_dim)
            + ". The size of "
            + "this dimension in the given input is "
            + str(dim_size)
            + ". The model will "
            + "be exported without the squeeze node. If the model is intended to be used with dynamic "
            + "input shapes, please use opset version 11 to "
            + "export the model."
        )
        return self

    warnings.warn(
        "This model contains a squeeze operation on dimension "
        + str(squeeze_dim)
        + ". If the model is "
        + "intended to be used with dynamic input shapes, please use opset version 11 to export the model."
    )
    return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim])


def prelu(g, self, weight):
    """Emit ONNX ``PRelu(self, weight)``.

    For a known input rank above 2, ``weight`` is unsqueezed on axes
    ``1 .. rank-2``; a rank-0 input is unsqueezed to rank 1. When both ranks
    are known, ``rank(self) >= rank(weight)`` is asserted.
    """
    self_rank = symbolic_helper._get_tensor_rank(self)
    if self_rank is not None:
        if self_rank > 2:
            # make weight unidirectional broadcastable
            weight = symbolic_helper._unsqueeze_helper(
                g, weight, list(range(1, self_rank - 1))
            )
        elif self_rank == 0:
            # weight is always rank 1. torch allows scalar self, and ONNX is ambiguous
            # about whether this is allowed, but some implementations enforce
            # rank(self) >= rank(weight), which makes sense.
            self = symbolic_helper._unsqueeze_helper(g, self, [0])
            self_rank = 1

    weight_rank = symbolic_helper._get_tensor_rank(weight)
    if self_rank is not None and weight_rank is not None:
        assert (
            self_rank >= weight_rank
        ), "rank(x) should be >= rank(slope) but got {} < {}".format(
            self_rank, weight_rank
        )
    return g.op("PRelu", self, weight)
def silu(g, input):
    """Emit ``input * Sigmoid(input)``."""
    return g.op("Mul", input, g.op("Sigmoid", input))


def mish(g, input):
    """Emit ``input * Tanh(Softplus(input))``."""
    return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input)))


def op_with_optional_float_cast(g, op_name, *args, **kwargs):
    """Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types.
    This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch
    operator data type. For example, `Cast(Clip(Cast(INPUT)))` can be used to mimic
    `Clip(INPUT)` (opset version < 12): the inputs are cast to a float type, the operator is
    applied, and the result is cast back to the dtype of the first input.

    Args:
        g (torch._C.Graph): graph to write the ONNX representation into.
        op_name (str): operator name in ONNX.
        *args (tuple): operands to the operator.
        **kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default)
            indicating the smallest opset version to trigger such casting behavior and "target_float_t"
            (optional, "Float" by default) indicating the data type of internal operator.

    Returns:
        Optional[torch._C.Value, Tuple[torch._C.Value, ...]]: output(s) of the operator.

    Raises:
        RuntimeError: when casting is required and a complete-tensor input has a
            scalar type different from that of the first input.
    """
    # These two keys steer the casting and are removed before kwargs reach g.op.
    opset_before = kwargs.pop("opset_before", None)
    target_float_t = kwargs.pop("target_float_t", "Float")

    inputs = list(args)
    dtype_0 = inputs[0].type().scalarType()

    # Cast only for a non-floating first input, and only below `opset_before`
    # when that bound is given.
    require_cast = not symbolic_helper._is_fp(inputs[0]) and (
        opset_before is None or GLOBALS.export_onnx_opset_version < opset_before
    )

    if require_cast:
        for input in inputs:
            if input.isCompleteTensor() and input.type().scalarType() != dtype_0:
                raise RuntimeError(
                    f"Inputs of {op_name} must have same dtype. Got {dtype_0} and {input.type().scalarType()}"
                )
        for i, input in enumerate(inputs):
            if input.isCompleteTensor() and not symbolic_helper._is_fp(input):
                inputs[i] = g.op(
                    "Cast",
                    input,
                    to_i=symbolic_helper.cast_pytorch_to_onnx[target_float_t],
                )

    self = g.op(op_name, *inputs, **kwargs)

    if require_cast:
        # Restore the dtype of the first input.
        self = g.op("Cast", self, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype_0])

    return self


@symbolic_helper.quantized_args(True)
def relu(g, input):
    """Emit ``Relu``, with the float round-trip cast for opsets below 14."""
    return op_with_optional_float_cast(g, "Relu", input, opset_before=14)


@symbolic_helper.quantized_args(True)
def relu6(g, input):
    """Emit ``Relu`` followed by ``clamp_max`` at 6."""
    relu = op_with_optional_float_cast(g, "Relu", input, opset_before=14)
    return clamp_max(g, relu, 6)


def ceil(g, input):
    """Emit ONNX ``Ceil``."""
    return g.op("Ceil", input)


def floor(g, input):
    """Emit ONNX ``Floor``."""
    return g.op("Floor", input)


def _len(g, self):
    """Return the size of dimension 0 of ``self``, squeezed on axis 0."""
    sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0])))
    return symbolic_helper._squeeze_helper(g, sz_0, [0])


@symbolic_helper.parse_args("v", "t", "t")
def threshold(g, self, threshold, value):
    """Emit ``Relu`` for ``threshold`` with threshold 0 and value 0.

    Any other threshold or value is reported as unimplemented.
    """
    # See Note [Export inplace]
    if symbolic_helper._scalar(threshold) != 0:
        return symbolic_helper._unimplemented("threshold", "non-zero threshold")
    if symbolic_helper._scalar(value) != 0:
        return symbolic_helper._unimplemented("threshold", "non-zero value")
    return g.op("Relu", self)


def leaky_relu(g, input, negative_slope, inplace=False):
    """Emit ONNX ``LeakyRelu`` with ``alpha`` taken from the constant ``negative_slope``.

    ``inplace`` is not used.
    """
    negative_slope = symbolic_helper._get_const(negative_slope, "t", "negative_slope")
    # See Note [Export inplace]
    # TODO: Talk to ONNX about unconditional cast of scalar to float
    return g.op("LeakyRelu", input, alpha_f=symbolic_helper._scalar(negative_slope))
@symbolic_helper.parse_args("v", "i")
def glu(g, input, dim):
    """Emit GLU: split ``input`` in two along ``dim`` and return ``first * Sigmoid(second)``.

    When the size of ``dim`` is known it is asserted to be even.
    """
    dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
    if dim_size is not None:
        assert dim_size % 2 == 0

    first, second = g.op("Split", input, axis_i=dim, outputs=2)
    return g.op("Mul", first, g.op("Sigmoid", second))


@symbolic_helper.parse_args("v", "i", "none")
def softmax(g, input, dim, dtype=None):
    """Emit softmax over ``dim``.

    With a known input rank, ONNX ``Softmax`` is applied on the last axis,
    transposing ``dim`` to the end and back when necessary. With an unknown
    rank the result is composed from ``ReduceMax``/``Sub``/``Exp``/reduce-sum/
    ``Div``. In both paths a ``dtype`` whose producing node is not
    ``prim::Constant`` adds a final ``Cast``.
    """
    # Softmax does normalization at vector level.
    # PyTorch and ONNX use different strategies to split the input tensor into vectors.
    # Thus dim and axis have different meanings.
    # PyTorch slices the input tensor into vectors along the `dim`-th dimension.
    # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced.
    # If input is a 2 x 3 tensor:
    # input = [[1.0, 1.0, 1.0],
    #          [1.0, 1.0, 1.0]]
    # with dim = 0, the result is:
    # result = [[0.5, 0.5, 0.5],
    #           [0.5, 0.5, 0.5]]
    # with axis = 0, the result is:
    # result = [[0.167, 0.167, 0.167],
    #           [0.167, 0.167, 0.167]]
    # So only when dim and axis both equal to ndim - 1 (the last dimension),
    # their semantics are equivalent.
    # So use softmax when dim and axis both equal to ndim - 1,
    # otherwise transpose the input to put the vectors to be normalized to the last dimension.
    # When input rank is not known at export time we compute softmax using a subgraph
    # with other operators
    input_dim = symbolic_helper._get_tensor_rank(input)
    if input_dim is not None:
        # TODO: remove this as onnx opset 11 spec allows negative axes
        if dim < 0:
            dim = input_dim + dim

        is_transpose_required = input_dim != dim + 1

        if is_transpose_required:
            # Swap `dim` with the last axis; the same permutation undoes the swap below.
            axes = list(range(input_dim))
            axes[dim], axes[-1] = axes[-1], axes[dim]
            input = g.op("Transpose", input, perm_i=axes)
            dim = input_dim - 1

        softmax = g.op("Softmax", input, axis_i=dim)
        if dtype and dtype.node().kind() != "prim::Constant":
            parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
            softmax = g.op(
                "Cast", softmax, to_i=symbolic_helper.scalar_type_to_onnx[parsed_dtype]
            )

        if is_transpose_required:
            softmax = g.op("Transpose", softmax, perm_i=axes)
        return softmax

    # Apply max normalization.
    input = g.op("Sub", input, g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1))

    exp = g.op("Exp", input)
    sum = symbolic_helper._reducesum_helper(g, exp, axes_i=[dim])
    softmax = g.op("Div", exp, sum)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        softmax = g.op(
            "Cast", softmax, to_i=symbolic_helper.scalar_type_to_onnx[parsed_dtype]
        )
    return softmax
def softplus(g, self, beta, threshold):
    """Emit ONNX ``Softplus``.

    When the constant value of ``beta`` is not 1 the result is
    ``Softplus(self * beta) / beta``. ``threshold`` is not used.
    """
    beta_const = symbolic_helper._maybe_get_const(beta, "f")
    if beta_const != 1:
        return g.op("Div", g.op("Softplus", g.op("Mul", self, beta)), beta)
    return g.op("Softplus", self)


def get_pool_ceil_padding(input, kernel_size, stride, padding):
    """Compute the extra end-padding that emulates ``ceil_mode`` pooling.

    Args:
        input: graph value whose trailing ``len(padding)`` dimensions are the
            pooled dimensions; their sizes must be known.
        kernel_size: kernel size per pooled dimension.
        stride: stride per pooled dimension.
        padding: symmetric padding per pooled dimension.

    Returns:
        A list with one int per pooled dimension, or the result of
        ``symbolic_helper._unimplemented`` when the input sizes are not
        available.
    """
    sizes = symbolic_helper._get_tensor_sizes(input)
    dim = sizes[-len(padding) :] if sizes is not None else None
    if dim is None or any(i is None for i in dim):
        # The op name is spelled out: there is no `name` variable in this scope.
        return symbolic_helper._unimplemented(
            "get_pool_ceil_padding", "input size not accessible"
        )

    n_dims = len(padding)

    # Output size per dimension when rounding up.
    ceiled_output_dim = [
        int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i])))
        + 1
        for i in range(n_dims)
    ]
    # ensure last pooling starts inside
    ceiled_output_dim = [
        ceiled_output_dim[i] - 1
        if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i]))
        else ceiled_output_dim[i]
        for i in range(len(ceiled_output_dim))
    ]
    padding_ceil = [
        0
        if (stride[i] == 1)
        else (
            kernel_size[i]
            - (dim[i] + 2 * padding[i] - ((ceiled_output_dim[i] - 1) * stride[i] + 1))
        )
        for i in range(n_dims)
    ]
    # ensure padding is not > kernel_size
    padding_ceil = [
        (
            int(padding_ceil[i])
            if padding_ceil[i] < kernel_size[i] - 1
            else int(kernel_size[i] - 1)
        )
        if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i]))
        else int(padding_ceil[i])
        for i in range(len(padding_ceil))
    ]
    return padding_ceil
def _max_pool(name, tuple_fn, ndims, return_indices):
    """Build the symbolic function for an ``ndims``-dimensional max pool.

    Args:
        name: aten op name, used in "unimplemented" messages.
        tuple_fn: callable that expands an argument to an ``ndims``-tuple.
        ndims: number of pooled dimensions.
        return_indices: when true the symbolic returns ``(output, indices)``,
            otherwise only the pooled output.
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        # Only a dilation of 1 in every dimension is exported.
        if set(tuple_fn(dilation)) != {1}:
            return symbolic_helper._unimplemented(name, "dilation")
        if not stride:
            stride = kernel_size
        padding = tuple(tuple_fn(padding))
        if ceil_mode:
            # Begin pads stay as given; end pads get the ceil-mode extra added.
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
            padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding))
        else:
            # Same padding at the beginning and the end of every dimension.
            padding = padding * 2
        kwargs = {
            "kernel_shape_i": tuple_fn(kernel_size),
            "pads_i": padding,
            "strides_i": tuple_fn(stride),
        }
        # easy but hacky way to get flattened indices values
        # to be used to convert the indices values to non-flattened.
        # In ONNX the indices are computed as a flatten 1-D tensor,
        # so the values in indices are in [0, N x C x D1 x ... x Dn).
        # To convert the indices to the same format used by Pytorch,
        # we first execute a maxpool with a kernel and stride of 1 on the same input.
        # This will result in a tensor of indices in which each index will have its own value.
        # Using this tensor as a reference, we extract the first index of each axis and subtract
        # it from each index of this axis in the indices to convert.
        # This step will result in a tensor where each dimension has values of indices within
        # the dimension it is in.
        # For more information :
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        if return_indices:
            r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
            _, flattened_indices = g.op(
                "MaxPool",
                input,
                outputs=2,
                kernel_shape_i=[1 for _ in range(ndims)],
                strides_i=[1 for _ in range(ndims)],
            )
            # convert indices to have non-flattened indices values
            s = symbolic_helper._slice_helper(
                g,
                flattened_indices,
                axes=[2 + i for i in range(ndims)],
                starts=tuple_fn(0),
                ends=tuple_fn(1),
            )
            indices = sub(g, indices, s)
            return r, indices
        else:
            r = g.op("MaxPool", input, outputs=1, **kwargs)
            return r

    return symbolic_fn
# This step will result in a tensor were each dimension has values of indices within # the dimension it is in. # For more information : # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407 if return_indices: r, indices = g.op("MaxPool", input, outputs=2, **kwargs) _, flattened_indices = g.op( "MaxPool", input, outputs=2, kernel_shape_i=[1 for _ in range(ndims)], strides_i=[1 for _ in range(ndims)], ) # convert indices to have non-flattened indices values s = symbolic_helper._slice_helper( g, flattened_indices, axes=[2 + i for i in range(ndims)], starts=tuple_fn(0), ends=tuple_fn(1), ) indices = sub(g, indices, s) return r, indices else: r = g.op("MaxPool", input, outputs=1, **kwargs) return r return symbolic_fn max_pool1d = _max_pool( "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False ) max_pool2d = _max_pool( "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False ) max_pool3d = _max_pool( "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False ) max_pool1d_with_indices = _max_pool( "max_pool1d_with_indices", torch.nn.modules.utils._single, 1, return_indices=True, ) max_pool2d_with_indices = _max_pool( "max_pool2d_with_indices", torch.nn.modules.utils._pair, 2, return_indices=True, ) max_pool3d_with_indices = _max_pool( "max_pool3d_with_indices", torch.nn.modules.utils._triple, 3, return_indices=True, ) def _avg_pool(name, tuple_fn): @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") def symbolic_fn( g, input: _C.Value, kernel_size: Tuple[int, ...], stride: Tuple[int, ...], padding: Union[int, Tuple[int, ...]], ceil_mode: int, count_include_pad: int, divisor_override=None, ): if not stride: stride = kernel_size padding = symbolic_helper._avgpool_helper( tuple_fn, padding, kernel_size, stride, divisor_override, name ) adjusted_padding = padding if count_include_pad: input = g.op( "Pad", input, pads_i=((0,) * 2 + padding) * 2, mode_s="constant", 
def _adaptive_pool(name, type, tuple_fn, fn=None):
    """Build the symbolic function for one adaptive pooling operator.

    Args:
        name: Operator name used in "unimplemented" diagnostics.
        type: ONNX pooling op kind; the code distinguishes ``"AveragePool"``
            and ``"MaxPool"``.
        tuple_fn: Callable applied to the computed kernel/stride list before
            it is passed as an ONNX attribute.
        fn: Symbolic used for the ``"MaxPool"`` case so that indices are
            returned as well (called as
            ``fn(g, input, kernel, stride, padding, dilation, ceil_mode)``).

    Returns:
        The symbolic function ``symbolic_fn(g, input, output_size)``.
    """

    @symbolic_helper.quantized_args(True, False)
    def symbolic_fn(g, input, output_size):
        # Adaptive pooling is exported in two situations:
        #  * output_size is 1 for every dimension -> a Global*Pool op;
        #  * every input size is a multiple of the matching output size, so
        #    kernel and stride are uniform -> a plain pooling op.
        # GlobalMaxPool does not produce indices, so for MaxPool the
        # max_poolXd_with_indices symbolic is preferred; when that is not
        # possible (input sizes unknown) and the output size is all ones,
        # GlobalMaxPool is emitted and None is returned for the indices.
        try:
            output_size = symbolic_helper._parse_arg(output_size, "is")
        except Exception:
            return symbolic_helper._onnx_unsupported(
                "adaptive pooling, since output_size is not constant."
            )
        is_global = output_size == [1] * len(output_size)
        if is_global and type == "AveragePool":
            return g.op("GlobalAveragePool", input)

        sizes = symbolic_helper._get_tensor_sizes(input)
        try:
            dim = sizes[2:]
        except Exception:
            dim = None
        if dim is None or any(i is None for i in dim):
            if is_global:
                return g.op("GlobalMaxPool", input), None
            return symbolic_helper._unimplemented(name, "input size not accessible")

        # Verify that input size % output size == 0 for every dimension.
        # (An all-ones output size always divides evenly, so no Global*Pool
        # fallback is needed on this path.)
        mod = [dim[i] % output_size[i] for i in range(len(dim))]
        if mod != [0] * len(mod):
            return symbolic_helper._unimplemented(
                name, "output size that are not factor of input size"
            )

        # Exact integer division: the remainder was just checked to be zero.
        k = [dim[i] // output_size[i] for i in range(len(dim))]
        # call max_poolxd_with_indices to get indices in the output
        if type == "MaxPool":
            return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False)
        return g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k))

    return symbolic_fn
def _prepare_onnx_paddings(dim, pad):
    """Reorder PyTorch-style paddings into the ONNX ``pads`` layout.

    Args:
        dim: Rank of the tensor being padded (must be an ``int``).
        pad: PyTorch paddings ordered
            ``dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...``.

    Returns:
        A list ordered ``dim_0_begin, dim_1_begin, ..., dim_0_end, ...,
        dim_n_end``; dimensions not mentioned in ``pad`` get zero padding.
    """
    assert isinstance(dim, int)
    # Leading dimensions that ``pad`` does not cover are padded with zeros.
    missing = dim * 2 - len(pad)
    full = [*pad[:], *([0] * missing)]
    # Walking backwards with step 2 yields the outermost dimension first.
    begins = full[-2::-2]
    ends = full[-1::-2]
    return begins + ends
def pad(g, input, pad, mode, value):
    """Dispatch a padding call to the symbolic for the requested mode.

    Args:
        g: Graph context forwarded to the mode-specific symbolic.
        input: Value to pad.
        pad: Padding specification forwarded unchanged.
        mode: Padding mode; parsed as a string and expected to be one of
            ``"replicate"``, ``"reflect"``, ``"constant"`` or ``"circular"``.
        value: Fill value, only used for ``"constant"`` mode.

    Raises:
        RuntimeError: If the parsed mode is none of the four known modes.
    """
    padding_mode = symbolic_helper._parse_arg(mode, "s")
    if padding_mode == "replicate":
        return replication_pad(g, input, pad)
    if padding_mode == "reflect":
        return reflection_pad(g, input, pad)
    if padding_mode == "constant":
        return constant_pad_nd(g, input, pad, value)
    if padding_mode == "circular":
        return _pad_circular(g, input, pad)
    raise RuntimeError(f"Unrecognized padding mode {padding_mode}")
def wrap_logical_op_with_cast_to(to_type):
    """Decorator factory: cast the result of a binary symbolic to ``to_type``.

    Args:
        to_type: Key into ``symbolic_helper.cast_pytorch_to_onnx`` naming the
            type the wrapped symbolic's output is cast to.

    Returns:
        A decorator for symbolics with the signature ``fn(g, input, other)``.
    """

    def decorator(fn):
        # functools.wraps keeps the wrapped symbolic's name and docstring.
        @functools.wraps(fn)
        def wrap_with_cast(g, input, other):
            return g.op(
                "Cast",
                fn(g, input, other),
                to_i=symbolic_helper.cast_pytorch_to_onnx[to_type],
            )

        return wrap_with_cast

    return decorator


def wrap_logical_op_with_cast_to_and_from(to_type):
    """Decorator factory: cast both operands to ``to_type`` and the result back.

    The operands are converted with the module-level ``_cast_<to_type>``
    function (looked up at call time); the result is cast to the scalar type
    ``input`` had before conversion.

    Args:
        to_type: Type name used both to pick ``_cast_<to_type>`` and as the
            operand type for the wrapped symbolic.

    Returns:
        A decorator for symbolics with the signature ``fn(g, input, other)``.
    """

    def decorator(fn):
        @functools.wraps(fn)
        def wrap_with_cast(g, input, other):
            to_cast_func = globals()[f"_cast_{to_type}"]
            from_cast_func = wrap_logical_op_with_cast_to(input.type().scalarType())(fn)
            return from_cast_func(
                g, to_cast_func(g, input, False), to_cast_func(g, other, False)
            )

        return wrap_with_cast

    return decorator


def wrap_logical_op_with_negation(func):
    """Decorator: emit ONNX ``Not`` on the result of a binary symbolic.

    Args:
        func: Symbolic with the signature ``func(g, input, other)``.

    Returns:
        A symbolic with the same signature and metadata as ``func``.
    """

    @functools.wraps(func)
    def wrap_with_not(g, input, other):
        return g.op("Not", func(g, input, other))

    return wrap_with_not
def gt_impl(g, input, other):
    """Emit ONNX ``Greater`` for ``input > other``.

    When both operands report the scalar type ``"Bool"`` they are first cast
    to the ONNX type registered for ``"Int"``; otherwise the operands are
    passed to ``Greater`` unchanged.
    """
    lhs_type = input.type().scalarType()
    both_bool = lhs_type is not None and lhs_type == "Bool"
    if both_bool:
        # Only inspect the right operand once the left one is known to be Bool.
        rhs_type = other.type().scalarType()
        both_bool = rhs_type is not None and rhs_type == "Bool"

    if both_bool:
        input = g.op("Cast", input, to_i=symbolic_helper.cast_pytorch_to_onnx["Int"])
        other = g.op("Cast", other, to_i=symbolic_helper.cast_pytorch_to_onnx["Int"])
    return g.op("Greater", input, other)
@symbolic_helper.parse_args("v", "v", "v", "i")
def where(g, condition, self=None, other=None, _outputs=None):
    """Export ``where`` in both its three-argument and one-argument forms.

    With ``self`` given, emits ONNX ``Where(condition, self, other)``.
    Without it, computes ``nonzero(condition)`` and unbinds the result along
    dimension 1 into ``_outputs`` values.
    """
    # Assumes that torch.where's first argument takes only Bool and Byte
    # tensors, so anything that is not Bool is cast to Bool.
    already_bool = condition.type().scalarType() == "Bool"
    if not already_bool:
        bool_type = symbolic_helper.cast_pytorch_to_onnx["Bool"]
        condition = g.op("Cast", condition, to_i=bool_type)

    if self is not None:
        return g.op("Where", condition, self, other)

    indices = nonzero(g, condition)
    unbind_dim = g.op("Constant", value_t=torch.tensor(1))
    return symbolic_helper._unbind_helper(g, indices, unbind_dim, _outputs)
@symbolic_helper.parse_args(
    "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i"
)
def _convolution(
    g,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    benchmark,
    deterministic,
    cudnn_enabled,
    allow_tf32=None,
):
    """Emit ONNX ``Conv`` or ``ConvTranspose`` for a generic convolution.

    A bias of rank 1 is passed as the third operator input; any other
    non-None bias is applied with a separate ``Add`` on the result.
    ``benchmark``, ``deterministic``, ``cudnn_enabled`` and ``allow_tf32``
    are accepted but not used.

    Raises:
        RuntimeError: If the kernel (weight) spatial shape is not known.
    """
    weight_size = symbolic_helper._get_tensor_sizes(weight)
    try:
        kernel_shape = weight_size[2:]
    except Exception:
        kernel_shape = None

    shape_known = kernel_shape is not None and all(
        extent is not None for extent in kernel_shape
    )
    if not shape_known:
        raise RuntimeError(
            "Unsupported: ONNX export of convolution for kernel " "of unknown shape."
        )

    # ONNX only supports 1D bias as an operator input.
    has_bias = not symbolic_helper._is_none(bias)
    bias_is_1d = has_bias and symbolic_helper._get_tensor_rank(bias) == 1
    conv_inputs = [input, weight, bias] if bias_is_1d else [input, weight]

    conv_attrs = dict(
        kernel_shape_i=weight_size[2:],
        strides_i=stride,
        # NB: ONNX supports asymmetric padding, whereas PyTorch supports only
        # symmetric padding, so begin and end pads are the same list.
        pads_i=padding + padding,
        dilations_i=dilation,
        group_i=groups,
    )

    if any(amount != 0 for amount in output_padding):
        # ONNX supports both output_shape and output_padding; they are
        # equally expressive and output_padding is the more direct one.
        # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2
        assert transposed
        assert len(stride) == len(output_padding)
        conv_attrs["output_padding_i"] = output_padding

    op_kind = "ConvTranspose" if transposed else "Conv"
    conv = g.op(op_kind, *conv_inputs, **conv_attrs)

    if has_bias and not bias_is_1d:
        return g.op("Add", conv, bias)
    return conv
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
def batch_norm(
    g,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    training,
    momentum,
    eps,
    cudnn_enabled,
):
    """Emit ONNX ``BatchNormalization``.

    In inference mode the op has a single output, which is returned. In
    training mode it has five outputs; the running-stat outputs get the types
    of the incoming running stats, the two saved-stat outputs are renamed
    with a ``batch_norm_dead_output-`` prefix, and only the normalized
    result is returned.
    """
    symbolic_helper.check_training_mode(training, "batch_norm")

    mixed_dtypes_under_autocast = (
        torch.is_autocast_enabled()
        and not symbolic_helper.args_have_same_dtype(
            [input, weight, bias, running_mean, running_var]
        )
    )
    if mixed_dtypes_under_autocast and GLOBALS.export_onnx_opset_version < 15:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "BatchNormalization",
            9,
            15,
            "All input tensors must have the same `dtype`."
            " Turn off Autocast or export using opset version 15.",
        )

    weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var
    )
    output_count = 5 if training else 1
    out = g.op(
        "BatchNormalization",
        input,
        weight,
        bias,
        running_mean,
        running_var,
        epsilon_f=eps,
        momentum_f=1 - momentum,
        outputs=output_count,
    )
    if not training:
        return out

    res, new_running_mean, new_running_var, saved_mean, saved_var = out
    new_running_mean.setType(running_mean.type())
    new_running_var.setType(running_var.type())
    for dead_output in (saved_mean, saved_var):
        dead_output.setDebugName("batch_norm_dead_output-" + dead_output.debugName())
    return res
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
def instance_norm(
    g,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    use_input_stats,
    momentum,
    eps,
    cudnn_enabled,
):
    """Export instance normalization.

    Without running statistics this emits ONNX ``InstanceNormalization``.
    With running statistics the input ``[N, C, ...]`` is reshaped to
    ``[1, N * C, ...]``, ``batch_norm`` is applied with the per-channel
    parameters repeated ``N`` times, and the result is viewed back to the
    original shape.

    Raises:
        RuntimeError: If a default weight/bias is needed but the channel size
            is unknown, or if running statistics are given and the input
            shape is not fully known.
    """
    symbolic_helper.check_training_mode(use_input_stats, "instance_norm")
    channel_size = symbolic_helper._get_tensor_dim_size(input, 1)

    def _is_missing(value):
        return value is None or symbolic_helper._is_none(value)

    def _channel_constant(fill_value):
        # Default per-channel parameter in the input's scalar type.
        if channel_size is None:
            raise RuntimeError(
                "Unsupported: ONNX export of instance_norm for unknown "
                "channel size."
            )
        value = torch.tensor([fill_value] * channel_size).type(
            "torch." + input.type().scalarType() + "Tensor"
        )
        return g.op("Constant", value_t=value)

    if _is_missing(weight):
        weight = _channel_constant(1.0)
    if _is_missing(bias):
        bias = _channel_constant(0.0)

    if _is_missing(running_mean) or _is_missing(running_var):
        return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)

    input_size = symbolic_helper._get_tensor_sizes(input)
    # Previously an unknown shape failed with AttributeError on `.copy()` and
    # an unknown non-batch dimension failed while building constant tensors;
    # report both explicitly instead.
    if input_size is None:
        raise RuntimeError(
            "Unsupported: ONNX export of instance_norm training for unknown "
            "input shape."
        )
    n = input_size[0]
    if n is None:
        raise RuntimeError(
            "Unsupported: ONNX export of instance_norm training for unknown "
            "batch size."
        )
    c = input_size[1]
    if any(extent is None for extent in input_size[1:]):
        raise RuntimeError(
            "Unsupported: ONNX export of instance_norm training for input "
            "with unknown dimensions."
        )

    # If input shape is [N, C, H, W], reshape to [1, N * C, H, W] and call batch_norm.
    # For more information instance_norm():
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L542
    input_size_reshape = input_size.copy()
    input_size_reshape[0] = 1
    input_size_reshape[1] = n * c

    def _repeat_per_sample(value):
        repeats = g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64))
        return repeat(g, value, repeats)

    weight_ = _repeat_per_sample(weight)
    bias_ = _repeat_per_sample(bias)
    running_mean_ = _repeat_per_sample(running_mean)
    running_var_ = _repeat_per_sample(running_var)
    input_reshaped = g.op(
        "Reshape",
        input,
        g.op("Constant", value_t=torch.LongTensor(input_size_reshape)),
    )
    out = batch_norm(
        g,
        input_reshaped,
        weight_,
        bias_,
        running_mean_,
        running_var_,
        use_input_stats,
        momentum,
        eps,
        cudnn_enabled,
    )
    return view(g, out, g.op("Constant", value_t=torch.tensor(input_size)))
def index_put(g, self, indices_list_value, values, accumulate):
    """Export ``index_put`` for the cases expressible in opset 9.

    Only an empty index list is handled here: the result is ``values``, or
    ``self + values`` when ``accumulate`` is true. Any non-empty index list
    is reported as requiring opset 11. Under the Caffe2 ATen fallback the
    call is forwarded as an ATen ``index_put`` op instead.
    """
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]

    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if len(indices_list) == 0:
        return add(g, self, values) if accumulate else values

    # Propagate the helper's result like the other "unsupported" call sites in
    # this file do; previously it was dropped and the function fell through
    # to an implicit None.
    return symbolic_helper._onnx_opset_unsupported("index_put", 9, 11)
@symbolic_helper.parse_args("v", "v", "b", "b")
def bucketize(g, self, boundaries, out_int32=False, right=False):
    """Export ``bucketize`` as compare-against-every-boundary plus a sum.

    Each element of ``self`` is compared with every boundary; the number of
    boundaries for which the comparison holds is the bucket index. The
    result is cast to INT32 when ``out_int32`` is set, INT64 otherwise.
    """
    onnx_types = _C_onnx.TensorProtoDataType
    out_type = onnx_types.INT32 if out_int32 else onnx_types.INT64

    # Target shape: boundaries' shape followed by self's shape, so that one
    # copy of the boundaries exists for each element of self.
    boundaries_shape = g.op("Shape", boundaries)
    self_shape = g.op("Shape", self)
    new_shape = g.op("Concat", boundaries_shape, self_shape, axis_i=0)

    # Unsqueeze to respect ONNX's numpy-style broadcasting for comparison ops
    # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
    self_rank = symbolic_helper._get_tensor_rank(self)
    unsqueeze_axes = list(range(1, self_rank + 1))
    unsqueezed = symbolic_helper._unsqueeze_helper(g, boundaries, unsqueeze_axes)
    expanded_boundaries = expand(g, unsqueezed, new_shape, None)

    # Comparing an element with the boundaries gives leading 1s and trailing
    # 0s, e.g. 4 > [1, 3, 4] = [1, 1, 0].
    compare = ge if right else gt
    cond = compare(g, self, expanded_boundaries)
    cond_out = g.op("Cast", cond, to_i=out_type)

    # The count of 1s is the bucket index, e.g. sum([1, 1, 0]) = 2.
    return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0)
def pow(g, self, exponent):
    """Emit ONNX ``Pow``, casting non-floating-point operands first.

    A non-floating-point base is cast to ``"Float"``. A non-floating-point
    exponent is cast to the base's floating type: ``"Float"`` if the base
    had to be cast, otherwise the base's own scalar type.
    """
    target_dtype = self.type().scalarType()
    if not symbolic_helper._is_fp(self):
        target_dtype = "Float"
        self = g.op(
            "Cast", self, to_i=symbolic_helper.cast_pytorch_to_onnx[target_dtype]
        )
    if not symbolic_helper._is_fp(exponent):
        exponent = g.op(
            "Cast", exponent, to_i=symbolic_helper.cast_pytorch_to_onnx[target_dtype]
        )
    return g.op("Pow", self, exponent)
@symbolic_helper.parse_args("v", "v")
def clamp_min(g, self, min):
    """Export ``aten::clamp_min``.

    A constant lower bound becomes a ``Clip`` with a ``min`` attribute; a
    dynamic bound is cast to the scalar type of ``self`` and applied with an
    element-wise ``Max``.
    """
    if not symbolic_helper._is_constant(min):
        # Dynamic bound: both Max operands must share one element type.
        self_scalar_type = self.type().scalarType()
        lower_bound = g.op(
            "Cast",
            min,
            to_i=symbolic_helper.cast_pytorch_to_onnx[self_scalar_type],
        )
        return op_with_optional_float_cast(
            g, "Max", self, lower_bound, opset_before=12
        )

    # Constant bound: bake it into the Clip attribute.
    lower_value = symbolic_helper._parse_arg(min, "f")
    return op_with_optional_float_cast(
        g, "Clip", self, min_f=lower_value, opset_before=12
    )
def min(g, self, dim_or_y=None, keepdim=None):
    """Export the three ``torch.min`` overloads.

    * ``min(input)`` -> ``ReduceMin`` over all axes without keeping dims.
    * ``min(input, other)`` -> element-wise ``Min``.
    * ``min(input, dim, keepdim)`` -> ``(ReduceMin, ArgMin)`` along ``dim``.
    """
    if keepdim is not None:
        # torch.min(input, dim, keepdim)
        axis = symbolic_helper._get_const(dim_or_y, "i", "dim")
        keep = symbolic_helper._get_const(keepdim, "i", "keepdim")
        values = g.op("ReduceMin", self, axes_i=[axis], keepdims_i=keep)
        positions = g.op("ArgMin", self, axis_i=axis, keepdims_i=keep)
        return values, positions

    if dim_or_y is None:
        # torch.min(input)
        return g.op("ReduceMin", self, keepdims_i=0)

    # torch.min(input, other)
    return op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12)
@symbolic_helper.parse_args("v", "t", "is", "i")
def norm(g, self, p, dim, keepdim):
    """Export ``aten::norm`` for p = 1 (``ReduceL1``) or p = 2 (``ReduceL2``).

    Raises:
        RuntimeError: if ``p`` is neither 1 nor 2.
    """
    # Checked in order so that p is compared with 2 only when it is not 1.
    for order, onnx_reduce in ((1, "ReduceL1"), (2, "ReduceL2")):
        if p == order:
            reduce_symbolic = _reduce_op_symbolic(onnx_reduce)
            return reduce_symbolic(g, self, dim=dim, keepdim=keepdim)
    raise RuntimeError("ONNX export only p-norms with p of 1 or 2")
def new_empty(g, self, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::new_empty``, inheriting the dtype of ``self`` by default.

    Args:
        g: Graph the ONNX nodes are appended to.
        self: Tensor whose scalar type is used when ``dtype`` is not given.
        sizes: Shape of the new tensor.
        dtype: Requested dtype; either Python ``None`` or a graph value.
        layout, device, pin_memory: Forwarded unchanged to ``empty``.

    Returns:
        The result of ``empty`` for the resolved dtype.
    """
    self_dtype = symbolic_helper._try_get_scalar_type(self)
    # This symbolic is not wrapped in parse_args, so an omitted dtype may be a
    # None-constant graph value rather than Python None. Both forms must be
    # treated as "not specified", otherwise the dtype of self is never used.
    dtype_unspecified = dtype is None or (
        symbolic_helper._is_value(dtype) and symbolic_helper._is_none(dtype)
    )
    if dtype_unspecified and self_dtype is not None:
        dtype = symbolic_helper.scalar_type_to_onnx.index(
            symbolic_helper.cast_pytorch_to_onnx[self_dtype]
        )
    return empty(g, sizes, dtype, layout, device, pin_memory)
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def zeros(g, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::zeros`` as ``ConstantOfShape`` filled with 0.

    ``layout``, ``device`` and ``pin_memory`` have no ONNX counterpart and are
    ignored. A missing ``dtype`` defaults to float.
    """
    scalar_type = symbolic_helper.ScalarType.FLOAT if dtype is None else dtype

    # An empty constant size list denotes a scalar; ConstantOfShape then needs
    # an explicit empty int64 shape tensor.
    static_sizes = symbolic_helper._maybe_get_const(sizes, "is")
    if isinstance(static_sizes, list) and not static_sizes:
        sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))

    fill = torch.tensor(
        [0], dtype=symbolic_helper.scalar_type_to_pytorch_type[scalar_type]
    )
    return g.op("ConstantOfShape", sizes, value_t=fill)
def new_ones(g, self, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::new_ones``, inheriting the dtype of ``self`` by default.

    Args:
        g: Graph the ONNX nodes are appended to.
        self: Tensor whose scalar type is used when ``dtype`` is not given.
        sizes: Shape of the new tensor.
        dtype: Requested dtype; either Python ``None`` or a graph value.
        layout, device, pin_memory: Forwarded unchanged to ``ones``.

    Returns:
        The result of ``ones`` for the resolved dtype.
    """
    self_dtype = symbolic_helper._try_get_scalar_type(self)
    # This symbolic is not wrapped in parse_args, so an omitted dtype may be a
    # None-constant graph value rather than Python None. Both forms must be
    # treated as "not specified", otherwise the dtype of self is never used.
    dtype_unspecified = dtype is None or (
        symbolic_helper._is_value(dtype) and symbolic_helper._is_none(dtype)
    )
    if dtype_unspecified and self_dtype is not None:
        dtype = symbolic_helper.scalar_type_to_onnx.index(
            symbolic_helper.cast_pytorch_to_onnx[self_dtype]
        )
    return ones(g, sizes, dtype, layout, device, pin_memory)
def new_full(g, self, size, fill_value, dtype, layout, device, pin_memory=False):
    """Export ``aten::new_full``, inheriting the dtype of ``self`` by default.

    Args:
        g: Graph the ONNX nodes are appended to.
        self: Tensor whose scalar type is used when ``dtype`` is not given.
        size: Shape of the new tensor.
        fill_value: Value every element is set to.
        dtype: Requested dtype; either Python ``None`` or a graph value.
        layout, device, pin_memory: Forwarded unchanged to ``full``.

    Returns:
        The result of ``full`` for the resolved dtype.
    """
    self_dtype = symbolic_helper._try_get_scalar_type(self)
    # This symbolic is not wrapped in parse_args, so an omitted dtype may be a
    # None-constant graph value rather than Python None. Both forms must be
    # treated as "not specified", otherwise the dtype of self is never used.
    dtype_unspecified = dtype is None or (
        symbolic_helper._is_value(dtype) and symbolic_helper._is_none(dtype)
    )
    if dtype_unspecified and self_dtype is not None:
        dtype = symbolic_helper.scalar_type_to_onnx.index(
            symbolic_helper.cast_pytorch_to_onnx[self_dtype]
        )
    return full(g, size, fill_value, dtype, layout, device, pin_memory)
slice(g, self, *args): if len(args) == 4: # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor dim, start, end, step = args step = symbolic_helper._parse_arg(step, "i") if step != 1: raise RuntimeError("step!=1 is currently not supported") is_start_none = ( start.node().kind() == "prim::Constant" and start.type().kind() == "NoneType" ) is_end_none = ( end.node().kind() == "prim::Constant" and end.type().kind() == "NoneType" ) is_start_onnx_const = start.node().kind() == "onnx::Constant" is_end_onnx_const = end.node().kind() == "onnx::Constant" if ( ((not is_start_none) and (not is_start_onnx_const)) or ((not is_end_none) and (not is_end_onnx_const)) or dim.node().kind() != "onnx::Constant" ): if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: raise RuntimeError( "Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice " "is a deprecated experimental op. Please use statically allocated " "variables or export to a higher opset version." ) else: start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0]) end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0]) dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0]) return g.op( "DynamicSlice", self, start_unsqueezed, end_unsqueezed, dim_unsqueezed, ) else: start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") end = ( 9223372036854775807 if is_end_none else symbolic_helper._parse_arg(end, "i") ) dim = symbolic_helper._parse_arg(dim, "i") return symbolic_helper._slice_helper( g, self, axes=[dim], starts=[start], ends=[end] ) elif len(args) == 3: # aten::slice(t[] l, int start, int end, int step) -> t[] start, end, step = args dim = 0 is_start_none = ( start.node().kind() == "prim::Constant" and start.type().kind() == "NoneType" ) is_end_none = ( end.node().kind() == "prim::Constant" and end.type().kind() == "NoneType" ) start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") end = ( 9223372036854775807 if 
@symbolic_helper.parse_args("v", "f")
def hardshrink(g, self, lambd):
    """Export ``aten::hardshrink``.

    Elements with ``self > lambd`` or ``self < -lambd`` pass through; every
    other element is replaced by 0.
    """
    threshold = g.op("Constant", value_t=torch.FloatTensor([lambd]))
    above_band = gt(g, self, threshold)
    below_band = lt(g, self, neg(g, threshold))
    outside_band = logical_or(g, above_band, below_band)
    zero = g.op("Constant", value_t=torch.FloatTensor([0]))
    return g.op("Where", outside_band, self, zero)
@symbolic_helper.parse_args("v", "i", "i", "none")
def sort(g, self, dim, decending, out=None):
    """Export a descending ``aten::sort`` as a full-length ``TopK``.

    Args:
        g: Graph the ONNX nodes are appended to.
        self: Tensor to sort.
        dim: Axis to sort along.
        decending: Non-zero for a descending sort. Ascending sorts are
            reported as unimplemented.
        out: Unsupported; must be ``None``.

    Returns:
        The two ``TopK`` outputs (values, indices), or the result of
        ``_unimplemented`` when the request cannot be exported.
    """
    if out is not None:
        symbolic_helper._unimplemented(
            "Sort", "Out parameter is not supported for sort"
        )
    # The TopK node emitted below carries no attribute that selects the
    # smallest elements, so an ascending sort would silently be exported in
    # descending order. Refuse it instead of producing a wrong graph.
    if not decending:
        return symbolic_helper._unimplemented(
            "Sort", "Ascending sort is not supported"
        )
    self_sizes = symbolic_helper._get_tensor_sizes(self)
    try:
        dim_size = self_sizes[dim]
    except (TypeError, IndexError):
        # TypeError: sizes unknown (None); IndexError: dim outside the known rank.
        dim_size = None

    if dim_size is None:
        return symbolic_helper._unimplemented("Sort", "input size not accessible")

    return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2)
Layout, Device, bool, bool, memory_format) -> Tensor # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor # When dtype is None, this is a aten::to(device) call dtype = symbolic_helper._get_const(args[0], "i", "dtype") return dtype is None return False # ONNX doesn't have a concept of a device, so we ignore device-only casts if is_aten_to_device_only(args): return self if len(args) == 4: # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=]() # In this case, the constant value is a tensor not int, # so symbolic_helper._maybe_get_const(args[0], 'i') would not work. dtype = args[0] if ( symbolic_helper._is_value(args[0]) and args[0].node().kind() == "onnx::Constant" ): tval = args[0].node()["value"] if isinstance(tval, torch.Tensor): if len(tval.shape) == 0: tval = tval.item() dtype = int(tval) else: dtype = tval if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor): # aten::to(Tensor, Tensor, bool, bool, memory_format) dtype = args[0].type().scalarType() return g.op("Cast", self, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype]) else: # aten::to(Tensor, ScalarType, bool, bool, memory_format) # memory_format is ignored return g.op("Cast", self, to_i=symbolic_helper.scalar_type_to_onnx[dtype]) elif len(args) == 5: # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) dtype = symbolic_helper._get_const(args[1], "i", "dtype") # memory_format is ignored return g.op("Cast", self, to_i=symbolic_helper.scalar_type_to_onnx[dtype]) elif len(args) == 6: # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor dtype = symbolic_helper._get_const(args[0], "i", "dtype") # Layout, device and memory_format are ignored return g.op("Cast", self, to_i=symbolic_helper.scalar_type_to_onnx[dtype]) elif len(args) == 7: # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor dtype = symbolic_helper._get_const(args[0], "i", 
def repeat(g, self, repeats):
    """Export ``aten::repeat`` as ``Expand`` followed by ``Tile``.

    ``self`` is first expanded against an all-ones INT64 tensor shaped like
    ``repeats``, then tiled by ``repeats``.
    """
    ones_shape = ones_like(g, repeats, symbolic_helper.ScalarType.INT64)
    expanded_self = g.op("Expand", self, ones_shape)
    return g.op("Tile", expanded_self, repeats)
) input_sizes_temp = input_sizes.copy() for idx, input_size in enumerate(input_sizes): if input_size is None: input_sizes[idx], input_sizes_temp[idx] = 0, -1 # Cases where repeats is an int or single value tensor if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): if not symbolic_helper._is_tensor(repeats): repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) if input_sizes[dim] == 0: return symbolic_helper._onnx_opset_unsupported_detailed( "repeat_interleave", 9, 13, "Unsupported along dimension with unknown input size", ) else: reps = input_sizes[dim] repeats = expand( g, repeats, g.op("Constant", value_t=torch.tensor([reps])), None ) # Cases where repeats is a 1 dim Tensor elif repeats_dim == 1: if input_sizes[dim] == 0: return symbolic_helper._onnx_opset_unsupported_detailed( "repeat_interleave", 9, 13, "Unsupported along dimension with unknown input size", ) if repeats_sizes[0] is None: return symbolic_helper._onnx_opset_unsupported_detailed( "repeat_interleave", 9, 13, "Unsupported for cases with dynamic repeats" ) assert ( repeats_sizes[0] == input_sizes[dim] ), "repeats must have the same size as input along dim" reps = repeats_sizes[0] else: raise RuntimeError("repeats must be 0-dim or 1-dim tensor") final_splits = list() r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0) i_splits = symbolic_helper._repeat_interleave_split_helper(g, input, reps, dim) input_sizes[dim], input_sizes_temp[dim] = -1, 1 for idx, r_split in enumerate(r_splits): i_split = unsqueeze(g, i_splits[idx], dim + 1) r_concat = [ g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])), r_split, g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])), ] r_concat = g.op("Concat", *r_concat, axis_i=0) i_split = expand(g, i_split, r_concat, None) i_split = symbolic_helper._reshape_helper( g, i_split, g.op("Constant", value_t=torch.LongTensor(input_sizes)), allowzero=0, ) final_splits.append(i_split) return 
@symbolic_helper.parse_args("v", "i")
def pixel_unshuffle(g, self, downscale_factor):
    """Export ``aten::pixel_unshuffle`` with Reshape/Transpose nodes.

    Args:
        g: Graph the ONNX nodes are appended to.
        self: Input value; must have a known rank of 4.
        downscale_factor: Integer factor by which the last two axes shrink.

    Returns:
        The rearranged tensor, or the result of ``_unimplemented`` when the
        input is not known to be 4-D.
    """
    dims = symbolic_helper._get_tensor_sizes(self)
    # Sizes are None when the shape is unknown; report that the same way as a
    # wrong rank instead of failing on len(None).
    if dims is None or len(dims) != 4:
        return symbolic_helper._unimplemented(
            "pixel_unshuffle", "only support 4d input"
        )
    if any(i is None for i in dims[1:]):
        # For dynamic input shapes, two reshapes are performed
        reshape_h = symbolic_helper._reshape_helper(
            g,
            symbolic_helper._unsqueeze_helper(g, self, [3]),
            g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])),
            allowzero=0,
        )
        reshape_w = symbolic_helper._reshape_helper(
            g,
            reshape_h,
            g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])),
            allowzero=0,
        )
        after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4])
        final_reshape = symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])),
            allowzero=0,
        )
        return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3])
    else:
        # Static shapes: fold the scale factor into explicit target shapes.
        output_channel = dims[1] * downscale_factor * downscale_factor
        after_view = symbolic_helper._reshape_helper(
            g,
            self,
            g.op(
                "Constant",
                value_t=torch.tensor(
                    [
                        -1,
                        dims[1],
                        dims[2] // downscale_factor,
                        downscale_factor,
                        dims[3] // downscale_factor,
                        downscale_factor,
                    ]
                ),
            ),
            allowzero=0,
        )
        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4])
        return symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op(
                "Constant",
                value_t=torch.tensor(
                    [
                        -1,
                        output_channel,
                        dims[2] // downscale_factor,
                        dims[3] // downscale_factor,
                    ]
                ),
            ),
            allowzero=0,
        )
" ) onnxActivations = [ "Relu", "Tanh", "Sigmoid", "Affine", "LeakyRelu", "ThresholdedRelu", "ScaledTanh", "HardSigmoid", "Elu", "Softsign", "Softplus", ] variantToOnnxActivationMap = dict( zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations) ) weights_per_layer = 4 if has_biases else 2 # this means that projections are used inside LSTM, so need to tell user that it's not supported if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * ( 1 + bidirectional ): return symbolic_helper._unimplemented("LSTM", "LSTMs with projections") assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional) layer_weights = [ all_weights[i : i + weights_per_layer] for i in range(0, len(all_weights), weights_per_layer) ] if batch_first: # batch, seq, feat -> seq, batch, feat input = g.op("Transpose", input, perm_i=[1, 0, 2]) if dropout and train: return symbolic_helper._unimplemented( "RNN/GRU/LSTM", "dropout in training mode" ) if variant.startswith("RNN"): nonlinearity = variantToOnnxActivationMap[variant[4:].lower()] variant = "RNN" w_hh = all_weights[1] hidden_size = symbolic_helper._get_tensor_dim_size(w_hh, 1) if hidden_size is None: return symbolic_helper._unimplemented("RNN/GRU/LSTM", "unknown hidden size") unidirectional = not bidirectional prev_output = input h_outs = [] if variant == "RNN" or variant == "GRU": h0 = initial_states elif variant == "LSTM": h0, c0 = initial_states c_outs = [] sequence_lens = unused(g) if batch_sizes is None else batch_sizes if variant == "GRU": # pytorch is reset, input, hidden # onnx is input, reset, hidden reform_permutation = [(1, 2), (0, 1), (2, 3)] elif variant == "LSTM": # pytorch is input, forget, cell, output. # onnx is input, output, forget, cell. 
reform_permutation = [(0, 1), (3, 4), (1, 3)] def reform_weights(g, w, n, intervals): slices = [ symbolic_helper._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n]) for x, y in intervals ] return g.op("Concat", *slices, axis_i=0) def transform_weights_no_bias(layer_index): weights = layer_weights[layer_index] if variant == "RNN": weight_ih, weight_hh = weights elif variant == "GRU" or variant == "LSTM": weight_ih, weight_hh = [ reform_weights(g, w, hidden_size, reform_permutation) for w in weights ] return tuple( symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh) ) def transform_weights(layer_index): weights = layer_weights[layer_index] if variant == "RNN": weight_ih, weight_hh, bias_ih, bias_hh = weights elif variant == "GRU" or variant == "LSTM": weight_ih, weight_hh, bias_ih, bias_hh = [ reform_weights(g, w, hidden_size, reform_permutation) for w in weights ] bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0) return tuple( symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh, bias_concat) ) def retrieve_state(x, start, end): return ( x if num_layers == 1 else symbolic_helper._slice_helper( g, x, axes=[0], starts=[start], ends=[end] ) ) for i in range(num_layers): if unidirectional: if weights_per_layer == 4: weight_ih, weight_hh, bias_concat = transform_weights(i) else: weight_ih, weight_hh = transform_weights_no_bias(i) bias_concat = unused(g) state_indices = i, i + 1 else: if weights_per_layer == 4: weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i) weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1) bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0) else: weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i) weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1) bias_concat = unused(g) weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0) weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0) state_indices = 2 * i, 2 * i + 2 inputs = 
[prev_output, weight_ih, weight_hh, bias_concat, sequence_lens] inputs.append(retrieve_state(h0, *state_indices)) if variant == "LSTM": inputs.append(retrieve_state(c0, *state_indices)) extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"} if variant == "RNN": if bidirectional: activation = [nonlinearity, nonlinearity] else: activation = [nonlinearity] prev_output, h_out = g.op( "RNN", *inputs, outputs=2, hidden_size_i=hidden_size, activations_s=activation, **extra_kwargs, ) elif variant == "GRU": prev_output, h_out = g.op( "GRU", *inputs, outputs=2, hidden_size_i=hidden_size, linear_before_reset_i=1, **extra_kwargs, ) elif variant == "LSTM": prev_output, h_out, c_out = g.op( "LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs ) if bidirectional: # The ONNX RNN/GRU/LSTM produce an output of dimensions # seq_len, num_directions, batch, hidden_size # We have to convert to match pytorch's expected # seq_len, batch, num_directions * hidden_size # by first moving num_directions before hidden_size with # Transpose, and then combining it with hidden_size # with Reshape. 
def lstm(g, *args):
    """Symbolic entry point for LSTM; picks the packed or the full overload.

    The overload is recognised from ``args[3]``: in the packed-sequence
    signature that position holds the weight list, whereas in the full
    signature it holds the ``has_biases`` flag. All arguments are forwarded
    unchanged to the chosen implementation.
    """
    handler = _lstm_packed if symbolic_helper._is_tensor_list(args[3]) else _lstm_full
    return handler(g, *args)
@symbolic_helper.parse_args("v", "i")
def _dim_arange(g, like, dim):
    """Emit a range whose length is the size of ``like`` along ``dim``.

    The extent is read at run time from the shape of ``like`` (Shape followed
    by Gather), so it does not have to be known during export.
    """
    shape = g.op("Shape", like)
    dim_index = g.op("Constant", value_t=torch.tensor(dim))
    extent = g.op("Gather", shape, dim_index, axis_i=0)
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.op("_caffe2::Range", extent)
    # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
    # NOTE(review): 4 is passed as the ScalarType code; presumably int64 - confirm.
    return arange(g, extent, 4, None, None, None)
def randn(g, shapes, dtype, *options):
    """Symbolic for ``aten::randn``: normally distributed values of a given shape.

    Args:
        g: graph being built.
        shapes: requested output shape, either a constant int list or a
            value that is only known at run time.
        dtype: constant scalar-type code, or ``None`` to use FLOAT.
        *options: remaining aten arguments; they are not used here.

    Returns:
        The output of ``RandomNormalLike`` (dynamic shape) or ``RandomNormal``
        (constant shape).
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    if dtype is None:
        dtype = symbolic_helper.ScalarType.FLOAT
    onnx_dtype = symbolic_helper.scalar_type_to_onnx[dtype]
    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if symbolic_helper._is_value(shape):
        # The shape is not a compile-time constant: materialise a tensor of
        # that shape and let RandomNormalLike copy the shape from it.
        shape_const = g.op(
            "ConstantOfShape",
            shapes,
            value_t=torch.tensor(
                [0], dtype=symbolic_helper.scalar_type_to_pytorch_type[6]
            ),
        )
        return g.op("RandomNormalLike", shape_const, dtype_i=onnx_dtype)
    # Forward the requested dtype here as well; without it the operator falls
    # back to its default element type and a non-default dtype request is
    # silently dropped on the constant-shape path.
    return g.op("RandomNormal", shape_i=shape, dtype_i=onnx_dtype)
@symbolic_helper.parse_args("v", "f", "f", "i", "none")
def rrelu(g, input, lower, upper, training, generator):
    """Symbolic for ``aten::rrelu`` (randomized leaky ReLU).

    Args:
        g: graph being built.
        input: tensor to activate.
        lower: lower bound of the negative slope.
        upper: upper bound of the negative slope.
        training: non-zero when exporting in training mode.
        generator: random generator argument; not used.

    Returns:
        In evaluation mode a ``LeakyRelu`` with the fixed slope
        ``(lower + upper) / 2``; in training mode a ``PRelu`` whose slopes
        are drawn per element from ``RandomUniformLike`` over
        ``[lower, upper]``.
    """
    if not training:
        # RReLU is deterministic outside of training: the slope is the mean
        # of the sampling interval, so no random op should be exported.
        slope = (upper + lower) / 2.0
        return g.op("LeakyRelu", input, alpha_f=slope)
    p = g.op("RandomUniformLike", input, high_f=upper, low_f=lower)
    return g.op("PRelu", input, p)
@symbolic_helper.quantized_args(True, False, False)
@symbolic_helper.parse_args("v", "i", "i")
def flatten(g, input, start_dim, end_dim):
    """Symbolic for ``aten::flatten`` over the dims ``start_dim..end_dim``.

    The input rank must be known at export time. The ONNX ``Flatten`` operator
    is used for the two cases where the result is 2-D; every other case is
    delegated to ``symbolic_helper._flatten_helper``.
    """
    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        return symbolic_helper._unimplemented(
            "dim",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
        )

    # TODO: remove this as onnx opset 11 spec allows negative axes
    last_dim = rank + end_dim if end_dim < 0 else end_dim

    # Keep the first dim and collapse everything after it.
    if start_dim == 1 and last_dim == rank - 1:
        return g.op("Flatten", input, axis_i=start_dim)
    # Collapse everything except the last dim.
    if start_dim == 0 and last_dim == rank - 2:
        return g.op("Flatten", input, axis_i=last_dim + 1)

    return symbolic_helper._flatten_helper(g, input, start_dim, last_dim, rank)
def argmax(g, input, dim, keepdim):
    """Symbolic for ``aten::argmax``.

    With an explicit ``dim`` the reduction runs along that axis and honours
    ``keepdim``. When ``dim`` is none, the input is reshaped to 1-D first and
    the index is taken over axis 0 without keeping the reduced dimension.
    """
    if not symbolic_helper._is_none(dim):
        axis = symbolic_helper._parse_arg(dim, "i")
        keep = symbolic_helper._parse_arg(keepdim, "i")
        return g.op("ArgMax", input, axis_i=axis, keepdims_i=keep)

    flat_shape = g.op("Constant", value_t=torch.tensor([-1]))
    flat_input = symbolic_helper._reshape_helper(g, input, flat_shape)
    return g.op("ArgMax", flat_input, axis_i=0, keepdims_i=False)
@symbolic_helper.parse_args("v", "i", "v", "v")
def scatter_add(g, self, dim, index, src):
    """Symbolic for ``aten::scatter_add``.

    ``src`` is scattered into a zero tensor shaped like ``self`` and the
    result is added to ``self``. The zero tensor is a ``Constant`` when the
    sizes of ``self`` are static, and is produced by ``zeros_like`` otherwise.
    """
    scalar_name = symbolic_helper._try_get_scalar_type(self)
    if scalar_name is None:
        return symbolic_helper._unimplemented(
            "scatter_add", "input dtype not accessible"
        )

    # Translate the scalar type name into the matching torch dtype by going
    # through the ONNX type tables.
    onnx_type = symbolic_helper.cast_pytorch_to_onnx[scalar_name]
    type_index = symbolic_helper.scalar_type_to_onnx.index(onnx_type)
    torch_dtype = symbolic_helper.scalar_type_to_pytorch_type[type_index]

    sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False)
    if sizes:
        zeros = g.op("Constant", value_t=torch.zeros(sizes, dtype=torch_dtype))
    else:
        # zeros_like takes the scalar-type index rather than the torch dtype.
        zeros_type = symbolic_helper.scalar_type_to_pytorch_type.index(torch_dtype)
        zeros = zeros_like(g, self, zeros_type)

    scattered = symbolic_helper._scatter_helper(g, zeros, dim, index, src)
    return add(g, self, scattered)
@symbolic_helper.parse_args("v", "i", "v", "v")
def gather(g, self, dim, index, sparse_grad=False):
    """Symbolic for ``aten::gather`` built from OneHot, Mul and ReduceSum.

    Args:
        g: graph being built.
        self: tensor to gather from.
        dim: axis to gather along; negative values count from the back.
        index: indices to pick along ``dim``.
        sparse_grad: must be false; a true value is reported as unimplemented.

    Returns:
        The gathered tensor, or the result of ``_unimplemented`` when
        ``sparse_grad`` is set or when ``dim`` is negative and the rank of
        ``self`` is unknown.
    """
    if symbolic_helper._maybe_get_const(sparse_grad, "i"):
        return symbolic_helper._unimplemented("gather", "sparse_grad == True")
    if dim < 0:
        # The construction below combines the axes ``dim`` and ``dim + 1``.
        # With a negative ``dim`` those refer to the wrong positions (for
        # example dim=-1 would unsqueeze at axis 0), so make it non-negative.
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is None:
            return symbolic_helper._unimplemented(
                "gather", "negative dim on an input of unknown rank"
            )
        dim += rank
    # NOTE: This workaround is needed since GatherElement is only supported
    #       since opset 11, and Gather in ONNX is not the same as torch.gather.
    dtype = self.type().scalarType()
    values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
    depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim])))
    # One-hot encode the indices along ``dim`` so that multiplying with the
    # unsqueezed input and summing over ``dim`` selects the indexed entries.
    one_hot_index = g.op(
        "Cast",
        g.op("OneHot", index, depth, values, axis_i=dim),
        to_i=symbolic_helper.cast_pytorch_to_onnx[dtype],
    )
    selected = g.op(
        "Mul", symbolic_helper._unsqueeze_helper(g, self, [dim + 1]), one_hot_index
    )
    return symbolic_helper._reducesum_helper(
        g, selected, axes_i=[dim], keepdims_i=0
    )
def var_mean(g, input, *args):
    """Dispatch the ``aten::var_mean`` overloads to ``_var_mean``.

    Handled signatures:
        aten::var_mean(Tensor self, bool unbiased)
        aten::var_mean(Tensor self, int[1] dim, bool unbiased, bool keepdim=False)
        aten::var_mean(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False)

    A single extra argument is treated as the ``unbiased`` flag and passed in
    the correction slot with ``dim`` and ``keepdim`` set to ``None``; any other
    argument list is forwarded as is.
    """
    if len(args) != 1:
        return _var_mean(g, input, *args)
    (unbiased,) = args
    return _var_mean(g, input, None, unbiased, None)
symbolic_helper._squeeze_helper( g, nonzero(g, ones(g, range_tensor, dtype, None, None)), [1] ) return g.op( "Cast", arange_tensor, to_i=symbolic_helper.scalar_type_to_onnx[dtype] ) elif len(args) == 4 or len(args) == 7: if len(args) == 4: # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out) dtype = None else: # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) dtype = _get_arange_dtype(args[3]) dtype, end, start, step = symbolic_helper._arange_cast_helper( g, start=args[0], end=args[1], step=args[2], dtype=dtype ) step = symbolic_helper._unsqueeze_helper(g, step, [0]) end = symbolic_helper._unsqueeze_helper(g, end, [0]) start = symbolic_helper._unsqueeze_helper(g, start, [0]) range_tensor = _float_step_convert(g.op("Div", g.op("Sub", end, start), step)) arange_tensor = symbolic_helper._squeeze_helper( g, nonzero(g, ones(g, range_tensor, None, None, None)), [1] ) arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start) return g.op( "Cast", arange_tensor, to_i=symbolic_helper.scalar_type_to_onnx[dtype] ) elif len(args) == 6: # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) dtype = _get_arange_dtype(args[2]) dtype, end, start, step = symbolic_helper._arange_cast_helper( g, start=args[0], end=args[1], dtype=dtype ) end = symbolic_helper._unsqueeze_helper(g, end, [0]) start = symbolic_helper._unsqueeze_helper(g, start, [0]) range_tensor = _float_step_convert(g.op("Sub", end, start)) arange_tensor = g.op( "Add", symbolic_helper._squeeze_helper( g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1] ), start, ) return g.op( "Cast", arange_tensor, to_i=symbolic_helper.scalar_type_to_onnx[dtype] ) else: raise NotImplementedError( "Unknown aten::arange signature taking " + str(len(args)) + " arguments." 
def index(g, self, index):
    """Symbolic for ``aten::index`` (tensor indexing with tensor indices).

    Args:
        g: graph being built.
        self: tensor being indexed.
        index: a packed list of index values (one per leading dimension), or
            a single index value. A none entry stands for ``:``.

    Returns:
        The graph value holding the indexed result. With no tensor index at
        all, ``self`` is returned unchanged.

    Raises:
        RuntimeError: a Byte/Bool mask index is used with an opset below 9.
        NotImplementedError: two or more tensor indices are used on a tensor
            whose rank is unknown.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        # Defer to the ATen operator instead of decomposing into ONNX ops.
        return g.at("index", self, index, overload_name="Tensor")

    # Normalise the argument to a Python list of index values.
    if symbolic_helper._is_packed_list(index):
        indices = symbolic_helper._unpack_list(index)
    else:
        indices = [index]

    def try_mask_to_index(index):
        """Turn a Byte/Bool mask into integer indices; pass other values through."""
        if not symbolic_helper._is_none(index) and (
            index.type().scalarType() == "Byte"
            or index.type().scalarType() == "Bool"
        ):
            if GLOBALS.export_onnx_opset_version < 9:
                raise RuntimeError(
                    "Exporting masked indices are only supported after ONNX opset 9."
                )
            warnings.warn(
                "Exporting aten::index operator with indices of type Byte. "
                "Only 1-D indices are supported. In any other case, "
                "this will produce an incorrect ONNX graph."
            )
            # Positions of the non-zero mask entries, with axis 1 squeezed away.
            index = symbolic_helper._squeeze_helper(g, nonzero(g, index), [1])
        return index

    indices = [try_mask_to_index(idx) for idx in indices]
    if len(indices) == 1:
        # A single index always applies to dimension 0.
        return symbolic_helper._select_helper(
            g, self, 0, indices[0], apply_reshape=False
        )
    else:
        # Multiple tensors as indices. Each tensor could either be
        #   1. prim::Constant()
        #           representing ":" in python indexing. E.g. tensor[:, :]
        #   2. prim::Constant[value=...] or tensor output
        #           representing advanced indexing. E.g. tensor[[0, 1], [2, 0]].
        # For more info on advanced indexing,
        # check https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing

        # Consider a general case of
        #       t: [x_1, y_1, y_2, ..., x_m, ..., y_n]
        # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes for ":".
        # Same results can be achieved through transposing t into
        #       t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n]
        # and use gatherND. However ONNX does not have gatherND, to use 1d gather we'll need to flatten t
        # and process the tensor indices.
        #       t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n]
        #       tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j))
        # After gather, reshape and transpose back.

        # Positions (axes) that carry a real tensor index rather than ":".
        adv_idx_indices = [
            i for i, idx in enumerate(indices) if not symbolic_helper._is_none(idx)
        ]

        if len(adv_idx_indices) == 0:
            # Only ":" entries: nothing to select.
            return self
        elif len(adv_idx_indices) == 1:
            # Exactly one tensor index: a plain index_select along that axis.
            return index_select(
                g, self, adv_idx_indices[0], indices[adv_idx_indices[0]]
            )
        else:
            rank = symbolic_helper._get_tensor_rank(self)
            if rank is None:
                raise NotImplementedError(
                    "Unsupported aten::index operator of advanced indexing on tensor of unknown rank, "
                    + "try turning on shape and type propagate during export: "
                    + "torch.onnx._export(..., propagate=True)."
                )
            # TODO: If indexing is supported natively in ONNX in future opsets,
            #       update the warning to recommend exporting with higher opset version.
            warnings.warn(
                "Exporting aten::index operator of advanced indexing in opset "
                + str(GLOBALS.export_onnx_opset_version)
                + " is achieved by combination of multiple ONNX operators, "
                + "including Reshape, Transpose, Concat, and Gather. "
                + "If indices include negative values, the exported graph will produce incorrect results."
            )
            adv_idx_count = len(adv_idx_indices)
            shape_tensor = _shape_as_tensor(g, self)
            # One 1-element tensor per dimension, holding that dimension's size.
            dim_tensor_list = [
                g.op(
                    "Gather",
                    shape_tensor,
                    g.op("Constant", value_t=torch.LongTensor([dim])),
                    axis_i=0,
                )
                for dim in range(rank)
            ]

            # Move the indexed axes to the front, then flatten into a 2-D
            # tensor split at the boundary between indexed and ":" axes.
            self = g.op(
                "Transpose",
                self,
                perm_i=adv_idx_indices
                + [i for i in range(rank) if i not in adv_idx_indices],
            )
            self = g.op("Flatten", self, axis_i=adv_idx_count)

            # Note that tensor indices will be broadcasted while accumulating. Thus we get the final subarray shape as well.
            # Accumulate the flat index from the last indexed axis backwards,
            # growing the multiplier by each axis size that has been consumed.
            cum_adv_index = indices[adv_idx_indices[-1]]
            multiplier = dim_tensor_list[adv_idx_indices[-1]]
            for i in range(adv_idx_count - 2, -1, -1):
                adv_index = g.op("Mul", indices[adv_idx_indices[i]], multiplier)
                cum_adv_index = g.op("Add", cum_adv_index, adv_index)
                multiplier = g.op(
                    "Mul", multiplier, dim_tensor_list[adv_idx_indices[i]]
                )

            # perform gather
            self = index_select(g, self, 0, cum_adv_index)

            cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index)
            # check if all advanced indices are consecutive.
            # Refer to https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
            # to understand how the subarray position is decided.
            if adv_idx_indices == list(
                range(adv_idx_indices[0], adv_idx_indices[-1] + 1)
            ):
                # unfold regular index axes
                folded_adv_idx_shape_list = [
                    g.op("Constant", value_t=torch.LongTensor([-1]))
                ] + [
                    dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices
                ]
                folded_adv_idx_shape = g.op(
                    "Concat", *folded_adv_idx_shape_list, axis_i=0
                )
                self = symbolic_helper._reshape_helper(g, self, folded_adv_idx_shape)

                # Transpose folded advanced indexed axis to its original location.
                adv_idx_permute = (
                    list(range(1, adv_idx_indices[0] + 1))
                    + [0]
                    + list(range(adv_idx_indices[0] + 1, rank - adv_idx_count + 1))
                )
                self = g.op("Transpose", self, perm_i=adv_idx_permute)

                # unfold advanced index axes
                final_shape_list = (
                    [dim_tensor_list[i] for i in range(adv_idx_indices[0])]
                    + [cum_adv_index_shape_tensor]
                    + [
                        dim_tensor_list[i]
                        for i in range(adv_idx_indices[0], rank)
                        if i not in adv_idx_indices
                    ]
                )
                final_shape = g.op("Concat", *final_shape_list, axis_i=0)
            else:
                # Non-consecutive tensor indices: the indexed result shape
                # comes first, followed by the remaining ":" dimensions.
                final_shape = g.op(
                    "Concat",
                    cum_adv_index_shape_tensor,
                    *[
                        dim_tensor_list[i]
                        for i in range(rank)
                        if i not in adv_idx_indices
                    ],
                    axis_i=0,
                )

            return symbolic_helper._reshape_helper(g, self, final_shape)
@symbolic_helper.parse_args("v", "f", "is", "i", "v")
def linalg_vector_norm(g, self, ord, dim, keepdim, dtype):
    """Symbolic for ``aten::linalg_vector_norm``.

    Args:
        g: graph being built.
        self: input tensor.
        ord: order of the norm as a float; ``inf`` and ``-inf`` map to the
            max / min of the absolute values, ``0`` is not supported here.
        dim: axes to reduce, or ``None`` to reduce over the flattened input.
        keepdim: whether the reduced axes are kept (ignored when ``dim`` is
            ``None``, see below).
        dtype: not used by this implementation.

    Returns:
        The norm value, or the "unsupported" result for ``ord == 0``.
    """
    # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html
    if dim is None:
        self = symbolic_helper._reshape_helper(g, self, [-1])
        # The input is now 1-D and reduced over all axes. Request keepdims=0
        # explicitly: a ``None`` attribute is omitted from the node, and the
        # ONNX reductions then default to keeping the reduced axis, which
        # would yield a 1-element tensor instead of a scalar.
        keepdim = 0

    if ord == math.inf:
        result = g.op("ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim)
    elif ord == -math.inf:
        result = g.op("ReduceMin", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim)
    elif ord == 0:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "linalg_vector_norm", 9, 11, "ord=0 not supported"
        )
    else:
        # General p-norm: (sum(|x| ** ord)) ** (1 / ord)
        ord_op = g.op("Constant", value_t=torch.FloatTensor([ord]))
        result = symbolic_helper._reducesum_helper(
            g, g.op("Pow", g.op("Abs", self), ord_op), axes_i=dim, keepdims_i=keepdim
        )
        result = g.op(
            "Pow",
            result,
            g.op("Div", g.op("Constant", value_t=torch.FloatTensor([1])), ord_op),
        )
    return result
@symbolic_helper.parse_args("v", "i", "b", "v")
def multinomial(g, input, num_samples, replacement=False, generator=None):
    """Symbolic for ``aten::multinomial``.

    A supplied generator, or sampling more than one element without
    replacement, is reported through ``_unimplemented``. The ONNX
    ``Multinomial`` node is then fed the logarithm of ``input`` and emits
    Long samples.
    """
    generator_given = not (generator is None or symbolic_helper._is_none(generator))
    if generator_given:
        symbolic_helper._unimplemented(
            "Multinomial", "generator is not supported for multinomial"
        )
    if num_samples > 1 and not replacement:
        symbolic_helper._unimplemented(
            "Multinomial",
            "replacement=False when num_samples > 1 is not supported for multinomial",
        )
    log_probs = log(g, input)
    long_type = symbolic_helper.cast_pytorch_to_onnx["Long"]
    return g.op(
        "Multinomial", log_probs, dtype_i=long_type, sample_size_i=num_samples
    )
@symbolic_helper.parse_args("v", "s")
def meshgrid(g, tensor_list, indexing: Optional[str] = None):
    """Symbolic for ``aten::meshgrid``.

    Args:
        g: graph being built.
        tensor_list: packed list value holding the input tensors.
        indexing: ``"ij"`` (default when ``None``) or ``"xy"``.

    Returns:
        A ``prim::ListConstruct`` of the expanded grids, one per input.

    Raises:
        ValueError: ``indexing`` is neither ``"ij"`` nor ``"xy"``.
    """
    if indexing is None:
        indexing = "ij"
    elif indexing not in {"ij", "xy"}:
        raise ValueError(f"Unsupported indexing: {indexing}")

    # Unpack first: ``tensor_list`` is a single graph value, so the "xy" swap
    # has to be done on the Python list of its elements, not on the value.
    unpacked = list(symbolic_helper._unpack_list(tensor_list))
    if indexing == "xy":
        # Swap the first two entries; a slice swap is a no-op for one tensor.
        unpacked[:2] = unpacked[1::-1]

    tensors = [
        symbolic_helper._reshape_helper(
            g, tensor, g.op("Constant", value_t=torch.LongTensor([-1]))
        )
        for tensor in unpacked
    ]
    tensors_shape = [g.op("Shape", tensor) for tensor in tensors]
    out_shape = g.op("Concat", *tensors_shape, axis_i=0)
    out = []
    for i, tensor in enumerate(tensors):
        # Shape that is 1 everywhere except for this tensor's own axis.
        shape_i = [g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))] * len(
            tensors
        )
        shape_i[i] = tensors_shape[i]
        t_reshaped = _reshape_from_tensor(
            g, tensor, g.op("Concat", *shape_i, axis_i=0)
        )
        out.append(g.op("Expand", t_reshaped, out_shape))
    if indexing == "xy":
        # Undo the input swap so the outputs line up with the caller's order.
        out[:2] = out[1::-1]
    return g.op("prim::ListConstruct", *out)
g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)) ) return mul( g, mul(g, self, erf_plusone), g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)), ) @symbolic_helper.parse_args("v", "i", "v", "v", "f", "i") def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled): if symbolic_helper.is_caffe2_aten_fallback(): return g.at( "group_norm", input, weight, bias, num_groups_i=num_groups, eps_f=eps, cudnn_enabled_i=cudnn_enabled, ) channel_size = symbolic_helper._get_tensor_dim_size(input, 1) if channel_size is not None: assert channel_size % num_groups == 0 input_rank = symbolic_helper._get_tensor_rank(input) if input_rank is None: return symbolic_helper._unimplemented("group_norm", "unknown input rank") # 0 in the shape list keeps dimension value unchanged. shape = [0, num_groups, -1] input_reshaped = symbolic_helper._reshape_helper( g, input, g.op("Constant", value_t=torch.LongTensor(shape)) ) # C is always divisible by num_groups # Due to shape difference. we need to apply weight and bias after # instance norm computation and reshape weight_ = g.op( "Constant", value_t=torch.tensor([1.0] * num_groups).type( "torch." + input.type().scalarType() + "Tensor" ), ) bias_ = g.op( "Constant", value_t=torch.tensor([0.0] * num_groups).type( "torch." + input.type().scalarType() + "Tensor" ), ) norm_reshaped = g.op( "InstanceNormalization", input_reshaped, weight_, bias_, epsilon_f=eps ) norm = symbolic_helper._reshape_helper(g, norm_reshaped, g.op("Shape", input)) if weight is None or weight.node().mustBeNone(): weight_value = torch.tensor([1.0]).type( "torch." + input.type().scalarType() + "Tensor" ) weight = g.op("Constant", value_t=weight_value) if bias is None or bias.node().mustBeNone(): bias_value = torch.tensor([0.0]).type( "torch." 
+ input.type().scalarType() + "Tensor" ) bias = g.op("Constant", value_t=bias_value) # Norm has shape [N, C, *] so we reshape weight and bias to [C, *] axes = list(range(1, input_rank - 1)) return add( g, mul(g, norm, symbolic_helper._unsqueeze_helper(g, weight, axes)), symbolic_helper._unsqueeze_helper(g, bias, axes), ) @symbolic_helper.parse_args("v", "v", "i") def _weight_norm(g, weight_v, weight_g, dim): rank = symbolic_helper._get_tensor_rank(weight_v) if rank is not None: # W = g * ((v) / ||v||) # Compute norm_except_dim for l2 norm. dim = None means over all dims # torch's weight_norm module sets dim = -1 if it's None. # This conflicts the logic for negative axes to access dims backwards # TODO: Might need a fix in torch group_norm module axes = list(range(rank)) if dim is not None: if dim < -1: dim += rank if dim != -1: axes.remove(dim) norm_v = norm(g, weight_v, 2, axes, 1) div = g.op("Div", weight_v, norm_v) return g.op("Mul", div, weight_g) elif symbolic_helper.is_caffe2_aten_fallback(): return g.at("_weight_norm", weight_v, weight_g, dim_i=dim) else: raise RuntimeError( "Unsupported: ONNX export of _weight_norm for tensor " "of unknown rank." 
) def dim(g, self): """Implement the dim functionality available for a pytorch tensor in ONNX""" # ONNX does not support dim directly in this opset so we can use 2 ops to get the info shape = g.op("Shape", self) return g.op("Size", shape) def __getitem_(g, self, i): return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i) def item(g, self): return self def take(g, self, index): self_flattened = symbolic_helper._reshape_helper( g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) ) out = index_select(g, self_flattened, 0, index) out = reshape_as(g, out, index) return out def _kl_div_log_target_impl(g, input, target): diff_ = sub(g, target, input) exp_ = exp(g, target) output = mul(g, exp_, diff_) return output def _kl_div_non_log_target_impl(g, input, target): log_ = log(g, target) diff_ = sub(g, log_, input) output_pos = mul(g, target, diff_) zeros_ = zeros_like(g, output_pos) mask_ = gt(g, target, g.op("Constant", value_t=torch.tensor(0))) output = where(g, mask_, output_pos, zeros_) return output @symbolic_helper.parse_args("v", "v", "i", "b") def kl_div(g, input, target, reduction, log_target): if log_target: output = _kl_div_log_target_impl(g, input, target) else: output = _kl_div_non_log_target_impl(g, input, target) if reduction == 0: return output elif reduction == 1: return g.op("ReduceMean", output, keepdims_i=0) elif reduction == 2: return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) else: return symbolic_helper._onnx_unsupported( "kl_div with reduction other than none, mean, or sum." 
) @symbolic_helper.parse_args("v", "v", "is", "i") def as_strided(g, self, sizes, strides, offset=None): sizes = symbolic_helper._maybe_get_const(sizes, "is") rank = len(strides) self_1d = symbolic_helper._reshape_helper( g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) ) ind: Optional[torch.Tensor] if not symbolic_helper._is_value(sizes): ind = torch.tensor([0], dtype=torch.long) for i, (size, stride) in enumerate(zip(sizes, strides)): r_size = [1] * rank r_size[i] = -1 ind = ind + torch.arange(size).view(r_size) * stride if offset: ind = ind + offset return g.op("Gather", self_1d, g.op("Constant", value_t=ind)) else: ind = None for i, stride in enumerate(strides): r_size = [1] * rank r_size[i] = -1 size = select( g, sizes, g.op("Constant", value_t=torch.tensor([0])), g.op("Constant", value_t=torch.tensor(i)), ) tmp_ind = symbolic_helper._reshape_helper( g, arange(g, size, 4, None, None, None), g.op("Constant", value_t=torch.tensor(r_size)), ) tmp_ind = g.op( "Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride])) ) if ind is None: ind = tmp_ind else: ind = g.op("Add", ind, tmp_ind) if offset: ind = g.op("Add", ind, g.op("Constant", torch.tensor([offset]))) return g.op("Gather", self_1d, ind) def __derive_index(g, index, start, step): return g.op("Add", start, g.op("Mul", index, step)) # Source code for aten op can be found here: pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp # if (step > 0 && lo < hi) { # push(stack, 1 + (hi - 1 - lo) / step); # } else if (step < 0 && lo > hi) { # push(stack, 1 + (lo - 1 - hi) / (0 - step)); # } else { # push(stack, 0); # } def __range_length(g, lo, hi, step): sub = g.op("Sub", hi, lo) div = g.op("Ceil", true_divide(g, sub, step)) return g.op("Cast", div, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"]) def linear(g, input, weight, bias): rank = symbolic_helper._get_tensor_rank(input) weight = t(g, weight) if rank == 2 and not bias.node().mustBeNone(): alpha = g.op("Constant", 
value_t=torch.tensor(1, dtype=torch.int64)) beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) output = addmm(g, bias, input, weight, alpha, beta) else: output = matmul(g, input, weight) if not bias.node().mustBeNone(): output = add(g, bias, output) return output @symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v") def hann_window( g, window_length, periodic=True, dtype=None, layout=None, device=None, pin_memory=None, requires_grad=False, ): if dtype is None: dtype = torch.get_default_dtype() if not dtype or not dtype.is_floating_point: dtype = torch.float dtype = symbolic_helper.scalar_type_to_pytorch_type.index(dtype) n_array = arange(g, window_length, 4, None, None, None) output = g.op("Cast", n_array, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"]) output = mul( g, g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float)), output ) if periodic is False: window_length = sub( g, window_length, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int)) ) output = div(g, output, window_length) output = g.op( "Cast", square(g, sin(g, output)), to_i=symbolic_helper.scalar_type_to_onnx[dtype], ) return output def mv(g, self, vec): return matmul(g, self, vec) def dot(g, self, other): return matmul(g, self, other) @symbolic_helper.parse_args("v", "v") def fill(g, self, value): dtype = self.type().scalarType() if dtype is None: dtype = symbolic_helper.ScalarType.FLOAT else: dtype = symbolic_helper.scalar_type_to_onnx.index( symbolic_helper.cast_pytorch_to_onnx[dtype] ) return full_like(g, self, value, dtype) def index_add(g, self, dim, index, other, alpha=None): warnings.warn( "Warning: ONNX export does not support duplicated values in 'index' field, " + "this will cause the ONNX model to be incorrect." 
) # ONNX does not support "alpha" argument, unlike aten index_add # See: https://github.com/pytorch/pytorch/pull/65993#issuecomment-953151102 for more context if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: return symbolic_helper._unimplemented("index_add", "alpha != 1") dim = symbolic_helper._maybe_get_const(dim, "i") if dim is None: raise NotImplementedError( "ONNX export does NOT support exporting 'index_add_()' function with " + "unknown 'dim' value." ) self_dim_rank = symbolic_helper._get_tensor_rank(self) other_dim_rank = symbolic_helper._get_tensor_rank(other) if self_dim_rank is None or other_dim_rank is None: raise NotImplementedError( "ONNX export does NOT support exporting 'index_add_()' function while " + "the rank of self tensor or tensor to be added is unknown." ) if other_dim_rank != self_dim_rank: delta = self_dim_rank - other_dim_rank for i in range(delta): other = symbolic_helper._unsqueeze_helper( g, other, [symbolic_helper._get_tensor_rank(other)] ) other_dim_size = symbolic_helper._get_tensor_dim_size(other, dim) self_dim_size = symbolic_helper._get_tensor_dim_size(self, dim) if (other_dim_size is not None) and (self_dim_size is not None): if other_dim_size > self_dim_size: raise NotImplementedError( "ONNX export does NOT support exporting 'index_add_()' function with " + "duplicated values in 'index' parameter yet." ) # Construct a new shape. It's almost as same as self except the size of the 'dim' # dimension is 1, so that we can expand other dimensions as expected. 
new_shape_axes = list(range(self_dim_rank)) new_shape_starts = [0 for i in range(self_dim_rank)] new_shape_ends = [sys.maxsize if (i != dim) else 1 for i in range(self_dim_rank)] new_shape = symbolic_helper._slice_helper( g, self, axes=new_shape_axes, starts=new_shape_starts, ends=new_shape_ends ) other = expand_as(g, other, new_shape) for i in range(dim): index = symbolic_helper._unsqueeze_helper(g, index, [0]) for i in range(self_dim_rank - dim - 1): index = symbolic_helper._unsqueeze_helper( g, index, [symbolic_helper._get_tensor_rank(index)] ) return scatter_add(g, self, dim, expand_as(g, index, other), other) @symbolic_helper.parse_args("v", "is", "is") def roll(g, self, shifts, dims): assert len(shifts) == len(dims) result = self for i in range(len(shifts)): shapes = [] shape = symbolic_helper._slice_helper( g, result, axes=[dims[i]], starts=[-shifts[i]], ends=[sys.maxsize] ) shapes.append(shape) shape = symbolic_helper._slice_helper( g, result, axes=[dims[i]], starts=[0], ends=[-shifts[i]] ) shapes.append(shape) result = g.op("Concat", *shapes, axis_i=dims[i]) return result @symbolic_helper.parse_args("v", "v", "i") def cross(g, input, other, dim=None): dim = symbolic_helper._get_dim_for_cross(input, dim) # If we have two tensors such that # A = [a, b, c], B = [d, e, f], we permute the tensor such that we have # After first roll, # A' = [b, c, a], B' = [f, d, e], so that we calculate (b*f, c*d, a*e) roll_x_1 = roll(g, input, [2], [dim]) roll_y_1 = roll(g, other, [1], [dim]) # After second roll, # A' = [c, a, b], B' = [e, f, d], so that we calculate (c*e, a*f, b*d) roll_x_2 = roll(g, input, [1], [dim]) roll_y_2 = roll(g, other, [2], [dim]) # cross product is calculated as # result = [(b*f - c*e), (c*d - a*f), (a*e - b*d)] return sub(g, mul(g, roll_x_1, roll_y_1), mul(g, roll_x_2, roll_y_2)) def cdist(g, x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"): # X1.shape = (B * P * D), X2.shape = (B * R * D) # In order to respect numpy style 
broadcasting as demonstrated in # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md # we unsqueeze both input tensors # Currently we ignore the 'compute_mode' variable as we use default to # using matrix multiplication to calculate the euclidean distance rank = symbolic_helper._get_tensor_rank(x1) broadcasted_x1 = symbolic_helper._unsqueeze_helper(g, x1, [rank - 1]) broadcasted_x2 = symbolic_helper._unsqueeze_helper(g, x2, [rank - 2]) return pairwise_distance( g, broadcasted_x1, broadcasted_x2, p, eps=1e-06, keepdim=False ) def broadcast_tensors(g, self): all_tensors = symbolic_helper._unpack_list(self) t_with_final_shape = zeros_like(g, all_tensors[0]) # Add operator supports multidirectional broadcasting. So we leverage this function # to infer the final shape generated by the broadcast. for t in all_tensors: t_with_final_shape = add(g, t_with_final_shape, t) t_list = [expand_as(g, t, t_with_final_shape) for t in all_tensors] return g.op("prim::ListConstruct", *t_list) class Prim: domain = "prim" @staticmethod def ConstantSplit(g, self, split_size, dim): size = symbolic_helper._get_tensor_dim_size(self, dim) if size is None: return symbolic_helper._unimplemented( "prim::ConstantSplit", "unknown dimension size" ) splits = [split_size] * (size // split_size) leftover = size % split_size if leftover: splits.append(leftover) return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits)) # TODO: It would be better to export this as a chunk directly, as this is # less sensitive to changes in input size. 
# TODO: Once we have proper scoping, stop reimplementing chunk, delete this # method, and use the desugared version @staticmethod def ConstantChunk(g, self, chunks, dim): dim_size = symbolic_helper._get_tensor_dim_size(self, dim) if dim_size is None: return symbolic_helper._unimplemented( "prim::ConstantChunk", "unknown dimension size" ) split_size = (dim_size + chunks - 1) // chunks return Prim.ConstantSplit(g, self, split_size, dim) @staticmethod def shape(g, self): return g.op("Shape", self) @staticmethod def max(g, self, other): return op_with_optional_float_cast(g, "Max", self, other, opset_before=12) @staticmethod def min(g, self, other=None): if not other: if symbolic_helper._is_packed_list(self): self = stack(g, self, g.op("Constant", value_t=torch.tensor([0]))) return min(g, self) return min(g, self, other) @staticmethod def data(g, self): return self @staticmethod def ListConstruct(g, *inputs, **kwargs): return None @staticmethod def ListUnpack(g, *inputs, **kwargs) -> Optional[List[_C.Value]]: if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct": # Cancel the previous node if it is ListConstruct by returning its inputs # TODO(justinchuby): Use a public method in the helper module return symbolic_helper._unpack_list(inputs[0]) return None @staticmethod def TupleConstruct(g, *inputs, **kwargs): return None @staticmethod def Uninitialized(g, *inputs, **kwargs): return None # exists to refine the type of the Value # if x is an optional Tensor, unchecked_cast will cast # x to Tensor, so the rest of the graph knows that x is a Tensor # this doesn't do anything in runtime and is a noop in ONNX @staticmethod def unchecked_cast(g, self): return self @staticmethod def dtype(g, self): dtype = symbolic_helper._try_get_scalar_type(self) if dtype is None: dtype = "Float" dtype = symbolic_helper.scalar_type_to_onnx.index( symbolic_helper.cast_pytorch_to_onnx[dtype] ) return g.op("Constant", value_t=torch.tensor(dtype)) # tolist is currently 
supported only for 1D input tensors. # dim_val and elem_ty_val represent dimension and type annotations # that need to match dimension and type of the input tensor. @staticmethod def tolist(g, input, dim_val, elem_ty_val): dim = symbolic_helper._maybe_get_const(dim_val, "i") if dim > 1: return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1") return input # ----------------------------------------------------------------------------- # Symbolic functions that need extra context # ----------------------------------------------------------------------------- @staticmethod def device(ctx: torch.onnx.SymbolicContext, g, *inputs, **kwargs): n = ctx.cur_node if n.output().type().kind() == "DeviceObjType": return None return symbolic_helper._unimplemented( "prim::device", "output type is not `DeviceObjType`." ) @staticmethod def Loop(ctx: torch.onnx.SymbolicContext, g, *inputs, **attrs): n = ctx.cur_node env = ctx.env params_dict = ctx.params_dict operator_export_type = GLOBALS.operator_export_type opset_version = GLOBALS.export_onnx_opset_version new_op_outputs = g.op("Loop", *inputs, outputs=n.outputsSize()) new_node = ( new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node() ) for b in n.blocks(): new_block = new_node.addBlock() # Copy input metadata to subblock # # prim::Loop(iter, cond, input_1, ..., input_n) # block0(iter, input_1, ..., input_n) # # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`. for i, b_in in enumerate(b.inputs()): if i == 0 and i < len(inputs): b_in.setType(inputs[i].type()) # For optional block inputs, they may switch between None not-None inside # the loop body, so if the loop input is not optional, the block input may # still need to be optional. 
if ( i > 0 and (i + 1) < len(inputs) and not isinstance(b_in.type(), _C.OptionalType) ): b_in.setType(inputs[i + 1].type()) torch._C._jit_pass_onnx_block( b, new_block, operator_export_type, env, False # type:ignore[arg-type] ) new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( new_node, opset_version ) # Run shape type inference for Loop after subblock is converted. if GLOBALS.onnx_shape_inference: torch._C._jit_pass_onnx_node_shape_type_inference( new_node, params_dict, opset_version ) return new_op_outputs @staticmethod def If(ctx: torch.onnx.SymbolicContext, g, *inputs, **attrs): n = ctx.cur_node block = ctx.onnx_block env = ctx.env params_dict = ctx.params_dict operator_export_type = GLOBALS.operator_export_type opset_version = GLOBALS.export_onnx_opset_version static_if = inputs[0].node().kind() == "onnx::Constant" if static_if: # Fold static if # # The torch IR # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu), # %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ... 
# %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]() # %21 : Long(device=cpu) = aten::eq(%20, %64) # %22 : Long(device=cpu) = prim::If(%21) # block0(): # %23 : Long(device=cpu) = aten::is_floating_point(%input.1) # -> (%23) # block1(): # -> (%65) # %input.53 : Tensor, %weight : Tensor = prim::If(%22) # block0(): # -> (%embedding_matrix.1, %input.1) # block1(): # -> (%input.1, %embedding_matrix.1) # %26 : int[] = aten::size(%input.53) # # The converted ONNX graph # %10 : Bool(device=cpu) = onnx::Constant[value={0}]() # %14 : Bool(device=cpu) = onnx::Equal(%13, %8) # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]() # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1) input_flag = inputs[0].node()["value"].tolist() const_value = ( all(input_flag) if isinstance(input_flag, list) else bool(input_flag) ) block_idx = 0 if const_value else 1 current_b = list(n.blocks())[block_idx] env = torch._C._jit_pass_onnx_block( current_b, block, operator_export_type, # type:ignore[arg-type] env, # type:ignore[arg-type] True, ) if_output_list = list(n.outputs()) current_b_list = list(current_b.outputs()) final_b_list = [] for idx in range(len(if_output_list)): if current_b_list[idx] not in env: raise RuntimeError( "The sub block ATen output {}" " is not in env.".format(current_b_list[idx]) ) # type:ignore[operator] onnx_b = env[current_b_list[idx]] final_b_list.append(onnx_b) return final_b_list else: new_op_outputs = g.op("If", *inputs, outputs=n.outputsSize()) new_node = ( new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node() ) for b in n.blocks(): new_block = new_node.addBlock() torch._C._jit_pass_onnx_block( b, new_block, operator_export_type, # type:ignore[arg-type] env, False, ) new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( new_node, opset_version ) # Run shape type inference for If after subblock is converted. 
if GLOBALS.onnx_shape_inference: torch._C._jit_pass_onnx_node_shape_type_inference( new_node, params_dict, opset_version ) return new_op_outputs @staticmethod def Constant(ctx: torch.onnx.SymbolicContext, g, *inputs, **attrs): n = ctx.cur_node if n.mustBeNone(): return None # This must go before checking for string values, because some device constants # have string values, but we want to keep them as unconverted Device types so # that eq() can work on them. if isinstance(n.output().type(), _C.DeviceObjType): return None if n.kindOf("value") == "t": return g.op("Constant", value_t=n["value"]) if n.kindOf("value") == "s": return g.op("Constant", value_s=n["value"]) elif n.output().type().isSubtypeOf( _C.ListType.ofInts() ) or n.output().type().isSubtypeOf(_C.ListType.ofFloats()): return g.op("Constant", value_t=torch.tensor(n["value"])) else: raise RuntimeError( "Unsupported prim::Constant kind: `{}`. Send a bug report.".format( n.kindOf("value") ) ) class Onnx: domain = "onnx" # ----------------------------------------------------------------------------- # Symbolic functions that need extra context # ----------------------------------------------------------------------------- @staticmethod def Placeholder(ctx: torch.onnx.SymbolicContext, g, *inputs, **attrs): n = ctx.cur_node block = ctx.onnx_block env = ctx.env return torch._C._jit_onnx_convert_pattern_from_subblock(block, n, env)