frontend.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
## @package onnx
# Module caffe2.python.onnx.frontend
"""Caffe2 Protobuf to ONNX converter

To run this, you will need to have Caffe2 installed as well.
"""
import collections
import collections.abc
import itertools
import logging
import re

from caffe2.python import core as caffe2_core
from onnx import (checker, helper, numpy_helper, mapping,
                  GraphProto, NodeProto, TensorProto, OperatorSetIdProto)
from onnx.helper import make_tensor_value_info, make_model
import numpy as np
from caffe2.python.onnx.helper import c2_native_run_net
import caffe2.python._import_c_extension as C
  17. logging.basicConfig(level=logging.INFO)
  18. logger = logging.getLogger(__name__)
  19. class Caffe2Frontend(object):
  20. # This number controls the semantics of the operators we target. Whenever
  21. # ONNX makes a BC breaking change to semantics of operators, having this set
  22. # to an accurate number will prevent our models form exporting. However,
  23. # we should strive to keep this up-to-date as much as possible.
  24. target_opset_version = 9
  25. _renamed_operators = {
  26. 'SpatialBN': 'BatchNormalization',
  27. 'Conv1D': 'Conv',
  28. 'Conv2D': 'Conv',
  29. 'Conv3D': 'Conv',
  30. 'ConvTranspose1D': 'ConvTranspose',
  31. 'ConvTranspose2D': 'ConvTranspose',
  32. 'ConvTranspose3D': 'ConvTranspose',
  33. 'MaxPool1D': 'MaxPool',
  34. 'MaxPool2D': 'MaxPool',
  35. 'MaxPool3D': 'MaxPool',
  36. 'AveragePool1D': 'AveragePool',
  37. 'AveragePool2D': 'AveragePool',
  38. 'AveragePool3D': 'AveragePool',
  39. }
  40. # caffe2 arguments that are completely removed in onnx
  41. _blocklist_caffe2_args = {
  42. 'order': {b'NCHW'},
  43. 'cudnn_exhaustive_search': {0, 1},
  44. 'exhaustive_search': {0, 1},
  45. 'use_cudnn': {0, 1},
  46. }
  47. _global_renamed_args = {
  48. 'kernels': 'kernel_shape',
  49. }
  50. _per_op_renamed_args = {
  51. 'Squeeze': {'dims': 'axes'},
  52. 'Transpose': {'axes': 'perm'},
  53. }
  54. _special_operators = {}
  55. # Dummy name generator
  56. _dummy_name = C.DummyName()
  57. @classmethod
  58. def dummy_name(cls):
  59. return cls._dummy_name.new_dummy_name()
  60. @classmethod
  61. def _common_caffe2_arg_to_onnx_attr(cls, op_def, arg):
  62. # name
  63. op_type = op_def.type
  64. name = cls._global_renamed_args.get(arg.name, arg.name)
  65. if op_type in cls._per_op_renamed_args:
  66. # Per-op attribute renames override the global attribute renames
  67. name = cls._per_op_renamed_args[op_type].get(arg.name, name)
  68. # value
  69. if arg.HasField('f'):
  70. value = arg.f
  71. elif arg.HasField('i'):
  72. value = arg.i
  73. elif arg.HasField('s'):
  74. value = arg.s
  75. elif arg.floats:
  76. value = arg.floats
  77. elif arg.ints:
  78. value = arg.ints
  79. elif arg.strings:
  80. value = arg.strings
  81. else:
  82. raise ValueError('Could not find data field in arg: {}'.format(arg))
  83. if name in cls._blocklist_caffe2_args:
  84. assert value in cls._blocklist_caffe2_args[arg.name]
  85. return None
  86. return helper.make_attribute(name, value)
  87. @classmethod
  88. def caffe2_arg_to_onnx_attr(cls, op_def, arg):
  89. return cls._common_caffe2_arg_to_onnx_attr(op_def, arg)
  90. @classmethod
  91. def _common_caffe2_op_to_onnx_node(cls, op_def, shapes):
  92. node_def = NodeProto()
  93. node_def.name = op_def.name
  94. node_def.op_type = cls._renamed_operators.get(op_def.type, op_def.type)
  95. node_def.input.extend(op_def.input)
  96. node_def.output.extend(op_def.output)
  97. attrs = filter(None, [cls.caffe2_arg_to_onnx_attr(op_def, arg)
  98. for arg in op_def.arg])
  99. node_def.attribute.extend(attrs)
  100. return node_def
  101. @classmethod
  102. def caffe2_op_to_onnx_node(cls, op_def, shapes):
  103. if C.support_onnx_export(op_def.type):
  104. node_strs, tensor_strs = C.export_to_onnx(cls._dummy_name, op_def.SerializeToString(), shapes)
  105. nodes = []
  106. for s in node_strs:
  107. node = NodeProto()
  108. node.ParseFromString(s)
  109. nodes.append(node)
  110. const_tensors = []
  111. for s in tensor_strs:
  112. tensor = TensorProto()
  113. tensor.ParseFromString(s)
  114. const_tensors.append(tensor)
  115. return nodes, const_tensors
  116. elif op_def.type in cls._special_operators:
  117. translator = getattr(cls, cls._special_operators[op_def.type])
  118. else:
  119. translator = cls._common_caffe2_op_to_onnx_node
  120. nodes = translator(op_def, shapes)
  121. const_tensors = []
  122. if isinstance(nodes, tuple):
  123. nodes, const_tensors = nodes
  124. if not isinstance(nodes, collections.abc.Iterable):
  125. nodes = [nodes]
  126. return nodes, const_tensors
  127. @staticmethod
  128. def _all_names_in_net(net):
  129. if net is None:
  130. return set()
  131. names = set()
  132. names.update(net.external_input)
  133. names.update(net.external_output)
  134. for op in net.op:
  135. names.update(op.input)
  136. names.update(op.output)
  137. return names
  138. @staticmethod
  139. def _extract_value_info(tensor):
  140. return make_tensor_value_info(
  141. name=tensor.name,
  142. elem_type=tensor.data_type,
  143. shape=tensor.dims)
  144. @classmethod
  145. def caffe2_net_to_onnx_graph(cls,
  146. predict_net,
  147. init_net=None,
  148. value_info=None):
  149. if value_info is None:
  150. value_info = {}
  151. if not isinstance(value_info, dict):
  152. raise ValueError('Please pass value_info as a '
  153. 'name -> (type, shape) dictionary')
  154. cls._filter_fake_init(init_net, value_info)
  155. cls._ssa_rewrite(predict_net, init_net, value_info)
  156. if init_net:
  157. initializer = cls.caffe2_init_net_to_initializer(init_net)
  158. value_info.update({init.name: (init.data_type, init.dims)
  159. for init in initializer})
  160. else:
  161. initializer = []
  162. # Check if value_info contains the types/shapes of all the blobs, in
  163. # which case we don't need to infer them by running the net.
  164. run_native_net = False
  165. for op in predict_net.op:
  166. for name in itertools.chain(op.input, op.output):
  167. if name not in value_info:
  168. run_native_net = True
  169. break
  170. # Check whether we have got type shape info of all input
  171. missing = (set(list(predict_net.external_input)) -
  172. set(value_info.keys()))
  173. if missing:
  174. raise RuntimeError('Could not find value info of inputs: {}'.format(
  175. ', '.join(missing)))
  176. ws = None
  177. outputs = None
  178. if run_native_net:
  179. inputs = {}
  180. for name in predict_net.external_input:
  181. elem_type, shape = value_info[name]
  182. inputs[name] = np.random.randn(*shape).astype(
  183. mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
  184. ws, outputs = c2_native_run_net(
  185. init_net,
  186. predict_net,
  187. inputs)
  188. for name in predict_net.external_output:
  189. output = outputs[name]
  190. elem_type = mapping.NP_TYPE_TO_TENSOR_TYPE[output.dtype]
  191. shape = output.shape
  192. value_info[name] = (elem_type, shape)
  193. graph_def = GraphProto()
  194. graph_def.name = predict_net.name
  195. graph_def.initializer.extend(initializer)
  196. # This is a mapping from Caffe2 names to ONNX names
  197. graph_def.input.extend(
  198. make_tensor_value_info(
  199. name=name,
  200. elem_type=value_info[name][0],
  201. shape=value_info[name][1])
  202. for name in predict_net.external_input)
  203. cls._dummy_name.reset(cls._all_names_in_net(predict_net) | cls._all_names_in_net(init_net))
  204. for op in predict_net.op:
  205. shapes = {}
  206. for name in itertools.chain(op.input, op.output):
  207. if ws:
  208. blob = ws.FetchBlob(name)
  209. if hasattr(blob, 'shape'):
  210. shapes[name] = blob.shape
  211. else:
  212. shapes[name] = value_info[name][1]
  213. nodes, const_tensors = cls.caffe2_op_to_onnx_node(op, shapes=shapes)
  214. graph_def.node.extend(nodes)
  215. graph_def.initializer.extend(const_tensors)
  216. graph_def.input.extend([cls._extract_value_info(tensor) for tensor in const_tensors])
  217. all_output = set(sum((list(node.output) for node in graph_def.node),
  218. [init.name for init in graph_def.initializer]))
  219. redundant_output = set(vi.name for vi in graph_def.output) - all_output
  220. if redundant_output:
  221. logger.warning(
  222. 'There are graph output not produced by any node or initializer: {}'
  223. '! Will drop them.'.format(', '.join(redundant_output)))
  224. graph_def.output.extend(
  225. make_tensor_value_info(
  226. name=name,
  227. elem_type=value_info[name][0],
  228. shape=value_info[name][1])
  229. for name in predict_net.external_output
  230. if name in all_output)
  231. return graph_def
  232. @classmethod
  233. def caffe2_init_net_to_initializer(cls, init_net):
  234. ws, _ = c2_native_run_net(init_net=None, predict_net=init_net, inputs=[])
  235. output_names = []
  236. for op in init_net.op:
  237. output_names.extend(op.output)
  238. initializer = [numpy_helper.from_array(ws.FetchBlob(name), name=name)
  239. for name in sorted(set(output_names))]
  240. return initializer
  241. @classmethod
  242. def _filter_fake_init(cls, init_net, value_info):
  243. if init_net:
  244. fake_inits = [op for op in init_net.op
  245. if len(op.output) == 1 and op.output[0] in value_info and
  246. re.match('GivenTensor.*Fill|ConstantFill', op.type)]
  247. for fake_init in fake_inits:
  248. init_net.op.remove(fake_init)
  249. del fake_inits[:]
  250. del fake_inits
  251. @classmethod
  252. def ssa_rewrite(cls, net, init_net, value_info):
  253. return cls._ssa_rewrite(net, init_net, value_info)
  254. @classmethod
  255. def _ssa_rewrite(cls, net, init_net, value_info):
  256. def ssa_name(name, version, version_cnt=None):
  257. if version == 0:
  258. return name
  259. if version_cnt and len(version_cnt.get(name, {})) <= 1:
  260. return name
  261. return '{}_{}'.format(name, version)
  262. if init_net:
  263. for op in init_net.op:
  264. assert re.match('GivenTensor.*Fill', op.type), "type is {}, \n{}".format(op.type, op)
  265. assert len(op.output) == 1
  266. ssa, blob_versions = caffe2_core.get_ssa(net)
  267. version_cnt = {}
  268. versioned_blobs = []
  269. for versioned_input, versioned_output in ssa:
  270. versioned_blobs += versioned_input
  271. versioned_blobs += versioned_output
  272. for (name, version) in versioned_blobs:
  273. if name not in version_cnt:
  274. version_cnt[name] = {version}
  275. else:
  276. version_cnt[name].add(version)
  277. assert len(net.op) == len(ssa)
  278. for op, (versioned_inputs, versioned_outputs) in zip(net.op, ssa):
  279. op.input[:] = [ssa_name(name, version, version_cnt)
  280. for name, version in versioned_inputs]
  281. op.output[:] = [ssa_name(name, version, version_cnt)
  282. for name, version in versioned_outputs]
  283. net.external_output[:] = [ssa_name(name, blob_versions[name], version_cnt)
  284. for name in net.external_output]
  285. @classmethod
  286. def caffe2_net_to_onnx_model(cls, *args, **kwargs):
  287. opset_id = OperatorSetIdProto()
  288. opset_id.domain = '' # ONNX default domain
  289. opset_id.version = cls.target_opset_version
  290. model = make_model(cls.caffe2_net_to_onnx_graph(*args, **kwargs),
  291. opset_imports=[opset_id], # current supported opset version
  292. producer_name='onnx-caffe2', # producer name
  293. )
  294. checker.check_model(model)
  295. return model
  296. caffe2_net_to_onnx_graph = Caffe2Frontend.caffe2_net_to_onnx_graph
  297. caffe2_net_to_onnx_model = Caffe2Frontend.caffe2_net_to_onnx_model
  298. caffe2_init_net_to_initializer = Caffe2Frontend.caffe2_init_net_to_initializer
  299. ssa_rewrite = Caffe2Frontend.ssa_rewrite