transform.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. ## @package onnx
  2. #Module caffe2.python.trt.transform
  3. """
  4. TensorRT related transformation
  5. Note that ONNX-TRT enforces NCHW input!
  6. """
  7. from caffe2.proto import caffe2_pb2
  8. from caffe2.python import workspace
  9. import caffe2.python._import_c_extension as C
  10. import numpy as np
  11. def _dim_values_to_list(dim_values):
  12. return [x.dim_value for x in dim_values]
  13. def _get_output_shapes(output_value_infos):
  14. names = [x.name for x in output_value_infos]
  15. shapes = [_dim_values_to_list(x.type.tensor_type.shape.dim) for x in output_value_infos]
  16. return dict(zip(names, shapes))
  17. def check_gpu_():
  18. try:
  19. C.get_cuda_version()
  20. except Exception as _:
  21. raise Exception("TensorRT related functions require CUDA support")
  22. def convert_onnx_model_to_trt_op(onnx_model,
  23. max_batch_size=64,
  24. max_workspace_size=2*1024*1024,
  25. verbosity=1,
  26. debug_builder=False):
  27. """
  28. Convert the whole ONNX model to a TensorRT C2 op
  29. """
  30. check_gpu_()
  31. trt_str = C.onnx_to_trt_op(onnx_model.SerializeToString(),
  32. _get_output_shapes(onnx_model.graph.output),
  33. max_batch_size,
  34. max_workspace_size,
  35. verbosity,
  36. debug_builder)
  37. op = caffe2_pb2.OperatorDef()
  38. op.ParseFromString(trt_str)
  39. return op
  40. # Assume the workspace is already filled with init weights
  41. def _infer_shapes(pred_net, inputs):
  42. workspace.RunNetOnce(pred_net)
  43. hints = {}
  44. for op in pred_net.op:
  45. for o in op.output:
  46. if o not in hints:
  47. blob = workspace.FetchBlob(o)
  48. if hasattr(blob, 'shape'):
  49. hints[o] = blob.shape
  50. for i in op.input:
  51. if i not in hints:
  52. blob = workspace.FetchBlob(i)
  53. if hasattr(blob, 'shape'):
  54. hints[i] = blob.shape
  55. return hints
  56. def transform_caffe2_net(
  57. pred_net,
  58. input_shapes,
  59. populate_shapes = False,
  60. max_batch_size=64,
  61. max_workspace_size=2*1024*1024,
  62. verbosity=1,
  63. debug_builder=False,
  64. build_serializable_op=True):
  65. """
  66. Transform the caffe2_net by collapsing TRT-runnable nodes into trt c2 ops
  67. """
  68. check_gpu_()
  69. # Hacky way to infer shapes as not all our operators have shape inference function.
  70. # Normally this is not needed
  71. shape_hints = {}
  72. if populate_shapes:
  73. input_data = {}
  74. for k,v in input_shapes.items():
  75. input_data[k] = np.random.randn(*v).astype(np.float32)
  76. shape_hints = _infer_shapes(pred_net, input_data)
  77. for k,v in input_shapes.items():
  78. shape_hints[k] = v
  79. pred_net_str = C.transform_trt(pred_net.SerializeToString(),
  80. shape_hints,
  81. max_batch_size,
  82. max_workspace_size,
  83. verbosity,
  84. debug_builder,
  85. build_serializable_op)
  86. pred_net_cut = caffe2_pb2.NetDef()
  87. pred_net_cut.ParseFromString(pred_net_str)
  88. return pred_net_cut