import sys
import warnings

import torch
from torch.onnx import symbolic_helper
from torch.onnx import symbolic_opset9 as opset9
from torch.onnx import utils

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

# This file exports ONNX ops for opset 12


def einsum_helper(g, equation, tensors):
    """Emit an ONNX ``Einsum`` node evaluating ``equation`` over ``tensors``.

    ONNX ``Einsum`` does not accept bool inputs, so when the first tensor is
    ``Bool`` every input is cast to ``Long`` and the result is cast back to
    ``Bool``.

    Args:
        g: graph context used to create nodes.
        equation: einsum equation string, e.g. ``"i,j->ij"``.
        tensors: list of graph values fed to ``Einsum``.

    Raises:
        RuntimeError: if ``tensors`` is empty.
    """
    if not tensors:
        raise RuntimeError("Einsum inputs are empty.")
    # ONNX does not support bool for Einsum inputs.
    if tensors[0].type().scalarType() == "Bool":
        tensors = [
            g.op("Cast", tensor, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"])
            for tensor in tensors
        ]
        return g.op(
            "Cast",
            g.op("Einsum", *tensors, equation_s=equation),
            to_i=symbolic_helper.cast_pytorch_to_onnx["Bool"],
        )
    return g.op("Einsum", *tensors, equation_s=equation)


@symbolic_helper.parse_args("s", "v")
def einsum(g, equation, tensor_list):
    """Export ``aten::einsum``: unpack the tensor list and emit ``Einsum``."""
    tensors = symbolic_helper._unpack_list(tensor_list)
    return einsum_helper(g, equation, tensors)


@symbolic_helper.parse_args("v", "v")
def outer(g, input, other):
    """Export ``aten::outer`` as ``Einsum("i,j->ij")``.

    ``other`` is cast to ``input``'s scalar type first when the two differ,
    so both Einsum operands share one type.
    """
    # make sure to cast other to self's type
    if other.type().scalarType() != input.type().scalarType():
        other = g.op(
            "Cast",
            other,
            to_i=symbolic_helper.cast_pytorch_to_onnx[input.type().scalarType()],
        )
    return einsum_helper(g, "i,j->ij", [input, other])


@symbolic_helper.parse_args("v", "f", "i")
def dropout(g, input, p, train):
    """Export ``aten::dropout``.

    When ``train`` is false the op is a no-op and ``input`` is returned
    unchanged. Otherwise a warning is issued and an ONNX ``Dropout`` node with
    ratio ``p`` and ``training_mode=True`` is emitted; only its first output
    (the dropped-out tensor, not the mask) is returned.
    """
    symbolic_helper.check_training_mode(train, "dropout")
    # in eval mode, dropout is non-op - if the node's train param is set to False, dropout is non-op
    if not train:
        return input
    warnings.warn(
        "Dropout is a training op and should not be exported in inference mode. "
        "For inference, make sure to call eval() on the model and to export it with param training=False."
    )
    p = g.op("Constant", value_t=torch.tensor(p))
    t = g.op("Constant", value_t=torch.tensor(True))
    r, _ = g.op("Dropout", input, p, t, outputs=2)
    return r


def nll_loss(g, self, target, weight, reduction, ignore_index):
    """Export ``aten::nll_loss`` as ONNX ``NegativeLogLikelihoodLoss``.

    ``reduction`` is an index into ``["none", "mean", "sum"]``. ``weight`` is
    only passed to ONNX when it is not a ``None`` constant.
    """
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    # in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value.
    # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")

    inputs = [self, target]
    if not weight.node().mustBeNone():
        inputs.append(weight)
    return g.op(
        "NegativeLogLikelihoodLoss",
        *inputs,
        reduction_s=reduction,
        ignore_index_i=ignore_index,
    )


def nll_loss2d(g, self, target, weight, reduction, ignore_index):
    """Export ``aten::nll_loss2d``; identical lowering to :func:`nll_loss`."""
    return nll_loss(g, self, target, weight, reduction, ignore_index)


def nll_loss_nd(g, self, target, weight, reduction, ignore_index):
    """Export ``aten::nll_loss_nd``; identical lowering to :func:`nll_loss`."""
    return nll_loss(g, self, target, weight, reduction, ignore_index)


def cross_entropy_loss(
    g, self, target, weight, reduction, ignore_index, label_smoothing
):
    """Export ``aten::cross_entropy_loss`` as ONNX ``SoftmaxCrossEntropyLoss``.

    ``reduction`` is an index into ``["none", "mean", "sum"]``. ``weight`` is
    only passed to ONNX when it is not a ``None`` constant.

    Raises:
        RuntimeError: if ``label_smoothing`` is greater than zero.
    """
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
    if label_smoothing > 0.0:
        raise RuntimeError("Unsupported: ONNX does not support label_smoothing")

    # in onnx SoftmaxCrossEntropyLoss specification, ignore_index is optional without default value.
    # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")

    inputs = [self, target]
    if not weight.node().mustBeNone():
        inputs.append(weight)
    return g.op(
        "SoftmaxCrossEntropyLoss",
        *inputs,
        reduction_s=reduction,
        ignore_index_i=ignore_index,
    )


@symbolic_helper.parse_args("v", "v", "v", "v", "i")
def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduction):
    """Export ``aten::binary_cross_entropy_with_logits``.

    Computes, element-wise::

        loss = -(pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)))

    where the ``pos_weight`` factor is applied only when it is given, and the
    result is multiplied by ``weight`` when given. ``reduction`` 0/1/2 selects
    none / ``ReduceMean`` / ``ReduceSum``; any other value is reported via
    ``_onnx_unsupported``.

    NOTE(review): log(sigmoid(x)) is computed directly rather than in a
    log-sum-exp form, which can lose precision for large |x|.
    """
    p = g.op("Constant", value_t=torch.tensor([1]))
    sig_x = opset9.sigmoid(g, input)
    log_sig_x = opset9.log(g, sig_x)
    sub_1_x = opset9.sub(g, p, sig_x)
    sub_1_y = opset9.sub(g, p, target)
    log_1_x = opset9.log(g, sub_1_x)

    pos_term = opset9.mul(g, target, log_sig_x)
    if pos_weight is not None and not symbolic_helper._is_none(pos_weight):
        pos_term = opset9.mul(g, pos_term, pos_weight)
    neg_term = opset9.mul(g, sub_1_y, log_1_x)
    output = opset9.neg(g, opset9.add(g, pos_term, neg_term))

    if weight is not None and not symbolic_helper._is_none(weight):
        output = opset9.mul(g, weight, output)

    # `reduction` was already parsed to a Python int by parse_args("i").
    if reduction == 0:
        return output
    elif reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    elif reduction == 2:
        return g.op("ReduceSum", output, keepdims_i=0)
    else:
        return symbolic_helper._onnx_unsupported(
            "binary_cross_entropy_with_logits with reduction other than none, mean, or sum"
        )


def celu(g, self, alpha):
    """Export ``aten::celu`` as ONNX ``Celu``.

    Double inputs are cast to Float, passed through ``Celu`` and cast back to
    Double. NOTE(review): presumably because ONNX ``Celu`` does not accept
    double inputs — confirm against the ONNX spec.
    """
    alpha = symbolic_helper._maybe_get_const(alpha, "f")
    # if the input is of type double cast it to float
    if self.type().scalarType() == "Double":
        self = g.op("Cast", self, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"])
        out = g.op("Celu", self, alpha_f=alpha)
        return g.op("Cast", out, to_i=symbolic_helper.cast_pytorch_to_onnx["Double"])

    return g.op("Celu", self, alpha_f=alpha)


def _argmax_argmin_helper(g, op_name, input, dim, keepdim):
    """Emit ONNX ``op_name`` (``ArgMax``/``ArgMin``) for aten argmax/argmin.

    When ``dim`` is ``None`` the input is flattened to 1-D and reduced over
    axis 0 with ``keepdims=False``. Otherwise ``dim``/``keepdim`` are parsed
    as ints and used directly. ``select_last_index`` is always false.
    """
    if symbolic_helper._is_none(dim):
        flattened = symbolic_helper._reshape_helper(
            g, input, g.op("Constant", value_t=torch.tensor([-1]))
        )
        return g.op(
            op_name, flattened, axis_i=0, keepdims_i=False, select_last_index_i=False
        )
    dim = symbolic_helper._parse_arg(dim, "i")
    keepdim = symbolic_helper._parse_arg(keepdim, "i")
    return g.op(
        op_name, input, axis_i=dim, keepdims_i=keepdim, select_last_index_i=False
    )


def argmax(g, input, dim, keepdim):
    """Export ``aten::argmax`` as ONNX ``ArgMax``."""
    return _argmax_argmin_helper(g, "ArgMax", input, dim, keepdim)


def argmin(g, input, dim, keepdim):
    """Export ``aten::argmin`` as ONNX ``ArgMin``."""
    return _argmax_argmin_helper(g, "ArgMin", input, dim, keepdim)


def pow(g, self, exponent):
    """Export ``aten::pow`` as ONNX ``Pow``."""
    return g.op("Pow", self, exponent)


def ge(g, input, other):
    """Export ``aten::ge`` as ONNX ``GreaterOrEqual``."""
    return g.op("GreaterOrEqual", input, other)


def le(g, input, other):
    """Export ``aten::le`` as ONNX ``LessOrEqual``."""
    return g.op("LessOrEqual", input, other)


@symbolic_helper.parse_args("v", "i", "v", "v")
def unfold(g, input, dimension, size, step):
    """Export ``aten::unfold`` (sliding windows along ``dimension``).

    If both ``size`` and ``step`` are constants the opset 9 lowering is used.
    Otherwise, when the length of ``dimension`` is statically known, an ONNX
    ``Loop`` slices one window per iteration:
    window ``i`` spans ``[low_indices[i], hi_indices[i])`` where
    ``low_indices = Range(0, sizedim, step)`` and
    ``hi_indices = Range(size, sizedim + 1, step)``; the iteration count is the
    shorter of the two ranges.

    ``dimension`` may be negative; it is normalized against the input rank
    before being used in the dynamic lowering.
    """
    const_size = symbolic_helper._maybe_get_const(size, "i")
    const_step = symbolic_helper._maybe_get_const(step, "i")
    if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
        const_step
    ):
        return opset9.unfold(g, input, dimension, const_size, const_step)
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)

    sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
    if sizedim is None:
        return symbolic_helper._unimplemented("Unfold", "input size not accessible")

    ndim = symbolic_helper._get_tensor_rank(input)
    # Unsqueeze axes and the permutation arithmetic below need a
    # non-negative axis index.
    if dimension < 0:
        dimension += ndim

    low_start = g.op("Constant", value_t=torch.tensor(0))
    low_end = g.op("Constant", value_t=torch.tensor(sizedim))
    hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
    low_indices = g.op("Range", low_start, low_end, step)
    hi_indices = g.op("Range", size, hi_end, step)

    low_size = symbolic_helper._size_helper(
        g, low_indices, g.op("Constant", value_t=torch.tensor(0))
    )
    hi_size = symbolic_helper._size_helper(
        g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
    )

    # Permutation moving `dimension` to the last axis of each window.
    perm = list(range(0, ndim))
    perm.append(perm.pop(dimension))

    unsqueeze_list = []
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=9)
    loop_len = g.op("Min", low_size, hi_size)
    loop = g.op("Loop", loop_len, loop_condition)

    loop_block = utils._add_block(loop.node())
    block_input_iter = utils._add_input_to_block(loop_block)
    # The second block input is the loop condition; the body does not read it.
    utils._add_input_to_block(loop_block)

    starts = loop_block.op("Gather", low_indices, block_input_iter)
    ends = loop_block.op("Gather", hi_indices, block_input_iter)
    # Slice along the unfolded dimension (previously hard-coded to axis 2).
    axes = loop_block.op("Constant", value_t=torch.tensor([dimension]))
    starts = symbolic_helper._unsqueeze_helper(loop_block, starts, [0])
    ends = symbolic_helper._unsqueeze_helper(loop_block, ends, [0])
    stack = loop_block.op("Slice", input, starts, ends, axes)

    # Each window: move `dimension` last, then insert a size-1 axis at
    # `dimension` -> rank ndim + 1.
    unsqueeze = symbolic_helper._unsqueeze_helper(
        loop_block, loop_block.op("Transpose", stack, perm_i=perm), [dimension]
    )
    unsqueeze_list.append(unsqueeze)
    concat = loop_block.op("Concat", *unsqueeze_list, axis_i=0)

    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
    utils._add_output_to_block(loop_block, cond_out)
    utils._add_output_to_block(loop_block, concat)

    loop_output = loop.node().output()
    # Loop scan outputs gain a leading iteration axis, so loop_output has
    # rank ndim + 2. Swap the iteration axis into the slot after `dimension`
    # (bringing the size-1 axis to the front) and squeeze that axis away.
    perm = list(range(ndim + 2))
    perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
    transpose = g.op("Transpose", loop_output, perm_i=perm)
    squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])

    return squeeze


@symbolic_helper.parse_args("v", "v", "is", "is", "v")
def tensordot(g, input_a, input_b, dims_a, dims_b, out=None):
    """Export ``aten::tensordot`` via permute + reshape + 2-D Einsum + reshape.

    The contracted dims of ``input_a`` are moved to the end and those of
    ``input_b`` to the front. Both are flattened to matrices, multiplied with
    ``Einsum("ij,jk->ik")``, and the product is reshaped to the sizes of the
    remaining (non-contracted) dims of ``input_a`` followed by those of
    ``input_b``. Negative entries in ``dims_a``/``dims_b`` are normalized
    against each input's rank.

    Raises:
        RuntimeError: if the rank of either input is unknown.
    """
    if out is not None:
        symbolic_helper._unimplemented(
            "Tensordot", "Out parameter is not supported for tensordot."
        )

    dim_count_a = symbolic_helper._get_tensor_rank(input_a)
    if dim_count_a is None:
        raise RuntimeError(
            "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank."
        )

    dim_count_b = symbolic_helper._get_tensor_rank(input_b)
    if dim_count_b is None:
        raise RuntimeError(
            "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank."
        )

    dims_a = [d + dim_count_a if d < 0 else d for d in dims_a]
    dims_b = [d + dim_count_b if d < 0 else d for d in dims_b]

    left_dims_a = [i for i in range(dim_count_a) if i not in dims_a]
    left_dims_b = [i for i in range(dim_count_b) if i not in dims_b]

    new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
    new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)

    # input_a -> [left sizes..., -1]; its last dim is the contracted size K.
    input_shape = g.op("Shape", new_input_a)
    left_sizes_a = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
    )
    shape_sizes = [
        left_sizes_a,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
    ]
    output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)

    # input_a -> [-1, K].
    input_shape = g.op("Shape", output_a)
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
    )
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
        slices,
    ]
    output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)

    # input_b -> [contracted sizes..., -1]; its last dim is the product of
    # the remaining dims of input_b.
    input_shape = g.op("Shape", new_input_b)
    left_sizes_b = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
    )
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
    )
    shape_sizes = [
        slices,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
    ]
    output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)

    # input_b -> [-1, N].
    input_shape = g.op("Shape", output_b)
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
    )
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
        slices,
    ]
    output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)

    output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))

    shape_sizes = [left_sizes_a, left_sizes_b]
    return opset9._reshape_from_tensor(g, output, shape_sizes)