backend.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967
  1. ## @package onnx
  2. # Module caffe2.python.onnx.backend
  3. """Backend for running ONNX on Caffe2
  4. To run this, you will need to have Caffe2 installed as well.
  5. """
  6. import collections
  7. import sys
  8. import zipfile
  9. import itertools
  10. # When onnx is built against a version of protobuf that is older than
  11. # that which is vendored with caffe2, onnx will crash if caffe2's
  12. # vendored protobuf is loaded first. We can work around this by
  13. # importing onnx first, which will cause it to go out and pick up the
  14. # system protobuf.
  15. import onnx.backend
  16. from caffe2.python import core, workspace, rnn_cell, gru_cell
  17. from caffe2.python.model_helper import ModelHelper
  18. from caffe2.proto import caffe2_pb2
  19. import caffe2.python.utils
  20. import numpy as np
  21. import onnx
  22. from onnx import TensorProto
  23. import onnx.numpy_helper
  24. import onnx.defs
  25. import onnx.shape_inference
  26. import onnx.utils
  27. from onnx.backend.base import Backend, Device, DeviceType, namedtupledict
  28. from caffe2.python.onnx.workspace import Workspace
  29. from caffe2.python.onnx.backend_rep import Caffe2Rep
  30. import caffe2.python._import_c_extension as C
  31. import warnings
  32. def force_unicode(s):
  33. try:
  34. return s.decode('utf-8')
  35. except AttributeError:
  36. return s
  37. def get_device_option(device):
  38. m = {DeviceType.CPU: caffe2_pb2.CPU,
  39. DeviceType.CUDA: workspace.GpuDeviceType}
  40. return core.DeviceOption(m[device.type], device.device_id)
  41. class OnnxAttributes(dict):
  42. """
  43. This is a more convenient way to work with ONNX/Caffe2 attributes
  44. that is not the protobuf representation.
  45. """
  46. @staticmethod
  47. def from_onnx(args):
  48. d = OnnxAttributes()
  49. for arg in args:
  50. d[arg.name] = convertAttributeProto(arg)
  51. return d
  52. def caffe2(self, kmap=lambda k: k):
  53. for k, v in self.items():
  54. if kmap(k) != '':
  55. yield caffe2.python.utils.MakeArgument(kmap(k), v)
  56. # TODO: Move this into ONNX main library
  57. def convertAttributeProto(onnx_arg):
  58. """
  59. Convert an ONNX AttributeProto into an appropriate Python object
  60. for the type.
  61. NB: Tensor attribute gets returned as the straight proto.
  62. """
  63. if onnx_arg.HasField('f'):
  64. return onnx_arg.f
  65. elif onnx_arg.HasField('i'):
  66. return onnx_arg.i
  67. elif onnx_arg.HasField('s'):
  68. return onnx_arg.s
  69. elif onnx_arg.HasField('t'):
  70. return onnx_arg.t # this is a proto!
  71. elif onnx_arg.HasField('g'):
  72. return Caffe2Backend._graph_to_net(onnx_arg.g, Caffe2Backend._known_opset_version)
  73. elif len(onnx_arg.floats):
  74. return list(onnx_arg.floats)
  75. elif len(onnx_arg.ints):
  76. return list(onnx_arg.ints)
  77. elif len(onnx_arg.strings):
  78. return list(onnx_arg.strings)
  79. elif len(onnx_arg.graphs):
  80. retval = []
  81. # TODO: this doesn't work with RNN ops
  82. for g in onnx_arg.graphs:
  83. retval.append(Caffe2Backend._graph_to_net(g, Caffe2Backend._known_opset_version))
  84. return retval
  85. else:
  86. raise ValueError("Unsupported ONNX attribute: {}".format(onnx_arg))
  87. # TODO: Move this into ONNX main library
  88. class OnnxNode(object):
  89. """
  90. Reimplementation of NodeProto from ONNX, but in a form
  91. more convenient to work with from Python.
  92. We may temporarily edit these nodes to get them into Caffe2 form,
  93. before actually translating into the Caffe2 protobuf, since this
  94. is easier than decomposing everything, and putting it back together
  95. when we're ready.
  96. """
  97. def __init__(self, node):
  98. self.name = str(node.name)
  99. self.op_type = str(node.op_type)
  100. self.attrs = OnnxAttributes.from_onnx(node.attribute)
  101. self.inputs = list(node.input)
  102. self.outputs = list(node.output)
  103. Caffe2Ops = collections.namedtuple('Caffe2Ops', ['ops', 'init_ops', 'interface_blobs'])
  104. class Caffe2Backend(Backend):
  105. # The greatest version of the ONNX operator set which we are aware of.
  106. # Models whose version is larger than this will cause us to emit a warning
  107. # that we are attempting to translate on a "best effort" basis.
  108. #
  109. # If you increase this, make SURE you cross-reference all BC-breaking
  110. # changes from one version to the next, and any that you did not
  111. # implement, mark as broken in _broken_operators
  112. _known_opset_version = 9
  113. # This dictionary will record operators which are KNOWN to be
  114. # broken, so we give a good error message rather than do something
  115. # bogus and then fail.
  116. _broken_operators = {
  117. # 'BrokenOp': version_it_was_broken_in
  118. }
  119. # Operators that are different between Caffe2 and
  120. # ONNX but only in their name.
  121. # In most cases, this should be empty - as the effort of ONNX is
  122. # to unify the operator definitions.
  123. _renamed_operators = {
  124. 'GlobalMaxPool': 'MaxPool',
  125. 'GlobalAveragePool': 'AveragePool',
  126. 'Pad': 'PadImage',
  127. 'Neg': 'Negative',
  128. 'BatchNormalization': 'SpatialBN',
  129. 'InstanceNormalization': 'InstanceNorm',
  130. 'MatMul': 'BatchMatMul',
  131. 'Upsample': 'ResizeNearest',
  132. 'Identity': 'Copy',
  133. 'InstanceNormalization': 'InstanceNorm',
  134. 'Equal': 'EQ',
  135. 'Less': 'LT',
  136. 'Greater': 'GT',
  137. 'Unsqueeze': 'ExpandDims',
  138. 'Loop': 'ONNXWhile',
  139. 'Tile': 'NumpyTile',
  140. 'RandomNormal': 'GaussianFill',
  141. 'RandomUniform': 'UniformFill',
  142. }
  143. _global_renamed_attrs = {'kernel_shape': 'kernels'}
  144. _per_op_renamed_attrs = {
  145. 'Squeeze': {'axes': 'dims'},
  146. 'Unsqueeze': {'axes': 'dims'},
  147. 'Transpose': {'perm': 'axes'},
  148. 'Upsample': {'mode': '',
  149. 'scales': ''},
  150. 'ConvTranspose': {'output_padding': 'adjs'},
  151. 'Selu': {'gamma': 'scale'},
  152. 'If': {'then_branch': 'then_net',
  153. 'else_branch': 'else_net'},
  154. 'RandomUniform': {'low': 'min',
  155. 'high': 'max'}
  156. }
  157. # operators whose behavior is different beyond renaming
  158. # the value is an attribute of this class that is a
  159. # function from ToffeIR node_def to caffe2 op_def
  160. _special_operators = {
  161. 'LSTM': '_create_rnn_variant',
  162. 'GRU': '_create_rnn_variant',
  163. 'RNN': '_create_rnn_variant',
  164. 'Loop': '_create_loop',
  165. 'If': '_create_if',
  166. 'Upsample': '_create_upsample',
  167. 'RandomNormal': '_create_gaussian_fill'
  168. }
  169. # Dummy name generator
  170. _dummy_name = C.DummyName()
  171. @classmethod
  172. def dummy_name(cls):
  173. return cls._dummy_name.new_dummy_name()
  174. # NB: By default, you will use the LATEST definition of the operator,
  175. # so this interface MAY make BC-breaking changes. Specify an
  176. # opset_version if you don't want this to version.
  177. @classmethod
  178. def run_node(cls, node, inputs, device='CPU', opset_version=_known_opset_version, outputs_info=None):
  179. super(Caffe2Backend, cls).run_node(node, inputs, device=device,
  180. outputs_info=outputs_info, opset_version=opset_version)
  181. value_infos = []
  182. device_option = get_device_option(Device(device))
  183. ws = Workspace()
  184. with core.DeviceScope(device_option): # temporary!
  185. if isinstance(inputs, dict):
  186. for key, value in inputs.items():
  187. ws.FeedBlob(key, value)
  188. value_infos.append(onnx.helper.make_tensor_value_info(
  189. name=key,
  190. elem_type=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype],
  191. shape=value.shape).SerializeToString())
  192. else:
  193. assert len(node.input) == len(inputs), "{}: expected {} but got {}".format(
  194. node.op_type, len(node.input), len(inputs))
  195. for key, value in zip(node.input, inputs):
  196. ws.FeedBlob(key, value)
  197. value_infos.append(onnx.helper.make_tensor_value_info(
  198. name=key,
  199. elem_type=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype],
  200. shape=value.shape).SerializeToString())
  201. ops = []
  202. cbackend = C.Caffe2Backend(cls._dummy_name)
  203. ops_str = cbackend.convert_node(node.SerializeToString(), value_infos, opset_version)
  204. for s in ops_str[0] + ops_str[1]:
  205. op = caffe2_pb2.OperatorDef()
  206. op.ParseFromString(s)
  207. op.device_option.CopyFrom(device_option)
  208. ops.append(op)
  209. ws.RunOperatorsOnce(ops)
  210. output_values = [ws.FetchBlob(name) for name in node.output]
  211. return namedtupledict('Outputs', node.output)(*output_values)
  212. @classmethod
  213. def _create_tensor_filling_op(cls, onnx_tensor, name=None):
  214. """
  215. Given an Onnx TensorProto, translate it into a Caffe2 operator
  216. which produces the given tensor filling op.
  217. """
  218. assert name or onnx_tensor.name
  219. name = name or onnx_tensor.name
  220. c2_op = caffe2_pb2.OperatorDef()
  221. c2_values = c2_op.arg.add()
  222. c2_values.name = "values"
  223. def tensor2list(onnx_tensor):
  224. # Use the onnx.numpy_helper because the data may be raw
  225. return onnx.numpy_helper.to_array(onnx_tensor).flatten().tolist()
  226. if onnx_tensor.data_type in [TensorProto.FLOAT]:
  227. c2_op.type = 'GivenTensorFill'
  228. c2_values.floats.extend(tensor2list(onnx_tensor))
  229. elif onnx_tensor.data_type in [TensorProto.DOUBLE]:
  230. c2_op.type = 'GivenTensorDoubleFill'
  231. c2_values.floats.extend(tensor2list(onnx_tensor))
  232. elif onnx_tensor.data_type in [TensorProto.INT64,
  233. TensorProto.UINT32]:
  234. c2_op.type = 'GivenTensorInt64Fill'
  235. c2_values.ints.extend(tensor2list(onnx_tensor))
  236. elif onnx_tensor.data_type in [TensorProto.UINT8,
  237. TensorProto.INT8,
  238. TensorProto.UINT16,
  239. TensorProto.INT16,
  240. TensorProto.INT32]:
  241. c2_op.type = 'GivenTensorIntFill'
  242. c2_values.ints.extend(tensor2list(onnx_tensor))
  243. elif onnx_tensor.data_type == TensorProto.BOOL:
  244. c2_op.type = 'GivenTensorBoolFill'
  245. c2_values.ints.extend(tensor2list(onnx_tensor))
  246. elif onnx_tensor.data_type == TensorProto.STRING:
  247. c2_op.type = 'GivenTensorStringFill'
  248. c2_values.strings.extend(onnx_tensor.string_data)
  249. else:
  250. raise RuntimeError(
  251. "unrecognized tensor type {}".format(onnx_tensor.data_type))
  252. c2_shape = c2_op.arg.add()
  253. c2_shape.name = "shape"
  254. c2_shape.ints.extend(onnx_tensor.dims)
  255. c2_op.output.append(name)
  256. return c2_op
  257. @classmethod
  258. def _rnn_reform_weights(cls, reforms, name, hidden_size, init_net, gates, reorder_indices):
  259. for name_from, name_to, do_concat, extra_dims in reforms:
  260. gate_blobs = ['%s/%s_%s' % (name, prefix, name_to) for prefix in gates]
  261. for i, x in enumerate(gate_blobs):
  262. dim0 = i * hidden_size, (i+1) * hidden_size
  263. starts, ends = zip(dim0, *extra_dims)
  264. init_net.Slice(name_from, x, starts=starts, ends=ends)
  265. if do_concat:
  266. reordered_gate_blobs = [gate_blobs[i] for i in reorder_indices]
  267. init_net.Concat(reordered_gate_blobs, ['%s/%s' % (name, name_to), cls.dummy_name()], axis=0)
    @classmethod
    def _make_rnn_direction(cls, input_blob, B, W, R, initial_states_and_names, sequence_lens,
                            pred_mh, init_net,
                            input_size, hidden_size, num_gates, direction_offset,
                            Bi, Br, W_, R_,
                            reform, make_cell, keep_outputs):
        """Build one direction (forward or backward) of an RNN/GRU/LSTM.

        Slices this direction's biases and weights out of the combined blobs
        ``B``, ``W`` and ``R`` (ops added to ``init_net``). Slices the initial
        states for this direction (ops added to ``pred_mh.net``). Calls
        ``reform`` to rearrange the sliced weights, then ``make_cell`` to emit
        the recurrent network. For ``direction_offset == 1`` the input is
        reversed (ReversePackedSegs) before the cell, and the first output is
        reversed back afterwards.

        Args:
            input_blob: name of the sequence input blob.
            B, W, R: combined bias, input-weight and recurrence-weight blobs
                (the visible caller flattens their leading axis first).
            initial_states_and_names: list of (state_blob, name_suffix) pairs.
            sequence_lens: sequence-lengths blob name, or None.
            pred_mh: ModelHelper receiving the predict-time ops.
            init_net: core.Net receiving the weight-slicing ops.
            input_size, hidden_size: forwarded to ``make_cell``.
            num_gates: gates packed per direction (callers pass 1, 3 or 4).
            direction_offset: 0 for the forward pass, 1 for the backward pass.
            Bi, Br, W_, R_: name suffixes for the sliced input bias,
                recurrence bias, input weights and recurrence weights.
            reform: callable(Bi, Br, W_, R_, name, hidden_size, init_net).
            make_cell: callable that builds the recurrent cell.
            keep_outputs: callable that selects which cell outputs to return.

        Returns:
            The list produced by ``keep_outputs``. For the backward direction
            its first element is the re-reversed output blob.
        """
        # Unique prefix for every blob created for this direction.
        name = cls.dummy_name()
        # input and recurrence biases are squashed together in onnx
        # but not in caffe2
        gates_hidden_size = num_gates * hidden_size
        # Per direction, B holds [input bias | recurrence bias], each
        # gates_hidden_size long, hence the factor 2 for the bias offset.
        bias_offset = 2 * direction_offset * gates_hidden_size
        weight_offset = direction_offset * gates_hidden_size
        Bi = init_net.Slice(B, name + Bi,
                            starts=[bias_offset + 0 * gates_hidden_size],
                            ends =[bias_offset + 1 * gates_hidden_size])
        Br = init_net.Slice(B, name + Br,
                            starts=[bias_offset + 1 * gates_hidden_size],
                            ends =[bias_offset + 2 * gates_hidden_size])
        # Take this direction's gates_hidden_size rows; -1 keeps the rest of axis 1.
        W_ = init_net.Slice(W, name + W_,
                            starts=[weight_offset + 0 * gates_hidden_size, 0],
                            ends =[weight_offset + 1 * gates_hidden_size,-1])
        R_ = init_net.Slice(R, name + R_,
                            starts=[weight_offset + 0 * gates_hidden_size, 0],
                            ends =[weight_offset + 1 * gates_hidden_size,-1])

        # Initial states are sliced on axis 0 (the direction index).
        initial_states_sliced = []
        for initial_state, name_suffix in initial_states_and_names:
            initial_states_sliced.append(
                pred_mh.net.Slice(initial_state, name + name_suffix,
                                  starts=[direction_offset + 0, 0, 0],
                                  ends =[direction_offset + 1,-1,-1]))

        if direction_offset == 1:
            if sequence_lens is not None:
                seq_lens_for_reverse = sequence_lens
            else:
                # No sequence_lens given: synthesize one entry per batch item,
                # each equal to the full sequence length (axis 0 of the input
                # shape is used as seq_len and axis 1 as batch_size).
                input_shape = pred_mh.net.Shape(input_blob, name + '/input_shape')
                batch_size = pred_mh.net.Slice(input_shape, name + '/batch_size_slice', starts=[1], ends=[2])
                seq_len = pred_mh.net.Slice(input_shape, name + '/seq_len_slice', starts=[0], ends=[1])
                dummy_sequence_lens = pred_mh.net.Tile([seq_len, batch_size], name + '/dummy_sequence_lens', axis=0)
                pred_mh.net.Reshape(dummy_sequence_lens, [dummy_sequence_lens, cls.dummy_name()], shape=[-1])
                seq_lens_for_reverse = pred_mh.net.Cast(dummy_sequence_lens, name + '/seq_lens_for_reverse', to=core.DataType.INT32)
        reform(Bi, Br, W_, R_, name, hidden_size, init_net)

        if direction_offset == 1:
            input = pred_mh.net.ReversePackedSegs(
                [input_blob, seq_lens_for_reverse], name + "/input-reversed")
        else:
            input = input_blob

        outputs = keep_outputs(list(make_cell(
            pred_mh,
            input,
            sequence_lens,
            initial_states_sliced,
            input_size,
            hidden_size,
            name,
            drop_states=False,
            forward_only=True,
        )))

        # Undo the input reversal on the full output sequence only.
        if direction_offset == 1:
            outputs[0] = pred_mh.net.ReversePackedSegs(
                [outputs[0], seq_lens_for_reverse], name + "/output-reversed")

        return outputs
  329. @classmethod
  330. def _create_rnn_variant(cls, init_model, pred_model, n, opset_version):
  331. assert init_model is not None, "cannot convert RNNs without access to the full model"
  332. assert pred_model is not None, "cannot convert RNNs without access to the full model"
  333. attrs = dict(n.attrs) # make a copy, which is safe to mutate
  334. hidden_size = attrs.pop('hidden_size')
  335. direction = force_unicode(attrs.pop('direction', 'forward'))
  336. if n.op_type == 'RNN':
  337. activation = force_unicode(attrs.pop('activations', ('tanh',))[0].lower())
  338. elif n.op_type == 'GRU':
  339. linear_before_reset = attrs.pop('linear_before_reset', 0)
  340. assert not attrs, "unsupported RNN attributes: " + str(attrs.keys())
  341. assert direction in ['forward', 'bidirectional'], "unsupported backwards RNN/GRU/LSTM"
  342. if n.op_type in ['RNN', 'GRU']:
  343. input_blob, W, R, B, sequence_lens, initial_h = n.inputs
  344. elif n.op_type == 'LSTM':
  345. input_blob, W, R, B, sequence_lens, initial_h, initial_c = n.inputs
  346. if sequence_lens == "":
  347. sequence_lens = None
  348. for x in itertools.chain(init_model.graph.input,
  349. init_model.graph.value_info,
  350. pred_model.graph.input,
  351. pred_model.graph.value_info):
  352. if x.name == W:
  353. input_size = x.type.tensor_type.shape.dim[2].dim_value
  354. break
  355. else:
  356. raise RuntimeError("best-effort shape inference for RNN/GRU/LSTM failed")
  357. pred_mh = ModelHelper()
  358. init_net = core.Net("init-net")
  359. init_net.Reshape(W, [W, cls.dummy_name()], shape=[1,-1,0])
  360. init_net.Squeeze(W, W, dims=[0])
  361. init_net.Reshape(R, [R, cls.dummy_name()], shape=[1,-1,0])
  362. init_net.Squeeze(R, R, dims=[0])
  363. init_net.Reshape(B, [B, cls.dummy_name()], shape=[1,-1])
  364. init_net.Squeeze(B, B, dims=[0])
  365. if n.op_type == 'RNN':
  366. def reform(*args):
  367. pass
  368. def make_cell(*args, **kwargs):
  369. return rnn_cell.BasicRNN(*args, activation=activation, **kwargs)
  370. def make_rnn(direction_offset):
  371. return cls._make_rnn_direction(
  372. input_blob, B, W, R, [(initial_h, '/initial_h')], sequence_lens,
  373. pred_mh, init_net, input_size, hidden_size, 1, direction_offset,
  374. "/i2h_b", "/gates_t_b", "/i2h_w", "/gates_t_w",
  375. reform, make_cell, lambda x: x)
  376. elif n.op_type == 'GRU':
  377. def reform(Bi, Br, W_, R_, name, hidden_size, init_net):
  378. # caffe2 has a different order from onnx. We need to rearrange
  379. # z r h -> r z h
  380. reforms = ((W_, 'i2h_w', True, [(0,-1)]),
  381. (R_, 'gate_t_w', False, [(0,-1)]),
  382. (Bi, 'i2h_b', True, []),
  383. (Br, 'gate_t_b', False, []))
  384. cls._rnn_reform_weights(reforms, name, hidden_size, init_net,
  385. ['update', 'reset', 'output'], [1, 0, 2])
  386. def make_cell(*args, **kwargs):
  387. return gru_cell.GRU(*args, linear_before_reset=linear_before_reset, **kwargs)
  388. def make_rnn(direction_offset):
  389. return cls._make_rnn_direction(
  390. input_blob, B, W, R, [(initial_h, '/initial_h')], sequence_lens,
  391. pred_mh, init_net, input_size, hidden_size, 3, direction_offset,
  392. "_bias_i2h", "_bias_gates", "/i2h_w_pre", "/gates_t_w_pre",
  393. reform, make_cell, lambda x: x)
  394. elif n.op_type == 'LSTM':
  395. def reform(Bi, Br, W_, R_, name, hidden_size, init_net):
  396. # caffe2 has a different order from onnx. We need to rearrange
  397. # i o f c -> i f o c
  398. reforms = ((W_, 'i2h_w', True, [(0, -1)]),
  399. (R_, 'gates_t_w', True, [(0, -1)]),
  400. (Bi, 'i2h_b' , True, []),
  401. (Br, 'gates_t_b', True, []))
  402. cls._rnn_reform_weights(reforms, name, hidden_size, init_net,
  403. ['input', 'output', 'forget', 'cell'], [0, 2, 1, 3])
  404. def make_cell(*args, **kwargs):
  405. return rnn_cell.LSTM(*args, **kwargs)
  406. def make_rnn(direction_offset):
  407. return cls._make_rnn_direction(
  408. input_blob, B, W, R, [(initial_h, '/initial_h'), (initial_c, '/initial_c')], sequence_lens,
  409. pred_mh, init_net, input_size, hidden_size, 4, direction_offset,
  410. "/i2h_b", "/gates_t_b", "/i2h_w", "/gates_t_w",
  411. reform, make_cell, lambda x: [x[0], x[1], x[3]])
  412. if direction == 'forward':
  413. outputs = make_rnn(0)
  414. # in the forward case, storage is shared between the
  415. # last outputs. We need to decouple them so that the
  416. # VariableLengthSequencePadding only mutates
  417. # n.outputs[0]
  418. for i in range(1, len(outputs)):
  419. pred_mh.net.Copy(outputs[i], n.outputs[i])
  420. if sequence_lens is not None:
  421. pred_mh.net.VariableLengthSequencePadding(
  422. [outputs[0], sequence_lens], [outputs[0]])
  423. pred_mh.net.ExpandDims([outputs[0]], [n.outputs[0]], dims=[1])
  424. elif direction == 'bidirectional':
  425. outputs_f = make_rnn(0)
  426. outputs_b = make_rnn(1)
  427. concatted_output, _ = pred_mh.net.Concat(
  428. [outputs_f[0], outputs_b[0]], [cls.dummy_name(), cls.dummy_name()], axis=2)
  429. if sequence_lens is not None:
  430. pred_mh.net.VariableLengthSequencePadding(
  431. [concatted_output, sequence_lens], [concatted_output])
  432. reshaped_output, _ = pred_mh.net.Reshape(concatted_output, [cls.dummy_name(), cls.dummy_name()], shape=[0,0,-1,2])
  433. pred_mh.net.Transpose(reshaped_output, n.outputs[0], axes=[0,2,1,3])
  434. for i in range(1, len(n.outputs)):
  435. pred_mh.net.Concat([outputs_f[i], outputs_b[i]],
  436. [n.outputs[i], cls.dummy_name()], axis=0)
  437. # We want to decide whether to put all of our weight-reshaping
  438. # operators in the init net or the predict net. We can put
  439. # them in the init net iff the inputs to those operators are
  440. # already available, either as graph initializers, or as the
  441. # output of other operators in the init net. The latter case
  442. # occurs, for example, when exporting from pytorch to onnx.
  443. # In most production use, we expect has_initializers to be
  444. # true.
  445. initializers = {i.name for i in init_model.graph.initializer}
  446. outputs = {output for node in init_model.graph.node for output in node.output}
  447. has_initializers = all(x in initializers or x in outputs for x in (W, R, B))
  448. pred_ops = []
  449. init_ops = []
  450. (init_ops if has_initializers else pred_ops).extend(init_net.Proto().op)
  451. pred_ops.extend(pred_mh.Proto().op)
  452. return Caffe2Ops(pred_ops, init_ops, list(pred_mh.Proto().external_input))
  453. @classmethod
  454. def _create_control_op(cls, init_model, pred_model, n, opset_version):
  455. control_inputs = []
  456. if '__control_inputs' in n.attrs:
  457. control_inputs.extend(n.attrs['__control_inputs'])
  458. node = cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version)
  459. node.control_input.extend(control_inputs)
  460. return Caffe2Ops([node], [], [])
  461. @classmethod
  462. def _remove_ssa(cls, net, remap_dict):
  463. for op in net.op:
  464. for i, name in enumerate(op.output):
  465. if name in remap_dict:
  466. op.output[i] = remap_dict[name]
  467. for i, out in enumerate(net.external_output):
  468. if out in remap_dict:
  469. net.external_output[i] = remap_dict[out]
  470. @classmethod
  471. def _create_if(cls, init_model, pred_model, n, opset_version):
  472. ops = cls._create_control_op(init_model, pred_model, n, opset_version)
  473. assert ops[0][0].type == 'If'
  474. if_op = ops[0][0]
  475. then_net = else_net = None
  476. control_inputs = []
  477. for arg in if_op.arg:
  478. if arg.name == 'then_net':
  479. then_net = arg.n
  480. if arg.name == 'else_net':
  481. else_net = arg.n
  482. if arg.name == '__control_inputs':
  483. control_inputs = arg.strings
  484. assert then_net and else_net
  485. then_net_outs = then_net.external_output
  486. else_net_outs = else_net.external_output
  487. op_outputs = if_op.output
  488. assert len(then_net_outs) == len(else_net_outs)
  489. assert len(else_net_outs) == len(op_outputs)
  490. for arg in if_op.arg:
  491. if arg.name == 'then_net':
  492. arg.n.external_input.extend(control_inputs)
  493. if arg.name == 'else_net':
  494. arg.n.external_input.extend(control_inputs)
  495. return ops
  496. @classmethod
  497. def _create_loop(cls, init_model, pred_model, n, opset_version):
  498. ops = cls._create_control_op(init_model, pred_model, n, opset_version)
  499. assert ops[0][0].type == 'ONNXWhile'
  500. while_op = ops[0][0]
  501. while_op.arg.extend([caffe2.python.utils.MakeArgument('has_trip_count', True)])
  502. while_op.arg.extend([caffe2.python.utils.MakeArgument('has_cond', True)])
  503. while_op.arg.extend([caffe2.python.utils.MakeArgument('disable_scopes', True)])
  504. control_inputs = []
  505. for arg in while_op.arg:
  506. if arg.name == '__control_inputs':
  507. control_inputs = arg.strings
  508. num_loop_carried_deps = 0
  509. for arg in while_op.arg:
  510. if arg.name == 'body':
  511. num_loop_carried_deps = len(arg.n.external_input) - 2
  512. arg.n.external_input.extend(control_inputs)
  513. while_op.arg.extend([
  514. caffe2.python.utils.MakeArgument('num_loop_carried_deps',
  515. num_loop_carried_deps)
  516. ])
  517. return ops
  518. @classmethod
  519. def _substitute_raw_value(cls, tp, raw_values_dict):
  520. if tp.HasField('raw_data') and tp.raw_data == bytes(b'__EXTERNAL'):
  521. if tp.name not in raw_values_dict:
  522. raise RuntimeError('TensorProto for value {} referenced raw data but it was not found!'.format(tp.name))
  523. else:
  524. tp.raw_data = raw_values_dict[tp.name]
  525. @classmethod
  526. def _visit_and_substitute_raw_values(cls, nodes, raw_values_dict):
  527. for node in nodes:
  528. for attr in node.attribute:
  529. if attr.HasField('t'):
  530. cls._substitute_raw_value(attr.t, raw_values_dict)
  531. for t in attr.tensors:
  532. cls._substitute_raw_value(t, raw_values_dict)
  533. if attr.HasField('g'):
  534. cls._visit_and_substitute_raw_values(attr.g.node, raw_values_dict)
  535. for g in attr.graphs:
  536. cls._visit_and_substitute_raw_values(g.node, raw_values_dict)
  537. @classmethod
  538. def _external_value_resolution_pass(cls, model, raw_values_dict):
  539. for init in model.graph.initializer:
  540. cls._substitute_raw_value(init, raw_values_dict)
  541. cls._visit_and_substitute_raw_values(model.graph.node, raw_values_dict)
  542. @classmethod
  543. def _direct_initialize_parameters(cls, initializer, ws, device_option):
  544. for tp in initializer:
  545. ws.FeedBlob(tp.name, onnx.numpy_helper.to_array(tp), device_option)
  546. @classmethod
  547. def _direct_initialize_inputs(cls, inputs, initialized, ws, device_option):
  548. for value_info in inputs:
  549. if value_info.name in initialized:
  550. continue
  551. shape = list(d.dim_value for d in value_info.type.tensor_type.shape.dim)
  552. ws.FeedBlob(
  553. value_info.name,
  554. np.ones(shape, dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[value_info.type.tensor_type.elem_type]),
  555. device_option)
  556. @staticmethod
  557. def optimize_onnx(input, init=False, predict=False):
  558. passes = ['fuse_consecutive_transposes',
  559. 'eliminate_nop_transpose',
  560. 'fuse_transpose_into_gemm',
  561. 'lift_lexical_references']
  562. if init:
  563. passes.append('split_init')
  564. if predict:
  565. passes.append('split_predict')
  566. try:
  567. out = onnx.optimizer.optimize(input, passes)
  568. except AttributeError:
  569. warnings.warn("OptimizerWarning: optimizer module not found in ONNX version {}".format(onnx.__version__))
  570. # ONNX does no ship onnx.optimizer since version 1.9+
  571. import onnxoptimizer
  572. out = onnxoptimizer.optimize(input, passes)
  573. return out
  574. @classmethod
  575. def prepare_zip_archive(cls, file, device='CPU', **kwargs):
  576. with zipfile.ZipFile(file, mode='r') as z:
  577. with z.open('__MODEL_PROTO', 'r') as f:
  578. model = onnx.load(f);
  579. blob_names = set(z.namelist()) - set('__MODEL_PROTO')
  580. # TODO: make this more efficient
  581. raw_values_dict = {}
  582. for name in blob_names:
  583. with z.open(name, 'r') as blob_file:
  584. raw_values_dict[name] = blob_file.read()
  585. return cls.prepare(model, device, raw_values_dict=raw_values_dict, **kwargs)
  586. @classmethod
  587. def prepare(cls, model, device='CPU', raw_values_dict=None, **kwargs):
  588. '''
  589. For Onnx Caffe2Backend, we require that init_graph don't initialize the actual input of the predict_graph,
  590. for example, if "img" is the input blob for the predict_net, we require that in init_graph and in
  591. initializer of the predict_graph, "img" is not initalized. We don't have a check for this, since
  592. there is no way we can know which blob is the input of the predict_graph.
  593. '''
  594. if not kwargs.pop('no_check_UNSAFE', False):
  595. super(Caffe2Backend, cls).prepare(model, device, **kwargs)
  596. opset_version = None
  597. for imp in model.opset_import:
  598. if not imp.HasField("domain") or imp.domain == "":
  599. opset_version = imp.version
  600. if imp.version > cls._known_opset_version:
  601. warnings.warn("This version of onnx-caffe2 targets ONNX operator set version {}, but the model we are trying to import uses version {}. We will try to import it anyway, but if the model uses operators which had BC-breaking changes in the intervening versions, import will fail.".format(cls._known_opset_version, imp.version))
  602. else:
  603. warnings.warn("Unrecognized operator set {}".format(imp.domain))
  604. if opset_version is None:
  605. if model.ir_version >= 0x00000003:
  606. raise RuntimeError("Model with IR version >= 3 did not specify ONNX operator set version (onnx-caffe2 requires it)")
  607. else:
  608. opset_version = 1
  609. # Prior to onnx version update to onnx-1.8.0, errors caused by failures in
  610. # in the onnx shape inference call were being supressed. Hence a try-catch block
  611. # is added around the infer_shapes call to avoid these failures and preserve status
  612. try:
  613. model = onnx.shape_inference.infer_shapes(model)
  614. except RuntimeError:
  615. warnings.warn("ShapeInferenceWarning: Inferred shape and existing shape differ in rank")
  616. ws = Workspace()
  617. device_option = get_device_option(Device(device))
  618. init_net, predict_net = cls._onnx_model_to_caffe2_net(model, device, opset_version, False)
  619. if raw_values_dict:
  620. cls._external_value_resolution_pass(model, raw_values_dict)
  621. # Directly load initializer data into blobs in workspace
  622. cls._direct_initialize_parameters(
  623. model.graph.initializer,
  624. ws,
  625. device_option,
  626. )
  627. initialized = {init.name for init in model.graph.initializer}
  628. cls._direct_initialize_inputs(
  629. model.graph.input,
  630. initialized,
  631. ws,
  632. device_option,
  633. )
  634. uninitialized = [value_info.name for value_info in model.graph.input if value_info.name not in initialized]
  635. retval = Caffe2Rep(init_net, predict_net, ws, uninitialized)
  636. return retval
  637. @classmethod
  638. # TODO: This method needs a refactor for clarity
  639. def _onnx_node_to_caffe2_op(cls, init_model, pred_model, node_def, opset_version):
  640. cbackend = C.Caffe2Backend(cls._dummy_name)
  641. if cbackend.support_onnx_import(node_def.op_type):
  642. # extract value infos from pred model (value infos of
  643. # node's inputs that are in init model should be all
  644. # available in pred model)
  645. value_infos = []
  646. for name in node_def.input:
  647. if pred_model is not None:
  648. for vi in itertools.chain(pred_model.graph.input,
  649. pred_model.graph.output,
  650. pred_model.graph.value_info):
  651. if vi.name == name:
  652. value_infos.append(vi.SerializeToString())
  653. op_strs = cbackend.convert_node(node_def.SerializeToString(), value_infos, opset_version)
  654. init_ops = []
  655. for s in op_strs[0]:
  656. op = caffe2_pb2.OperatorDef()
  657. op.ParseFromString(s)
  658. init_ops.append(op)
  659. ops = []
  660. for s in op_strs[1]:
  661. op = caffe2_pb2.OperatorDef()
  662. op.ParseFromString(s)
  663. ops.append(op)
  664. return Caffe2Ops(ops, init_ops, [])
  665. if node_def.op_type in cls._special_operators:
  666. translator = getattr(cls, cls._special_operators[node_def.op_type])
  667. else:
  668. translator = cls._common_onnx_node_to_caffe2_op
  669. ops = translator(init_model, pred_model, OnnxNode(node_def), opset_version)
  670. if isinstance(ops, Caffe2Ops):
  671. return ops
  672. if not isinstance(ops, collections.abc.Iterable):
  673. ops = [ops]
  674. return Caffe2Ops(ops, [], [])
  675. _broadcast_operators = {
  676. 'Add',
  677. 'Sub',
  678. }
  679. @classmethod
  680. def _common_onnx_node_to_caffe2_op(cls, init_model, pred_model, onnx_node, opset_version):
  681. """
  682. This translator performs the basic translation of ONNX nodes into
  683. Caffe2 operators. Besides doing a straightforward marshalling from
  684. one format to another, it also does these extra things:
  685. - Renames operators based on '_renamed_operators'
  686. - Renames attributes based on '_global_renamed_attrs' and
  687. '_per_op_renamed_attrs'
  688. If you're writing a custom translator, consider calling this first,
  689. and then fixing things up further.
  690. """
  691. c2_op = caffe2_pb2.OperatorDef()
  692. c2_op.input.extend(onnx_node.inputs)
  693. c2_op.output.extend(onnx_node.outputs)
  694. c2_op.name = onnx_node.name
  695. onnx_op_type = onnx_node.op_type
  696. broken_version = cls._broken_operators.get(onnx_op_type, float('Inf'))
  697. if broken_version <= opset_version:
  698. raise ValueError(
  699. "Don't know how to translate op {} in ONNX operator set v{} (I only support prior to v{})".format(onnx_op_type, opset_version, broken_version))
  700. c2_op.type = cls._renamed_operators.get(onnx_op_type, onnx_op_type)
  701. if not core.IsOperator(c2_op.type):
  702. raise ValueError(
  703. "Don't know how to translate op {}".format(onnx_op_type))
  704. def kmap(k):
  705. if (onnx_op_type in cls._per_op_renamed_attrs and
  706. k in cls._per_op_renamed_attrs[onnx_op_type]):
  707. return cls._per_op_renamed_attrs[onnx_op_type][k]
  708. if k in cls._global_renamed_attrs:
  709. return cls._global_renamed_attrs[k]
  710. return k
  711. c2_op.arg.extend(onnx_node.attrs.caffe2(kmap=kmap))
  712. if opset_version < 7:
  713. # onnx opset 7 and newest caffe2 have adopted full onnx broadcast semantics
  714. # so we don't need this hack anymore
  715. if c2_op.type in cls._broadcast_operators:
  716. already_broadcast = False
  717. for arg in c2_op.arg:
  718. if arg.name == 'broadcast':
  719. already_broadcast = True
  720. if not already_broadcast:
  721. c2_op.arg.extend([caffe2.python.utils.MakeArgument('broadcast', 1)])
  722. return c2_op
  723. @staticmethod
  724. def _all_names_in_graph(graph):
  725. if graph is None:
  726. return set()
  727. names = set()
  728. names.update(value_info.name for value_info in graph.input)
  729. names.update(value_info.name for value_info in graph.output)
  730. for node in graph.node:
  731. names.update(node.input)
  732. names.update(node.output)
  733. return names
  734. @classmethod
  735. def _graph_to_net(cls, onnx_graph, opset_version):
  736. net = caffe2_pb2.NetDef()
  737. for node in onnx_graph.node:
  738. try:
  739. c2ops = cls._onnx_node_to_caffe2_op(
  740. None, None, node, opset_version)
  741. except Exception as e:
  742. print('ONNX FATAL:', e)
  743. continue
  744. net.op.extend(c2ops.init_ops)
  745. net.op.extend(c2ops.ops)
  746. net.external_input.extend(c2ops.interface_blobs)
  747. net.external_output.extend(
  748. value_info.name for value_info in onnx_graph.output)
  749. net.external_input.extend(
  750. value_info.name for value_info in onnx_graph.input)
  751. return net
  752. @classmethod
  753. def _onnx_model_to_caffe2_net(cls, onnx_model, device, opset_version, include_initializers):
  754. device_option = get_device_option(Device(device))
  755. # Prior to onnx version update to onnx-1.8.0, errors caused by failures in
  756. # in the onnx shape inference call were being supressed. Hence a try-catch block
  757. # is added around the infer_shapes call to avoid these failures and preserve status
  758. try:
  759. onnx_model = onnx.utils.polish_model(onnx_model)
  760. except RuntimeError:
  761. warnings.warn("ShapeInferenceWarning: Inferred shape and existing shape differ in rank")
  762. except AttributeError:
  763. warnings.warn("ShapeInferenceWarning: utils module not found in ONNX version {}".format(onnx.__version__))
  764. # Optimizer module has been removed in ONNX-1.9 or later, warn caller if that is the case
  765. try:
  766. init_model = cls.optimize_onnx(onnx_model, init=True)
  767. pred_model = cls.optimize_onnx(onnx_model, predict=True)
  768. except ModuleNotFoundError:
  769. warnings.warn("OptimizerWarning: onnxoptimizer module not installed. "
  770. "init_model and pred_model models will not be splitted, which can cause a runtime error")
  771. init_model = onnx_model
  772. pred_model = onnx_model
  773. init_net = caffe2_pb2.NetDef()
  774. pred_net = caffe2_pb2.NetDef()
  775. init_net.name = onnx_model.graph.name + '_init'
  776. pred_net.name = onnx_model.graph.name + '_predict'
  777. if include_initializers:
  778. init_net.op.extend(cls._create_tensor_filling_op(tp) for tp in onnx_model.graph.initializer)
  779. cls._dummy_name.reset(cls._all_names_in_graph(init_model.graph) | cls._all_names_in_graph(pred_model.graph))
  780. errors = []
  781. for net, model in ( (init_net, init_model), (pred_net, pred_model) ):
  782. net.device_option.CopyFrom(device_option)
  783. for node in model.graph.node:
  784. try:
  785. c2ops = cls._onnx_node_to_caffe2_op(
  786. init_model, pred_model, node, opset_version)
  787. except Exception as e:
  788. msg = 'Error while processing node: {}. Exception: {}'.format(node, e)
  789. errors.append(msg)
  790. print('ONNX FATAL:', msg, file=sys.stderr)
  791. continue
  792. init_net.op.extend(c2ops.init_ops)
  793. net.op.extend(c2ops.ops)
  794. net.external_input.extend(c2ops.interface_blobs)
  795. net.external_output.extend(
  796. value_info.name for value_info in model.graph.output)
  797. net.external_input.extend(
  798. value_info.name for value_info in model.graph.input)
  799. if len(errors) > 0:
  800. raise RuntimeError(
  801. "ONNX conversion failed, encountered {} errors:\n\n{}".format(
  802. len(errors), "\n\n".join(errors)))
  803. return init_net, pred_net
  804. # wrapper for backwards compatibility
  805. @classmethod
  806. def onnx_graph_to_caffe2_net(cls, model, device="CPU", opset_version=_known_opset_version):
  807. return cls._onnx_model_to_caffe2_net(model, device=device, opset_version=opset_version, include_initializers=True)
  808. @classmethod
  809. def supports_device(cls, device_str):
  810. device = Device(device_str)
  811. if device.type == DeviceType.CPU:
  812. return True
  813. elif core.IsGPUDeviceType(device.type):
  814. return workspace.has_gpu_support
  815. return False
  816. @classmethod
  817. def is_compatible(cls, model, device='CPU', **kwargs):
  818. if hasattr(super(Caffe2Backend, cls), 'is_compatible') \
  819. and callable(super(Caffe2Backend, cls).is_compatible):
  820. if not super(Caffe2Backend, cls).is_compatible(model, device, **kwargs):
  821. return False
  822. # TODO: should have an unspported list of operators, be optimistic for now
  823. return True
  824. prepare = Caffe2Backend.prepare
  825. prepare_zip_archive = Caffe2Backend.prepare_zip_archive
  826. run_node = Caffe2Backend.run_node
  827. run_model = Caffe2Backend.run_model
  828. supports_device = Caffe2Backend.supports_device # noqa
  829. is_compatible = Caffe2Backend.is_compatible