onnxifi.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. ## @package onnx
  2. #Module caffe2.python.onnx.onnxifi
  3. """
  4. ONNXIFI a Caffe2 net
  5. """
  6. from caffe2.proto import caffe2_pb2
  7. import caffe2.python._import_c_extension as C
  8. def onnxifi_set_option(option_name, option_value):
  9. """
  10. Set onnxifi option
  11. """
  12. return C.onnxifi_set_option(option_name, str(option_value))
  13. def onnxifi_get_option(option_name):
  14. """
  15. Get onnxifi option
  16. """
  17. return C.onnxifi_get_option(option_name)
  18. def onnxifi_caffe2_net(
  19. pred_net,
  20. input_shapes,
  21. max_batch_size=1,
  22. max_seq_size=1,
  23. debug=False,
  24. use_onnx=True,
  25. merge_fp32_inputs_into_fp16=False,
  26. adjust_batch=True,
  27. block_list=None,
  28. weight_names=None,
  29. net_ssa_rewritten=False,
  30. timeout=0):
  31. """
  32. Transform the caffe2_net by collapsing ONNXIFI-runnable nodes into Onnxifi c2 ops
  33. """
  34. shape_hints = caffe2_pb2.TensorBoundShapes()
  35. if type(input_shapes) is caffe2_pb2.TensorBoundShapes:
  36. shape_hints = input_shapes
  37. elif type(input_shapes) is dict:
  38. for k, v in input_shapes.items():
  39. tbs = caffe2_pb2.TensorBoundShape()
  40. tbs.name = k
  41. tbs.shape.dims.extend(v)
  42. tbs.dim_type.extend([caffe2_pb2.TensorBoundShape.CONSTANT] * len(tbs.shape.dims))
  43. tbs.dim_type[0] = caffe2_pb2.TensorBoundShape.BATCH
  44. shape_hints.shapes.extend([tbs])
  45. shape_hints.max_batch_size = max_batch_size
  46. shape_hints.max_feature_len = max_seq_size
  47. pred_net_str = C.onnxifi(pred_net.SerializeToString(),
  48. shape_hints.SerializeToString(),
  49. block_list if block_list else [],
  50. weight_names if weight_names is not None else [],
  51. max_batch_size,
  52. max_seq_size,
  53. timeout,
  54. adjust_batch,
  55. debug,
  56. merge_fp32_inputs_into_fp16,
  57. net_ssa_rewritten,
  58. use_onnx)
  59. pred_net_cut = caffe2_pb2.NetDef()
  60. pred_net_cut.ParseFromString(pred_net_str)
  61. return pred_net_cut