rewrite_graph.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. import copy
  2. from caffe2.proto import caffe2_pb2
  3. from caffe2.python import core
  4. def rewrite_init_net_simple(net):
  5. for op in net.op:
  6. op.device_option.device_type = caffe2_pb2.IDEEP
  7. def last_producer(ops, blob):
  8. for (i, op) in reversed(list(enumerate(ops))):
  9. if blob in op.output:
  10. return i
  11. raise ValueError("Failed to find last producer of blob, %s", blob)
  12. def fix_BoxWithNMSLimit(net):
  13. outputs = set()
  14. for op in net.op:
  15. if op.type == 'BoxWithNMSLimit':
  16. outputs.add(op.output[0])
  17. outputs.add(op.output[1])
  18. outputs.add(op.output[2])
  19. for op in net.op:
  20. if op.type == 'CopyIDEEPToCPU':
  21. if op.input[0] in outputs:
  22. print("Chaning CopyIDEEPToCPU to Copy for {}".format(op.input[0]))
  23. op.type = 'Copy'
  24. op.device_option.device_type = caffe2_pb2.CPU
  25. def rewrite_run_net_simple(net):
  26. # Simple rewrite for now - assume entire graph can be executed
  27. # with MKL, so just insert copy ops for external_input[0] and
  28. # external_output[0]
  29. def mkl_tmp(name):
  30. return "{}__MKL__".format(name)
  31. input_blob = net.external_input[0]
  32. if input_blob != net.op[0].input[0]:
  33. raise Exception(
  34. "Input blob: {} is not consumed by first op: {}".format(
  35. input_blob, net.op[0]))
  36. # Modify input/outputs to point to copied MKL blobs.
  37. from_cpu = "CopyCPUToIDEEP"
  38. to_cpu = "CopyIDEEPToCPU"
  39. copy_input_op = core.CreateOperator(
  40. from_cpu, input_blob, mkl_tmp(input_blob))
  41. net.op[0].input[0] = mkl_tmp(input_blob)
  42. copy_output_ops = [
  43. core.CreateOperator(to_cpu, mkl_tmp(output_blob), output_blob)
  44. for output_blob in net.external_output]
  45. for output_blob in net.external_output:
  46. last_producer_idx = last_producer(net.op, output_blob)
  47. renamed_outputs = [blob if blob != output_blob else mkl_tmp(blob)
  48. for blob in net.op[last_producer_idx].output]
  49. net.op[last_producer_idx].output[:] = renamed_outputs
  50. # Rename any subsequent consumers of an output blob.
  51. for op in net.op[last_producer_idx + 1:]:
  52. renamed_input = [blob if blob != output_blob else mkl_tmp(blob)
  53. for blob in op.input]
  54. op.input[:] = renamed_input
  55. ops = [copy_input_op] + net.op[:] + copy_output_ops
  56. del net.op[:]
  57. net.op.extend(ops)
  58. device = caffe2_pb2.IDEEP
  59. for op in net.op:
  60. op.device_option.MergeFrom(
  61. core.DeviceOption(device_type=device))
  62. op.engine = ""
  63. # Temporarily disable conv+relu fusion until we verify further
  64. # net.ParseFromString(
  65. # C.transform_optimizeForMKLDNN(net.SerializeToString()))
  66. fix_BoxWithNMSLimit(net)
  67. def rewrite_run_net_simple_xrayocr_lstm(net):
  68. # For xrayocr model with lstm, only rewrite the non-lstm part of the net to
  69. # enable mkl, then copy the temporary output blob at the break point
  70. # and all external inputs for lstm part to cpu, and execuate rest of the net
  71. # (two lstm) on cpu
  72. # This only works for the xrayocr lstm model which uses the first 'Shape' op
  73. # to decide the break point, and after two lstm it's external_output
  74. # directly so there's no need to copy back to ideep/mkl
  75. def mkl_tmp(name):
  76. return "{}__MKL__".format(name)
  77. def cpu_tmp(name):
  78. return "{}__CPU__".format(name)
  79. input_blob = net.external_input[0]
  80. if input_blob != net.op[0].input[0]:
  81. raise Exception(
  82. "Input blob: {} is not consumed by first op: {}".format(
  83. input_blob, net.op[0]))
  84. # Modify input/outputs to point to copied MKL blobs.
  85. from_cpu = "CopyCPUToIDEEP"
  86. to_cpu = "CopyIDEEPToCPU"
  87. copy_input_op = core.CreateOperator(
  88. from_cpu, input_blob, mkl_tmp(input_blob))
  89. net.op[0].input[0] = mkl_tmp(input_blob)
  90. # the net may contain some external_inputs falsely added during ONNX->Caffe2
  91. # This should be taken care of in early steps during pytorch_to_caffe2,
  92. # but if not it can cause issue in follow up steps, so check here to confirm
  93. for input_blob in net.external_input:
  94. for op in net.op:
  95. # look for if the external_input blob is output of any op in the net
  96. assert input_blob not in op.output
  97. external_output = None
  98. external_inputs_to_cpu = set()
  99. find_first_shape_op = False
  100. cpu_op_start_idx = -1
  101. for op_idx, op in enumerate(net.op):
  102. # the first Shape op mark the starting point of LSTM chunk of the net
  103. if not find_first_shape_op:
  104. if op.type == 'Shape':
  105. external_output = op.input
  106. find_first_shape_op = True
  107. cpu_op_start_idx = op_idx
  108. else:
  109. # any external input in the LSTM part need to be copied to CPU
  110. for in_blob in op.input:
  111. if in_blob in net.external_input:
  112. external_inputs_to_cpu.add(in_blob)
  113. # make sure we found the expected break point of the net
  114. assert external_output is not None
  115. # create op to copy external input blobs used in LSTM part from IDEEP to CPU
  116. copy_extra_input_ops = []
  117. for in_blob in external_inputs_to_cpu:
  118. copy_extra_input_ops.append(core.CreateOperator(to_cpu, in_blob,
  119. cpu_tmp(in_blob)))
  120. # rename input blobs in LSTM part to use the CPU copy
  121. for op in net.op[cpu_op_start_idx:]:
  122. renamed_input = [blob if blob != in_blob else cpu_tmp(in_blob)
  123. for blob in op.input]
  124. op.input[:] = renamed_input
  125. copy_output_ops = [
  126. core.CreateOperator(to_cpu, mkl_tmp(output_blob), output_blob)
  127. for output_blob in external_output]
  128. for output_blob in external_output:
  129. last_producer_idx = last_producer(net.op, output_blob)
  130. renamed_outputs = [blob if blob != output_blob else mkl_tmp(blob)
  131. for blob in net.op[last_producer_idx].output]
  132. net.op[last_producer_idx].output[:] = renamed_outputs
  133. # rearrange all ops in correct order
  134. ops = [copy_input_op] + net.op[:cpu_op_start_idx] \
  135. + copy_output_ops + copy_extra_input_ops + net.op[cpu_op_start_idx:]
  136. del net.op[:]
  137. net.op.extend(ops)
  138. device = caffe2_pb2.IDEEP
  139. for op in net.op:
  140. # the first Shape op mark the starting point of LSTM chunk of the net
  141. if op.type == 'Shape':
  142. # all LSTM ops should run on CPU
  143. device = caffe2_pb2.CPU
  144. op.device_option.MergeFrom(
  145. core.DeviceOption(device_type=device))
  146. op.engine = ""
  147. # RecurrentNetwork has a nested step_net that needs special treatment
  148. if op.type == 'RecurrentNetwork':
  149. for arg in op.arg:
  150. if arg.name == 'step_net':
  151. for nested_op in arg.n.op:
  152. # set device to CPU
  153. nested_op.device_option.MergeFrom(
  154. core.DeviceOption(device_type=device))
  155. nested_op.engine = ""
  156. # rename inputs in op of nested net
  157. renamed_input = []
  158. for blob in nested_op.input:
  159. renamed_input.append(blob
  160. if blob not in external_inputs_to_cpu
  161. else cpu_tmp(blob))
  162. nested_op.input[:] = renamed_input
  163. # rename external inputs of nested net
  164. new_external_input = []
  165. for blob in arg.n.external_input:
  166. new_external_input.append(blob
  167. if blob not in external_inputs_to_cpu
  168. else cpu_tmp(blob))
  169. arg.n.external_input[:] = new_external_input
  170. # Temporarily disable conv+relu fusion until we verify further
  171. # net.ParseFromString(
  172. # C.transform_optimizeForMKLDNN(net.SerializeToString()))
  173. fix_BoxWithNMSLimit(net)
  174. def rewrite_model_helper_simple(model):
  175. model = copy.deepcopy(model)
  176. # All parameter initialization should run on MKL
  177. rewrite_init_net_simple(model.param_init_net.Proto())
  178. rewrite_run_net_simple(model.net.Proto())
  179. return model