helper.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. ## @package onnx
  2. # Module caffe2.python.onnx.helper
  3. from caffe2.proto import caffe2_pb2
  4. from onnx.backend.base import namedtupledict
  5. from caffe2.python.onnx.workspace import Workspace
  6. import logging
  7. import time
  8. log = logging.getLogger(__name__)
  9. def c2_native_run_op(op_def, inputs):
  10. ws = Workspace()
  11. if isinstance(inputs, dict):
  12. for key, value in inputs.items():
  13. ws.FeedBlob(key, value, op_def.device_option)
  14. else:
  15. assert(len(op_def.input) == len(inputs))
  16. for key, value in zip(op_def.input, inputs):
  17. ws.FeedBlob(key, value, op_def.device_option)
  18. ws.RunOperatorOnce(op_def)
  19. output_names = op_def.output
  20. output_values = [ws.FetchBlob(name) for name in output_names]
  21. return ws, namedtupledict('Outputs', output_names)(*output_values)
  22. def c2_native_run_net(init_net, predict_net, inputs, debug_arg=None):
  23. ws = Workspace()
  24. if init_net:
  25. ws.RunNetOnce(init_net)
  26. if isinstance(inputs, dict):
  27. for key, value in inputs.items():
  28. ws.FeedBlob(key, value, predict_net.device_option)
  29. else:
  30. uninitialized = [input_name
  31. for input_name in predict_net.external_input
  32. if not ws.HasBlob(input_name)]
  33. if len(uninitialized) == len(inputs):
  34. for key, value in zip(uninitialized, inputs):
  35. ws.FeedBlob(key, value, predict_net.device_option)
  36. else:
  37. # If everything is initialized,
  38. # we just initialized the first len(inputs) external_input.
  39. # Added some extra logging to help debug sporadic sandcastle fails
  40. if len(inputs) > len(predict_net.external_input):
  41. print("c2_native_run_net assert. len(inputs)=", len(inputs),
  42. "len(predict_net.external_input)=",
  43. len(predict_net.external_input))
  44. print("debug_arg: ", debug_arg)
  45. print("predict_net ", type(predict_net), ":", predict_net)
  46. print("inputs ", type(inputs), ":", inputs)
  47. assert(len(inputs) <= len(predict_net.external_input))
  48. for i in range(len(inputs)):
  49. ws.FeedBlob(predict_net.external_input[i], inputs[i],
  50. predict_net.device_option)
  51. ws.RunNetOnce(predict_net)
  52. output_names = predict_net.external_output
  53. output_values = [ws.FetchBlob(name) for name in output_names]
  54. return ws, namedtupledict('Outputs', output_names)(*output_values)
  55. def load_caffe2_net(file):
  56. net = caffe2_pb2.NetDef()
  57. with open(file, "rb") as f:
  58. net.ParseFromString(f.read())
  59. return net
  60. def save_caffe2_net(net, file, output_txt=False):
  61. with open(file, "wb") as f:
  62. f.write(net.SerializeToString())
  63. if output_txt:
  64. with open(file + "txt", "w") as f:
  65. f.write(str(net))
  66. def benchmark_caffe2_model(init_net, predict_net, warmup_iters=3, main_iters=10, layer_details=True):
  67. '''
  68. Run the benchmark net on the target model.
  69. Return the execution time per iteration (millisecond).
  70. '''
  71. ws = Workspace()
  72. if init_net:
  73. ws.RunNetOnce(init_net)
  74. ws.CreateNet(predict_net)
  75. results = ws.BenchmarkNet(predict_net.name, warmup_iters, main_iters, layer_details)
  76. del ws
  77. return results[0]
  78. def benchmark_pytorch_model(model, inputs, training=False, warmup_iters=3,
  79. main_iters=10, verbose=False):
  80. '''
  81. Run the model several times, and measure the execution time.
  82. Return the execution time per iteration (millisecond).
  83. '''
  84. for _i in range(warmup_iters):
  85. model(*inputs)
  86. total_pytorch_time = 0.0
  87. for _i in range(main_iters):
  88. ts = time.time()
  89. model(*inputs)
  90. te = time.time()
  91. total_pytorch_time += te - ts
  92. log.info("The PyTorch model execution time per iter is {} milliseconds, "
  93. "{} iters per second.".format(total_pytorch_time / main_iters * 1000,
  94. main_iters / total_pytorch_time))
  95. return total_pytorch_time * 1000 / main_iters