benchmark_generator.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. #!/usr/bin/env python3
  2. import string
  3. import argparse
  4. import numpy as np
  5. from caffe2.python.model_helper import ModelHelper
  6. from caffe2.python.predictor import mobile_exporter
  7. from caffe2.python import core, workspace, brew, utils
  8. def parse_kwarg(kwarg_str):
  9. key, value = map(string.strip, kwarg_str.split("=", 1))
  10. try:
  11. value = int(value)
  12. except ValueError:
  13. try:
  14. value = float(value)
  15. except ValueError:
  16. pass
  17. return key, value
  18. def main(args):
  19. # User defined keyword arguments
  20. kwargs = {"order": "NCHW"}
  21. kwargs.update(dict(args.kwargs))
  22. model = ModelHelper(name=args.benchmark_name)
  23. op_type = args.operator # assumes a brew type op name
  24. input_name = args.input_name
  25. output_name = args.output_name
  26. iters = int(args.iters)
  27. for i in range(iters):
  28. input_blob_name = input_name + (str(i) if i > 0 and args.chain else '')
  29. output_blob_name = output_name + str(i + 1)
  30. add_op = getattr(brew, op_type)
  31. add_op(model, input_blob_name, output_blob_name, **kwargs)
  32. if args.chain:
  33. input_name, output_name = output_name, input_name
  34. workspace.RunNetOnce(model.param_init_net)
  35. extra_init_net_ops = []
  36. def make_blob_on_context(blob_name, blob_data, context):
  37. if context.upper() != "CPU":
  38. blob_name_modified = "{}_CPU".format(blob_name)
  39. else: # CPU case is simple
  40. blob_name_modified = blob_name
  41. fill_op = core.CreateOperator(
  42. "GivenTensorFill", [], [blob_name_modified],
  43. arg=[
  44. utils.MakeArgument("shape", blob_data.shape),
  45. utils.MakeArgument("values", blob_data)
  46. ]
  47. )
  48. extra_init_net_ops.append(fill_op)
  49. # We need to create CPU blobs and add some copy operations in
  50. # the init_net
  51. if context.upper() == "OPENGL":
  52. copy_op = core.CreateOperator("CopyToOpenGL", [blob_name_modified],
  53. [blob_name])
  54. extra_init_net_ops.append(copy_op)
  55. for unparsed_blob in args.blob:
  56. name, unparsed_dims = unparsed_blob.split('=')
  57. dims = [int(d) for d in unparsed_dims.split(',')]
  58. np_input = np.random.rand(*dims).astype(np.float32)
  59. make_blob_on_context(name, np_input, args.context)
  60. init_net, predict_net = mobile_exporter.Export(
  61. workspace, model.net, model.params
  62. )
  63. init_net.op.extend(extra_init_net_ops)
  64. # Handle manual rewrite
  65. if args.context.upper() == "OPENGL":
  66. old_ops = [op for op in predict_net.op]
  67. del predict_net.op[:]
  68. for op in old_ops:
  69. op.type = 'OpenGL{}'.format(op.type)
  70. predict_net.op.extend(old_ops)
  71. if args.debug:
  72. print("init_net:")
  73. for op in init_net.op:
  74. print(" ", op.type, op.input, "-->", op.output)
  75. print("predict_net:")
  76. for op in predict_net.op:
  77. print(" ", op.type, op.input, "-->", op.output)
  78. with open(args.predict_net, 'wb') as f:
  79. f.write(predict_net.SerializeToString())
  80. with open(args.init_net, 'wb') as f:
  81. f.write(init_net.SerializeToString())
  82. if __name__ == "__main__":
  83. parser = argparse.ArgumentParser(
  84. description="Utility to generate Caffe2 benchmark models.")
  85. parser.add_argument("operator", help="Caffe2 operator to benchmark.")
  86. parser.add_argument("-b", "--blob",
  87. help="Instantiate a blob --blob name=dim1,dim2,dim3",
  88. action='append')
  89. parser.add_argument("--context", help="Context to run on.", default="CPU")
  90. parser.add_argument("--kwargs", help="kwargs to pass to operator.",
  91. nargs="*", type=parse_kwarg, default=[])
  92. parser.add_argument("--init_net", help="Output initialization net.",
  93. default="init_net.pb")
  94. parser.add_argument("--predict_net", help="Output prediction net.",
  95. default="predict_net.pb")
  96. parser.add_argument("--benchmark_name",
  97. help="Name of the benchmark network",
  98. default="benchmark")
  99. parser.add_argument("--input_name", help="Name of the input blob.",
  100. default="data")
  101. parser.add_argument("--output_name", help="Name of the output blob.",
  102. default="output")
  103. parser.add_argument("--iters",
  104. help="Number of iterations to run the operator.",
  105. default="1")
  106. parser.add_argument("-d", "--debug", help="Print debug information.",
  107. action='store_true')
  108. parser.add_argument("-c", "--chain",
  109. help="Chain ops together (create data dependencies)",
  110. action='store_true')
  111. args = parser.parse_args()
  112. main(args)