test_trt.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. from caffe2.proto import caffe2_pb2
  2. from caffe2.python import core, workspace
  3. import onnx
  4. import onnx.defs
  5. from onnx.helper import make_node, make_graph, make_tensor_value_info, make_model
  6. from onnx.backend.base import namedtupledict
  7. from caffe2.python.models.download import ModelDownloader
  8. import caffe2.python.onnx.backend as c2
  9. from caffe2.python.onnx.workspace import Workspace
  10. from caffe2.python.trt.transform import convert_onnx_model_to_trt_op, transform_caffe2_net
  11. from caffe2.python.onnx.tests.test_utils import TestCase
  12. import numpy as np
  13. import os.path
  14. import time
  15. import unittest
  16. import tarfile
  17. import tempfile
  18. import shutil
  19. from six.moves.urllib.request import urlretrieve
  20. def _print_net(net):
  21. for i in net.external_input:
  22. print("Input: {}".format(i))
  23. for i in net.external_output:
  24. print("Output: {}".format(i))
  25. for op in net.op:
  26. print("Op {}".format(op.type))
  27. for x in op.input:
  28. print(" input: {}".format(x))
  29. for y in op.output:
  30. print(" output: {}".format(y))
  31. def _base_url(opset_version):
  32. return 'https://s3.amazonaws.com/download.onnx/models/opset_{}'.format(opset_version)
# TODO: This is copied from https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py.
# Maybe we should expose a model retrieval API from ONNX instead.
  35. def _download_onnx_model(model_name, opset_version):
  36. onnx_home = os.path.expanduser(os.getenv('ONNX_HOME', os.path.join('~', '.onnx')))
  37. models_dir = os.getenv('ONNX_MODELS',
  38. os.path.join(onnx_home, 'models'))
  39. model_dir = os.path.join(models_dir, model_name)
  40. if not os.path.exists(os.path.join(model_dir, 'model.onnx')):
  41. if os.path.exists(model_dir):
  42. bi = 0
  43. while True:
  44. dest = '{}.old.{}'.format(model_dir, bi)
  45. if os.path.exists(dest):
  46. bi += 1
  47. continue
  48. shutil.move(model_dir, dest)
  49. break
  50. os.makedirs(model_dir)
  51. # On Windows, NamedTemporaryFile can not be opened for a
  52. # second time
  53. url = '{}/{}.tar.gz'.format(_base_url(opset_version), model_name)
  54. download_file = tempfile.NamedTemporaryFile(delete=False)
  55. try:
  56. download_file.close()
  57. print('Start downloading model {} from {}'.format(
  58. model_name, url))
  59. urlretrieve(url, download_file.name)
  60. print('Done')
  61. with tarfile.open(download_file.name) as t:
  62. t.extractall(models_dir)
  63. except Exception as e:
  64. print('Failed to prepare data for model {}: {}'.format(
  65. model_name, e))
  66. raise
  67. finally:
  68. os.remove(download_file.name)
  69. return model_dir
  70. class TensorRTOpTest(TestCase):
  71. def setUp(self):
  72. self.opset_version = onnx.defs.onnx_opset_version()
  73. def _test_relu_graph(self, X, batch_size, trt_max_batch_size):
  74. node_def = make_node("Relu", ["X"], ["Y"])
  75. Y_c2 = c2.run_node(node_def, {"X": X})
  76. graph_def = make_graph(
  77. [node_def],
  78. name="test",
  79. inputs=[make_tensor_value_info("X", onnx.TensorProto.FLOAT, [batch_size, 1, 3, 2])],
  80. outputs=[make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [batch_size, 1, 3, 2])])
  81. model_def = make_model(graph_def, producer_name='relu-test')
  82. op_outputs = [x.name for x in model_def.graph.output]
  83. op = convert_onnx_model_to_trt_op(model_def, max_batch_size=trt_max_batch_size)
  84. device_option = core.DeviceOption(caffe2_pb2.CUDA, 0)
  85. op.device_option.CopyFrom(device_option)
  86. Y_trt = None
  87. ws = Workspace()
  88. with core.DeviceScope(device_option):
  89. ws.FeedBlob("X", X)
  90. ws.RunOperatorsOnce([op])
  91. output_values = [ws.FetchBlob(name) for name in op_outputs]
  92. Y_trt = namedtupledict('Outputs', op_outputs)(*output_values)
  93. np.testing.assert_almost_equal(Y_c2, Y_trt)
  94. @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
  95. def test_relu_graph_simple(self):
  96. X = np.random.randn(1, 1, 3, 2).astype(np.float32)
  97. self._test_relu_graph(X, 1, 50)
  98. @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
  99. def test_relu_graph_big_batch(self):
  100. X = np.random.randn(52, 1, 3, 2).astype(np.float32)
  101. self._test_relu_graph(X, 52, 50)
  102. def _test_onnx_importer(self, model_name, data_input_index, opset_version=onnx.defs.onnx_opset_version()):
  103. model_dir = _download_onnx_model(model_name, opset_version)
  104. model_def = onnx.load(os.path.join(model_dir, 'model.onnx'))
  105. input_blob_dims = [int(x.dim_value) for x in model_def.graph.input[data_input_index].type.tensor_type.shape.dim]
  106. op_inputs = [x.name for x in model_def.graph.input]
  107. op_outputs = [x.name for x in model_def.graph.output]
  108. print("{}".format(op_inputs))
  109. data = np.random.randn(*input_blob_dims).astype(np.float32)
  110. Y_c2 = c2.run_model(model_def, {op_inputs[data_input_index]: data})
  111. op = convert_onnx_model_to_trt_op(model_def, verbosity=3)
  112. device_option = core.DeviceOption(caffe2_pb2.CUDA, 0)
  113. op.device_option.CopyFrom(device_option)
  114. Y_trt = None
  115. ws = Workspace()
  116. with core.DeviceScope(device_option):
  117. ws.FeedBlob(op_inputs[data_input_index], data)
  118. if opset_version >= 5:
  119. # Some newer models from ONNX Zoo come with pre-set "data_0" input
  120. ws.FeedBlob("data_0", data)
  121. ws.RunOperatorsOnce([op])
  122. output_values = [ws.FetchBlob(name) for name in op_outputs]
  123. Y_trt = namedtupledict('Outputs', op_outputs)(*output_values)
  124. np.testing.assert_allclose(Y_c2, Y_trt, rtol=1e-3)
  125. @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
  126. def test_resnet50(self):
  127. self._test_onnx_importer('resnet50', 0, 9)
  128. @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
  129. def test_bvlc_alexnet(self):
  130. self._test_onnx_importer('bvlc_alexnet', 0, 9)
  131. @unittest.skip("Until fixing Unsqueeze op")
  132. def test_densenet121(self):
  133. self._test_onnx_importer('densenet121', -1, 3)
  134. @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
  135. def test_inception_v1(self):
  136. self._test_onnx_importer('inception_v1', -3, 9)
  137. @unittest.skip("Until fixing Unsqueeze op")
  138. def test_inception_v2(self):
  139. self._test_onnx_importer('inception_v2', 0, 9)
  140. @unittest.skip('Need to revisit our ChannelShuffle exporter to avoid generating 5D tensor')
  141. def test_shufflenet(self):
  142. self._test_onnx_importer('shufflenet', 0)
  143. @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
  144. def test_squeezenet(self):
  145. self._test_onnx_importer('squeezenet', -1, 9)
  146. @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
  147. def test_vgg16(self):
  148. self._test_onnx_importer('vgg16', 0, 9)
  149. @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
  150. def test_vgg19(self):
  151. self._test_onnx_importer('vgg19', -2, 9)
  152. class TensorRTTransformTest(TestCase):
  153. def setUp(self):
  154. self.model_downloader = ModelDownloader()
  155. def _add_head_tail(self, pred_net, new_head, new_tail):
  156. orig_head = pred_net.external_input[0]
  157. orig_tail = pred_net.external_output[0]
  158. # Add head
  159. head = caffe2_pb2.OperatorDef()
  160. head.type = "Copy"
  161. head.input.append(new_head)
  162. head.output.append(orig_head)
  163. dummy = caffe2_pb2.NetDef()
  164. dummy.op.extend(pred_net.op)
  165. del pred_net.op[:]
  166. pred_net.op.extend([head])
  167. pred_net.op.extend(dummy.op)
  168. pred_net.external_input[0] = new_head
  169. # Add tail
  170. tail = caffe2_pb2.OperatorDef()
  171. tail.type = "Copy"
  172. tail.input.append(orig_tail)
  173. tail.output.append(new_tail)
  174. pred_net.op.extend([tail])
  175. pred_net.external_output[0] = new_tail
  176. @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
  177. def test_resnet50_core(self):
  178. N = 2
  179. warmup = 20
  180. repeat = 100
  181. print("Batch size: {}, repeat inference {} times, warmup {} times".format(N, repeat, warmup))
  182. init_net, pred_net, _ = self.model_downloader.get_c2_model('resnet50')
  183. self._add_head_tail(pred_net, 'real_data', 'real_softmax')
  184. input_blob_dims = (N, 3, 224, 224)
  185. input_name = "real_data"
  186. device_option = core.DeviceOption(caffe2_pb2.CUDA, 0)
  187. init_net.device_option.CopyFrom(device_option)
  188. pred_net.device_option.CopyFrom(device_option)
  189. for op in pred_net.op:
  190. op.device_option.CopyFrom(device_option)
  191. op.engine = 'CUDNN'
  192. net_outputs = pred_net.external_output
  193. Y_c2 = None
  194. data = np.random.randn(*input_blob_dims).astype(np.float32)
  195. c2_time = 1
  196. workspace.SwitchWorkspace("gpu_test", True)
  197. with core.DeviceScope(device_option):
  198. workspace.FeedBlob(input_name, data)
  199. workspace.RunNetOnce(init_net)
  200. workspace.CreateNet(pred_net)
  201. for _ in range(warmup):
  202. workspace.RunNet(pred_net.name)
  203. start = time.time()
  204. for _ in range(repeat):
  205. workspace.RunNet(pred_net.name)
  206. end = time.time()
  207. c2_time = end - start
  208. output_values = [workspace.FetchBlob(name) for name in net_outputs]
  209. Y_c2 = namedtupledict('Outputs', net_outputs)(*output_values)
  210. workspace.ResetWorkspace()
  211. # Fill the workspace with the weights
  212. with core.DeviceScope(device_option):
  213. workspace.RunNetOnce(init_net)
  214. # Cut the graph
  215. start = time.time()
  216. pred_net_cut = transform_caffe2_net(pred_net,
  217. {input_name: input_blob_dims},
  218. build_serializable_op=False)
  219. del init_net, pred_net
  220. pred_net_cut.device_option.CopyFrom(device_option)
  221. for op in pred_net_cut.op:
  222. op.device_option.CopyFrom(device_option)
  223. #_print_net(pred_net_cut)
  224. Y_trt = None
  225. input_name = pred_net_cut.external_input[0]
  226. print("C2 runtime: {}s".format(c2_time))
  227. with core.DeviceScope(device_option):
  228. workspace.FeedBlob(input_name, data)
  229. workspace.CreateNet(pred_net_cut)
  230. end = time.time()
  231. print("Conversion time: {:.2f}s".format(end -start))
  232. for _ in range(warmup):
  233. workspace.RunNet(pred_net_cut.name)
  234. start = time.time()
  235. for _ in range(repeat):
  236. workspace.RunNet(pred_net_cut.name)
  237. end = time.time()
  238. trt_time = end - start
  239. print("TRT runtime: {}s, improvement: {}%".format(trt_time, (c2_time-trt_time)/c2_time*100))
  240. output_values = [workspace.FetchBlob(name) for name in net_outputs]
  241. Y_trt = namedtupledict('Outputs', net_outputs)(*output_values)
  242. np.testing.assert_allclose(Y_c2, Y_trt, rtol=1e-3)